Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/ir/base_nodes.cpp
${NVFUSER_SRCS_DIR}/ir/builder.cpp
${NVFUSER_SRCS_DIR}/ir/cloner.cpp
${NVFUSER_SRCS_DIR}/ir/container.cpp
#${NVFUSER_SRCS_DIR}/ir/container.cpp
${NVFUSER_SRCS_DIR}/ir/storage.cpp
${NVFUSER_SRCS_DIR}/ir/graphviz.cpp
${NVFUSER_SRCS_DIR}/ir/iostream.cpp
Expand Down
2 changes: 1 addition & 1 deletion csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class NVF_API OptOutMutator : public PolymorphicBase {
}

protected:
virtual void removeExpr(IrContainer*, Expr*) const;
virtual void removeExpr(Fusion*, Expr*) const;
virtual void registerNewExpr(Expr*) {}

private:
Expand Down
55 changes: 45 additions & 10 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,37 @@ bool Fusion::sameDefinition(const Fusion& other) const {
void Fusion::swap(Fusion& a, Fusion& b) noexcept {
FUSER_PERF_SCOPE("Fusion swap");

// Swap IrContainer base class (contains IrStorage)
IrContainer::swap(static_cast<IrContainer&>(a), static_cast<IrContainer&>(b));
// We need to be careful to call IrContainer swap not unique_ptr swap, which
// will only swap the ptrs NOT the contents.
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));

// Fix parent pointers after swapping containers
// After swap, each IrContainer owns a different IrContainer, so we must
// update the parent backpointers in those containers to point to their new
// owners
if (a.ir_container_) {
// Also update all Statement ir_container_ pointers to point to new owner
// Note: IrContainer is now in impl namespace, but Statement::ir_container_
// is Fusion*. Since only Fusion (and its derived classes) inherit from
// impl::IrContainer, this cast is safe.
a.ir_container()->parent_ = &a;
for (auto val : a.vals()) {
val->ir_container_ = &a;
}
for (auto expr : a.deterministic_exprs()) {
expr->ir_container_ = &a;
}
}
if (b.ir_container_) {
// Also update all Statement ir_container_ pointers to point to new owner
b.ir_container()->parent_ = &b;
for (auto val : b.vals()) {
val->ir_container_ = &b;
}
for (auto expr : b.deterministic_exprs()) {
expr->ir_container_ = &b;
}
}

std::swap(a.inputs_, b.inputs_);
std::swap(a.outputs_, b.outputs_);
Expand All @@ -122,7 +151,7 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();

auto ir_cloner = IrContainer::copy(from, to);
auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());

for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
Expand Down Expand Up @@ -183,14 +212,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
return ir_cloner;
}

// Default constructor
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
ir_container_->parent_ = this;
}

// Copy constructor
Fusion::Fusion(const Fusion& other) {
Fusion::Fusion(const Fusion& other) : Fusion() {
FUSER_PERF_SCOPE("Fusion copy");
Fusion::copy(&other, this);
}

// Move constructor
Fusion::Fusion(Fusion&& other) noexcept {
Fusion::Fusion(Fusion&& other) noexcept : Fusion() {
FUSER_PERF_SCOPE("Fusion move");
swap(*this, other);
}
Expand Down Expand Up @@ -223,7 +257,7 @@ void Fusion::clear() noexcept {
// Clear container contents instead of destroying it
// This preserves the container object so Statement pointers don't become
// dangling
ir_storage()->clear();
ir_container()->clear();

inputs_.clear();
outputs_.clear();
Expand Down Expand Up @@ -260,7 +294,8 @@ void Fusion::removeExpr(Expr* expr) {
}
}

IrContainer::removeExpr(expr);
// TODO : CHECK THIS vvv
ir_container()->removeExpr(expr);
}

void Fusion::removeVal(Val* val) {
Expand Down Expand Up @@ -304,7 +339,7 @@ void Fusion::removeVal(Val* val) {
for (auto e : exprs_to_remove) {
removeExpr(e);
}
IrContainer::removeVal(val);
ir_container()->removeVal(val);

invalidateTvsAndUses();
}
Expand Down Expand Up @@ -668,7 +703,7 @@ void Fusion::registerVal(Val* val) {
val->fusion() == this, val, " was not found in the active fusion.");
}

IrContainer::registerVal(val);
ir_container()->registerVal(val);
}

void Fusion::registerExpr(Expr* expr) {
Expand All @@ -681,7 +716,7 @@ void Fusion::registerExpr(Expr* expr) {
expr->fusion() == this, expr, " was not found in the active fusion.");
}

IrContainer::registerExpr(expr);
ir_container()->registerExpr(expr);

for (Val* input : expr->inputs()) {
assertInContainer(input, "Input to expr is invalid, ");
Expand Down
Loading
Loading