From 5f9c929707e40b92870c2fcabaaf0a988c3aa4e9 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 14:28:27 -0800 Subject: [PATCH 1/7] Use Fusion* as base pointer for all IrContainers; Move interface into impl namespace before rm. --- csrc/dispatch.h | 2 +- csrc/fusion.h | 50 ++++++++++++++++++++++++++++++++------- csrc/ir/base_nodes.h | 22 ++++++++++------- csrc/ir/builder.h | 13 ++++------ csrc/ir/builder_passkey.h | 8 +++---- csrc/ir/cloner.cpp | 2 +- csrc/ir/cloner.h | 38 +++++------------------------ csrc/ir/container.cpp | 20 ++++++++++++---- csrc/ir/container.h | 17 ++++++++++--- csrc/ir/storage.cpp | 44 +++++++++++++++++++++------------- csrc/ir/storage.h | 11 ++++++--- csrc/ir/utils.cpp | 2 +- csrc/kernel.h | 4 ++-- csrc/mutator.cpp | 2 +- 14 files changed, 140 insertions(+), 95 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 0ef085592f8..7c53c86d903 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -343,7 +343,7 @@ class NVF_API OptOutMutator : public PolymorphicBase { } protected: - virtual void removeExpr(IrContainer*, Expr*) const; + virtual void removeExpr(Fusion*, Expr*) const; virtual void registerNewExpr(Expr*) {} private: diff --git a/csrc/fusion.h b/csrc/fusion.h index f57446e1767..aa4d3abe341 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -142,7 +142,7 @@ class AliasInfoMap { //! The Fusion owns the whole IR graph (Vals and Exprs) //! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class NVF_API Fusion : public IrContainer { +class NVF_API Fusion : public impl::IrContainer { typedef std::unordered_map> PermutationMap; public: @@ -483,8 +483,8 @@ class NVF_API Fusion : public IrContainer { friend class TranslateApplicableWelford; friend Val; - using IrContainer::registerExpr; - using IrContainer::registerVal; + using impl::IrContainer::registerExpr; + using impl::IrContainer::registerVal; //! Register the Val with this fusion void registerVal(Val* val) override; @@ -541,20 +541,15 @@ class NVF_API Fusion : public IrContainer { inline static const std::string exact_mappings_key = "exact_mappings"; }; +// Template implementations for Fusion::manage() that use IrCloner template std::any defaultCloneFunction(IrCloner& cloner, std::any data) { auto cloned_data = cloner.clone(std::any_cast(data)); - // Adding a static_assert to improve error message. Without this - // static_assert, the following cast will still fail, but the error message - // will be unreadable. static_assert( std::is_convertible_v, "IrCloner::clone returns a data type that is not compatible with the " "original managed data type. " "Likely you will need to check IrCloner::clone for your data type."); - // Convert the result of the clone back to T before assigning to std::any. - // This ensures the type of the std::any does not change over the clone of - // fusion. return std::any((T)cloned_data); } @@ -568,4 +563,41 @@ void Fusion::manage(std::string key, T data) { return manage(key, std::any(data), defaultCloneFunction); } +// Template implementations for IrBuilder that require Fusion to be fully +// defined +template +T* IrBuilder::createInContainer(Fusion* container, Args&&... args) { + NVF_ERROR(container != nullptr, "Need an active container to build IR."); + T* node = new T(IrBuilderPasskey(container), std::forward(args)...); + container->registerStmt(IrBuilderPasskey(container), node); + return node; +} + +template +T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { + NVF_ERROR( + ir_cloner != nullptr, + "Cannot use create when a cloner object is set. Use clone."); + NVF_ERROR( + ir_cloner->container() != nullptr, + "Cloner doesn't have a valid container to store cloned object."); + + T* dest = new T(src, ir_cloner); + const auto* src_stmt = dynamic_cast(src); + auto* dest_stmt = dynamic_cast(dest); + + auto dest_container = ir_cloner->container(); + auto src_container = src_stmt->container(); + + dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + + if (src_container != dest_container) { + dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); + } + + ir_cloner->registerClone(src_stmt, dest_stmt); + + return dest; +} + } // namespace nvfuser diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index d30f7abd145..661ddc6cbcb 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -57,12 +57,16 @@ class Fusion; class Expr; class Val; class IrCloner; -class IrContainer; class IrStorage; class IrBuilderPasskey; class IrContainerPasskey; class ExpressionEvaluator; +// Forward declaration of impl namespace +namespace impl { +class IrContainer; +} + namespace kir { class Kernel; class Predicate; @@ -95,8 +99,8 @@ class ExprPasskey { //! a Statment at runtime. This is currently implemented in dispatch.h class NVF_API Statement : public NonCopyable, public PolymorphicBase { friend void swap(Fusion&, Fusion&) noexcept; - friend void swap(IrContainer& a, IrContainer& b) noexcept; - friend class IrContainer; + friend void swap(impl::IrContainer& a, impl::IrContainer& b) noexcept; + friend class impl::IrContainer; public: Statement() = delete; @@ -143,7 +147,7 @@ class NVF_API Statement : public NonCopyable, public PolymorphicBase { kir::Kernel* kernel() const; // Return the container this statement belongs to - IrContainer* container() const { + Fusion* container() const { return ir_container_; } @@ -186,7 +190,7 @@ class NVF_API Statement : public NonCopyable, public PolymorphicBase { StmtNameType name_ = kInvalidStmName; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - IrContainer* ir_container_ = nullptr; + Fusion* ir_container_ = nullptr; }; inline std::string toString(Statement* stmt) { @@ -422,7 +426,7 @@ class NVF_API Val : public Statement { protected: friend class Fusion; - friend class IrContainer; + friend class impl::IrContainer; friend class IrStorage; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) @@ -480,7 +484,7 @@ class NVF_API Val : public Statement { }; using newObjectFuncType = Expr*( - IrContainer*, + Fusion*, std::vector, std::vector, std::vector); @@ -704,7 +708,7 @@ bool Val::isDefinitionType() const { #define NVFUSER_DECLARE_CLONE_AND_CREATE \ virtual Statement* clone(IrCloner* ir_cloner) const override; \ static Expr* newObject( \ - IrContainer* container, \ + Fusion* container, \ std::vector inputs, \ std::vector outputs, \ std::vector attributes); \ @@ -717,7 +721,7 @@ bool Val::isDefinitionType() const { return IrBuilder::clone(this, ir_cloner); \ } \ Expr* ClassName::newObject( \ - IrContainer* container, \ + Fusion* container, \ std::vector inputs, \ std::vector outputs, \ std::vector attributes) { \ diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 8df6c2d64e4..af0fea66d32 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -39,16 +39,11 @@ class IrBuilder { } //! Allocate a new IR node, forwarding the arguments to the appropriate - //! constructor and registering with the container + //! constructor and registering with the container. + //! Implementation provided at the end of fusion.h after Fusion is fully + //! defined. template - static T* createInContainer(IrContainer* container, Args&&... args) { - NVF_ERROR(container != nullptr, "Need an active container to build IR."); - T* node = new T(IrBuilderPasskey(container), std::forward(args)...); - - container->registerStmt(IrBuilderPasskey(container), node); - - return node; - } + static T* createInContainer(Fusion* container, Args&&... args); //! Clone an IR node, forwarding the arguments to the IrCloner constructor. //! Register clones with IrCloner's target container. diff --git a/csrc/ir/builder_passkey.h b/csrc/ir/builder_passkey.h index b15f2a9a76d..79aa6404bd6 100644 --- a/csrc/ir/builder_passkey.h +++ b/csrc/ir/builder_passkey.h @@ -9,18 +9,18 @@ namespace nvfuser { -class IrContainer; +class Fusion; // Passkey for builder to register properties with statements, and to call -// functions in IrContainer +// functions in IrContainer (now via Fusion) class IrBuilderPasskey { friend class IrBuilder; public: - IrContainer* const ir_container_ = nullptr; + Fusion* const ir_container_ = nullptr; private: - explicit IrBuilderPasskey(IrContainer* ir_container) + explicit IrBuilderPasskey(Fusion* ir_container) : ir_container_(ir_container) {} }; diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 6a38e7113b2..c71c04f082c 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -13,7 +13,7 @@ namespace nvfuser { -IrCloner::IrCloner(IrContainer* container) : ir_container_(container) { +IrCloner::IrCloner(Fusion* container) : ir_container_(container) { NVF_ERROR( container != nullptr, "IrCloner constructor received NULL container pointer"); diff --git a/csrc/ir/cloner.h b/csrc/ir/cloner.h index 83e0e444430..9a964f4d66a 100644 --- a/csrc/ir/cloner.h +++ b/csrc/ir/cloner.h @@ -21,7 +21,7 @@ namespace nvfuser { -class IrContainer; +class Fusion; //! Clones nodes from an exiting Fusion //! @@ -35,7 +35,7 @@ class IrCloner { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit IrCloner(IrContainer* container); + explicit IrCloner(Fusion* container); virtual ~IrCloner() = default; NVF_API Statement* clone(const Statement* statement); @@ -140,7 +140,7 @@ class IrCloner { return cloned_disjoint_sets; } - IrContainer* container() const { + Fusion* container() const { return ir_container_; } @@ -155,7 +155,7 @@ class IrCloner { private: // The destination Fusion container - IrContainer* ir_container_ = nullptr; + Fusion* ir_container_ = nullptr; // Builder to make all the new nodes IrBuilder builder_; @@ -177,33 +177,7 @@ class RecomputeTv : private IrCloner { Statement* handle(const TensorDomain*); }; -//! Clone an IR node, forwarding the arguments to the IrCloner constructor. -template -T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { - NVF_ERROR( - ir_cloner != nullptr, - "Cannot use create when a cloner object is set. Use clone."); - - NVF_ERROR( - ir_cloner->container() != nullptr, - "Cloner doesn't have a valid container to store cloned object."); - - T* dest = new T(src, ir_cloner); - const auto* src_stmt = dynamic_cast(src); - auto* dest_stmt = dynamic_cast(dest); - - auto dest_container = ir_cloner->container(); - auto src_container = src_stmt->container(); - - dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); - - if (src_container != dest_container) { - dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); - } - - ir_cloner->registerClone(src_stmt, dest_stmt); - - return dest; -} +// Note: IrBuilder::clone() template implementation is in fusion.h +// after Fusion is fully defined } // namespace nvfuser diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index e59905f1da3..1e460e4e063 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -13,6 +13,11 @@ namespace nvfuser { +// Forward declaration - Fusion inherits from impl::IrContainer +class Fusion; + +namespace impl { + IrContainer::IrContainer() : ir_storage_(std::make_unique()) { ir_storage()->parent_ = this; } @@ -30,21 +35,26 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { if (a.ir_storage_) { a.ir_storage()->parent_ = &a; // Also update all Statement ir_container_ pointers to point to new owner + // Note: IrContainer is now in impl namespace, but Statement::ir_container_ + // is Fusion*. Since only Fusion (and its derived classes) inherit from + // impl::IrContainer, this cast is safe. + auto* fusion_a = static_cast(&a); for (auto val : a.vals()) { - val->ir_container_ = &a; + val->ir_container_ = fusion_a; } for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; + expr->ir_container_ = fusion_a; } } if (b.ir_storage_) { b.ir_storage()->parent_ = &b; // Also update all Statement ir_container_ pointers to point to new owner + auto* fusion_b = static_cast(&b); for (auto val : b.vals()) { - val->ir_container_ = &b; + val->ir_container_ = fusion_b; } for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; + expr->ir_container_ = fusion_b; } } } @@ -54,4 +64,6 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { return ir_cloner; } + +} // namespace impl } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 1b73d22b4a1..9e1fa49dd9a 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -15,26 +15,36 @@ namespace nvfuser { +// Forward declaration of impl namespace +namespace impl { +class IrContainer; +} + class IrBuilderPasskey; class ExprPasskey; class OptOutMutator; // Passkey for container to register names with statements class IrContainerPasskey { - friend class IrContainer; + friend class impl::IrContainer; friend class IrStorage; private: explicit IrContainerPasskey() = default; }; -// IrContainer: Base class for types that provide IrContainer API via -// composition +namespace impl { + +// IrContainer: Implementation detail base class for types that provide +// IrContainer API via composition // // This class handles the composition infrastructure and forwarding boilerplate // for accessing IrStorage functionality. Derived classes (like Fusion) can // focus on their specific logic while inheriting the full IrContainer API. // +// Note: IrContainer is now in the impl namespace. External code should use +// Fusion as the public base class interface. +// // Key Features: // - Owns IrStorage via unique_ptr (can be shared_ptr in Phase 2) // - Forwards all IrStorage public methods @@ -236,4 +246,5 @@ class NVF_API IrContainer : public PolymorphicBase { std::unique_ptr ir_storage_; }; +} // namespace impl } // namespace nvfuser diff --git a/csrc/ir/storage.cpp b/csrc/ir/storage.cpp index 6dff6c52980..94008e0fe2c 100644 --- a/csrc/ir/storage.cpp +++ b/csrc/ir/storage.cpp @@ -15,6 +15,15 @@ namespace nvfuser { +// Forward declaration - Fusion inherits from impl::IrContainer +class Fusion; + +// Helper to cast parent() to Fusion* for IrBuilder calls +static inline Fusion* parentAsFusion(impl::IrContainer* parent) { + // Safe cast since only Fusion inherits from impl::IrContainer + return static_cast(parent); +} + //! Return values in insertion order const std::deque IrStorage::deterministic_vals() const noexcept { std::deque vals_deque; @@ -94,7 +103,8 @@ void IrStorage::swap(IrStorage& a, IrStorage& b) noexcept { IrCloner IrStorage::copy(const IrStorage* from, IrStorage* to) { to->clear(); - IrCloner ir_cloner(to->parent()); + // parent() returns impl::IrContainer*, but IrCloner needs Fusion* + IrCloner ir_cloner(parentAsFusion(to->parent())); // Copy values in deterministic order // deterministic_vals can contain special values like one_val_, zero_val_, etc @@ -225,7 +235,7 @@ bool IrStorage::inContainer(const Statement* const_stmt) const { } NVF_ERROR( - const_stmt->container() == this->parent(), + const_stmt->container() == parentAsFusion(this->parent()), "Container claims to own stmt, but stmt disagrees."); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -247,8 +257,8 @@ bool IrStorage::inContainer(const Statement* const_stmt) const { // Shortcuts for frequently used vals Val* IrStorage::zeroVal() { if (!zero_val_) { - auto zero_val = - IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); + auto zero_val = IrBuilder::createInContainer( + parentAsFusion(this->parent()), 0L, DataType::Index); NVF_ERROR(vals_up_.back().get() == zero_val); zero_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -263,14 +273,15 @@ Val* IrStorage::zeroVal(DataType dtype) { return falseVal(); } else { // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 0L, dtype); + return IrBuilder::createInContainer( + parentAsFusion(this->parent()), 0L, dtype); } } Val* IrStorage::oneVal() { if (!one_val_) { - auto one_val = - IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); + auto one_val = IrBuilder::createInContainer( + parentAsFusion(this->parent()), 1L, DataType::Index); NVF_ERROR(vals_up_.back().get() == one_val); one_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -285,14 +296,15 @@ Val* IrStorage::oneVal(DataType dtype) { return trueVal(); } else { // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 1L, dtype); + return IrBuilder::createInContainer( + parentAsFusion(this->parent()), 1L, dtype); } } Val* IrStorage::falseVal() { if (!false_val_) { auto false_val = IrBuilder::createInContainer( - this->parent(), false, DataType::Bool); + parentAsFusion(this->parent()), false, DataType::Bool); NVF_ERROR(vals_up_.back().get() == false_val); false_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -302,8 +314,8 @@ Val* IrStorage::falseVal() { Val* IrStorage::trueVal() { if (!true_val_) { - auto true_val = - IrBuilder::createInContainer(this->parent(), true, DataType::Bool); + auto true_val = IrBuilder::createInContainer( + parentAsFusion(this->parent()), true, DataType::Bool); NVF_ERROR(vals_up_.back().get() == true_val); true_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -325,10 +337,10 @@ NamedScalar* IrStorage::magicZeroVal() { Val* IrStorage::metadataOf(Val* v) { if (metadata_.count(v) == 0) { - auto metadata_val = - IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); + auto metadata_val = IrBuilder::createInContainer( + parentAsFusion(this->parent()), metaDataTypeOf(v)); auto metadata_expr = IrBuilder::createInContainer( - this->parent(), metadata_val, v); + parentAsFusion(this->parent()), metadata_val, v); metadata_[v] = std::make_pair(metadata_val, metadata_expr); } return metadata_.at(v).first; @@ -350,13 +362,13 @@ void IrStorage::lazyInitAxioms() { } void IrStorage::assumePositive(Val* val) { - NVF_ERROR(val->container() == this->parent()); + NVF_ERROR(val->container() == parentAsFusion(this->parent())); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); } void IrStorage::assumeNonNegative(Val* val) { - NVF_ERROR(val->container() == this->parent()); + NVF_ERROR(val->container() == parentAsFusion(this->parent())); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); } diff --git a/csrc/ir/storage.h b/csrc/ir/storage.h index 174d64091fc..b4c4528ddd8 100644 --- a/csrc/ir/storage.h +++ b/csrc/ir/storage.h @@ -18,6 +18,11 @@ namespace nvfuser { +// Forward declaration of impl namespace +namespace impl { +class IrContainer; +} + class NamedScalar; class IrStorage { @@ -102,7 +107,7 @@ class IrStorage { static void swap(IrStorage& a, IrStorage& b) noexcept; // Let IrInterface access protected methods for forwarding - friend class IrContainer; + friend class impl::IrContainer; // Let Fusion access IrStorage::clear() friend class Fusion; @@ -183,7 +188,7 @@ class IrStorage { std::unordered_map> metadata_; public: - IrContainer* parent() const { + impl::IrContainer* parent() const { NVF_ERROR( parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.") return parent_; @@ -192,7 +197,7 @@ class IrStorage { private: // Parent IrInterface that owns this container (for pure composition pattern) // Used by Statement::fusion() to navigate back to owning Fusion - IrContainer* parent_ = nullptr; + impl::IrContainer* parent_ = nullptr; }; } // namespace nvfuser diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index bf2dc04f281..34792ccb5cc 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -200,7 +200,7 @@ struct SubstituteInExpr : public OptOutMutator { } protected: - void removeExpr(IrContainer*, Expr*) const override {} + void removeExpr(Fusion*, Expr*) const override {} void registerNewExpr(Expr* expr) override { expr_ = expr; diff --git a/csrc/kernel.h b/csrc/kernel.h index 0c97f24dcde..b10a339887a 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -280,8 +280,8 @@ class NVF_API Kernel final : public Fusion { } protected: - using IrContainer::registerExpr; - using IrContainer::registerVal; + using impl::IrContainer::registerExpr; + using impl::IrContainer::registerVal; //! Register the Val with this fusion void registerVal(Val* val) override; diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index d5395c4f0f5..cfc529b6337 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -291,7 +291,7 @@ Expr* OptOutMutator::mutateExpr( return new_expr; } -void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) const { +void OptOutMutator::removeExpr(Fusion* container, Expr* expr) const { container->removeExpr(expr); } From 02f22a8c566df439b302db817b5bb8b2fa247bcc Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 14:57:30 -0800 Subject: [PATCH 2/7] Directly point to Fusion* types from IrStorage. --- csrc/ir/container.cpp | 6 +++--- csrc/ir/storage.cpp | 44 ++++++++++++++++--------------------------- csrc/ir/storage.h | 4 ++-- 3 files changed, 21 insertions(+), 33 deletions(-) diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 1e460e4e063..168fbb66452 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -19,7 +19,7 @@ class Fusion; namespace impl { IrContainer::IrContainer() : ir_storage_(std::make_unique()) { - ir_storage()->parent_ = this; + ir_storage()->parent_ = static_cast(this); } IrContainer::~IrContainer() {} @@ -33,12 +33,12 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { // After swap, each IrContainer owns a different IrStorage, so we must update // the parent backpointers in those containers to point to their new owners if (a.ir_storage_) { - a.ir_storage()->parent_ = &a; // Also update all Statement ir_container_ pointers to point to new owner // Note: IrContainer is now in impl namespace, but Statement::ir_container_ // is Fusion*. Since only Fusion (and its derived classes) inherit from // impl::IrContainer, this cast is safe. auto* fusion_a = static_cast(&a); + a.ir_storage()->parent_ = fusion_a; for (auto val : a.vals()) { val->ir_container_ = fusion_a; } @@ -47,9 +47,9 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { } } if (b.ir_storage_) { - b.ir_storage()->parent_ = &b; // Also update all Statement ir_container_ pointers to point to new owner auto* fusion_b = static_cast(&b); + b.ir_storage()->parent_ = fusion_b; for (auto val : b.vals()) { val->ir_container_ = fusion_b; } diff --git a/csrc/ir/storage.cpp b/csrc/ir/storage.cpp index 94008e0fe2c..6dff6c52980 100644 --- a/csrc/ir/storage.cpp +++ b/csrc/ir/storage.cpp @@ -15,15 +15,6 @@ namespace nvfuser { -// Forward declaration - Fusion inherits from impl::IrContainer -class Fusion; - -// Helper to cast parent() to Fusion* for IrBuilder calls -static inline Fusion* parentAsFusion(impl::IrContainer* parent) { - // Safe cast since only Fusion inherits from impl::IrContainer - return static_cast(parent); -} - //! Return values in insertion order const std::deque IrStorage::deterministic_vals() const noexcept { std::deque vals_deque; @@ -103,8 +94,7 @@ void IrStorage::swap(IrStorage& a, IrStorage& b) noexcept { IrCloner IrStorage::copy(const IrStorage* from, IrStorage* to) { to->clear(); - // parent() returns impl::IrContainer*, but IrCloner needs Fusion* - IrCloner ir_cloner(parentAsFusion(to->parent())); + IrCloner ir_cloner(to->parent()); // Copy values in deterministic order // deterministic_vals can contain special values like one_val_, zero_val_, etc @@ -235,7 +225,7 @@ bool IrStorage::inContainer(const Statement* const_stmt) const { } NVF_ERROR( - const_stmt->container() == parentAsFusion(this->parent()), + const_stmt->container() == this->parent(), "Container claims to own stmt, but stmt disagrees."); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) @@ -257,8 +247,8 @@ bool IrStorage::inContainer(const Statement* const_stmt) const { // Shortcuts for frequently used vals Val* IrStorage::zeroVal() { if (!zero_val_) { - auto zero_val = IrBuilder::createInContainer( - parentAsFusion(this->parent()), 0L, DataType::Index); + auto zero_val = + IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); NVF_ERROR(vals_up_.back().get() == zero_val); zero_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -273,15 +263,14 @@ Val* IrStorage::zeroVal(DataType dtype) { return falseVal(); } else { // NOTE: this does not cache values - return IrBuilder::createInContainer( - parentAsFusion(this->parent()), 0L, dtype); + return IrBuilder::createInContainer(this->parent(), 0L, dtype); } } Val* IrStorage::oneVal() { if (!one_val_) { - auto one_val = IrBuilder::createInContainer( - parentAsFusion(this->parent()), 1L, DataType::Index); + auto one_val = + IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); NVF_ERROR(vals_up_.back().get() == one_val); one_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -296,15 +285,14 @@ Val* IrStorage::oneVal(DataType dtype) { return trueVal(); } else { // NOTE: this does not cache values - return IrBuilder::createInContainer( - parentAsFusion(this->parent()), 1L, dtype); + return IrBuilder::createInContainer(this->parent(), 1L, dtype); } } Val* IrStorage::falseVal() { if (!false_val_) { auto false_val = IrBuilder::createInContainer( - parentAsFusion(this->parent()), false, DataType::Bool); + this->parent(), false, DataType::Bool); NVF_ERROR(vals_up_.back().get() == false_val); false_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -314,8 +302,8 @@ Val* IrStorage::falseVal() { Val* IrStorage::trueVal() { if (!true_val_) { - auto true_val = IrBuilder::createInContainer( - parentAsFusion(this->parent()), true, DataType::Bool); + auto true_val = + IrBuilder::createInContainer(this->parent(), true, DataType::Bool); NVF_ERROR(vals_up_.back().get() == true_val); true_val_ = std::unique_ptr(vals_up_.back().release()); vals_up_.pop_back(); @@ -337,10 +325,10 @@ NamedScalar* IrStorage::magicZeroVal() { Val* IrStorage::metadataOf(Val* v) { if (metadata_.count(v) == 0) { - auto metadata_val = IrBuilder::createInContainer( - parentAsFusion(this->parent()), metaDataTypeOf(v)); + auto metadata_val = + IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); auto metadata_expr = IrBuilder::createInContainer( - parentAsFusion(this->parent()), metadata_val, v); + this->parent(), metadata_val, v); metadata_[v] = std::make_pair(metadata_val, metadata_expr); } return metadata_.at(v).first; @@ -362,13 +350,13 @@ void IrStorage::lazyInitAxioms() { } void IrStorage::assumePositive(Val* val) { - NVF_ERROR(val->container() == parentAsFusion(this->parent())); + NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); } void IrStorage::assumeNonNegative(Val* val) { - NVF_ERROR(val->container() == parentAsFusion(this->parent())); + NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); } diff --git a/csrc/ir/storage.h b/csrc/ir/storage.h index b4c4528ddd8..1e3cbb94924 100644 --- a/csrc/ir/storage.h +++ b/csrc/ir/storage.h @@ -188,7 +188,7 @@ class IrStorage { std::unordered_map> metadata_; public: - impl::IrContainer* parent() const { + Fusion* parent() const { NVF_ERROR( parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.") return parent_; @@ -197,7 +197,7 @@ class IrStorage { private: // Parent IrInterface that owns this container (for pure composition pattern) // Used by Statement::fusion() to navigate back to owning Fusion - impl::IrContainer* parent_ = nullptr; + Fusion* parent_ = nullptr; }; } // namespace nvfuser From dfd0c174965c801eed298c6b084d88cd03743684 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 16:40:58 -0800 Subject: [PATCH 3/7] Moving forwarding interface into Fusion. --- csrc/fusion.cpp | 43 +++++- csrc/fusion.h | 138 ++++++++++++++++++- csrc/ir/base_nodes.h | 6 +- csrc/ir/container.cpp | 82 ++++++------ csrc/ir/container.h | 303 +++++++++++++++++++++--------------------- csrc/kernel.h | 3 - 6 files changed, 364 insertions(+), 211 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 542f715dba5..b795e8b420a 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -104,8 +104,36 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // Swap IrContainer base class (contains IrStorage) - IrContainer::swap(static_cast(a), static_cast(b)); + // We need to be careful to call IrStorage swap not unique_ptr swap, which + // will only swap the ptrs NOT the contents. + IrStorage::swap(*(a.ir_storage()), *(b.ir_storage())); + + // Fix parent pointers after swapping containers + // After swap, each IrContainer owns a different IrStorage, so we must update + // the parent backpointers in those containers to point to their new owners + if (a.ir_storage_) { + // Also update all Statement ir_container_ pointers to point to new owner + // Note: IrContainer is now in impl namespace, but Statement::ir_container_ + // is Fusion*. Since only Fusion (and its derived classes) inherit from + // impl::IrContainer, this cast is safe. + a.ir_storage()->parent_ = &a; + for (auto val : a.vals()) { + val->ir_container_ = &a; + } + for (auto expr : a.deterministic_exprs()) { + expr->ir_container_ = &a; + } + } + if (b.ir_storage_) { + // Also update all Statement ir_container_ pointers to point to new owner + b.ir_storage()->parent_ = &b; + for (auto val : b.vals()) { + val->ir_container_ = &b; + } + for (auto expr : b.deterministic_exprs()) { + expr->ir_container_ = &b; + } + } std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); @@ -122,7 +150,7 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = IrContainer::copy(from, to); + auto ir_cloner = IrStorage::copy(from->ir_storage(), to->ir_storage()); for (auto val : from->vals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); @@ -260,7 +288,8 @@ void Fusion::removeExpr(Expr* expr) { } } - IrContainer::removeExpr(expr); + // TODO : CHECK THIS vvv + ir_storage()->removeExpr(expr); } void Fusion::removeVal(Val* val) { @@ -304,7 +333,7 @@ void Fusion::removeVal(Val* val) { for (auto e : exprs_to_remove) { removeExpr(e); } - IrContainer::removeVal(val); + ir_storage()->removeVal(val); invalidateTvsAndUses(); } @@ -668,7 +697,7 @@ void Fusion::registerVal(Val* val) { val->fusion() == this, val, " was not found in the active fusion."); } - IrContainer::registerVal(val); + ir_storage()->registerVal(val); } void Fusion::registerExpr(Expr* expr) { @@ -681,7 +710,7 @@ void Fusion::registerExpr(Expr* expr) { expr->fusion() == this, expr, " was not found in the active fusion."); } - IrContainer::registerExpr(expr); + ir_storage()->registerExpr(expr); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); diff --git a/csrc/fusion.h b/csrc/fusion.h index aa4d3abe341..6d8ae4eb840 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -145,7 +145,30 @@ class AliasInfoMap { class NVF_API Fusion : public impl::IrContainer { typedef std::unordered_map> PermutationMap; + protected: + // Direct access to underlying container + IrStorage* ir_storage() { + NVF_ERROR( + ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") + return ir_storage_.get(); + } + + const IrStorage* ir_storage() const { + NVF_ERROR( + ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") + return ir_storage_.get(); + } + public: + // Registration (public API with passkey) + virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { + if (stmt->isVal()) { + registerVal(stmt->asVal()); + } else { + registerExpr(stmt->asExpr()); + } + } + Fusion() = default; Fusion(const Fusion& other); @@ -168,11 +191,11 @@ class NVF_API Fusion : public impl::IrContainer { //! Break dependency chains associated with Expr, remove references to expr //! delete expr - void removeExpr(Expr* expr) override; + virtual void removeExpr(Expr* expr); //! Completely remove val from the fusion, break all dependencies associated //! with it - void removeVal(Val* val) override; + virtual void removeVal(Val* val); //! Register input as an input of the fusion void addInput(Val* input); @@ -477,17 +500,118 @@ class NVF_API Fusion : public impl::IrContainer { void resetExactMappings(); + //=================================================================== + // IrStorage API Forwarding (Public Methods) + //=================================================================== + + // Container queries + bool inContainer(const Statement* stmt) const { + return ir_storage()->inContainer(stmt); + } + + void assertInContainer(const Statement* stmt, const std::string& msg) const { + ir_storage()->assertInContainer(stmt, msg); + } + + // Collections access (return values in insertion order) + const std::deque deterministic_vals() const noexcept { + return ir_storage()->deterministic_vals(); + } + + const std::deque deterministic_exprs() const noexcept { + return ir_storage()->deterministic_exprs(); + } + + const std::unordered_map deterministic_vals_map() + const noexcept { + return ir_storage()->deterministic_vals_map(); + } + + const std::unordered_map deterministic_exprs_map() + const noexcept { + return ir_storage()->deterministic_exprs_map(); + } + + // Collections access (unordered sets) + const std::unordered_set& unordered_exprs() const noexcept { + return ir_storage()->unordered_exprs(); + } + + const std::unordered_set& vals() const noexcept { + return ir_storage()->vals(); + } + + // Count queries + int64_t numExprs() const noexcept { + return ir_storage()->numExprs(); + } + + int64_t numVals(bool include_shortcuts) const noexcept { + return ir_storage()->numVals(include_shortcuts); + } + + // Shortcut values (frequently used constants) + Val* zeroVal() { + return ir_storage()->zeroVal(); + } + + Val* oneVal() { + return ir_storage()->oneVal(); + } + + Val* falseVal() { + return ir_storage()->falseVal(); + } + + Val* trueVal() { + return ir_storage()->trueVal(); + } + + NamedScalar* magicZeroVal() { + return ir_storage()->magicZeroVal(); + } + + Val* zeroVal(DataType dtype) { + return ir_storage()->zeroVal(dtype); + } + + Val* oneVal(DataType dtype) { + return ir_storage()->oneVal(dtype); + } + + Val* metadataOf(Val* val) { + return ir_storage()->metadataOf(val); + } + + // Axioms (CUDA programming assumptions) + const std::vector& axioms() { + return ir_storage()->axioms(); + } + + void assumePositive(Val* val) { + ir_storage()->assumePositive(val); + } + + void assumeNonNegative(Val* val) { + ir_storage()->assumeNonNegative(val); + } + + // Statement removal + void removeStatementsCreatedAfter( + int64_t num_exprs_before, + int64_t num_vals_before) { + ir_storage()->removeStatementsCreatedAfter( + num_exprs_before, num_vals_before); + } + protected: friend SegmentCandidateFinder; friend SegmentedFusion; friend class TranslateApplicableWelford; friend Val; - using impl::IrContainer::registerExpr; - using impl::IrContainer::registerVal; - //! Register the Val with this fusion - void registerVal(Val* val) override; + virtual void registerVal(Val* val); //! Register expr with this fusion. //! When we register an expression, we want to update the dependency tracking @@ -495,7 +619,7 @@ class NVF_API Fusion : public impl::IrContainer { //! definitions of outputs and register this Expr as the definition. Otherwise //! will update definition if not previously set, but will not remove old //! definitions. - void registerExpr(Expr* expr) override; + virtual void registerExpr(Expr* expr); //! Clear Expr's from TV uses that are not required to produce outputs from //! inputs. Only other place this is used (other than Fusion) is in diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 661ddc6cbcb..8142a956655 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -9,8 +9,6 @@ #include #include -#include -#include #include #include @@ -98,9 +96,7 @@ class ExprPasskey { //! Basically beinng able to succienctly traverse down the inhereitance stack of //! a Statment at runtime. This is currently implemented in dispatch.h class NVF_API Statement : public NonCopyable, public PolymorphicBase { - friend void swap(Fusion&, Fusion&) noexcept; - friend void swap(impl::IrContainer& a, impl::IrContainer& b) noexcept; - friend class impl::IrContainer; + friend class Fusion; public: Statement() = delete; diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 168fbb66452..baa1e44ade9 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -19,51 +19,53 @@ class Fusion; namespace impl { IrContainer::IrContainer() : ir_storage_(std::make_unique()) { - ir_storage()->parent_ = static_cast(this); + ir_storage_->parent_ = static_cast(this); } IrContainer::~IrContainer() {} -void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { - // We need to be careful to call IrStorage swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrStorage::swap(*(a.ir_storage()), *(b.ir_storage())); +// void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { +// // We need to be careful to call IrStorage swap not unique_ptr swap, which +// // will only swap the ptrs NOT the contents. +// IrStorage::swap(*(a.ir_storage()), *(b.ir_storage())); +// +// // Fix parent pointers after swapping containers +// // After swap, each IrContainer owns a different IrStorage, so we must +// update +// // the parent backpointers in those containers to point to their new owners +// if (a.ir_storage_) { +// // Also update all Statement ir_container_ pointers to point to new owner +// // Note: IrContainer is now in impl namespace, but +// Statement::ir_container_ +// // is Fusion*. Since only Fusion (and its derived classes) inherit from +// // impl::IrContainer, this cast is safe. +// auto* fusion_a = static_cast(&a); +// a.ir_storage()->parent_ = fusion_a; +// for (auto val : a.vals()) { +// val->ir_container_ = fusion_a; +// } +// for (auto expr : a.deterministic_exprs()) { +// expr->ir_container_ = fusion_a; +// } +// } +// if (b.ir_storage_) { +// // Also update all Statement ir_container_ pointers to point to new owner +// auto* fusion_b = static_cast(&b); +// b.ir_storage()->parent_ = fusion_b; +// for (auto val : b.vals()) { +// val->ir_container_ = fusion_b; +// } +// for (auto expr : b.deterministic_exprs()) { +// expr->ir_container_ = fusion_b; +// } +// } +// } - // Fix parent pointers after swapping containers - // After swap, each IrContainer owns a different IrStorage, so we must update - // the parent backpointers in those containers to point to their new owners - if (a.ir_storage_) { - // Also update all Statement ir_container_ pointers to point to new owner - // Note: IrContainer is now in impl namespace, but Statement::ir_container_ - // is Fusion*. Since only Fusion (and its derived classes) inherit from - // impl::IrContainer, this cast is safe. - auto* fusion_a = static_cast(&a); - a.ir_storage()->parent_ = fusion_a; - for (auto val : a.vals()) { - val->ir_container_ = fusion_a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = fusion_a; - } - } - if (b.ir_storage_) { - // Also update all Statement ir_container_ pointers to point to new owner - auto* fusion_b = static_cast(&b); - b.ir_storage()->parent_ = fusion_b; - for (auto val : b.vals()) { - val->ir_container_ = fusion_b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = fusion_b; - } - } -} - -IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { - auto ir_cloner = IrStorage::copy(from->ir_storage(), to->ir_storage()); - - return ir_cloner; -} +// IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { +// auto ir_cloner = IrStorage::copy(from->ir_storage(), to->ir_storage()); +// +// return ir_cloner; +// } } // namespace impl } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 9e1fa49dd9a..4693809fc7b 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -20,6 +20,7 @@ namespace impl { class IrContainer; } +class Fusion; class IrBuilderPasskey; class ExprPasskey; class OptOutMutator; @@ -75,175 +76,179 @@ class NVF_API IrContainer : public PolymorphicBase { // IrStorage API Forwarding (Public Methods) //=================================================================== - // Container queries - bool inContainer(const Statement* stmt) const { - return ir_storage()->inContainer(stmt); - } - - void assertInContainer(const Statement* stmt, const std::string& msg) const { - ir_storage()->assertInContainer(stmt, msg); - } - - // Collections access (return values in insertion order) - const std::deque deterministic_vals() const noexcept { - return ir_storage()->deterministic_vals(); - } - - const std::deque deterministic_exprs() const noexcept { - return ir_storage()->deterministic_exprs(); - } - - const std::unordered_map deterministic_vals_map() - const noexcept { - return ir_storage()->deterministic_vals_map(); - } - - const std::unordered_map deterministic_exprs_map() - const noexcept { - return ir_storage()->deterministic_exprs_map(); - } - - // Collections access (unordered sets) - const std::unordered_set& unordered_exprs() const noexcept { - return ir_storage()->unordered_exprs(); - } - - const std::unordered_set& vals() const noexcept { - return ir_storage()->vals(); - } - - // Count queries - int64_t numExprs() const noexcept { - return ir_storage()->numExprs(); - } - - int64_t numVals(bool include_shortcuts) const noexcept { - return ir_storage()->numVals(include_shortcuts); - } - - // Shortcut values (frequently used constants) - Val* zeroVal() { - return ir_storage()->zeroVal(); - } - - Val* oneVal() { - return ir_storage()->oneVal(); - } - - Val* falseVal() { - return ir_storage()->falseVal(); - } - - Val* trueVal() { - return ir_storage()->trueVal(); - } - - NamedScalar* magicZeroVal() { - return ir_storage()->magicZeroVal(); - } - - Val* zeroVal(DataType dtype) { - return ir_storage()->zeroVal(dtype); - } - - Val* oneVal(DataType dtype) { - return ir_storage()->oneVal(dtype); - } - - Val* metadataOf(Val* val) { - return ir_storage()->metadataOf(val); - } - - // Axioms (CUDA programming assumptions) - const std::vector& axioms() { - return ir_storage()->axioms(); - } - - void assumePositive(Val* val) { - ir_storage()->assumePositive(val); - } - - void assumeNonNegative(Val* val) { - ir_storage()->assumeNonNegative(val); - } - - // Statement removal - void removeStatementsCreatedAfter( - int64_t num_exprs_before, - int64_t num_vals_before) { - ir_storage()->removeStatementsCreatedAfter( - num_exprs_before, num_vals_before); - } - - // Registration (public API with passkey) - virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { - // Dispatch to Val or Expr registration, which calls the virtual protected - // methods that subclasses (like Fusion) override - if (stmt->isVal()) { - registerVal(passkey, stmt->asVal()); - } else { - registerExpr(passkey, stmt->asExpr()); - } - } - - virtual void registerVal(IrBuilderPasskey passkey, Val* val) { - // Call the protected virtual method that subclasses override - registerVal(val); - } - - virtual void registerExpr(IrBuilderPasskey passkey, Expr* expr) { - // Call the protected virtual method that subclasses override - registerExpr(expr); - } + //// Container queries + // bool inContainer(const Statement* stmt) const { + // return ir_storage()->inContainer(stmt); + // } + + // void assertInContainer(const Statement* stmt, const std::string& msg) const + // { + // ir_storage()->assertInContainer(stmt, msg); + // } + + //// Collections access (return values in insertion order) + // const std::deque deterministic_vals() const noexcept { + // return ir_storage()->deterministic_vals(); + // } + + // const std::deque deterministic_exprs() const noexcept { + // return ir_storage()->deterministic_exprs(); + // } + + // const std::unordered_map deterministic_vals_map() + // const noexcept { + // return ir_storage()->deterministic_vals_map(); + // } + + // const std::unordered_map deterministic_exprs_map() + // const noexcept { + // return ir_storage()->deterministic_exprs_map(); + // } + + //// Collections access (unordered sets) + // const std::unordered_set& unordered_exprs() const noexcept { + // return ir_storage()->unordered_exprs(); + // } + + // const std::unordered_set& vals() const noexcept { + // return ir_storage()->vals(); + // } + + //// Count queries + // int64_t numExprs() const noexcept { + // return ir_storage()->numExprs(); + // } + + // int64_t numVals(bool include_shortcuts) const noexcept { + // return ir_storage()->numVals(include_shortcuts); + // } + + //// Shortcut values (frequently used constants) + // Val* zeroVal() { + // return ir_storage()->zeroVal(); + // } + + // Val* oneVal() { + // return ir_storage()->oneVal(); + // } + + // Val* falseVal() { + // return ir_storage()->falseVal(); + // } + + // Val* trueVal() { + // return ir_storage()->trueVal(); + // } + + // NamedScalar* magicZeroVal() { + // return ir_storage()->magicZeroVal(); + // } + + // Val* zeroVal(DataType dtype) { + // return ir_storage()->zeroVal(dtype); + // } + + // Val* oneVal(DataType dtype) { + // return ir_storage()->oneVal(dtype); + // } + + // Val* metadataOf(Val* val) { + // return ir_storage()->metadataOf(val); + // } + + //// Axioms (CUDA programming assumptions) + // const std::vector& axioms() { + // return ir_storage()->axioms(); + // } + + // void assumePositive(Val* val) { + // ir_storage()->assumePositive(val); + // } + + // void assumeNonNegative(Val* val) { + // ir_storage()->assumeNonNegative(val); + // } + + //// Statement removal + // void removeStatementsCreatedAfter( + // int64_t num_exprs_before, + // int64_t num_vals_before) { + // ir_storage()->removeStatementsCreatedAfter( + // num_exprs_before, num_vals_before); + // } + + //// Registration (public API with passkey) + // virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { + // // Dispatch to Val or Expr registration, which calls the virtual + // protected + // // methods that subclasses (like Fusion) override + // if (stmt->isVal()) { + // registerVal(passkey, stmt->asVal()); + // } else { + // registerExpr(passkey, stmt->asExpr()); + // } + // } + + // virtual void registerVal(IrBuilderPasskey passkey, Val* val) { + // // Call the protected virtual method that subclasses override + // registerVal(val); + // } + + // virtual void registerExpr(IrBuilderPasskey passkey, Expr* expr) { + // // Call the protected virtual method that subclasses override + // registerExpr(expr); + // } //=================================================================== // Container Access //=================================================================== - // Direct access to underlying container - IrStorage* ir_storage() { - NVF_ERROR( - ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") - return ir_storage_.get(); - } - - const IrStorage* ir_storage() const { - NVF_ERROR( - ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") - return ir_storage_.get(); - } + //// Direct access to underlying container + // IrStorage* ir_storage() { + // NVF_ERROR( + // ir_storage_.get() != nullptr, "Accessing a uninitialized + // IrContainer!.") + // return ir_storage_.get(); + // } + + // const IrStorage* ir_storage() const { + // NVF_ERROR( + // ir_storage_.get() != nullptr, "Accessing a uninitialized + // IrContainer!.") + // return ir_storage_.get(); + // } protected: //=================================================================== // Protected Registration API (for derived class overrides) //=================================================================== - static IrCloner copy(const IrContainer* from, IrContainer* to); - static void swap(IrContainer& a, IrContainer& b) noexcept; + // static IrCloner copy(const IrContainer* from, IrContainer* to); + // static void swap(IrContainer& a, IrContainer& b) noexcept; + + //// Derived classes (like Fusion) override these to add custom logic + // virtual void registerVal(Val* val) { + // ir_storage()->registerVal(val); + // } - // Derived classes (like Fusion) override these to add custom logic - virtual void registerVal(Val* val) { - ir_storage()->registerVal(val); - } + // virtual void registerExpr(Expr* expr) { + // ir_storage()->registerExpr(expr); + // } - virtual void registerExpr(Expr* expr) { - ir_storage()->registerExpr(expr); - } + // virtual void removeExpr(Expr* expr) { + // ir_storage()->removeExpr(expr); + // } - virtual void removeExpr(Expr* expr) { - ir_storage()->removeExpr(expr); - } + // virtual void removeVal(Val* val) { + // ir_storage()->removeVal(val); + // } - virtual void removeVal(Val* val) { - ir_storage()->removeVal(val); - } + std::unique_ptr ir_storage_; private: //=================================================================== // Data Members //=================================================================== - - std::unique_ptr ir_storage_; }; } // namespace impl diff --git a/csrc/kernel.h b/csrc/kernel.h index b10a339887a..d6fdd2f2fa8 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -280,9 +280,6 @@ class NVF_API Kernel final : public Fusion { } protected: - using impl::IrContainer::registerExpr; - using impl::IrContainer::registerVal; - //! Register the Val with this fusion void registerVal(Val* val) override; From c1a3da731a59e6ae567b09a52084664ddec2b74f Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 16:42:49 -0800 Subject: [PATCH 4/7] Clean up commented interface definitions. --- csrc/ir/container.cpp | 43 --------- csrc/ir/container.h | 199 ------------------------------------------ 2 files changed, 242 deletions(-) diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index baa1e44ade9..77a93aa5ed5 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -24,48 +24,5 @@ IrContainer::IrContainer() : ir_storage_(std::make_unique()) { IrContainer::~IrContainer() {} -// void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { -// // We need to be careful to call IrStorage swap not unique_ptr swap, which -// // will only swap the ptrs NOT the contents. -// IrStorage::swap(*(a.ir_storage()), *(b.ir_storage())); -// -// // Fix parent pointers after swapping containers -// // After swap, each IrContainer owns a different IrStorage, so we must -// update -// // the parent backpointers in those containers to point to their new owners -// if (a.ir_storage_) { -// // Also update all Statement ir_container_ pointers to point to new owner -// // Note: IrContainer is now in impl namespace, but -// Statement::ir_container_ -// // is Fusion*. Since only Fusion (and its derived classes) inherit from -// // impl::IrContainer, this cast is safe. -// auto* fusion_a = static_cast(&a); -// a.ir_storage()->parent_ = fusion_a; -// for (auto val : a.vals()) { -// val->ir_container_ = fusion_a; -// } -// for (auto expr : a.deterministic_exprs()) { -// expr->ir_container_ = fusion_a; -// } -// } -// if (b.ir_storage_) { -// // Also update all Statement ir_container_ pointers to point to new owner -// auto* fusion_b = static_cast(&b); -// b.ir_storage()->parent_ = fusion_b; -// for (auto val : b.vals()) { -// val->ir_container_ = fusion_b; -// } -// for (auto expr : b.deterministic_exprs()) { -// expr->ir_container_ = fusion_b; -// } -// } -// } - -// IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { -// auto ir_cloner = IrStorage::copy(from->ir_storage(), to->ir_storage()); -// -// return ir_cloner; -// } - } // namespace impl } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 4693809fc7b..507b5b617bb 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -36,31 +36,11 @@ class IrContainerPasskey { namespace impl { -// IrContainer: Implementation detail base class for types that provide -// IrContainer API via composition -// -// This class handles the composition infrastructure and forwarding boilerplate -// for accessing IrStorage functionality. Derived classes (like Fusion) can -// focus on their specific logic while inheriting the full IrContainer API. -// -// Note: IrContainer is now in the impl namespace. External code should use -// Fusion as the public base class interface. -// -// Key Features: -// - Owns IrStorage via unique_ptr (can be shared_ptr in Phase 2) -// - Forwards all IrStorage public methods -// - Allows derived classes to override protected IrContainer methods class NVF_API IrContainer : public PolymorphicBase { protected: // Constructors explicit IrContainer(); - // TODO: The semantics of IrContainers are largely driven through copy/swap - // function behavior. It might be better if this behaviour was properly - // defined through class semantics directly. - // - // Copy/Move are deleted. IrContainer is a forwarding interface class. We - // rely on copy/swap function behavior to handle the semantics of IrStorage. IrContainer(const IrContainer& other) = delete; IrContainer(IrContainer&& other) noexcept = delete; IrContainer& operator=(const IrContainer& other) = delete; @@ -68,187 +48,8 @@ class NVF_API IrContainer : public PolymorphicBase { ~IrContainer() override; - // Let mutator remove Exprs. - friend OptOutMutator; - - public: - //=================================================================== - // IrStorage API Forwarding (Public Methods) - //=================================================================== - - //// Container queries - // bool inContainer(const Statement* stmt) const { - // return ir_storage()->inContainer(stmt); - // } - - // void assertInContainer(const Statement* stmt, const std::string& msg) const - // { - // ir_storage()->assertInContainer(stmt, msg); - // } - - //// Collections access (return values in insertion order) - // const std::deque deterministic_vals() const noexcept { - // return ir_storage()->deterministic_vals(); - // } - - // const std::deque deterministic_exprs() const noexcept { - // return ir_storage()->deterministic_exprs(); - // } - - // const std::unordered_map deterministic_vals_map() - // const noexcept { - // return ir_storage()->deterministic_vals_map(); - // } - - // const std::unordered_map deterministic_exprs_map() - // const noexcept { - // return ir_storage()->deterministic_exprs_map(); - // } - - //// Collections access (unordered sets) - // const std::unordered_set& unordered_exprs() const noexcept { - // return ir_storage()->unordered_exprs(); - // } - - // const std::unordered_set& vals() const noexcept { - // return ir_storage()->vals(); - // } - - //// Count queries - // int64_t numExprs() const noexcept { - // return ir_storage()->numExprs(); - // } - - // int64_t numVals(bool include_shortcuts) const noexcept { - // return ir_storage()->numVals(include_shortcuts); - // } - - //// Shortcut values (frequently used constants) - // Val* zeroVal() { - // return ir_storage()->zeroVal(); - // } - - // Val* oneVal() { - // return ir_storage()->oneVal(); - // } - - // Val* falseVal() { - // return ir_storage()->falseVal(); - // } - - // Val* trueVal() { - // return ir_storage()->trueVal(); - // } - - // NamedScalar* magicZeroVal() { - // return ir_storage()->magicZeroVal(); - // } - - // Val* zeroVal(DataType dtype) { - // return ir_storage()->zeroVal(dtype); - // } - - // Val* oneVal(DataType dtype) { - // return ir_storage()->oneVal(dtype); - // } - - // Val* metadataOf(Val* val) { - // return ir_storage()->metadataOf(val); - // } - - //// Axioms (CUDA programming assumptions) - // const std::vector& axioms() { - // return ir_storage()->axioms(); - // } - - // void assumePositive(Val* val) { - // ir_storage()->assumePositive(val); - // } - - // void assumeNonNegative(Val* val) { - // ir_storage()->assumeNonNegative(val); - // } - - //// Statement removal - // void removeStatementsCreatedAfter( - // int64_t num_exprs_before, - // int64_t num_vals_before) { - // ir_storage()->removeStatementsCreatedAfter( - // num_exprs_before, num_vals_before); - // } - - //// Registration (public API with passkey) - // virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { - // // Dispatch to Val or Expr registration, which calls the virtual - // protected - // // methods that subclasses (like Fusion) override - // if (stmt->isVal()) { - // registerVal(passkey, stmt->asVal()); - // } else { - // registerExpr(passkey, stmt->asExpr()); - // } - // } - - // virtual void registerVal(IrBuilderPasskey passkey, Val* val) { - // // Call the protected virtual method that subclasses override - // registerVal(val); - // } - - // virtual void registerExpr(IrBuilderPasskey passkey, Expr* expr) { - // // Call the protected virtual method that subclasses override - // registerExpr(expr); - // } - - //=================================================================== - // Container Access - //=================================================================== - - //// Direct access to underlying container - // IrStorage* ir_storage() { - // NVF_ERROR( - // ir_storage_.get() != nullptr, "Accessing a uninitialized - // IrContainer!.") - // return ir_storage_.get(); - // } - - // const IrStorage* ir_storage() const { - // NVF_ERROR( - // ir_storage_.get() != nullptr, "Accessing a uninitialized - // IrContainer!.") - // return ir_storage_.get(); - // } - protected: - //=================================================================== - // Protected Registration API (for derived class overrides) - //=================================================================== - - // static IrCloner copy(const IrContainer* from, IrContainer* to); - // static void swap(IrContainer& a, IrContainer& b) noexcept; - - //// Derived classes (like Fusion) override these to add custom logic - // virtual void registerVal(Val* val) { - // ir_storage()->registerVal(val); - // } - - // virtual void registerExpr(Expr* expr) { - // ir_storage()->registerExpr(expr); - // } - - // virtual void removeExpr(Expr* expr) { - // ir_storage()->removeExpr(expr); - // } - - // virtual void removeVal(Val* val) { - // ir_storage()->removeVal(val); - // } - std::unique_ptr ir_storage_; - - private: - //=================================================================== - // Data Members - //=================================================================== }; } // namespace impl From 2bd85f3297d48a4aad423dfa1b5390b734c16882 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 17:10:39 -0800 Subject: [PATCH 5/7] Fusion owns u_ptr, Interface IrContainer class rm. --- csrc/fusion.cpp | 9 +++++++-- csrc/fusion.h | 7 ++++--- csrc/ir/container.cpp | 12 ++++++------ csrc/ir/container.h | 30 +++++++++++++++--------------- 4 files changed, 32 insertions(+), 26 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b795e8b420a..4ba44f9258f 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -211,14 +211,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { return ir_cloner; } +// Default constructor +Fusion::Fusion() : ir_storage_(std::make_unique()) { + ir_storage_->parent_ = this; +} + // Copy constructor -Fusion::Fusion(const Fusion& other) { +Fusion::Fusion(const Fusion& other) : Fusion() { FUSER_PERF_SCOPE("Fusion copy"); Fusion::copy(&other, this); } // Move constructor -Fusion::Fusion(Fusion&& other) noexcept { +Fusion::Fusion(Fusion&& other) noexcept : Fusion() { FUSER_PERF_SCOPE("Fusion move"); swap(*this, other); } diff --git a/csrc/fusion.h b/csrc/fusion.h index 6d8ae4eb840..2aa1a40af8b 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -12,6 +12,7 @@ #include #include #include +#include "base.h" #include @@ -20,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -142,7 +142,7 @@ class AliasInfoMap { //! The Fusion owns the whole IR graph (Vals and Exprs) //! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class NVF_API Fusion : public impl::IrContainer { +class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; protected: @@ -169,7 +169,7 @@ class NVF_API Fusion : public impl::IrContainer { } } - Fusion() = default; + Fusion(); Fusion(const Fusion& other); Fusion(Fusion&& other) noexcept; @@ -663,6 +663,7 @@ class NVF_API Fusion : public impl::IrContainer { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; + std::unique_ptr ir_storage_; }; // Template implementations for Fusion::manage() that use IrCloner diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 77a93aa5ed5..1db665bcb0d 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -17,12 +17,12 @@ namespace nvfuser { class Fusion; namespace impl { - -IrContainer::IrContainer() : ir_storage_(std::make_unique()) { - ir_storage_->parent_ = static_cast(this); -} - -IrContainer::~IrContainer() {} +// +// IrContainer::IrContainer() : ir_storage_(std::make_unique()) { +// ir_storage_->parent_ = static_cast(this); +//} +// +// IrContainer::~IrContainer() {} } // namespace impl } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 507b5b617bb..e169903795e 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -36,21 +36,21 @@ class IrContainerPasskey { namespace impl { -class NVF_API IrContainer : public PolymorphicBase { - protected: - // Constructors - explicit IrContainer(); - - IrContainer(const IrContainer& other) = delete; - IrContainer(IrContainer&& other) noexcept = delete; - IrContainer& operator=(const IrContainer& other) = delete; - IrContainer& operator=(IrContainer&& other) noexcept = delete; - - ~IrContainer() override; - - protected: - std::unique_ptr ir_storage_; -}; +// class NVF_API IrContainer : public PolymorphicBase { +// protected: +// // Constructors +// explicit IrContainer(); +// +// IrContainer(const IrContainer& other) = delete; +// IrContainer(IrContainer&& other) noexcept = delete; +// IrContainer& operator=(const IrContainer& other) = delete; +// IrContainer& operator=(IrContainer&& other) noexcept = delete; +// +// ~IrContainer() override; +// +// protected: +// std::unique_ptr ir_storage_; +// }; } // namespace impl } // namespace nvfuser From f162b0c46f7c8489d4f96eb1302e5ca6645e1369 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 18:01:21 -0800 Subject: [PATCH 6/7] Completely removing the forwarding interface of IrContainer. --- CMakeLists.txt | 2 +- csrc/fusion.h | 1 + csrc/ir/base_nodes.h | 6 ------ csrc/ir/builder.h | 2 +- csrc/ir/container.cpp | 28 --------------------------- csrc/ir/container.h | 44 +------------------------------------------ csrc/ir/storage.h | 16 ++++++++-------- 7 files changed, 12 insertions(+), 87 deletions(-) delete mode 100644 csrc/ir/container.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 8069ba0d7d4..d6289a29bb7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -215,7 +215,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp ${NVFUSER_SRCS_DIR}/ir/builder.cpp ${NVFUSER_SRCS_DIR}/ir/cloner.cpp - ${NVFUSER_SRCS_DIR}/ir/container.cpp + #${NVFUSER_SRCS_DIR}/ir/container.cpp ${NVFUSER_SRCS_DIR}/ir/storage.cpp ${NVFUSER_SRCS_DIR}/ir/graphviz.cpp ${NVFUSER_SRCS_DIR}/ir/iostream.cpp diff --git a/csrc/fusion.h b/csrc/fusion.h index 2aa1a40af8b..219e71558d9 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index 8142a956655..e78db26e81b 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -60,11 +60,6 @@ class IrBuilderPasskey; class IrContainerPasskey; class ExpressionEvaluator; -// Forward declaration of impl namespace -namespace impl { -class IrContainer; -} - namespace kir { class Kernel; class Predicate; @@ -422,7 +417,6 @@ class NVF_API Val : public Statement { protected: friend class Fusion; - friend class impl::IrContainer; friend class IrStorage; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index af0fea66d32..80e7828067c 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -11,7 +11,7 @@ #include "exceptions.h" #include "fusion_guard.h" #include "ir/builder_passkey.h" -#include "ir/container.h" +#include "ir/storage.h" #include "visibility.h" namespace nvfuser { diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp deleted file mode 100644 index 1db665bcb0d..00000000000 --- a/csrc/ir/container.cpp +++ /dev/null @@ -1,28 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -#include - -#include -#include - -namespace nvfuser { - -// Forward declaration - Fusion inherits from impl::IrContainer -class Fusion; - -namespace impl { -// -// IrContainer::IrContainer() : ir_storage_(std::make_unique()) { -// ir_storage_->parent_ = static_cast(this); -//} -// -// IrContainer::~IrContainer() {} - -} // namespace impl -} // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e169903795e..d6a53fb1e3a 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -7,50 +7,8 @@ // clang-format on #pragma once -#include - -#include -#include -#include - namespace nvfuser { -// Forward declaration of impl namespace -namespace impl { -class IrContainer; -} - -class Fusion; -class IrBuilderPasskey; -class ExprPasskey; -class OptOutMutator; - -// Passkey for container to register names with statements -class IrContainerPasskey { - friend class impl::IrContainer; - friend class IrStorage; - - private: - explicit IrContainerPasskey() = default; -}; - -namespace impl { - -// class NVF_API IrContainer : public PolymorphicBase { -// protected: -// // Constructors -// explicit IrContainer(); -// -// IrContainer(const IrContainer& other) = delete; -// IrContainer(IrContainer&& other) noexcept = delete; -// IrContainer& operator=(const IrContainer& other) = delete; -// IrContainer& operator=(IrContainer&& other) noexcept = delete; -// -// ~IrContainer() override; -// -// protected: -// std::unique_ptr ir_storage_; -// }; +// Empty for now... -} // namespace impl } // namespace nvfuser diff --git a/csrc/ir/storage.h b/csrc/ir/storage.h index 1e3cbb94924..10f5ec1f91c 100644 --- a/csrc/ir/storage.h +++ b/csrc/ir/storage.h @@ -18,10 +18,13 @@ namespace nvfuser { -// Forward declaration of impl namespace -namespace impl { -class IrContainer; -} +// Passkey for container to register names with statements +class IrContainerPasskey { + friend class IrStorage; + + private: + explicit IrContainerPasskey() = default; +}; class NamedScalar; @@ -106,9 +109,6 @@ class IrStorage { static void swap(IrStorage& a, IrStorage& b) noexcept; - // Let IrInterface access protected methods for forwarding - friend class impl::IrContainer; - // Let Fusion access IrStorage::clear() friend class Fusion; @@ -195,7 +195,7 @@ class IrStorage { } private: - // Parent IrInterface that owns this container (for pure composition pattern) + // Parent Fusion that owns this container (for pure composition pattern) // Used by Statement::fusion() to navigate back to owning Fusion Fusion* parent_ = nullptr; }; From 3aedecfc0c47d7b0616def3ba0aaef4b2d013068 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Fri, 30 Jan 2026 18:16:55 -0800 Subject: [PATCH 7/7] IrStorage -> IrContainer --- csrc/fusion.cpp | 37 ++++++++++++-------------- csrc/fusion.h | 62 +++++++++++++++++++++++--------------------- csrc/ir/base_nodes.h | 4 +-- csrc/ir/storage.cpp | 54 +++++++++++++++++++------------------- csrc/ir/storage.h | 24 ++++++++--------- 5 files changed, 90 insertions(+), 91 deletions(-) diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 4ba44f9258f..baf1de84614 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -104,19 +104,17 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // We need to be careful to call IrStorage swap not unique_ptr swap, which + // We need to be careful to call IrContainer swap not unique_ptr swap, which // will only swap the ptrs NOT the contents. - IrStorage::swap(*(a.ir_storage()), *(b.ir_storage())); + IrContainer::swap(*(a.ir_container()), *(b.ir_container())); // Fix parent pointers after swapping containers - // After swap, each IrContainer owns a different IrStorage, so we must update - // the parent backpointers in those containers to point to their new owners - if (a.ir_storage_) { + // After swap, each Fusion owns a different IrContainer, so we must + // update the parent backpointers in those containers to point to their new + // owners + if (a.ir_container_) { // Also update all Statement ir_container_ pointers to point to new owner - // Note: IrContainer is now in impl namespace, but Statement::ir_container_ - // is Fusion*. Since only Fusion (and its derived classes) inherit from - // impl::IrContainer, this cast is safe. - a.ir_storage()->parent_ = &a; + a.ir_container()->parent_ = &a; for (auto val : a.vals()) { val->ir_container_ = &a; } @@ -124,9 +122,9 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { expr->ir_container_ = &a; } } - if (b.ir_storage_) { + if (b.ir_container_) { // Also update all Statement ir_container_ pointers to point to new owner - b.ir_storage()->parent_ = &b; + b.ir_container()->parent_ = &b; for (auto val : b.vals()) { val->ir_container_ = &b; } @@ -150,7 +148,7 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = IrStorage::copy(from->ir_storage(), to->ir_storage()); + auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); for (auto val : from->vals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); @@ -212,8 +210,8 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Default constructor -Fusion::Fusion() : ir_storage_(std::make_unique()) { - ir_storage_->parent_ = this; +Fusion::Fusion() : ir_container_(std::make_unique()) { + ir_container_->parent_ = this; } // Copy constructor @@ -256,7 +254,7 @@ void Fusion::clear() noexcept { // Clear container contents instead of destroying it // This preserves the container object so Statement pointers don't become // dangling - ir_storage()->clear(); + ir_container()->clear(); inputs_.clear(); outputs_.clear(); @@ -293,8 +291,7 @@ void Fusion::removeExpr(Expr* expr) { } } - // TODO : CHECK THIS vvv - ir_storage()->removeExpr(expr); + ir_container()->removeExpr(expr); } void Fusion::removeVal(Val* val) { @@ -338,7 +335,7 @@ void Fusion::removeVal(Val* val) { for (auto e : exprs_to_remove) { removeExpr(e); } - ir_storage()->removeVal(val); + ir_container()->removeVal(val); invalidateTvsAndUses(); } @@ -702,7 +699,7 @@ void Fusion::registerVal(Val* val) { val->fusion() == this, val, " was not found in the active fusion."); } - ir_storage()->registerVal(val); + ir_container()->registerVal(val); } void Fusion::registerExpr(Expr* expr) { @@ -715,7 +712,7 @@ void Fusion::registerExpr(Expr* expr) { expr->fusion() == this, expr, " was not found in the active fusion."); } - ir_storage()->registerExpr(expr); + ir_container()->registerExpr(expr); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); diff --git a/csrc/fusion.h b/csrc/fusion.h index 219e71558d9..4e7ce658574 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -148,16 +148,18 @@ class NVF_API Fusion : public PolymorphicBase { protected: // Direct access to underlying container - IrStorage* ir_storage() { + IrContainer* ir_container() { NVF_ERROR( - ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") - return ir_storage_.get(); + ir_container_.get() != nullptr, + "Accessing a uninitialized IrContainer!.") + return ir_container_.get(); } - const IrStorage* ir_storage() const { + const IrContainer* ir_container() const { NVF_ERROR( - ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") - return ir_storage_.get(); + ir_container_.get() != nullptr, + "Accessing a uninitialized IrContainer!.") + return ir_container_.get(); } public: @@ -502,106 +504,106 @@ class NVF_API Fusion : public PolymorphicBase { void resetExactMappings(); //=================================================================== - // IrStorage API Forwarding (Public Methods) + // IrContainer API Forwarding (Public Methods) //=================================================================== // Container queries bool inContainer(const Statement* stmt) const { - return ir_storage()->inContainer(stmt); + return ir_container()->inContainer(stmt); } void assertInContainer(const Statement* stmt, const std::string& msg) const { - ir_storage()->assertInContainer(stmt, msg); + ir_container()->assertInContainer(stmt, msg); } // Collections access (return values in insertion order) const std::deque deterministic_vals() const noexcept { - return ir_storage()->deterministic_vals(); + return ir_container()->deterministic_vals(); } const std::deque deterministic_exprs() const noexcept { - return ir_storage()->deterministic_exprs(); + return ir_container()->deterministic_exprs(); } const std::unordered_map deterministic_vals_map() const noexcept { - return ir_storage()->deterministic_vals_map(); + return ir_container()->deterministic_vals_map(); } const std::unordered_map deterministic_exprs_map() const noexcept { - return ir_storage()->deterministic_exprs_map(); + return ir_container()->deterministic_exprs_map(); } // Collections access (unordered sets) const std::unordered_set& unordered_exprs() const noexcept { - return ir_storage()->unordered_exprs(); + return ir_container()->unordered_exprs(); } const std::unordered_set& vals() const noexcept { - return ir_storage()->vals(); + return ir_container()->vals(); } // Count queries int64_t numExprs() const noexcept { - return ir_storage()->numExprs(); + return ir_container()->numExprs(); } int64_t numVals(bool include_shortcuts) const noexcept { - return ir_storage()->numVals(include_shortcuts); + return ir_container()->numVals(include_shortcuts); } // Shortcut values (frequently used constants) Val* zeroVal() { - return ir_storage()->zeroVal(); + return ir_container()->zeroVal(); } Val* oneVal() { - return ir_storage()->oneVal(); + return ir_container()->oneVal(); } Val* falseVal() { - return ir_storage()->falseVal(); + return ir_container()->falseVal(); } Val* trueVal() { - return ir_storage()->trueVal(); + return ir_container()->trueVal(); } NamedScalar* magicZeroVal() { - return ir_storage()->magicZeroVal(); + return ir_container()->magicZeroVal(); } Val* zeroVal(DataType dtype) { - return ir_storage()->zeroVal(dtype); + return ir_container()->zeroVal(dtype); } Val* oneVal(DataType dtype) { - return ir_storage()->oneVal(dtype); + return ir_container()->oneVal(dtype); } Val* metadataOf(Val* val) { - return ir_storage()->metadataOf(val); + return ir_container()->metadataOf(val); } // Axioms (CUDA programming assumptions) const std::vector& axioms() { - return ir_storage()->axioms(); + return ir_container()->axioms(); } void assumePositive(Val* val) { - ir_storage()->assumePositive(val); + ir_container()->assumePositive(val); } void assumeNonNegative(Val* val) { - ir_storage()->assumeNonNegative(val); + ir_container()->assumeNonNegative(val); } // Statement removal void removeStatementsCreatedAfter( int64_t num_exprs_before, int64_t num_vals_before) { - ir_storage()->removeStatementsCreatedAfter( + ir_container()->removeStatementsCreatedAfter( num_exprs_before, num_vals_before); } @@ -664,7 +666,7 @@ class NVF_API Fusion : public PolymorphicBase { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; - std::unique_ptr ir_storage_; + std::unique_ptr ir_container_; }; // Template implementations for Fusion::manage() that use IrCloner diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index e78db26e81b..dc04955a185 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -55,7 +55,7 @@ class Fusion; class Expr; class Val; class IrCloner; -class IrStorage; +class IrContainer; class IrBuilderPasskey; class IrContainerPasskey; class ExpressionEvaluator; @@ -417,7 +417,7 @@ class NVF_API Val : public Statement { protected: friend class Fusion; - friend class IrStorage; + friend class IrContainer; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const ValType vtype_; diff --git a/csrc/ir/storage.cpp b/csrc/ir/storage.cpp index 6dff6c52980..3c54966c87d 100644 --- a/csrc/ir/storage.cpp +++ b/csrc/ir/storage.cpp @@ -16,7 +16,7 @@ namespace nvfuser { //! Return values in insertion order -const std::deque IrStorage::deterministic_vals() const noexcept { +const std::deque IrContainer::deterministic_vals() const noexcept { std::deque vals_deque; std::transform( vals_up_.begin(), @@ -27,7 +27,7 @@ const std::deque IrStorage::deterministic_vals() const noexcept { } //! Return expression in insertion order -const std::deque IrStorage::deterministic_exprs() const noexcept { +const std::deque IrContainer::deterministic_exprs() const noexcept { std::deque exprs_deque; std::transform( exprs_up_.begin(), @@ -38,7 +38,7 @@ const std::deque IrStorage::deterministic_exprs() const noexcept { } //! Return mapping from value to integer id -const std::unordered_map IrStorage::deterministic_vals_map() +const std::unordered_map IrContainer::deterministic_vals_map() const noexcept { std::unordered_map vals_map; int64_t count = 0; @@ -53,7 +53,7 @@ const std::unordered_map IrStorage::deterministic_vals_map() } //! Return mapping from expression to integer id -const std::unordered_map IrStorage::deterministic_exprs_map() +const std::unordered_map IrContainer::deterministic_exprs_map() const noexcept { std::unordered_map exprs_map; int64_t count = 0; @@ -67,7 +67,7 @@ const std::unordered_map IrStorage::deterministic_exprs_map() return exprs_map; } -void IrStorage::swap(IrStorage& a, IrStorage& b) noexcept { +void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); // Swap the content @@ -92,7 +92,7 @@ void IrStorage::swap(IrStorage& a, IrStorage& b) noexcept { std::swap(a.axioms_, b.axioms_); } -IrCloner IrStorage::copy(const IrStorage* from, IrStorage* to) { +IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->clear(); IrCloner ir_cloner(to->parent()); @@ -127,13 +127,13 @@ IrCloner IrStorage::copy(const IrStorage* from, IrStorage* to) { return ir_cloner; } -IrStorage::IrStorage() = default; +IrContainer::IrContainer() = default; -IrStorage::~IrStorage() { +IrContainer::~IrContainer() { clear(); } -void IrStorage::removeExpr(Expr* expr) { +void IrContainer::removeExpr(Expr* expr) { NVF_ERROR( exprs_.find(expr) != exprs_.end(), "Wanted to remove an expression but it doesn't exist in this container."); @@ -152,7 +152,7 @@ void IrStorage::removeExpr(Expr* expr) { //! Completely remove val from the fusion, break all dependencies associated //! with it -void IrStorage::removeVal(Val* val) { +void IrContainer::removeVal(Val* val) { // Don't remove shortcuts if (val == true_val_.get() || val == false_val_.get() || val == one_val_.get() || val == zero_val_.get() || @@ -177,7 +177,7 @@ void IrStorage::removeVal(Val* val) { } //! Register the Val with this container -void IrStorage::registerVal(Val* val) { +void IrContainer::registerVal(Val* val) { if (inContainer(val)) { return; } @@ -189,7 +189,7 @@ void IrStorage::registerVal(Val* val) { } //! Register expr with this container. -void IrStorage::registerExpr(Expr* expr) { +void IrContainer::registerExpr(Expr* expr) { if (inContainer(expr)) { return; } @@ -200,8 +200,8 @@ void IrStorage::registerExpr(Expr* expr) { expr->setName(IrContainerPasskey(), getExprName()); } -void IrStorage::clear() noexcept { - FUSER_PERF_SCOPE("IrStorage clear"); +void IrContainer::clear() noexcept { + FUSER_PERF_SCOPE("IrContainer clear"); vals_.clear(); vals_up_.clear(); exprs_.clear(); @@ -212,7 +212,7 @@ void IrStorage::clear() noexcept { expr_name_counter_ = 0; } -bool IrStorage::inContainer(const Statement* const_stmt) const { +bool IrContainer::inContainer(const Statement* const_stmt) const { // We don't use dynamic_cast here because `const_stmt` may be an invalid // pointer. Specifically a pointer to a Statement owned by another container // that has been freed. @@ -245,7 +245,7 @@ bool IrStorage::inContainer(const Statement* const_stmt) const { } // Shortcuts for frequently used vals -Val* IrStorage::zeroVal() { +Val* IrContainer::zeroVal() { if (!zero_val_) { auto zero_val = IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); @@ -256,7 +256,7 @@ Val* IrStorage::zeroVal() { return zero_val_.get(); } -Val* IrStorage::zeroVal(DataType dtype) { +Val* IrContainer::zeroVal(DataType dtype) { if (dtype == DataType::Index) { return zeroVal(); } else if (isBooleanType(dtype)) { @@ -267,7 +267,7 @@ Val* IrStorage::zeroVal(DataType dtype) { } } -Val* IrStorage::oneVal() { +Val* IrContainer::oneVal() { if (!one_val_) { auto one_val = IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); @@ -278,7 +278,7 @@ Val* IrStorage::oneVal() { return one_val_.get(); } -Val* IrStorage::oneVal(DataType dtype) { +Val* IrContainer::oneVal(DataType dtype) { if (dtype == DataType::Index) { return oneVal(); } else if (isBooleanType(dtype)) { @@ -289,7 +289,7 @@ Val* IrStorage::oneVal(DataType dtype) { } } -Val* IrStorage::falseVal() { +Val* IrContainer::falseVal() { if (!false_val_) { auto false_val = IrBuilder::createInContainer( this->parent(), false, DataType::Bool); @@ -300,7 +300,7 @@ Val* IrStorage::falseVal() { return false_val_.get(); } -Val* IrStorage::trueVal() { +Val* IrContainer::trueVal() { if (!true_val_) { auto true_val = IrBuilder::createInContainer(this->parent(), true, DataType::Bool); @@ -311,7 +311,7 @@ Val* IrStorage::trueVal() { return true_val_.get(); } -NamedScalar* IrStorage::magicZeroVal() { +NamedScalar* IrContainer::magicZeroVal() { if (!magic_zero_val_) { auto magic_zero = IrBuilder::create(kMagicZeroName, DataType::Index); @@ -323,7 +323,7 @@ NamedScalar* IrStorage::magicZeroVal() { return magic_zero_val_.get(); } -Val* IrStorage::metadataOf(Val* v) { +Val* IrContainer::metadataOf(Val* v) { if (metadata_.count(v) == 0) { auto metadata_val = IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); @@ -334,7 +334,7 @@ Val* IrStorage::metadataOf(Val* v) { return metadata_.at(v).first; } -void IrStorage::lazyInitAxioms() { +void IrContainer::lazyInitAxioms() { if (!axioms_) { axioms_ = std::make_unique>(); axioms_->reserve(kParallelTypeThreads.size() * 3); @@ -349,19 +349,19 @@ void IrStorage::lazyInitAxioms() { } } -void IrStorage::assumePositive(Val* val) { +void IrContainer::assumePositive(Val* val) { NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); } -void IrStorage::assumeNonNegative(Val* val) { +void IrContainer::assumeNonNegative(Val* val) { NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); } -void IrStorage::removeStatementsCreatedAfter( +void IrContainer::removeStatementsCreatedAfter( int64_t prev_num_exprs, int64_t prev_num_vals) { NVF_ERROR( diff --git a/csrc/ir/storage.h b/csrc/ir/storage.h index 10f5ec1f91c..e361b8743ee 100644 --- a/csrc/ir/storage.h +++ b/csrc/ir/storage.h @@ -20,7 +20,7 @@ namespace nvfuser { // Passkey for container to register names with statements class IrContainerPasskey { - friend class IrStorage; + friend class IrContainer; private: explicit IrContainerPasskey() = default; @@ -28,20 +28,20 @@ class IrContainerPasskey { class NamedScalar; -class IrStorage { +class IrContainer { public: - NVF_API IrStorage(); + NVF_API IrContainer(); - // Copy/Move Constructors and Operators are deleted. IrStorage is managed + // Copy/Move Constructors and Operators are deleted. IrContainer is managed // through a smart pointer in IrContainer. Semantic operations for Fusion // types are handled directly through copy and swap functions. - IrStorage(const IrStorage& other) = delete; - IrStorage(IrStorage&& other) noexcept = delete; + IrContainer(const IrContainer& other) = delete; + IrContainer(IrContainer&& other) noexcept = delete; - IrStorage& operator=(const IrStorage& other) = delete; - IrStorage& operator=(IrStorage&& other) noexcept = delete; + IrContainer& operator=(const IrContainer& other) = delete; + IrContainer& operator=(IrContainer&& other) noexcept = delete; - ~IrStorage(); + ~IrContainer(); bool inContainer(const Statement* stmt) const; @@ -105,11 +105,11 @@ class IrStorage { void assumeNonNegative(Val* val); protected: - static IrCloner copy(const IrStorage* from, IrStorage* to); + static IrCloner copy(const IrContainer* from, IrContainer* to); - static void swap(IrStorage& a, IrStorage& b) noexcept; + static void swap(IrContainer& a, IrContainer& b) noexcept; - // Let Fusion access IrStorage::clear() + // Let Fusion access IrContainer::clear() friend class Fusion; void removeExpr(Expr* expr);