Skip to content

Conversation

@mdavis36
Copy link
Collaborator

@mdavis36 mdavis36 commented Jan 31, 2026

Intermediate PR to clean up diff in #5902 this just changes the file names and header include.

@github-actions
Copy link

Description

  • Rename storage.h/cpp files to container.h/cpp

  • Update include statements in fusion.h to reference new container.h

  • Update CMakeLists.txt to build container.cpp instead of storage.cpp

  • Move implementation code from commented storage.cpp to active container.cpp

Changes walkthrough

Relevant files
Other
container.cpp
Rename storage.cpp to container.cpp with full implementation

csrc/ir/container.cpp

  • File renamed from storage.cpp to container.cpp
  • Previously empty/commented implementation now contains full
    IrContainer class implementation
  • Added complete implementation of IrContainer methods including
    deterministic_vals(), deterministic_exprs(), registerVal(),
    registerExpr(), and other container management functions
  • [link]   
    container.h
    Rename storage.h to container.h with complete class definition

    csrc/ir/container.h

  • File renamed from storage.h to container.h
  • Contains complete IrContainer class definition with all methods and
    members
  • Includes passkey pattern, container management, and fusion integration
    functionality
  • +190/-1 
    fusion.h
    Update include path from storage.h to container.h               

    csrc/fusion.h

  • Updated include statement from #include to #include
  • Maintains same functionality with renamed header file
  • +1/-1     
    Configuration changes
    CMakeLists.txt
    Update build configuration for renamed container files     

    CMakeLists.txt

  • Updated build configuration to reference container.cpp instead of
    storage.cpp
  • Removed commented reference to container.cpp
  • Activated container.cpp in the build system
  • +1/-2     
    Additional files
    storage.h +0/-203 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review
    File Rename Completeness

    This appears to be a straightforward file rename from storage.h/cpp to container.h/cpp. The changes look comprehensive - includes are updated in fusion.h, the build system (CMakeLists.txt) is updated, and the old storage.h file is deleted. However, I should verify that all references to the old storage files have been updated throughout the codebase, including any tests, documentation, or other build files that might reference the old names.

    // clang-format off
    /*
     * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
     * All rights reserved.
     * SPDX-License-Identifier: BSD-3-Clause
     */
    // clang-format on
    #include "ir/container.h"
    
    #include "instrumentation.h"
    #include "ir/base_nodes.h"
    #include "ir/builder.h"
    #include "ir/cloner.h"
    #include "ir/internal_nodes.h"
    
    namespace nvfuser {
    
    //! Return values in insertion order
    const std::deque<Val*> IrContainer::deterministic_vals() const noexcept {
      std::deque<Val*> vals_deque;
      std::transform(
          vals_up_.begin(),
          vals_up_.end(),
          std::back_inserter(vals_deque),
          [](const std::unique_ptr<Val>& val_up) { return val_up.get(); });
      return vals_deque;
    }
    
    //! Return expression in insertion order
    const std::deque<Expr*> IrContainer::deterministic_exprs() const noexcept {
      std::deque<Expr*> exprs_deque;
      std::transform(
          exprs_up_.begin(),
          exprs_up_.end(),
          std::back_inserter(exprs_deque),
          [](const std::unique_ptr<Expr>& expr_up) { return expr_up.get(); });
      return exprs_deque;
    }
    
    //! Return mapping from value to integer id
    const std::unordered_map<Val*, int64_t> IrContainer::deterministic_vals_map()
        const noexcept {
      std::unordered_map<Val*, int64_t> vals_map;
      int64_t count = 0;
      std::transform(
          vals_up_.begin(),
          vals_up_.end(),
          std::inserter(vals_map, vals_map.end()),
          [&count](const std::unique_ptr<Val>& val_up) {
            return std::make_pair(val_up.get(), count++);
          });
      return vals_map;
    }
    
    //! Return mapping from expression to integer id
    const std::unordered_map<Expr*, int64_t> IrContainer::deterministic_exprs_map()
        const noexcept {
      std::unordered_map<Expr*, int64_t> exprs_map;
      int64_t count = 0;
      std::transform(
          exprs_up_.begin(),
          exprs_up_.end(),
          std::inserter(exprs_map, exprs_map.end()),
          [&count](const std::unique_ptr<Expr>& expr_up) {
            return std::make_pair(expr_up.get(), count++);
          });
      return exprs_map;
    }
    
    void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
      FUSER_PERF_SCOPE("Fusion swap");
    
      // Swap the content
      std::swap(a.vals_up_, b.vals_up_);
      std::swap(a.vals_, b.vals_);
    
      std::swap(a.exprs_up_, b.exprs_up_);
      std::swap(a.exprs_, b.exprs_);
    
      std::swap(a.val_type_name_map_, b.val_type_name_map_);
      std::swap(a.expr_name_counter_, b.expr_name_counter_);
    
      std::swap(a.metadata_, b.metadata_);
    
      std::swap(a.parent_, b.parent_);
    
      std::swap(a.zero_val_, b.zero_val_);
      std::swap(a.one_val_, b.one_val_);
      std::swap(a.true_val_, b.true_val_);
      std::swap(a.false_val_, b.false_val_);
      std::swap(a.magic_zero_val_, b.magic_zero_val_);
      std::swap(a.axioms_, b.axioms_);
    }
    
    IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
      to->clear();
      IrCloner ir_cloner(to->parent());
    
      // Copy values in deterministic order
      // deterministic_vals can contain special values like one_val_, zero_val_, etc
      // that are not registered in the container.
      for (auto val : from->deterministic_vals()) {
        if (from->vals().count(val) > 0) {
          to->vals_.insert(ir_cloner.clone(val));
        }
      }
    
      // Copy expressions in deterministic order
      for (auto expr : from->deterministic_exprs()) {
        if (from->unordered_exprs().count(expr) > 0) {
          to->exprs_.insert(ir_cloner.clone(expr));
        }
      }
    
      to->val_type_name_map_ = from->val_type_name_map_;
      to->expr_name_counter_ = from->expr_name_counter_;
    
      if (from->axioms_ != nullptr) {
        to->axioms_ = std::make_unique<std::vector<Val*>>();
        for (auto pred : *from->axioms_) {
          to->axioms_->push_back(ir_cloner.clone(pred));
        }
      }
    
      to->metadata_ = ir_cloner.clone(from->metadata_);
    
      return ir_cloner;
    }
    
    IrContainer::IrContainer() = default;
    
    IrContainer::~IrContainer() {
      clear();
    }
    
    void IrContainer::removeExpr(Expr* expr) {
      NVF_ERROR(
          exprs_.find(expr) != exprs_.end(),
          "Wanted to remove an expression but it doesn't exist in this container.");
      auto expr_in_deque = std::find_if(
          exprs_up_.begin(),
          exprs_up_.end(),
          [expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
    
      NVF_ERROR(
          expr_in_deque != exprs_up_.end(),
          "Wanted to remove an expression but its unique ptr is missing.");
    
      exprs_.erase(expr);
      exprs_up_.erase(expr_in_deque);
    }
    
    //! Completely remove val from the fusion, break all dependencies associated
    //! with it
    void IrContainer::removeVal(Val* val) {
      // Don't remove shortcuts
      if (val == true_val_.get() || val == false_val_.get() ||
          val == one_val_.get() || val == zero_val_.get() ||
          val == magic_zero_val_.get()) {
        return;
      }
    
      NVF_ERROR(
          vals_.find(val) != vals_.end(),
          "Wanted to remove a value but it doesn't exist in this container.");
      auto val_in_deque = std::find_if(
          vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr<Val>& val_up) {
            return val_up.get() == val;
          });
    
      NVF_ERROR(
          val_in_deque != vals_up_.end(),
          "Wanted to remove a value but its unique ptr is missing.");
    
      vals_.erase(val);
      vals_up_.erase(val_in_deque);
    }
    
    //! Register the Val with this container
    void IrContainer::registerVal(Val* val) {
      if (inContainer(val)) {
        return;
      }
    
      // Otherwise handle registration locally
      vals_up_.emplace_back(val);
      vals_.insert(val);
      val->setName(IrContainerPasskey(), getValName(val->vtype()));
    }
    
    //! Register expr with this container.
    void IrContainer::registerExpr(Expr* expr) {
      if (inContainer(expr)) {
        return;
      }
    
      // Otherwise handle registration locally
      exprs_up_.emplace_back(expr);
      exprs_.insert(expr);
      expr->setName(IrContainerPasskey(), getExprName());
    }
    
    void IrContainer::clear() noexcept {
      FUSER_PERF_SCOPE("IrContainer clear");
      vals_.clear();
      vals_up_.clear();
      exprs_.clear();
      exprs_up_.clear();
      axioms_.reset();
      val_type_name_map_.clear();
      metadata_.clear();
      expr_name_counter_ = 0;
    }
    
    bool IrContainer::inContainer(const Statement* const_stmt) const {
      // We don't use dynamic_cast here because `const_stmt` may be an invalid
      // pointer. Specifically a pointer to a Statement owned by another container
      // that has been freed.
    
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      void* raw_ptr = const_cast<void*>(reinterpret_cast<const void*>(const_stmt));
      if (exprs_.count(reinterpret_cast<Expr*>(raw_ptr)) == 0 &&
          vals_.count(reinterpret_cast<Val*>(raw_ptr)) == 0) {
        return false;
      }
    
      NVF_ERROR(
          const_stmt->container() == this->parent(),
          "Container claims to own stmt, but stmt disagrees.");
    
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      auto* stmt = const_cast<Statement*>(const_stmt);
      if (stmt->isExpr()) {
        NVF_ERROR(
            exprs_.find(stmt->as<Expr>()) != exprs_.end(),
            "Somehow container claims to and not to own an Expr.");
      }
      if (stmt->isVal()) {
        NVF_ERROR(
            vals_.find(stmt->as<Val>()) != vals_.end(),
            "Somehow container claims to and not to own an Val.");
      }
    
      return true;
    }
    
    // Shortcuts for frequently used vals
    Val* IrContainer::zeroVal() {
      if (!zero_val_) {
        auto zero_val =
            IrBuilder::createInContainer<Val>(this->parent(), 0L, DataType::Index);
        NVF_ERROR(vals_up_.back().get() == zero_val);
        zero_val_ = std::unique_ptr<Val>(vals_up_.back().release());
        vals_up_.pop_back();
      }
      return zero_val_.get();
    }
    
    Val* IrContainer::zeroVal(DataType dtype) {
      if (dtype == DataType::Index) {
        return zeroVal();
      } else if (isBooleanType(dtype)) {
        return falseVal();
      } else {
        // NOTE: this does not cache values
        return IrBuilder::createInContainer<Val>(this->parent(), 0L, dtype);
      }
    }
    
    Val* IrContainer::oneVal() {
      if (!one_val_) {
        auto one_val =
            IrBuilder::createInContainer<Val>(this->parent(), 1L, DataType::Index);
        NVF_ERROR(vals_up_.back().get() == one_val);
        one_val_ = std::unique_ptr<Val>(vals_up_.back().release());
        vals_up_.pop_back();
      }
      return one_val_.get();
    }
    
    Val* IrContainer::oneVal(DataType dtype) {
      if (dtype == DataType::Index) {
        return oneVal();
      } else if (isBooleanType(dtype)) {
        return trueVal();
      } else {
        // NOTE: this does not cache values
        return IrBuilder::createInContainer<Val>(this->parent(), 1L, dtype);
      }
    }
    
    Val* IrContainer::falseVal() {
      if (!false_val_) {
        auto false_val = IrBuilder::createInContainer<Val>(
            this->parent(), false, DataType::Bool);
        NVF_ERROR(vals_up_.back().get() == false_val);
        false_val_ = std::unique_ptr<Val>(vals_up_.back().release());
        vals_up_.pop_back();
      }
      return false_val_.get();
    }
    
    Val* IrContainer::trueVal() {
      if (!true_val_) {
        auto true_val =
            IrBuilder::createInContainer<Val>(this->parent(), true, DataType::Bool);
        NVF_ERROR(vals_up_.back().get() == true_val);
        true_val_ = std::unique_ptr<Val>(vals_up_.back().release());
        vals_up_.pop_back();
      }
      return true_val_.get();
    }
    
    NamedScalar* IrContainer::magicZeroVal() {
      if (!magic_zero_val_) {
        auto magic_zero =
            IrBuilder::create<NamedScalar>(kMagicZeroName, DataType::Index);
        NVF_ERROR(vals_up_.back().get() == magic_zero);
        magic_zero_val_ = std::unique_ptr<NamedScalar>(
            vals_up_.back().release()->as<NamedScalar>());
        vals_up_.pop_back();
      }
      return magic_zero_val_.get();
    }
    
    Val* IrContainer::metadataOf(Val* v) {
      if (metadata_.count(v) == 0) {
        auto metadata_val =
            IrBuilder::createInContainer<Val>(this->parent(), metaDataTypeOf(v));
        auto metadata_expr = IrBuilder::createInContainer<GetMetaData>(
            this->parent(), metadata_val, v);
        metadata_[v] = std::make_pair(metadata_val, metadata_expr);
      }
      return metadata_.at(v).first;
    }
    
    void IrContainer::lazyInitAxioms() {
      if (!axioms_) {
        axioms_ = std::make_unique<std::vector<Val*>>();
        axioms_->reserve(kParallelTypeThreads.size() * 3);
        auto zero = zeroVal();
        for (auto p : kParallelTypeThreads) {
          auto pidx = NamedScalar::getParallelIndex(p);
          auto pdim = NamedScalar::getParallelDim(p);
          axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero));
          axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero));
          axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim));
        }
      }
    }
    
    void IrContainer::assumePositive(Val* val) {
      NVF_ERROR(val->container() == this->parent());
      lazyInitAxioms();
      axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal()));
    }
    
    void IrContainer::assumeNonNegative(Val* val) {
      NVF_ERROR(val->container() == this->parent());
      lazyInitAxioms();
      axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal()));
    }
    
    void IrContainer::removeStatementsCreatedAfter(
        int64_t prev_num_exprs,
        int64_t prev_num_vals) {
      NVF_ERROR(
          exprs_up_.size() == exprs_.size(),
          "exprs_up_ (size ",
          exprs_up_.size(),
          ") and exprs_ (size ",
          exprs_.size(),
          ") are out of sync.");
      NVF_ERROR(
          std::ssize(exprs_up_) >= prev_num_exprs,
          "exprs_up_ size (",
          std::ssize(exprs_up_),
          ") is less than prev_num_exprs (",
          prev_num_exprs,
          ").");
    
      // Remove expressions before values because we need to change Val::uses_.
      while (std::ssize(exprs_up_) > prev_num_exprs) {
        Expr* e = exprs_up_.back().get();
        for (Val* in : e->inputs()) {
          in->removeUse(e);
        }
        exprs_.erase(e);
        exprs_up_.pop_back();
      }
    
      while (std::ssize(vals_up_) > prev_num_vals) {
        vals_.erase(vals_up_.back().get());
        vals_up_.pop_back();
      }
    }
    
    } // namespace nvfuser
    Header Content Verification

    The new container.h header appears to contain the full IrContainer class definition, which is good. However, I should verify that this is indeed the complete and correct content that should be in the renamed file, and that no functionality was lost in the rename process.

    // clang-format off
    /*
     * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
     * All rights reserved.
     * SPDX-License-Identifier: BSD-3-Clause
     */
    // clang-format on
    #pragma once
    
    #include <deque>
    #include <unordered_map>
    #include <unordered_set>
    
    #include "base.h"
    #include "exceptions.h"
    #include "ir/base_nodes.h"
    #include "visibility.h"
    
    namespace nvfuser {
    
    // Passkey for container to register names with statements
    class IrContainerPasskey {
      friend class IrContainer;
    
     private:
      explicit IrContainerPasskey() = default;
    };
    
    class NamedScalar;
    
    class IrContainer {
     public:
      NVF_API IrContainer();
    
      // Copy/Move Constructors and Operators are deleted. IrContainer is managed
      // through a smart pointer in IrContainer. Semantic operations for Fusion
      // types are handled directly through copy and swap functions.
      IrContainer(const IrContainer& other) = delete;
      IrContainer(IrContainer&& other) noexcept = delete;
    
      IrContainer& operator=(const IrContainer& other) = delete;
      IrContainer& operator=(IrContainer&& other) noexcept = delete;
    
      ~IrContainer();
    
      bool inContainer(const Statement* stmt) const;
    
      void assertInContainer(const Statement* stmt, const std::string& msg) const {
        NVF_CHECK(
            inContainer(stmt), msg, " it was not found in the active container.");
      }
    
      //! Return values in insertion order
      const std::deque<Val*> deterministic_vals() const noexcept;
    
      //! Return expression in insertion order
      const std::deque<Expr*> deterministic_exprs() const noexcept;
    
      //! Return mapping from value to integer id
      const std::unordered_map<Val*, int64_t> deterministic_vals_map()
          const noexcept;
    
      //! Return mapping from expression to integer id
      const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
          const noexcept;
    
      //! Return the set of Exprs registered with this fusion. Warning: This will
      //! return exprs outside inputs/outputs, so can be unsafe for use with
      //! segmented fusions.
      const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
        return exprs_;
      }
    
      //! Return the set of Vals registered with this fusion
      const std::unordered_set<Val*>& vals() const noexcept {
        return vals_;
      }
    
      int64_t numExprs() const noexcept {
        return std::ssize(exprs_);
      }
    
      // When include_shortcuts is true, it will count the shortcuts like true_val_.
      int64_t numVals(bool include_shortcuts) const noexcept {
        return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_);
      }
    
      // Shortcuts for frequently used vals
      NVF_API Val* zeroVal();
      NVF_API Val* oneVal();
      Val* falseVal();
      Val* trueVal();
      NamedScalar* magicZeroVal();
      NVF_API Val* zeroVal(DataType dtype);
      NVF_API Val* oneVal(DataType dtype);
      Val* metadataOf(Val*);
    
      // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x
      const std::vector<Val*>& axioms() {
        lazyInitAxioms();
        return *axioms_;
      }
    
      void assumePositive(Val* val);
      void assumeNonNegative(Val* val);
    
     protected:
      static IrCloner copy(const IrContainer* from, IrContainer* to);
    
      static void swap(IrContainer& a, IrContainer& b) noexcept;
    
      // Let Fusion access IrContainer::clear()
      friend class Fusion;
    
      void removeExpr(Expr* expr);
    
      //! Completely remove val from the fusion, break all dependencies associated
      //! with it
      void removeVal(Val* val);
    
      //! Register the Val with this container
      NVF_API void registerVal(Val* val);
    
      //! Register expr with this container.
      NVF_API void registerExpr(Expr* expr);
    
      StmtNameType getValName(ValType vtype) {
        if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
          val_type_name_map_[vtype] = 0;
        }
        return val_type_name_map_[vtype]++;
      }
    
      StmtNameType getExprName() {
        return expr_name_counter_++;
      }
    
      void clear() noexcept;
    
      void lazyInitAxioms();
    
      friend class StatementGuard;
    
      // A simple garbage collection mechanism to remove all Exprs and Vals that
      // were created after a certain point. This is useful for analysis that
      // creates new Exprs and Vals in the container and wants to clean up after
      // itself.
      //
      // Used by StatementGuard only.
      void removeStatementsCreatedAfter(
          int64_t prev_num_exprs,
          int64_t prev_num_vals);
    
      // Deque of unique pointer is the memory owning data structure
      std::deque<std::unique_ptr<Val>> vals_up_;
    
      // A convenient set to return when we just need an unordered set to do
      // something like check if a Val is in this container
      std::unordered_set<Val*> vals_;
    
      // Deque of unique pointer is the memory owning data structure
      std::deque<std::unique_ptr<Expr>> exprs_up_;
    
      // A convenient set to return when we just need an unordered set to do
      // something like check if an Expr is in this container
      std::unordered_set<Expr*> exprs_;
    
      // Values names counters
      std::unordered_map<ValType, StmtNameType> val_type_name_map_;
    
      // Expression names counter
      StmtNameType expr_name_counter_ = 0;
    
      // Manually store some persistent, frequently used nodes. It's very
      // challenging to do this anything but manually as detecting when a container
      // may or may not have one of these vals is tricky. Specifically because if
      // the container doesn't own it, it's hard to understand from the outside if
      // the node may have been removed then re-registered. It could also be tricky
      // to know when we're using a different container as in FusionCopy_test
      // demonstrates deleting then creating containers can result in the same
      // pointer for the container.
      std::unique_ptr<Val> true_val_;
      std::unique_ptr<Val> false_val_;
      std::unique_ptr<Val> one_val_;
      std::unique_ptr<Val> zero_val_;
      std::unique_ptr<NamedScalar> magic_zero_val_;
      std::unique_ptr<std::vector<Val*>> axioms_;
      std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
    
     public:
      Fusion* parent() const {
        NVF_ERROR(
            parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.")
        return parent_;
      }
    
     private:
      // Parent IrInterface that owns this container (for pure composition pattern)
      // Used by Statement::fusion() to navigate back to owning Fusion
      Fusion* parent_ = nullptr;
    };
    
    } // namespace nvfuser

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant