Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp
${NVFUSER_SRCS_DIR}/ir/builder.cpp
${NVFUSER_SRCS_DIR}/ir/cloner.cpp
#${NVFUSER_SRCS_DIR}/ir/container.cpp
${NVFUSER_SRCS_DIR}/ir/storage.cpp
${NVFUSER_SRCS_DIR}/ir/container.cpp
${NVFUSER_SRCS_DIR}/ir/graphviz.cpp
${NVFUSER_SRCS_DIR}/ir/iostream.cpp
${NVFUSER_SRCS_DIR}/ir/internal_base_nodes.cpp
Expand Down
2 changes: 1 addition & 1 deletion csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <fusion_guard.h>
#include <ir/base_nodes.h>
#include <ir/cloner.h>
#include <ir/storage.h>
#include <ir/container.h>
#include <iter_visitor.h>
#include <runtime/executor_params.h>
#include <visibility.h>
Expand Down
2 changes: 1 addition & 1 deletion csrc/ir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "exceptions.h"
#include "fusion_guard.h"
#include "ir/builder_passkey.h"
#include "ir/storage.h"
#include "ir/container.h"
#include "visibility.h"

namespace nvfuser {
Expand Down
File renamed without changes.
191 changes: 190 additions & 1 deletion csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,197 @@
// clang-format on
#pragma once

#include <deque>
#include <unordered_map>
#include <unordered_set>

#include "base.h"
#include "exceptions.h"
#include "ir/base_nodes.h"
#include "visibility.h"

namespace nvfuser {

// Passkey handed out by IrContainer when it registers names with statements.
// The constructor is private and IrContainer is the sole friend, so no other
// code can create a key.
class IrContainerPasskey {
  // Members of a `class` are private by default.
  explicit IrContainerPasskey() = default;

  friend class IrContainer;
};

class NamedScalar;

// Storage for the Vals and Exprs of an IR graph. Statements are owned through
// deques of unique_ptr (vals_up_ / exprs_up_); parallel unordered sets
// (vals_ / exprs_) are kept for membership queries. The container also holds
// per-type name counters, a handful of frequently used shortcut Vals, the
// axioms list, and a back-pointer to the Fusion that owns it.
class IrContainer {
 public:
  NVF_API IrContainer();

  // Copy/Move Constructors and Operators are deleted. IrContainer is owned by
  // its parent Fusion (see parent_) -- presumably through a smart pointer;
  // confirm against Fusion. Semantic operations for Fusion types are handled
  // directly through the copy and swap functions below.
  IrContainer(const IrContainer& other) = delete;
  IrContainer(IrContainer&& other) noexcept = delete;

  IrContainer& operator=(const IrContainer& other) = delete;
  IrContainer& operator=(IrContainer&& other) noexcept = delete;

  ~IrContainer();

  //! Return true if stmt is registered with this container
  bool inContainer(const Statement* stmt) const;

  //! Raise via NVF_CHECK, prefixing the error with msg, when stmt is not in
  //! this container
  void assertInContainer(const Statement* stmt, const std::string& msg) const {
    NVF_CHECK(
        inContainer(stmt), msg, " it was not found in the active container.");
  }

  //! Return values in insertion order
  const std::deque<Val*> deterministic_vals() const noexcept;

  //! Return expression in insertion order
  const std::deque<Expr*> deterministic_exprs() const noexcept;

  //! Return mapping from value to integer id
  const std::unordered_map<Val*, int64_t> deterministic_vals_map()
      const noexcept;

  //! Return mapping from expression to integer id
  const std::unordered_map<Expr*, int64_t> deterministic_exprs_map()
      const noexcept;

  //! Return the set of Exprs registered with this fusion. Warning: This will
  //! return exprs outside inputs/outputs, so can be unsafe for use with
  //! segmented fusions.
  const std::unordered_set<Expr*>& unordered_exprs() const noexcept {
    return exprs_;
  }

  //! Return the set of Vals registered with this fusion
  const std::unordered_set<Val*>& vals() const noexcept {
    return vals_;
  }

  //! Number of entries in the expression set
  int64_t numExprs() const noexcept {
    return std::ssize(exprs_);
  }

