Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp
${NVFUSER_SRCS_DIR}/ir/builder.cpp
${NVFUSER_SRCS_DIR}/ir/cloner.cpp
${NVFUSER_SRCS_DIR}/ir/container.cpp
#${NVFUSER_SRCS_DIR}/ir/container.cpp
${NVFUSER_SRCS_DIR}/ir/storage.cpp
${NVFUSER_SRCS_DIR}/ir/graphviz.cpp
${NVFUSER_SRCS_DIR}/ir/iostream.cpp
Expand Down
2 changes: 1 addition & 1 deletion csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class NVF_API OptOutMutator : public PolymorphicBase {
}

protected:
virtual void removeExpr(IrContainer*, Expr*) const;
virtual void removeExpr(Fusion*, Expr*) const;
virtual void registerNewExpr(Expr*) {}

private:
Expand Down
51 changes: 41 additions & 10 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,34 @@ bool Fusion::sameDefinition(const Fusion& other) const {
void Fusion::swap(Fusion& a, Fusion& b) noexcept {
FUSER_PERF_SCOPE("Fusion swap");

// Swap IrContainer base class (contains IrStorage)
IrContainer::swap(static_cast<IrContainer&>(a), static_cast<IrContainer&>(b));
// We need to be careful to call IrContainer swap not unique_ptr swap, which
// will only swap the ptrs NOT the contents.
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));

// Fix parent pointers after swapping containers
// After swap, each Fusion owns a different IrContainer, so we must
// update the parent backpointers in those containers to point to their new
// owners
if (a.ir_container_) {
// Also update all Statement ir_container_ pointers to point to new owner
a.ir_container()->parent_ = &a;
for (auto val : a.vals()) {
val->ir_container_ = &a;
}
for (auto expr : a.deterministic_exprs()) {
expr->ir_container_ = &a;
}
}
if (b.ir_container_) {
// Also update all Statement ir_container_ pointers to point to new owner
b.ir_container()->parent_ = &b;
for (auto val : b.vals()) {
val->ir_container_ = &b;
}
for (auto expr : b.deterministic_exprs()) {
expr->ir_container_ = &b;
}
}

std::swap(a.inputs_, b.inputs_);
std::swap(a.outputs_, b.outputs_);
Expand All @@ -122,7 +148,7 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();

auto ir_cloner = IrContainer::copy(from, to);
auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());

for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
Expand Down Expand Up @@ -183,14 +209,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
return ir_cloner;
}

// Default constructor
//
// Allocates a fresh IrContainer owned by this Fusion (held in ir_container_)
// and points that container's parent_ back-pointer at this Fusion.
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
ir_container_->parent_ = this;
}

// Copy constructor
Fusion::Fusion(const Fusion& other) {
Fusion::Fusion(const Fusion& other) : Fusion() {
FUSER_PERF_SCOPE("Fusion copy");
Fusion::copy(&other, this);
}

// Move constructor
Fusion::Fusion(Fusion&& other) noexcept {
Fusion::Fusion(Fusion&& other) noexcept : Fusion() {
FUSER_PERF_SCOPE("Fusion move");
swap(*this, other);
}
Expand Down Expand Up @@ -223,7 +254,7 @@ void Fusion::clear() noexcept {
// Clear container contents instead of destroying it
// This preserves the container object so Statement pointers don't become
// dangling
ir_storage()->clear();
ir_container()->clear();

inputs_.clear();
outputs_.clear();
Expand Down Expand Up @@ -260,7 +291,7 @@ void Fusion::removeExpr(Expr* expr) {
}
}

IrContainer::removeExpr(expr);
ir_container()->removeExpr(expr);
}

void Fusion::removeVal(Val* val) {
Expand Down Expand Up @@ -304,7 +335,7 @@ void Fusion::removeVal(Val* val) {
for (auto e : exprs_to_remove) {
removeExpr(e);
}
IrContainer::removeVal(val);
ir_container()->removeVal(val);

invalidateTvsAndUses();
}
Expand Down Expand Up @@ -668,7 +699,7 @@ void Fusion::registerVal(Val* val) {
val->fusion() == this, val, " was not found in the active fusion.");
}

IrContainer::registerVal(val);
ir_container()->registerVal(val);
}

