28 changes: 25 additions & 3 deletions csrc/host_ir/evaluator.cpp
@@ -806,14 +806,36 @@ void HostIrEvaluator::handle(ShardByStream* shard) {
IterDomain* stream_id = *i;

auto in_tensor = getKnownConcreteValue(shard->in()).as<at::Tensor>();
auto stream_index =
expr_evaluator_.evaluate(shard->stream_index()).as<int64_t>();
auto index = expr_evaluator_.evaluate(shard->stream_index()).as<int64_t>();

if (stream_id->definition() != nullptr) {
// If the stream axis is defined by a swizzle, the input to
// the swizzle is the index into the `in_tensor`.
// Currently, we use cyclic shift swizzle to compute the index:
// in_index = (out_index (stream index) + device_id) % num_devices

NVF_CHECK(stream_id->definition()->isA<hir::Swizzle>());
auto* swizzle = stream_id->definition()->as<hir::Swizzle>();
ParallelType pt = swizzle->pt();

auto mesh = out_tv->getDeviceMesh();
// Find the index of the current device in the slice of mesh corresponding
// to the parallel type
auto team_size = mesh.size(pt);
at::Tensor md_index =
mesh.multiDimensionalIndexOf(communicator_->deviceId());
auto pt_axis = mesh.parallelTypeToAxis(pt);
int64_t team_index = md_index[pt_axis].item<int64_t>();
index = (index + team_index) % team_size;
Review comment (Collaborator):
Note: consider putting the + and % into the HostIrContainer in the future. See also Fuser/csrc/swizzle.h, lines 26 to 52 in 77abd29:
NVF_API std::pair<Val*, Val*> dispatchSwizzle(
    Swizzle2DType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
NVF_API std::pair<Val*, Val*> dispatchSwizzle(
    SwizzleType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
NVF_API std::pair<Val*, Val*> dispatchUnSwizzle(
    Swizzle2DType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
NVF_API std::pair<Val*, Val*> dispatchUnSwizzle(
    SwizzleType type,
    Val* x,
    Val* y,
    Val* maybe_size_x,
    Val* maybe_size_y);
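
A rough sketch of what expressing the shift in host IR (as the note above suggests) could look like, assuming the usual IrBuilder arithmetic helpers (addExpr, modExpr); the variable names and their placement are illustrative and not part of this PR:

// Build `in_index = (stream_index + team_index) % team_size` as Vals owned by
// the HostIrContainer instead of evaluating the arithmetic inside
// HostIrEvaluator::handle(ShardByStream*).
Val* shifted = IrBuilder::addExpr(stream_index, team_index);
Val* in_index = IrBuilder::modExpr(shifted, team_size);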

}

at::Tensor out_tensor =
in_tensor
.chunk(
stream_id->extent()->evaluate().as<int64_t>(),
index,
getShardedLogicalAxis(out_tv, ParallelType::Stream))
.at(stream_index);
.at(index);
Review comment (Contributor) on lines 832 to +838:
Incorrect number of arguments passed to .chunk(): the method signature is chunk(int chunks, int dim) (2 args), but 3 arguments are being passed here.

Suggested change:
  at::Tensor out_tensor =
      in_tensor
          .chunk(
              stream_id->extent()->evaluate().as<int64_t>(),
              getShardedLogicalAxis(out_tv, ParallelType::Stream))
          .at(index);
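
For reference, a minimal sketch of the two-argument at::Tensor::chunk call the suggestion relies on; the tensor shape and values are illustrative only:

#include <ATen/ATen.h>

void chunkExample() {
  at::Tensor t = at::arange(12).reshape({4, 3});
  // chunk(chunks, dim) splits `t` into `chunks` pieces along `dim` and returns
  // a std::vector<at::Tensor>; there is no third argument.
  std::vector<at::Tensor> pieces = t.chunk(/*chunks=*/4, /*dim=*/0);
  at::Tensor second = pieces.at(1);  // rows [1, 2) of t, shape {1, 3}
  (void)second;
}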


expr_evaluator_.bind(out_tv, out_tensor);
}
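
For concreteness, a self-contained sketch of the cyclic-shift mapping the handler above computes, in_index = (stream_index + team_index) % team_size; the numbers below are illustrative:

#include <cassert>
#include <cstdint>

// Each device starts at its own shard and walks the shards round-robin, so at a
// given stream iteration no two devices in the team read the same shard.
int64_t cyclicShiftIndex(int64_t stream_index, int64_t team_index, int64_t team_size) {
  return (stream_index + team_index) % team_size;
}

int main() {
  // In a team of 4, the device with team_index 2 at stream iteration 1 reads
  // shard (1 + 2) % 4 == 3.
  assert(cyclicShiftIndex(1, 2, 4) == 3);
  return 0;
}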
32 changes: 32 additions & 0 deletions csrc/host_ir/ir.cpp
@@ -503,4 +503,36 @@ std::string ForLoop::toInlineString(int indent_size) const {
index, iter_domain->start(), iter_domain->stop());
}

Swizzle::Swizzle(
IrBuilderPasskey passkey,
IterDomain* in,
IterDomain* out,
ParallelType pt)
: Expr(passkey, {in}, {out}, {}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<HostIrContainer>(),
this,
"must be registered in a HostIrContainer");
NVF_ERROR(in != nullptr);
NVF_ERROR(out != nullptr);
addDataAttribute(pt);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Swizzle)

std::string Swizzle::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << " = Swizzle("
<< in()->toString() << ", pt=" << pt() << ")" << std::endl;
return ss.str();
}

std::string Swizzle::toInlineString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "Swizzle(" << in()->toInlineString()
<< ", pt=" << pt() << ")";
return ss.str();
}

} // namespace nvfuser::hir
36 changes: 36 additions & 0 deletions csrc/host_ir/ir.h
@@ -569,4 +569,40 @@ class ForLoop : public Expr {
}
};

class Swizzle : public Expr {
Review comment (Collaborator):
Suggested change:
class Swizzle1D : public Expr {

This should be moved to csrc/ir/internal_nodes.h. It's fine to not use it for kernel generation for now, but the Expr itself is agnostic and in fact will be used by preseg.

public:
using Expr::Expr;

Swizzle(
IrBuilderPasskey passkey,
IterDomain* in,
IterDomain* out,
ParallelType pt);

Swizzle(const Swizzle& other) = delete;
Swizzle& operator=(const Swizzle& other) = delete;
Swizzle(Swizzle&& other) = delete;
Swizzle& operator=(Swizzle&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::Swizzle";
}

IterDomain* in() const {
return inputs().at(0)->as<IterDomain>();
}

IterDomain* out() const {
return outputs().at(0)->as<IterDomain>();
}

ParallelType pt() const {
Review comment (Collaborator):
Suggested change:
  ParallelType parallelType() const {

return attribute<ParallelType>(0);
}
};

} // namespace nvfuser::hir
15 changes: 15 additions & 0 deletions csrc/host_ir/ops.cpp
@@ -24,6 +24,21 @@

namespace nvfuser::hir {

TensorView* swizzle(TensorView* in, int64_t axis, ParallelType pt) {
Review comment (Collaborator):
TensorView::swizzle1d

NVF_ERROR(in != nullptr);

IterDomain* swizzle_in = in->axis(axis);
IterDomain* swizzle_out = IterDomainBuilder(swizzle_in).build();
IrBuilder::create<Swizzle>(swizzle_in, swizzle_out, pt);

std::vector<IterDomain*> loop_domain = in->getLoopDomain();
loop_domain.erase(loop_domain.begin() + axis);
loop_domain.insert(loop_domain.begin() + axis, swizzle_out);
in->setLoopDomain(loop_domain);

return in;
}
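
A usage sketch of the new op, mirroring the scheduling in the test added below; it assumes a TensorView* tv whose outermost axis carries the stream dimension and d devices, and is not itself part of the diff:

// Split the outer axis into [d, rest], swizzle the new outer axis with the
// cyclic shift keyed on DIDx, then run that axis on streams.
tv->outer_split(0, d);
tv = hir::swizzle(tv, /*axis=*/0, ParallelType::DIDx);
tv->axis(0)->parallelize(ParallelType::Stream);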

TensorView* shardByStream(TensorView* source, Val* stream_index, Expr* e) {
NVF_ERROR(
getShardedIterDomain(
2 changes: 2 additions & 0 deletions csrc/host_ir/ops.h
@@ -20,6 +20,8 @@

namespace nvfuser::hir {

TensorView* swizzle(TensorView* in, int64_t axis, ParallelType pt);

// Creates a ShardByStream without needing the destination TensorView. Returns
// the destination TensorView. `e` is the Expr from which we propagate the loop
// domain from. `source` must be either an input or an output of `e`. The
3 changes: 3 additions & 0 deletions csrc/multidevice/utils.cpp
@@ -14,6 +14,7 @@
#include <vector>

#include "compute_at_map.h"
#include "host_ir/ir.h"
#include "ir/internal_base_nodes.h"
#include "ir/internal_nodes.h"
#include "transform_replay.h"
@@ -178,6 +179,8 @@ int64_t getProducingLogicalAxis(const TensorView* tv, IterDomain* id) {
// When `unshardedSizes` is given a local tensor of shape [1, 1], it's
// unclear the global shape is [1, D] or [D, 1] or even [2, D/2], etc.
id = merge->outer();
} else if (auto* swizzle = dynamic_cast<hir::Swizzle*>(def)) {
id = swizzle->in();
} else {
NVF_THROW(
"Unexpected transforms from logical to a DID-parallel allocation "
78 changes: 78 additions & 0 deletions tests/cpp/test_multidevice_host_ir.cpp
@@ -10,6 +10,7 @@
#include "fusion.h"
#include "host_ir/container.h"
#include "host_ir/evaluator.h"
#include "host_ir/ops.h"
#include "host_ir/pass/stream_parallel_type.h"
#include "ir/all_nodes.h"
#include "multidevice/symmetric_tensor.h"
@@ -507,6 +508,83 @@ TEST_F(MultiDeviceHostIrTest, SymmetricContiguousView) {
<< "Output tensor does not match expected values";
}

TEST_F(MultiDeviceTest, SwizzleWithParallelType) {
const int64_t d = communicator_->size();
const int64_t my_rank = communicator_->deviceId();
auto mesh = DeviceMesh::createForNumDevices(d);

auto hic = std::make_unique<HostIrContainer>();
FusionGuard fg(hic.get());
{
TensorView* in_tv = makeContigTensor(2);
TensorView* out_tv = set(in_tv);
hic->addInput(in_tv);
hic->addOutput(out_tv);

for (auto* tv : {in_tv, out_tv}) {
tv->setMemoryType(MemoryType::Global);
tv->setDeviceMesh(mesh);
tv->outer_split(1, d);
tv->axis(1)->parallelize(ParallelType::DIDx);
tv->setAllocationDomain(tv->getLoopDomain(), true);
}
auto* allocate_out = IrBuilder::create<kir::Allocate>(
out_tv, MemoryType::Global, std::vector<Val*>({}), /*zero_init=*/true);

for (auto* tv : {in_tv, out_tv}) {
tv->outer_split(0, d);
tv = hir::swizzle(tv, 0, ParallelType::DIDx);
tv->axis(0)->parallelize(ParallelType::Stream);
}

auto* stream_index = IrBuilder::create<Val>(DataType::Index);
auto* for_loop = IrBuilder::create<ForLoop>(
stream_index,
/*start=*/hic->zeroVal(DataType::Index),
/*stop=*/IrBuilder::create<Val>(d - 1, DataType::Index));

TensorView* in_shard =
ops::newValLike(in_tv, *in_tv->getDataType())->as<TensorView>();
TensorView* out_shard =
ops::newValLike(out_tv, *out_tv->getDataType())->as<TensorView>();

for (auto* tv : {in_shard, out_shard}) {
tv->setDeviceMesh(mesh);
tv->outer_split(1, d);
tv->axis(1)->parallelize(ParallelType::DIDx);
tv->outer_split(0, d);
tv = hir::swizzle(tv, 0, ParallelType::DIDx);
tv->axis(0)->parallelize(ParallelType::Stream);
tv->setAllocationDomain(tv->getLoopDomain(), true);
}

IrBuilder::create<ShardByStream>(in_shard, in_tv, stream_index);
IrBuilder::create<ShardByStream>(out_shard, out_tv, stream_index);
auto* copy = IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, out_shard, in_shard);

for_loop->body().pushBack(in_shard->definition());
for_loop->body().pushBack(out_shard->definition());
for_loop->body().pushBack(copy);

hic->pushBackTopLevelExprs(allocate_out);
hic->pushBackTopLevelExprs(for_loop);
}

HostIrEvaluator hie(std::move(hic));
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
at::Tensor unsharded_in = at::randn({d * 3, d * 5}, options);
at::Tensor sharded_in = shardTensor1D(unsharded_in, 1, mesh);

KernelArgumentHolder ins(sharded_in);
ins.setCacheId(0);
KernelArgumentHolder outs = hie.runWithInputs(ins);
at::Tensor out = outs[0].as<at::Tensor>();
at::Tensor expected_out = sharded_in;
expected_out.chunk(d, 0)[(my_rank + d - 1) % d].zero_();
EXPECT_TRUE(at::allclose(out, expected_out)) << out << " vs " << expected_out;
}

} // namespace hir

} // namespace nvfuser