-
Notifications
You must be signed in to change notification settings - Fork 76
[IR Refactor] Fusion Base Type #5902
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mdavis36
wants to merge
7
commits into
main
Choose a base branch
from
md/fusion-base
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+293
−419
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
5f9c929
Use Fusion* as base pointer for all IrContainers; Move interface into…
mdavis36 02f22a8
Directly point to Fusion* types from IrStorage.
mdavis36 dfd0c17
Moving forwarding interface into Fusion.
mdavis36 c1a3da7
Clean up commented interface definitions.
mdavis36 2bd85f3
Fusion owns u_ptr, Interface IrContainer class rm.
mdavis36 f162b0c
Completely removing the forwarding interface of IrContainer.
mdavis36 3aedecf
IrStorage -> IrContainer
mdavis36 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||
| #include <unordered_map> | ||||||
| #include <unordered_set> | ||||||
| #include <vector> | ||||||
| #include "base.h" | ||||||
|
|
||||||
| #include <ATen/core/ivalue.h> | ||||||
|
|
||||||
|
|
@@ -20,7 +21,7 @@ | |||||
| #include <fusion_guard.h> | ||||||
| #include <ir/base_nodes.h> | ||||||
| #include <ir/cloner.h> | ||||||
| #include <ir/container.h> | ||||||
| #include <ir/storage.h> | ||||||
| #include <iter_visitor.h> | ||||||
| #include <runtime/executor_params.h> | ||||||
| #include <visibility.h> | ||||||
|
|
@@ -142,11 +143,36 @@ class AliasInfoMap { | |||||
| //! The Fusion owns the whole IR graph (Vals and Exprs) | ||||||
| //! | ||||||
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) | ||||||
| class NVF_API Fusion : public IrContainer { | ||||||
| class NVF_API Fusion : public PolymorphicBase { | ||||||
| typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap; | ||||||
|
|
||||||
| protected: | ||||||
| // Direct access to underlying container | ||||||
| IrContainer* ir_container() { | ||||||
| NVF_ERROR( | ||||||
| ir_container_.get() != nullptr, | ||||||
| "Accessing a uninitialized IrContainer!.") | ||||||
| return ir_container_.get(); | ||||||
| } | ||||||
|
|
||||||
| const IrContainer* ir_container() const { | ||||||
| NVF_ERROR( | ||||||
| ir_container_.get() != nullptr, | ||||||
| "Accessing a uninitialized IrContainer!.") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same grammatical error: "a uninitialized" should be "an uninitialized"
Suggested change
|
||||||
| return ir_container_.get(); | ||||||
| } | ||||||
|
|
||||||
| public: | ||||||
| Fusion() = default; | ||||||
| // Registration (public API with passkey) | ||||||
| virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) { | ||||||
| if (stmt->isVal()) { | ||||||
| registerVal(stmt->asVal()); | ||||||
| } else { | ||||||
| registerExpr(stmt->asExpr()); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| Fusion(); | ||||||
|
|
||||||
| Fusion(const Fusion& other); | ||||||
| Fusion(Fusion&& other) noexcept; | ||||||
|
|
@@ -168,11 +194,11 @@ class NVF_API Fusion : public IrContainer { | |||||
|
|
||||||
| //! Break dependency chains associated with Expr, remove references to expr | ||||||
| //! delete expr | ||||||
| void removeExpr(Expr* expr) override; | ||||||
| virtual void removeExpr(Expr* expr); | ||||||
|
|
||||||
| //! Completely remove val from the fusion, break all dependencies associated | ||||||
| //! with it | ||||||
| void removeVal(Val* val) override; | ||||||
| virtual void removeVal(Val* val); | ||||||
|
|
||||||
| //! Register input as an input of the fusion | ||||||
| void addInput(Val* input); | ||||||
|
|
@@ -477,25 +503,126 @@ class NVF_API Fusion : public IrContainer { | |||||
|
|
||||||
| void resetExactMappings(); | ||||||
|
|
||||||
| //=================================================================== | ||||||
| // IrContainer API Forwarding (Public Methods) | ||||||
| //=================================================================== | ||||||
|
|
||||||
| // Container queries | ||||||
| bool inContainer(const Statement* stmt) const { | ||||||
| return ir_container()->inContainer(stmt); | ||||||
| } | ||||||
|
|
||||||
| void assertInContainer(const Statement* stmt, const std::string& msg) const { | ||||||
| ir_container()->assertInContainer(stmt, msg); | ||||||
| } | ||||||
|
|
||||||
| // Collections access (return values in insertion order) | ||||||
| const std::deque<Val*> deterministic_vals() const noexcept { | ||||||
| return ir_container()->deterministic_vals(); | ||||||
| } | ||||||
|
|
||||||
| const std::deque<Expr*> deterministic_exprs() const noexcept { | ||||||
| return ir_container()->deterministic_exprs(); | ||||||
| } | ||||||
|
|
||||||
| const std::unordered_map<Val*, int64_t> deterministic_vals_map() | ||||||
| const noexcept { | ||||||
| return ir_container()->deterministic_vals_map(); | ||||||
| } | ||||||
|
|
||||||
| const std::unordered_map<Expr*, int64_t> deterministic_exprs_map() | ||||||
| const noexcept { | ||||||
| return ir_container()->deterministic_exprs_map(); | ||||||
| } | ||||||
|
|
||||||
| // Collections access (unordered sets) | ||||||
| const std::unordered_set<Expr*>& unordered_exprs() const noexcept { | ||||||
| return ir_container()->unordered_exprs(); | ||||||
| } | ||||||
|
|
||||||
| const std::unordered_set<Val*>& vals() const noexcept { | ||||||
| return ir_container()->vals(); | ||||||
| } | ||||||
|
|
||||||
| // Count queries | ||||||
| int64_t numExprs() const noexcept { | ||||||
| return ir_container()->numExprs(); | ||||||
| } | ||||||
|
|
||||||
| int64_t numVals(bool include_shortcuts) const noexcept { | ||||||
| return ir_container()->numVals(include_shortcuts); | ||||||
| } | ||||||
|
|
||||||
| // Shortcut values (frequently used constants) | ||||||
| Val* zeroVal() { | ||||||
| return ir_container()->zeroVal(); | ||||||
| } | ||||||
|
|
||||||
| Val* oneVal() { | ||||||
| return ir_container()->oneVal(); | ||||||
| } | ||||||
|
|
||||||
| Val* falseVal() { | ||||||
| return ir_container()->falseVal(); | ||||||
| } | ||||||
|
|
||||||
| Val* trueVal() { | ||||||
| return ir_container()->trueVal(); | ||||||
| } | ||||||
|
|
||||||
| NamedScalar* magicZeroVal() { | ||||||
| return ir_container()->magicZeroVal(); | ||||||
| } | ||||||
|
|
||||||
| Val* zeroVal(DataType dtype) { | ||||||
| return ir_container()->zeroVal(dtype); | ||||||
| } | ||||||
|
|
||||||
| Val* oneVal(DataType dtype) { | ||||||
| return ir_container()->oneVal(dtype); | ||||||
| } | ||||||
|
|
||||||
| Val* metadataOf(Val* val) { | ||||||
| return ir_container()->metadataOf(val); | ||||||
| } | ||||||
|
|
||||||
| // Axioms (CUDA programming assumptions) | ||||||
| const std::vector<Val*>& axioms() { | ||||||
| return ir_container()->axioms(); | ||||||
| } | ||||||
|
|
||||||
| void assumePositive(Val* val) { | ||||||
| ir_container()->assumePositive(val); | ||||||
| } | ||||||
|
|
||||||
| void assumeNonNegative(Val* val) { | ||||||
| ir_container()->assumeNonNegative(val); | ||||||
| } | ||||||
|
|
||||||
| // Statement removal | ||||||
| void removeStatementsCreatedAfter( | ||||||
| int64_t num_exprs_before, | ||||||
| int64_t num_vals_before) { | ||||||
| ir_container()->removeStatementsCreatedAfter( | ||||||
| num_exprs_before, num_vals_before); | ||||||
| } | ||||||
|
|
||||||
| protected: | ||||||
| friend SegmentCandidateFinder; | ||||||
| friend SegmentedFusion; | ||||||
| friend class TranslateApplicableWelford; | ||||||
| friend Val; | ||||||
|
|
||||||
| using IrContainer::registerExpr; | ||||||
| using IrContainer::registerVal; | ||||||
|
|
||||||
| //! Register the Val with this fusion | ||||||
| void registerVal(Val* val) override; | ||||||
| virtual void registerVal(Val* val); | ||||||
|
|
||||||
| //! Register expr with this fusion. | ||||||
| //! When we register an expression, we want to update the dependency tracking | ||||||
| //! of Vals. If this container is a not a Kernel, it will remove previous | ||||||
| //! definitions of outputs and register this Expr as the definition. Otherwise | ||||||
| //! will update definition if not previously set, but will not remove old | ||||||
| //! definitions. | ||||||
| void registerExpr(Expr* expr) override; | ||||||
| virtual void registerExpr(Expr* expr); | ||||||
|
|
||||||
| //! Clear Expr's from TV uses that are not required to produce outputs from | ||||||
| //! inputs. Only other place this is used (other than Fusion) is in | ||||||
|
|
@@ -539,22 +666,18 @@ class NVF_API Fusion : public IrContainer { | |||||
| std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr; | ||||||
|
|
||||||
| inline static const std::string exact_mappings_key = "exact_mappings"; | ||||||
| std::unique_ptr<IrContainer> ir_container_; | ||||||
| }; | ||||||
|
|
||||||
| // Template implementations for Fusion::manage<T>() that use IrCloner | ||||||
| template <typename T> | ||||||
| std::any defaultCloneFunction(IrCloner& cloner, std::any data) { | ||||||
| auto cloned_data = cloner.clone(std::any_cast<T>(data)); | ||||||
| // Adding a static_assert to improve error message. Without this | ||||||
| // static_assert, the following cast will still fail, but the error message | ||||||
| // will be unreadable. | ||||||
| static_assert( | ||||||
| std::is_convertible_v<decltype(cloned_data), T>, | ||||||
| "IrCloner::clone returns a data type that is not compatible with the " | ||||||
| "original managed data type. " | ||||||
| "Likely you will need to check IrCloner::clone for your data type."); | ||||||
| // Convert the result of the clone back to T before assigning to std::any. | ||||||
| // This ensures the type of the std::any does not change over the clone of | ||||||
| // fusion. | ||||||
| return std::any((T)cloned_data); | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -568,4 +691,41 @@ void Fusion::manage(std::string key, T data) { | |||||
| return manage(key, std::any(data), defaultCloneFunction<T>); | ||||||
| } | ||||||
|
|
||||||
| // Template implementations for IrBuilder that require Fusion to be fully | ||||||
| // defined | ||||||
| template <class T, class... Args> | ||||||
| T* IrBuilder::createInContainer(Fusion* container, Args&&... args) { | ||||||
| NVF_ERROR(container != nullptr, "Need an active container to build IR."); | ||||||
| T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...); | ||||||
| container->registerStmt(IrBuilderPasskey(container), node); | ||||||
| return node; | ||||||
| } | ||||||
|
|
||||||
| template <class T> | ||||||
| T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { | ||||||
| NVF_ERROR( | ||||||
| ir_cloner != nullptr, | ||||||
| "Cannot use create when a cloner object is set. Use clone."); | ||||||
| NVF_ERROR( | ||||||
| ir_cloner->container() != nullptr, | ||||||
| "Cloner doesn't have a valid container to store cloned object."); | ||||||
|
|
||||||
| T* dest = new T(src, ir_cloner); | ||||||
| const auto* src_stmt = dynamic_cast<const Statement*>(src); | ||||||
| auto* dest_stmt = dynamic_cast<Statement*>(dest); | ||||||
|
|
||||||
| auto dest_container = ir_cloner->container(); | ||||||
| auto src_container = src_stmt->container(); | ||||||
|
|
||||||
| dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); | ||||||
|
|
||||||
| if (src_container != dest_container) { | ||||||
| dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); | ||||||
| } | ||||||
|
|
||||||
| ir_cloner->registerClone(src_stmt, dest_stmt); | ||||||
|
|
||||||
| return dest; | ||||||
| } | ||||||
|
|
||||||
| } // namespace nvfuser | ||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grammatical error: "a uninitialized" should be "an uninitialized"