diff --git a/CMakeLists.txt b/CMakeLists.txt index 8069ba0d7d4..d6289a29bb7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -215,7 +215,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp ${NVFUSER_SRCS_DIR}/ir/builder.cpp ${NVFUSER_SRCS_DIR}/ir/cloner.cpp - ${NVFUSER_SRCS_DIR}/ir/container.cpp + #${NVFUSER_SRCS_DIR}/ir/container.cpp ${NVFUSER_SRCS_DIR}/ir/storage.cpp ${NVFUSER_SRCS_DIR}/ir/graphviz.cpp ${NVFUSER_SRCS_DIR}/ir/iostream.cpp diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 0ef085592f8..7c53c86d903 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -343,7 +343,7 @@ class NVF_API OptOutMutator : public PolymorphicBase { } protected: - virtual void removeExpr(IrContainer*, Expr*) const; + virtual void removeExpr(Fusion*, Expr*) const; virtual void registerNewExpr(Expr*) {} private: diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index 542f715dba5..ce829c3afc6 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -104,8 +104,37 @@ bool Fusion::sameDefinition(const Fusion& other) const { void Fusion::swap(Fusion& a, Fusion& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); - // Swap IrContainer base class (contains IrStorage) - IrContainer::swap(static_cast(a), static_cast(b)); + // We need to be careful to call IrContainer swap not unique_ptr swap, which + // will only swap the ptrs NOT the contents. + IrContainer::swap(*(a.ir_container()), *(b.ir_container())); + + // Fix parent pointers after swapping containers + // After swap, each IrContainer owns a different IrContainer, so we must + // update the parent backpointers in those containers to point to their new + // owners + if (a.ir_container_) { + // Also update all Statement ir_container_ pointers to point to new owner + // Note: IrContainer is now in impl namespace, but Statement::ir_container_ + // is Fusion*. Since only Fusion (and its derived classes) inherit from + // impl::IrContainer, this cast is safe. + a.ir_container()->parent_ = &a; + for (auto val : a.vals()) { + val->ir_container_ = &a; + } + for (auto expr : a.deterministic_exprs()) { + expr->ir_container_ = &a; + } + } + if (b.ir_container_) { + // Also update all Statement ir_container_ pointers to point to new owner + b.ir_container()->parent_ = &b; + for (auto val : b.vals()) { + val->ir_container_ = &b; + } + for (auto expr : b.deterministic_exprs()) { + expr->ir_container_ = &b; + } + } std::swap(a.inputs_, b.inputs_); std::swap(a.outputs_, b.outputs_); @@ -122,7 +151,7 @@ std::unique_ptr Fusion::segment( IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->clear(); - auto ir_cloner = IrContainer::copy(from, to); + auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); for (auto val : from->vals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); @@ -183,14 +212,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { return ir_cloner; } +// Default constructor +Fusion::Fusion() : ir_container_(std::make_unique()) { + ir_container_->parent_ = this; +} + // Copy constructor -Fusion::Fusion(const Fusion& other) { +Fusion::Fusion(const Fusion& other) : Fusion() { FUSER_PERF_SCOPE("Fusion copy"); Fusion::copy(&other, this); } // Move constructor -Fusion::Fusion(Fusion&& other) noexcept { +Fusion::Fusion(Fusion&& other) noexcept : Fusion() { FUSER_PERF_SCOPE("Fusion move"); swap(*this, other); } @@ -223,7 +257,7 @@ void Fusion::clear() noexcept { // Clear container contents instead of destroying it // This preserves the container object so Statement pointers don't become // dangling - ir_storage()->clear(); + ir_container()->clear(); inputs_.clear(); outputs_.clear(); @@ -260,7 +294,8 @@ void Fusion::removeExpr(Expr* expr) { } } - IrContainer::removeExpr(expr); + // TODO : CHECK THIS vvv + ir_container()->removeExpr(expr); } void Fusion::removeVal(Val* val) { @@ -304,7 +339,7 @@ void Fusion::removeVal(Val* val) { for (auto e : exprs_to_remove) { removeExpr(e); } - IrContainer::removeVal(val); + ir_container()->removeVal(val); invalidateTvsAndUses(); } @@ -668,7 +703,7 @@ void Fusion::registerVal(Val* val) { val->fusion() == this, val, " was not found in the active fusion."); } - IrContainer::registerVal(val); + ir_container()->registerVal(val); } void Fusion::registerExpr(Expr* expr) { @@ -681,7 +716,7 @@ void Fusion::registerExpr(Expr* expr) { expr->fusion() == this, expr, " was not found in the active fusion."); } - IrContainer::registerExpr(expr); + ir_container()->registerExpr(expr); for (Val* input : expr->inputs()) { assertInContainer(input, "Input to expr is invalid, "); diff --git a/csrc/fusion.h b/csrc/fusion.h index f57446e1767..4e7ce658574 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -12,6 +12,7 @@ #include #include #include +#include "base.h" #include @@ -20,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -142,11 +143,36 @@ class AliasInfoMap { //! The Fusion owns the whole IR graph (Vals and Exprs) //! // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class NVF_API Fusion : public IrContainer { +class NVF_API Fusion : public PolymorphicBase { typedef std::unordered_map> PermutationMap; + protected: + // Direct access to underlying container + IrContainer* ir_container() { + NVF_ERROR( + ir_container_.get() != nullptr, + "Accessing a uninitialized IrContainer!.") + return ir_container_.get(); + } + + const IrContainer* ir_container() const { + NVF_ERROR( + ir_container_.get() != nullptr, + "Accessing a uninitialized IrContainer!.") + return ir_container_.get(); + } + public: - Fusion() = default; + // Registration (public API with passkey) + virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { + if (stmt->isVal()) { + registerVal(stmt->asVal()); + } else { + registerExpr(stmt->asExpr()); + } + } + + Fusion(); Fusion(const Fusion& other); Fusion(Fusion&& other) noexcept; @@ -168,11 +194,11 @@ class NVF_API Fusion : public IrContainer { //! Break dependency chains associated with Expr, remove references to expr //! delete expr - void removeExpr(Expr* expr) override; + virtual void removeExpr(Expr* expr); //! Completely remove val from the fusion, break all dependencies associated //! with it - void removeVal(Val* val) override; + virtual void removeVal(Val* val); //! Register input as an input of the fusion void addInput(Val* input); @@ -477,17 +503,118 @@ class NVF_API Fusion : public IrContainer { void resetExactMappings(); + //=================================================================== + // IrContainer API Forwarding (Public Methods) + //=================================================================== + + // Container queries + bool inContainer(const Statement* stmt) const { + return ir_container()->inContainer(stmt); + } + + void assertInContainer(const Statement* stmt, const std::string& msg) const { + ir_container()->assertInContainer(stmt, msg); + } + + // Collections access (return values in insertion order) + const std::deque deterministic_vals() const noexcept { + return ir_container()->deterministic_vals(); + } + + const std::deque deterministic_exprs() const noexcept { + return ir_container()->deterministic_exprs(); + } + + const std::unordered_map deterministic_vals_map() + const noexcept { + return ir_container()->deterministic_vals_map(); + } + + const std::unordered_map deterministic_exprs_map() + const noexcept { + return ir_container()->deterministic_exprs_map(); + } + + // Collections access (unordered sets) + const std::unordered_set& unordered_exprs() const noexcept { + return ir_container()->unordered_exprs(); + } + + const std::unordered_set& vals() const noexcept { + return ir_container()->vals(); + } + + // Count queries + int64_t numExprs() const noexcept { + return ir_container()->numExprs(); + } + + int64_t numVals(bool include_shortcuts) const noexcept { + return ir_container()->numVals(include_shortcuts); + } + + // Shortcut values (frequently used constants) + Val* zeroVal() { + return ir_container()->zeroVal(); + } + + Val* oneVal() { + return ir_container()->oneVal(); + } + + Val* falseVal() { + return ir_container()->falseVal(); + } + + Val* trueVal() { + return ir_container()->trueVal(); + } + + NamedScalar* magicZeroVal() { + return ir_container()->magicZeroVal(); + } + + Val* zeroVal(DataType dtype) { + return ir_container()->zeroVal(dtype); + } + + Val* oneVal(DataType dtype) { + return ir_container()->oneVal(dtype); + } + + Val* metadataOf(Val* val) { + return ir_container()->metadataOf(val); + } + + // Axioms (CUDA programming assumptions) + const std::vector& axioms() { + return ir_container()->axioms(); + } + + void assumePositive(Val* val) { + ir_container()->assumePositive(val); + } + + void assumeNonNegative(Val* val) { + ir_container()->assumeNonNegative(val); + } + + // Statement removal + void removeStatementsCreatedAfter( + int64_t num_exprs_before, + int64_t num_vals_before) { + ir_container()->removeStatementsCreatedAfter( + num_exprs_before, num_vals_before); + } + protected: friend SegmentCandidateFinder; friend SegmentedFusion; friend class TranslateApplicableWelford; friend Val; - using IrContainer::registerExpr; - using IrContainer::registerVal; - //! Register the Val with this fusion - void registerVal(Val* val) override; + virtual void registerVal(Val* val); //! Register expr with this fusion. //! When we register an expression, we want to update the dependency tracking @@ -495,7 +622,7 @@ class NVF_API Fusion : public IrContainer { //! definitions of outputs and register this Expr as the definition. Otherwise //! will update definition if not previously set, but will not remove old //! definitions. - void registerExpr(Expr* expr) override; + virtual void registerExpr(Expr* expr); //! Clear Expr's from TV uses that are not required to produce outputs from //! inputs. Only other place this is used (other than Fusion) is in @@ -539,22 +666,18 @@ class NVF_API Fusion : public IrContainer { std::unique_ptr> all_tvs_ptr_ = nullptr; inline static const std::string exact_mappings_key = "exact_mappings"; + std::unique_ptr ir_container_; }; +// Template implementations for Fusion::manage() that use IrCloner template std::any defaultCloneFunction(IrCloner& cloner, std::any data) { auto cloned_data = cloner.clone(std::any_cast(data)); - // Adding a static_assert to improve error message. Without this - // static_assert, the following cast will still fail, but the error message - // will be unreadable. static_assert( std::is_convertible_v, "IrCloner::clone returns a data type that is not compatible with the " "original managed data type. " "Likely you will need to check IrCloner::clone for your data type."); - // Convert the result of the clone back to T before assigning to std::any. - // This ensures the type of the std::any does not change over the clone of - // fusion. return std::any((T)cloned_data); } @@ -568,4 +691,41 @@ void Fusion::manage(std::string key, T data) { return manage(key, std::any(data), defaultCloneFunction); } +// Template implementations for IrBuilder that require Fusion to be fully +// defined +template +T* IrBuilder::createInContainer(Fusion* container, Args&&... args) { + NVF_ERROR(container != nullptr, "Need an active container to build IR."); + T* node = new T(IrBuilderPasskey(container), std::forward(args)...); + container->registerStmt(IrBuilderPasskey(container), node); + return node; +} + +template +T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { + NVF_ERROR( + ir_cloner != nullptr, + "Cannot use create when a cloner object is set. Use clone."); + NVF_ERROR( + ir_cloner->container() != nullptr, + "Cloner doesn't have a valid container to store cloned object."); + + T* dest = new T(src, ir_cloner); + const auto* src_stmt = dynamic_cast(src); + auto* dest_stmt = dynamic_cast(dest); + + auto dest_container = ir_cloner->container(); + auto src_container = src_stmt->container(); + + dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + + if (src_container != dest_container) { + dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); + } + + ir_cloner->registerClone(src_stmt, dest_stmt); + + return dest; +} + } // namespace nvfuser diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index d30f7abd145..dc04955a185 100644 --- a/csrc/ir/base_nodes.h +++ b/csrc/ir/base_nodes.h @@ -9,8 +9,6 @@ #include #include -#include -#include #include #include @@ -58,7 +56,6 @@ class Expr; class Val; class IrCloner; class IrContainer; -class IrStorage; class IrBuilderPasskey; class IrContainerPasskey; class ExpressionEvaluator; @@ -94,9 +91,7 @@ class ExprPasskey { //! Basically beinng able to succienctly traverse down the inhereitance stack of //! a Statment at runtime. This is currently implemented in dispatch.h class NVF_API Statement : public NonCopyable, public PolymorphicBase { - friend void swap(Fusion&, Fusion&) noexcept; - friend void swap(IrContainer& a, IrContainer& b) noexcept; - friend class IrContainer; + friend class Fusion; public: Statement() = delete; @@ -143,7 +138,7 @@ class NVF_API Statement : public NonCopyable, public PolymorphicBase { kir::Kernel* kernel() const; // Return the container this statement belongs to - IrContainer* container() const { + Fusion* container() const { return ir_container_; } @@ -186,7 +181,7 @@ class NVF_API Statement : public NonCopyable, public PolymorphicBase { StmtNameType name_ = kInvalidStmName; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - IrContainer* ir_container_ = nullptr; + Fusion* ir_container_ = nullptr; }; inline std::string toString(Statement* stmt) { @@ -423,7 +418,6 @@ class NVF_API Val : public Statement { protected: friend class Fusion; friend class IrContainer; - friend class IrStorage; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const ValType vtype_; @@ -480,7 +474,7 @@ class NVF_API Val : public Statement { }; using newObjectFuncType = Expr*( - IrContainer*, + Fusion*, std::vector, std::vector, std::vector); @@ -704,7 +698,7 @@ bool Val::isDefinitionType() const { #define NVFUSER_DECLARE_CLONE_AND_CREATE \ virtual Statement* clone(IrCloner* ir_cloner) const override; \ static Expr* newObject( \ - IrContainer* container, \ + Fusion* container, \ std::vector inputs, \ std::vector outputs, \ std::vector attributes); \ @@ -717,7 +711,7 @@ bool Val::isDefinitionType() const { return IrBuilder::clone(this, ir_cloner); \ } \ Expr* ClassName::newObject( \ - IrContainer* container, \ + Fusion* container, \ std::vector inputs, \ std::vector outputs, \ std::vector attributes) { \ diff --git a/csrc/ir/builder.h b/csrc/ir/builder.h index 8df6c2d64e4..af0fea66d32 100644 --- a/csrc/ir/builder.h +++ b/csrc/ir/builder.h @@ -39,16 +39,11 @@ class IrBuilder { } //! Allocate a new IR node, forwarding the arguments to the appropriate - //! constructor and registering with the container + //! constructor and registering with the container. + //! Implementation provided at the end of fusion.h after Fusion is fully + //! defined. template - static T* createInContainer(IrContainer* container, Args&&... args) { - NVF_ERROR(container != nullptr, "Need an active container to build IR."); - T* node = new T(IrBuilderPasskey(container), std::forward(args)...); - - container->registerStmt(IrBuilderPasskey(container), node); - - return node; - } + static T* createInContainer(Fusion* container, Args&&... args); //! Clone an IR node, forwarding the arguments to the IrCloner constructor. //! Register clones with IrCloner's target container. diff --git a/csrc/ir/builder_passkey.h b/csrc/ir/builder_passkey.h index b15f2a9a76d..79aa6404bd6 100644 --- a/csrc/ir/builder_passkey.h +++ b/csrc/ir/builder_passkey.h @@ -9,18 +9,18 @@ namespace nvfuser { -class IrContainer; +class Fusion; // Passkey for builder to register properties with statements, and to call -// functions in IrContainer +// functions in IrContainer (now via Fusion) class IrBuilderPasskey { friend class IrBuilder; public: - IrContainer* const ir_container_ = nullptr; + Fusion* const ir_container_ = nullptr; private: - explicit IrBuilderPasskey(IrContainer* ir_container) + explicit IrBuilderPasskey(Fusion* ir_container) : ir_container_(ir_container) {} }; diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index 6a38e7113b2..c71c04f082c 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -13,7 +13,7 @@ namespace nvfuser { -IrCloner::IrCloner(IrContainer* container) : ir_container_(container) { +IrCloner::IrCloner(Fusion* container) : ir_container_(container) { NVF_ERROR( container != nullptr, "IrCloner constructor received NULL container pointer"); diff --git a/csrc/ir/cloner.h b/csrc/ir/cloner.h index 83e0e444430..9a964f4d66a 100644 --- a/csrc/ir/cloner.h +++ b/csrc/ir/cloner.h @@ -21,7 +21,7 @@ namespace nvfuser { -class IrContainer; +class Fusion; //! Clones nodes from an exiting Fusion //! @@ -35,7 +35,7 @@ class IrCloner { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - explicit IrCloner(IrContainer* container); + explicit IrCloner(Fusion* container); virtual ~IrCloner() = default; NVF_API Statement* clone(const Statement* statement); @@ -140,7 +140,7 @@ class IrCloner { return cloned_disjoint_sets; } - IrContainer* container() const { + Fusion* container() const { return ir_container_; } @@ -155,7 +155,7 @@ class IrCloner { private: // The destination Fusion container - IrContainer* ir_container_ = nullptr; + Fusion* ir_container_ = nullptr; // Builder to make all the new nodes IrBuilder builder_; @@ -177,33 +177,7 @@ class RecomputeTv : private IrCloner { Statement* handle(const TensorDomain*); }; -//! Clone an IR node, forwarding the arguments to the IrCloner constructor. -template -T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { - NVF_ERROR( - ir_cloner != nullptr, - "Cannot use create when a cloner object is set. Use clone."); - - NVF_ERROR( - ir_cloner->container() != nullptr, - "Cloner doesn't have a valid container to store cloned object."); - - T* dest = new T(src, ir_cloner); - const auto* src_stmt = dynamic_cast(src); - auto* dest_stmt = dynamic_cast(dest); - - auto dest_container = ir_cloner->container(); - auto src_container = src_stmt->container(); - - dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); - - if (src_container != dest_container) { - dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); - } - - ir_cloner->registerClone(src_stmt, dest_stmt); - - return dest; -} +// Note: IrBuilder::clone() template implementation is in fusion.h +// after Fusion is fully defined } // namespace nvfuser diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp deleted file mode 100644 index e59905f1da3..00000000000 --- a/csrc/ir/container.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -#include - -#include -#include - -namespace nvfuser { - -IrContainer::IrContainer() : ir_storage_(std::make_unique()) { - ir_storage()->parent_ = this; -} - -IrContainer::~IrContainer() {} - -void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { - // We need to be careful to call IrStorage swap not unique_ptr swap, which - // will only swap the ptrs NOT the contents. - IrStorage::swap(*(a.ir_storage()), *(b.ir_storage())); - - // Fix parent pointers after swapping containers - // After swap, each IrContainer owns a different IrStorage, so we must update - // the parent backpointers in those containers to point to their new owners - if (a.ir_storage_) { - a.ir_storage()->parent_ = &a; - // Also update all Statement ir_container_ pointers to point to new owner - for (auto val : a.vals()) { - val->ir_container_ = &a; - } - for (auto expr : a.deterministic_exprs()) { - expr->ir_container_ = &a; - } - } - if (b.ir_storage_) { - b.ir_storage()->parent_ = &b; - // Also update all Statement ir_container_ pointers to point to new owner - for (auto val : b.vals()) { - val->ir_container_ = &b; - } - for (auto expr : b.deterministic_exprs()) { - expr->ir_container_ = &b; - } - } -} - -IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { - auto ir_cloner = IrStorage::copy(from->ir_storage(), to->ir_storage()); - - return ir_cloner; -} -} // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 1b73d22b4a1..d6a53fb1e3a 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -7,233 +7,8 @@ // clang-format on #pragma once -#include - -#include -#include -#include - namespace nvfuser { -class IrBuilderPasskey; -class ExprPasskey; -class OptOutMutator; - -// Passkey for container to register names with statements -class IrContainerPasskey { - friend class IrContainer; - friend class IrStorage; - - private: - explicit IrContainerPasskey() = default; -}; - -// IrContainer: Base class for types that provide IrContainer API via -// composition -// -// This class handles the composition infrastructure and forwarding boilerplate -// for accessing IrStorage functionality. Derived classes (like Fusion) can -// focus on their specific logic while inheriting the full IrContainer API. -// -// Key Features: -// - Owns IrStorage via unique_ptr (can be shared_ptr in Phase 2) -// - Forwards all IrStorage public methods -// - Allows derived classes to override protected IrContainer methods -class NVF_API IrContainer : public PolymorphicBase { - protected: - // Constructors - explicit IrContainer(); - - // TODO: The semantics of IrContainers are largely driven through copy/swap - // function behavior. It might be better if this behaviour was properly - // defined through class semantics directly. - // - // Copy/Move are deleted. IrContainer is a forwarding interface class. We - // rely on copy/swap function behavior to handle the semantics of IrStorage. - IrContainer(const IrContainer& other) = delete; - IrContainer(IrContainer&& other) noexcept = delete; - IrContainer& operator=(const IrContainer& other) = delete; - IrContainer& operator=(IrContainer&& other) noexcept = delete; - - ~IrContainer() override; - - // Let mutator remove Exprs. - friend OptOutMutator; - - public: - //=================================================================== - // IrStorage API Forwarding (Public Methods) - //=================================================================== - - // Container queries - bool inContainer(const Statement* stmt) const { - return ir_storage()->inContainer(stmt); - } - - void assertInContainer(const Statement* stmt, const std::string& msg) const { - ir_storage()->assertInContainer(stmt, msg); - } - - // Collections access (return values in insertion order) - const std::deque deterministic_vals() const noexcept { - return ir_storage()->deterministic_vals(); - } - - const std::deque deterministic_exprs() const noexcept { - return ir_storage()->deterministic_exprs(); - } - - const std::unordered_map deterministic_vals_map() - const noexcept { - return ir_storage()->deterministic_vals_map(); - } - - const std::unordered_map deterministic_exprs_map() - const noexcept { - return ir_storage()->deterministic_exprs_map(); - } - - // Collections access (unordered sets) - const std::unordered_set& unordered_exprs() const noexcept { - return ir_storage()->unordered_exprs(); - } - - const std::unordered_set& vals() const noexcept { - return ir_storage()->vals(); - } - - // Count queries - int64_t numExprs() const noexcept { - return ir_storage()->numExprs(); - } - - int64_t numVals(bool include_shortcuts) const noexcept { - return ir_storage()->numVals(include_shortcuts); - } - - // Shortcut values (frequently used constants) - Val* zeroVal() { - return ir_storage()->zeroVal(); - } - - Val* oneVal() { - return ir_storage()->oneVal(); - } - - Val* falseVal() { - return ir_storage()->falseVal(); - } - - Val* trueVal() { - return ir_storage()->trueVal(); - } - - NamedScalar* magicZeroVal() { - return ir_storage()->magicZeroVal(); - } - - Val* zeroVal(DataType dtype) { - return ir_storage()->zeroVal(dtype); - } - - Val* oneVal(DataType dtype) { - return ir_storage()->oneVal(dtype); - } - - Val* metadataOf(Val* val) { - return ir_storage()->metadataOf(val); - } - - // Axioms (CUDA programming assumptions) - const std::vector& axioms() { - return ir_storage()->axioms(); - } - - void assumePositive(Val* val) { - ir_storage()->assumePositive(val); - } - - void assumeNonNegative(Val* val) { - ir_storage()->assumeNonNegative(val); - } - - // Statement removal - void removeStatementsCreatedAfter( - int64_t num_exprs_before, - int64_t num_vals_before) { - ir_storage()->removeStatementsCreatedAfter( - num_exprs_before, num_vals_before); - } - - // Registration (public API with passkey) - virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { - // Dispatch to Val or Expr registration, which calls the virtual protected - // methods that subclasses (like Fusion) override - if (stmt->isVal()) { - registerVal(passkey, stmt->asVal()); - } else { - registerExpr(passkey, stmt->asExpr()); - } - } - - virtual void registerVal(IrBuilderPasskey passkey, Val* val) { - // Call the protected virtual method that subclasses override - registerVal(val); - } - - virtual void registerExpr(IrBuilderPasskey passkey, Expr* expr) { - // Call the protected virtual method that subclasses override - registerExpr(expr); - } - - //=================================================================== - // Container Access - //=================================================================== - - // Direct access to underlying container - IrStorage* ir_storage() { - NVF_ERROR( - ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") - return ir_storage_.get(); - } - - const IrStorage* ir_storage() const { - NVF_ERROR( - ir_storage_.get() != nullptr, "Accessing a uninitialized IrContainer!.") - return ir_storage_.get(); - } - - protected: - //=================================================================== - // Protected Registration API (for derived class overrides) - //=================================================================== - - static IrCloner copy(const IrContainer* from, IrContainer* to); - static void swap(IrContainer& a, IrContainer& b) noexcept; - - // Derived classes (like Fusion) override these to add custom logic - virtual void registerVal(Val* val) { - ir_storage()->registerVal(val); - } - - virtual void registerExpr(Expr* expr) { - ir_storage()->registerExpr(expr); - } - - virtual void removeExpr(Expr* expr) { - ir_storage()->removeExpr(expr); - } - - virtual void removeVal(Val* val) { - ir_storage()->removeVal(val); - } - - private: - //=================================================================== - // Data Members - //=================================================================== - - std::unique_ptr ir_storage_; -}; +// Empty for now... } // namespace nvfuser diff --git a/csrc/ir/storage.cpp b/csrc/ir/storage.cpp index 6dff6c52980..3c54966c87d 100644 --- a/csrc/ir/storage.cpp +++ b/csrc/ir/storage.cpp @@ -16,7 +16,7 @@ namespace nvfuser { //! Return values in insertion order -const std::deque IrStorage::deterministic_vals() const noexcept { +const std::deque IrContainer::deterministic_vals() const noexcept { std::deque vals_deque; std::transform( vals_up_.begin(), @@ -27,7 +27,7 @@ const std::deque IrStorage::deterministic_vals() const noexcept { } //! Return expression in insertion order -const std::deque IrStorage::deterministic_exprs() const noexcept { +const std::deque IrContainer::deterministic_exprs() const noexcept { std::deque exprs_deque; std::transform( exprs_up_.begin(), @@ -38,7 +38,7 @@ const std::deque IrStorage::deterministic_exprs() const noexcept { } //! Return mapping from value to integer id -const std::unordered_map IrStorage::deterministic_vals_map() +const std::unordered_map IrContainer::deterministic_vals_map() const noexcept { std::unordered_map vals_map; int64_t count = 0; @@ -53,7 +53,7 @@ const std::unordered_map IrStorage::deterministic_vals_map() } //! Return mapping from expression to integer id -const std::unordered_map IrStorage::deterministic_exprs_map() +const std::unordered_map IrContainer::deterministic_exprs_map() const noexcept { std::unordered_map exprs_map; int64_t count = 0; @@ -67,7 +67,7 @@ const std::unordered_map IrStorage::deterministic_exprs_map() return exprs_map; } -void IrStorage::swap(IrStorage& a, IrStorage& b) noexcept { +void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { FUSER_PERF_SCOPE("Fusion swap"); // Swap the content @@ -92,7 +92,7 @@ void IrStorage::swap(IrStorage& a, IrStorage& b) noexcept { std::swap(a.axioms_, b.axioms_); } -IrCloner IrStorage::copy(const IrStorage* from, IrStorage* to) { +IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->clear(); IrCloner ir_cloner(to->parent()); @@ -127,13 +127,13 @@ IrCloner IrStorage::copy(const IrStorage* from, IrStorage* to) { return ir_cloner; } -IrStorage::IrStorage() = default; +IrContainer::IrContainer() = default; -IrStorage::~IrStorage() { +IrContainer::~IrContainer() { clear(); } -void IrStorage::removeExpr(Expr* expr) { +void IrContainer::removeExpr(Expr* expr) { NVF_ERROR( exprs_.find(expr) != exprs_.end(), "Wanted to remove an expression but it doesn't exist in this container."); @@ -152,7 +152,7 @@ void IrStorage::removeExpr(Expr* expr) { //! Completely remove val from the fusion, break all dependencies associated //! with it -void IrStorage::removeVal(Val* val) { +void IrContainer::removeVal(Val* val) { // Don't remove shortcuts if (val == true_val_.get() || val == false_val_.get() || val == one_val_.get() || val == zero_val_.get() || @@ -177,7 +177,7 @@ void IrStorage::removeVal(Val* val) { } //! Register the Val with this container -void IrStorage::registerVal(Val* val) { +void IrContainer::registerVal(Val* val) { if (inContainer(val)) { return; } @@ -189,7 +189,7 @@ void IrStorage::registerVal(Val* val) { } //! Register expr with this container. -void IrStorage::registerExpr(Expr* expr) { +void IrContainer::registerExpr(Expr* expr) { if (inContainer(expr)) { return; } @@ -200,8 +200,8 @@ void IrStorage::registerExpr(Expr* expr) { expr->setName(IrContainerPasskey(), getExprName()); } -void IrStorage::clear() noexcept { - FUSER_PERF_SCOPE("IrStorage clear"); +void IrContainer::clear() noexcept { + FUSER_PERF_SCOPE("IrContainer clear"); vals_.clear(); vals_up_.clear(); exprs_.clear(); @@ -212,7 +212,7 @@ void IrStorage::clear() noexcept { expr_name_counter_ = 0; } -bool IrStorage::inContainer(const Statement* const_stmt) const { +bool IrContainer::inContainer(const Statement* const_stmt) const { // We don't use dynamic_cast here because `const_stmt` may be an invalid // pointer. Specifically a pointer to a Statement owned by another container // that has been freed. @@ -245,7 +245,7 @@ bool IrStorage::inContainer(const Statement* const_stmt) const { } // Shortcuts for frequently used vals -Val* IrStorage::zeroVal() { +Val* IrContainer::zeroVal() { if (!zero_val_) { auto zero_val = IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); @@ -256,7 +256,7 @@ Val* IrStorage::zeroVal() { return zero_val_.get(); } -Val* IrStorage::zeroVal(DataType dtype) { +Val* IrContainer::zeroVal(DataType dtype) { if (dtype == DataType::Index) { return zeroVal(); } else if (isBooleanType(dtype)) { @@ -267,7 +267,7 @@ Val* IrStorage::zeroVal(DataType dtype) { } } -Val* IrStorage::oneVal() { +Val* IrContainer::oneVal() { if (!one_val_) { auto one_val = IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); @@ -278,7 +278,7 @@ Val* IrStorage::oneVal() { return one_val_.get(); } -Val* IrStorage::oneVal(DataType dtype) { +Val* IrContainer::oneVal(DataType dtype) { if (dtype == DataType::Index) { return oneVal(); } else if (isBooleanType(dtype)) { @@ -289,7 +289,7 @@ Val* IrStorage::oneVal(DataType dtype) { } } -Val* IrStorage::falseVal() { +Val* IrContainer::falseVal() { if (!false_val_) { auto false_val = IrBuilder::createInContainer( this->parent(), false, DataType::Bool); @@ -300,7 +300,7 @@ Val* IrStorage::falseVal() { return false_val_.get(); } -Val* IrStorage::trueVal() { +Val* IrContainer::trueVal() { if (!true_val_) { auto true_val = IrBuilder::createInContainer(this->parent(), true, DataType::Bool); @@ -311,7 +311,7 @@ Val* IrStorage::trueVal() { return true_val_.get(); } -NamedScalar* IrStorage::magicZeroVal() { +NamedScalar* IrContainer::magicZeroVal() { if (!magic_zero_val_) { auto magic_zero = IrBuilder::create(kMagicZeroName, DataType::Index); @@ -323,7 +323,7 @@ NamedScalar* IrStorage::magicZeroVal() { return magic_zero_val_.get(); } -Val* IrStorage::metadataOf(Val* v) { +Val* IrContainer::metadataOf(Val* v) { if (metadata_.count(v) == 0) { auto metadata_val = IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); @@ -334,7 +334,7 @@ Val* IrStorage::metadataOf(Val* v) { return metadata_.at(v).first; } -void IrStorage::lazyInitAxioms() { +void IrContainer::lazyInitAxioms() { if (!axioms_) { axioms_ = std::make_unique>(); axioms_->reserve(kParallelTypeThreads.size() * 3); @@ -349,19 +349,19 @@ void IrStorage::lazyInitAxioms() { } } -void IrStorage::assumePositive(Val* val) { +void IrContainer::assumePositive(Val* val) { NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); } -void IrStorage::assumeNonNegative(Val* val) { +void IrContainer::assumeNonNegative(Val* val) { NVF_ERROR(val->container() == this->parent()); lazyInitAxioms(); axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); } -void IrStorage::removeStatementsCreatedAfter( +void IrContainer::removeStatementsCreatedAfter( int64_t prev_num_exprs, int64_t prev_num_vals) { NVF_ERROR( diff --git a/csrc/ir/storage.h b/csrc/ir/storage.h index 174d64091fc..863c2795628 100644 --- a/csrc/ir/storage.h +++ b/csrc/ir/storage.h @@ -18,22 +18,30 @@ namespace nvfuser { +// Passkey for container to register names with statements +class IrContainerPasskey { + friend class IrContainer; + + private: + explicit IrContainerPasskey() = default; +}; + class NamedScalar; -class IrStorage { +class IrContainer { public: - NVF_API IrStorage(); + NVF_API IrContainer(); - // Copy/Move Constructors and Operators are deleted. IrStorage is managed + // Copy/Move Constructors and Operators are deleted. IrContainer is managed // through a smart pointer in IrContainer. Semantic operations for Fusion // types are handled directly through copy and swap functions. - IrStorage(const IrStorage& other) = delete; - IrStorage(IrStorage&& other) noexcept = delete; + IrContainer(const IrContainer& other) = delete; + IrContainer(IrContainer&& other) noexcept = delete; - IrStorage& operator=(const IrStorage& other) = delete; - IrStorage& operator=(IrStorage&& other) noexcept = delete; + IrContainer& operator=(const IrContainer& other) = delete; + IrContainer& operator=(IrContainer&& other) noexcept = delete; - ~IrStorage(); + ~IrContainer(); bool inContainer(const Statement* stmt) const; @@ -97,14 +105,11 @@ class IrStorage { void assumeNonNegative(Val* val); protected: - static IrCloner copy(const IrStorage* from, IrStorage* to); - - static void swap(IrStorage& a, IrStorage& b) noexcept; + static IrCloner copy(const IrContainer* from, IrContainer* to); - // Let IrInterface access protected methods for forwarding - friend class IrContainer; + static void swap(IrContainer& a, IrContainer& b) noexcept; - // Let Fusion access IrStorage::clear() + // Let Fusion access IrContainer::clear() friend class Fusion; void removeExpr(Expr* expr); @@ -183,7 +188,7 @@ class IrStorage { std::unordered_map> metadata_; public: - IrContainer* parent() const { + Fusion* parent() const { NVF_ERROR( parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.") return parent_; @@ -192,7 +197,7 @@ class IrStorage { private: // Parent IrInterface that owns this container (for pure composition pattern) // Used by Statement::fusion() to navigate back to owning Fusion - IrContainer* parent_ = nullptr; + Fusion* parent_ = nullptr; }; } // namespace nvfuser diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index bf2dc04f281..34792ccb5cc 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -200,7 +200,7 @@ struct SubstituteInExpr : public OptOutMutator { } protected: - void removeExpr(IrContainer*, Expr*) const override {} + void removeExpr(Fusion*, Expr*) const override {} void registerNewExpr(Expr* expr) override { expr_ = expr; diff --git a/csrc/kernel.h b/csrc/kernel.h index 0c97f24dcde..d6fdd2f2fa8 100644 --- a/csrc/kernel.h +++ b/csrc/kernel.h @@ -280,9 +280,6 @@ class NVF_API Kernel final : public Fusion { } protected: - using IrContainer::registerExpr; - using IrContainer::registerVal; - //! Register the Val with this fusion void registerVal(Val* val) override; diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index d5395c4f0f5..cfc529b6337 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -291,7 +291,7 @@ Expr* OptOutMutator::mutateExpr( return new_expr; } -void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) const { +void OptOutMutator::removeExpr(Fusion* container, Expr* expr) const { container->removeExpr(expr); }