void Fusion::registerExpr(Expr* expr) {
Expand All @@ -681,7 +712,7 @@ void Fusion::registerExpr(Expr* expr) {
expr->fusion() == this, expr, " was not found in the active fusion.");
}

IrContainer::registerExpr(expr);
ir_container()->registerExpr(expr);

for (Val* input : expr->inputs()) {
assertInContainer(input, "Input to expr is invalid, ");
Expand Down
192 changes: 176 additions & 16 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "base.h"

#include <ATen/core/ivalue.h>

Expand All @@ -20,7 +21,7 @@
#include <fusion_guard.h>
#include <ir/base_nodes.h>
#include <ir/cloner.h>
#include <ir/container.h>
#include <ir/storage.h>
#include <iter_visitor.h>
#include <runtime/executor_params.h>
#include <visibility.h>
Expand Down Expand Up @@ -142,11 +143,36 @@ class AliasInfoMap {
//! The Fusion owns the whole IR graph (Vals and Exprs)
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class NVF_API Fusion : public IrContainer {
class NVF_API Fusion : public PolymorphicBase {
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;

protected:
// Direct access to underlying container
IrContainer* ir_container() {
NVF_ERROR(
ir_container_.get() != nullptr,
"Accessing a uninitialized IrContainer!.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grammatical error: "a uninitialized" should be "an uninitialized"

Suggested change
"Accessing a uninitialized IrContainer!.")
"Accessing an uninitialized IrContainer!.")

return ir_container_.get();
}

const IrContainer* ir_container() const {
NVF_ERROR(
ir_container_.get() != nullptr,
"Accessing a uninitialized IrContainer!.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same grammatical error: "a uninitialized" should be "an uninitialized"

Suggested change
"Accessing a uninitialized IrContainer!.")
"Accessing an uninitialized IrContainer!.")

return ir_container_.get();
}

public:
Fusion() = default;
// Registration (public API with passkey)
//! Registers stmt with this Fusion by dispatching on its kind: a statement
//! for which isVal() is true goes to registerVal(), anything else goes to
//! registerExpr().
//! The passkey argument is not read in the body.
//! NOTE(review): presumably the passkey only restricts who may call this
//! entry point -- confirm against IrBuilderPasskey.
virtual void registerStmt(IrBuilderPasskey passkey, Statement* stmt) {
  if (!stmt->isVal()) {
    registerExpr(stmt->asExpr());
    return;
  }
  registerVal(stmt->asVal());
}

Fusion();

Fusion(const Fusion& other);
Fusion(Fusion&& other) noexcept;
Expand All @@ -168,11 +194,11 @@ class NVF_API Fusion : public IrContainer {

//! Break dependency chains associated with Expr, remove references to expr
//! delete expr
void removeExpr(Expr* expr) override;
virtual void removeExpr(Expr* expr);

//! Completely remove val from the fusion, break all dependencies associated
//! with it
void removeVal(Val* val) override;
virtual void removeVal(Val* val);

//! Register input as an input of the fusion
void addInput(Val* input);
Expand Down Expand Up @@ -477,25 +503,126 @@ class NVF_API Fusion : public IrContainer {

void resetExactMappings();

//===================================================================
// IrContainer API Forwarding (Public Methods)
//===================================================================

// Container queries
//! Forwards to IrContainer::inContainer() on the IrContainer owned by this
//! Fusion and returns its result.
bool inContainer(const Statement* stmt) const {
return ir_container()->inContainer(stmt);
}

//! Forwards to IrContainer::assertInContainer() on the owned IrContainer,
//! passing stmt and msg through unchanged.
void assertInContainer(const Statement* stmt, const std::string& msg) const {
ir_container()->assertInContainer(stmt, msg);
}

// Collections access (return values in insertion order)
const std::deque<Val*> deterministic_vals() const noexcept {
return ir_container()->deterministic_vals();
}

const std::deque<Expr*> deterministic_exprs() const noexcept {
return ir_container()->deterministic_exprs();
}

const std::unordered_map<Val*, int64_t> deterministic_vals_map()
const noexcept {
return ir_container()->deterministic_vals_map();
}

const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
const noexcept {
return ir_container()->deterministic_exprs_map();
}

// Collections access (unordered sets)
//! Returns, by const reference, whatever IrContainer::unordered_exprs()
//! returns for the IrContainer owned by this Fusion.
const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
return ir_container()->unordered_exprs();
}

//! Returns, by const reference, whatever IrContainer::vals() returns for the
//! IrContainer owned by this Fusion.
const std::unordered_set<Val*>& vals() const noexcept {
return ir_container()->vals();
}

// Count queries
//! Forwards to IrContainer::numExprs() on the owned IrContainer.
int64_t numExprs() const noexcept {
return ir_container()->numExprs();
}

//! Forwards to IrContainer::numVals() on the owned IrContainer, passing
//! include_shortcuts through unchanged.
//! NOTE(review): presumably include_shortcuts controls whether the shortcut
//! values (zeroVal(), oneVal(), ...) are counted -- confirm in IrContainer.
int64_t numVals(bool include_shortcuts) const noexcept {
return ir_container()->numVals(include_shortcuts);
}

// Shortcut values (frequently used constants)
// Every accessor below forwards to the same-named method on the IrContainer
// owned by this Fusion and returns that method's result unchanged.

//! Forwards to IrContainer::zeroVal().
Val* zeroVal() {
return ir_container()->zeroVal();
}

//! Forwards to IrContainer::oneVal().
Val* oneVal() {
return ir_container()->oneVal();
}

//! Forwards to IrContainer::falseVal().
Val* falseVal() {
return ir_container()->falseVal();
}

//! Forwards to IrContainer::trueVal().
Val* trueVal() {
return ir_container()->trueVal();
}

//! Forwards to IrContainer::magicZeroVal().
NamedScalar* magicZeroVal() {
return ir_container()->magicZeroVal();
}

//! Forwards to IrContainer::zeroVal(DataType), passing dtype through.
Val* zeroVal(DataType dtype) {
return ir_container()->zeroVal(dtype);
}

//! Forwards to IrContainer::oneVal(DataType), passing dtype through.
Val* oneVal(DataType dtype) {
return ir_container()->oneVal(dtype);
}

//! Forwards to IrContainer::metadataOf(), passing val through.
Val* metadataOf(Val* val) {
return ir_container()->metadataOf(val);
}

// Axioms (CUDA programming assumptions)
//! Returns, by const reference, whatever IrContainer::axioms() returns for
//! the IrContainer owned by this Fusion.
const std::vector<Val*>& axioms() {
return ir_container()->axioms();
}

//! Forwards to IrContainer::assumePositive() on the owned IrContainer.
void assumePositive(Val* val) {
ir_container()->assumePositive(val);
}

//! Forwards to IrContainer::assumeNonNegative() on the owned IrContainer.
void assumeNonNegative(Val* val) {
ir_container()->assumeNonNegative(val);
}

// Statement removal
//! Forwards both counts, unchanged, to
//! IrContainer::removeStatementsCreatedAfter() on the owned IrContainer.
//! NOTE(review): presumably the arguments are numExprs()/numVals() snapshots
//! taken before the statements to discard were created -- confirm in
//! IrContainer.
void removeStatementsCreatedAfter(
int64_t num_exprs_before,
int64_t num_vals_before) {
ir_container()->removeStatementsCreatedAfter(
num_exprs_before, num_vals_before);
}

protected:
friend SegmentCandidateFinder;
friend SegmentedFusion;
friend class TranslateApplicableWelford;
friend Val;

using IrContainer::registerExpr;
using IrContainer::registerVal;

//! Register the Val with this fusion
void registerVal(Val* val) override;
virtual void registerVal(Val* val);

//! Register expr with this fusion.
//! When we register an expression, we want to update the dependency tracking
//! of Vals. If this container is not a Kernel, it will remove previous
//! definitions of outputs and register this Expr as the definition. Otherwise
//! will update definition if not previously set, but will not remove old
//! definitions.
void registerExpr(Expr* expr) override;
virtual void registerExpr(Expr* expr);

//! Clear Expr's from TV uses that are not required to produce outputs from
//! inputs. Only other place this is used (other than Fusion) is in
Expand Down Expand Up @@ -539,22 +666,18 @@ class NVF_API Fusion : public IrContainer {
std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;

inline static const std::string exact_mappings_key = "exact_mappings";
std::unique_ptr<IrContainer> ir_container_;
};

// Template implementations for Fusion::manage<T>() that use IrCloner
//! Default clone hook for data of type T stored in a std::any.
//!
//! Extracts a T from `data`, clones it through `cloner`, and returns the
//! clone wrapped in a new std::any that again holds a T.
//!
//! \param cloner IrCloner used to clone the extracted value.
//! \param data   std::any expected to hold a T; std::any_cast<T> is applied
//!               to it.
//! \return std::any holding the cloned value converted back to T.
template <typename T>
std::any defaultCloneFunction(IrCloner& cloner, std::any data) {
  auto cloned_data = cloner.clone(std::any_cast<T>(data));
  // Adding a static_assert to improve error message. Without this
  // static_assert, the following cast will still fail, but the error message
  // will be unreadable.
  static_assert(
      std::is_convertible_v<decltype(cloned_data), T>,
      "IrCloner::clone returns a data type that is not compatible with the "
      "original managed data type. "
      "Likely you will need to check IrCloner::clone for your data type.");
  // Convert the result of the clone back to T before assigning to std::any.
  // This ensures the type of the std::any does not change over the clone of
  // fusion. The static_assert above guarantees the conversion is implicit, so
  // a named static_cast is sufficient and cannot silently degrade into a
  // reinterpret/const cast the way a C-style cast could.
  return std::any(static_cast<T>(cloned_data));
}

Expand All @@ -568,4 +691,41 @@ void Fusion::manage(std::string key, T data) {
return manage(key, std::any(data), defaultCloneFunction<T>);
}

// Template implementations for IrBuilder that require Fusion to be fully
// defined
//! Constructs a T with the forwarded arguments, registers it with
//! `container`, and returns the new object.
//!
//! \param container Fusion that receives the new node; must not be null.
//! \param args      Arguments forwarded to T's constructor after the
//!                  IrBuilderPasskey.
//! \return Pointer to the newly created node.
template <class T, class... Args>
T* IrBuilder::createInContainer(Fusion* container, Args&&... args) {
  NVF_ERROR(container != nullptr, "Need an active container to build IR.");
  // NOTE(review): presumably registerStmt() takes ownership of this raw
  // allocation -- confirm against Fusion::registerVal/registerExpr.
  auto* const created =
      new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
  container->registerStmt(IrBuilderPasskey(container), created);
  return created;
}

//! Clones `src` into the container held by `ir_cloner`.
//!
//! Allocates a new T from (src, ir_cloner), registers it with the cloner's
//! container, copies the source name onto the clone when the source lives in
//! a different container, and records the src -> clone mapping in the cloner.
//!
//! \param src       Object to clone.
//! \param ir_cloner Cloner providing the destination container; must not be
//!                  null and must have a non-null container().
//! \return Pointer to the newly created clone.
//!
//! NOTE(review): the dynamic_cast results below are dereferenced unchecked;
//! presumably T always derives from Statement -- confirm.
template <class T>
T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) {
  // The previous message here ("Cannot use create when a cloner object is
  // set. Use clone.") described the opposite condition from the one checked.
  NVF_ERROR(ir_cloner != nullptr, "Need a valid IrCloner to clone IR.");
  NVF_ERROR(
      ir_cloner->container() != nullptr,
      "Cloner doesn't have a valid container to store cloned object.");

  T* dest = new T(src, ir_cloner);
  const auto* src_stmt = dynamic_cast<const Statement*>(src);
  auto* dest_stmt = dynamic_cast<Statement*>(dest);

  auto dest_container = ir_cloner->container();
  auto src_container = src_stmt->container();

  dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt);

  // Only carry the name over when cloning across containers.
  if (src_container != dest_container) {
    dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name());
  }

  ir_cloner->registerClone(src_stmt, dest_stmt);

  return dest;
}

} // namespace nvfuser
Loading