  // When include_shortcuts is true, it will count the shortcuts like
  // true_val_: the size of vals_ is returned instead of the size of the
  // owning deque vals_up_.
  int64_t numVals(bool include_shortcuts) const noexcept {
    return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_);
  }

  // Shortcuts for frequently used vals
  NVF_API Val* zeroVal();
  NVF_API Val* oneVal();
  Val* falseVal();
  Val* trueVal();
  NamedScalar* magicZeroVal();
  NVF_API Val* zeroVal(DataType dtype);
  NVF_API Val* oneVal(DataType dtype);
  Val* metadataOf(Val*);

  // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x
  // lazyInitAxioms() is called first so that axioms_ can be dereferenced.
  const std::vector<Val*>& axioms() {
    lazyInitAxioms();
    return *axioms_;
  }

  void assumePositive(Val* val);
  void assumeNonNegative(Val* val);

 protected:
  static IrCloner copy(const IrContainer* from, IrContainer* to);

  static void swap(IrContainer& a, IrContainer& b) noexcept;

  // Let Fusion access IrContainer::clear()
  friend class Fusion;

  void removeExpr(Expr* expr);

  //! Completely remove val from the fusion, break all dependencies associated
  //! with it
  void removeVal(Val* val);

  //! Register the Val with this container
  NVF_API void registerVal(Val* val);

  //! Register expr with this container.
  NVF_API void registerExpr(Expr* expr);

  //! Return the next name for a Val of type vtype and advance that type's
  //! counter. A type seen for the first time starts counting at 0.
  StmtNameType getValName(ValType vtype) {
    // try_emplace inserts a zero counter only when vtype is absent, so a
    // single hash lookup covers both the first-use and the steady-state case.
    return val_type_name_map_.try_emplace(vtype, 0).first->second++;
  }

  //! Return the next expression name and advance the counter.
  StmtNameType getExprName() {
    return expr_name_counter_++;
  }

  void clear() noexcept;

  void lazyInitAxioms();

  friend class StatementGuard;

  // A simple garbage collection mechanism to remove all Exprs and Vals that
  // were created after a certain point. This is useful for analysis that
  // creates new Exprs and Vals in the container and wants to clean up after
  // itself.
  //
  // Used by StatementGuard only.
  void removeStatementsCreatedAfter(
      int64_t prev_num_exprs,
      int64_t prev_num_vals);

  // Deque of unique pointer is the memory owning data structure
  std::deque<std::unique_ptr<Val>> vals_up_;

  // A convenient set to return when we just need an unordered set to do
  // something like check if a Val is in this container
  std::unordered_set<Val*> vals_;

  // Deque of unique pointer is the memory owning data structure
  std::deque<std::unique_ptr<Expr>> exprs_up_;

  // A convenient set to return when we just need an unordered set to do
  // something like check if an Expr is in this container
  std::unordered_set<Expr*> exprs_;

  // Values names counters
  std::unordered_map<ValType, StmtNameType> val_type_name_map_;

  // Expression names counter
  StmtNameType expr_name_counter_ = 0;

  // Manually store some persistent, frequently used nodes. It's very
  // challenging to do this anything but manually as detecting when a container
  // may or may not have one of these vals is tricky. Specifically because if
  // the container doesn't own it, it's hard to understand from the outside if
  // the node may have been removed then re-registered. It could also be tricky
  // to know when we're using a different container as in FusionCopy_test
  // demonstrates deleting then creating containers can result in the same
  // pointer for the container.
  std::unique_ptr<Val> true_val_;
  std::unique_ptr<Val> false_val_;
  std::unique_ptr<Val> one_val_;
  std::unique_ptr<Val> zero_val_;
  std::unique_ptr<NamedScalar> magic_zero_val_;
  std::unique_ptr<std::vector<Val*>> axioms_;
  std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;

 public:
  //! Return the Fusion that owns this container. NVF_ERROR fires when no
  //! parent has been set.
  Fusion* parent() const {
    NVF_ERROR(
        parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.");
    return parent_;
  }

 private:
  // Parent Fusion that owns this container (for pure composition pattern)
  // Used by Statement::fusion() to navigate back to owning Fusion
  Fusion* parent_ = nullptr;
};

} // namespace nvfuser
Loading