diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp
index 2ceedfddc40..2396767d5b0 100644
--- a/csrc/host_ir/evaluator.cpp
+++ b/csrc/host_ir/evaluator.cpp
@@ -806,14 +806,37 @@ void HostIrEvaluator::handle(ShardByStream* shard) {
   IterDomain* stream_id = *i;
 
   auto in_tensor = getKnownConcreteValue(shard->in()).as<at::Tensor>();
-  auto stream_index =
-      expr_evaluator_.evaluate(shard->stream_index()).as<int64_t>();
+  auto index = expr_evaluator_.evaluate(shard->stream_index()).as<int64_t>();
+
+  if (stream_id->definition()->isA<Swizzle1D>()) {
+    // If the stream axis is defined by a swizzle, the input to the swizzle
+    // is the index into `in_tensor`. Currently, we use a cyclic-shift
+    // swizzle to compute the index:
+    //   in_index = (out_index (stream index) + device_id) % num_devices
+    // TODO(prmishra): In the future, the swizzle compute should be done
+    // outside of `shardByStream` so that `add` and `mod` live in the
+    // HostIrContainer, similar to
+    // https://github.com/NVIDIA/Fuser/blob/0a6adb140d440cc1b6d5f21dfd05874f9699b2c6/csrc/swizzle.h#L26-L31.
+    auto* swizzle = stream_id->definition()->as<Swizzle1D>();
+    ParallelType pt = swizzle->parallelType();
+
+    auto mesh = out_tv->getDeviceMesh();
+    // Find the index of the current device in the slice of the mesh
+    // corresponding to the parallel type.
+    auto team_size = mesh.size(pt);
+    at::Tensor md_index =
+        mesh.multiDimensionalIndexOf(communicator_->deviceId());
+    auto pt_axis = mesh.parallelTypeToAxis(pt);
+    int64_t team_index = md_index[pt_axis].item<int64_t>();
+    index = (index + team_index) % team_size;
+  }
+
   at::Tensor out_tensor =
       in_tensor
           .chunk(
               stream_id->extent()->evaluate().as<int64_t>(),
               getShardedLogicalAxis(out_tv, ParallelType::Stream))
-          .at(stream_index);
+          .at(index);
   expr_evaluator_.bind(out_tv, out_tensor);
 }
diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h
index 289e0fc9de3..017c0a7db8e 100644
--- a/csrc/ir/interface_nodes.h
+++ b/csrc/ir/interface_nodes.h
@@ -646,6 +646,12 @@ class NVF_API TensorView : public Val {
   //! to the 2 given indices.
   TensorView* swizzle(SwizzleType swizzle_type, int64_t x, int64_t y);
 
+  //! Swizzle1D is currently only used and handled in HostIr.
+  //! It computes the `in` id of the swizzle as a function of the device id
+  //! (corresponding to the parallel type) and the `out` id. See
+  //! `HostIrEvaluator::handle(ShardByStream)` for usage.
+  TensorView* swizzle1d(int64_t x, ParallelType pt);
+
   //! Resize an IterDomain by expanding both the left and right sides
   //! by given widths. The resulting IterDomain has an extent of
   //! (left_expansion + axis->extent() + right_expansion).
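Note for reviewers: the cyclic-shift arithmetic in the handler above is easy to sanity-check in isolation. Below is a minimal standalone sketch (plain C++, not nvFuser code; `team_index` and `team_size` simply mirror the handler's variables) showing which input chunk each stream iteration reads:

```cpp
#include <cstdint>
#include <iostream>

// in_index = (stream_index + team_index) % team_size, as in the handler above.
int64_t swizzledInIndex(int64_t stream_index, int64_t team_index, int64_t team_size) {
  return (stream_index + team_index) % team_size;
}

int main() {
  const int64_t team_size = 4;   // hypothetical team of 4 devices
  const int64_t team_index = 1;  // this device's position in the team
  for (int64_t s = 0; s < team_size; ++s) {
    std::cout << "stream " << s << " -> input chunk "
              << swizzledInIndex(s, team_index, team_size) << "\n";
  }
  return 0;
}
```

With `team_size = 4` and `team_index = 1`, stream iterations 0..2 read chunks 1, 2, 3; chunk 0 is only reached at stream index `team_size - 1`, which matches the `(my_rank + d - 1) % d` expectation in the test below.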
diff --git a/csrc/ir/internal_base_nodes.cpp b/csrc/ir/internal_base_nodes.cpp
index c1b16d3bfab..e122641852c 100644
--- a/csrc/ir/internal_base_nodes.cpp
+++ b/csrc/ir/internal_base_nodes.cpp
@@ -547,6 +547,12 @@ std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
   return std::make_pair(out_x, out_y);
 }
 
+IterDomain* IterDomain::swizzle1d(IterDomain* in, ParallelType pt) {
+  IterDomain* out = IterDomainBuilder(in).build();
+  IrBuilder::createInContainer<Swizzle1D>(in->container(), out, in, pt);
+  return out;
+}
+
 IterDomain* IterDomain::resize(
     IterDomain* in,
     Val* left_expansion,
@@ -1806,6 +1812,16 @@ void TensorDomain::swizzle(
   loop_domain_.insert(loop_domain_.begin() + y, axis_out_y);
 }
 
+void TensorDomain::swizzle1d(int64_t x, ParallelType pt) {
+  x = wrapDim(x);
+
+  IterDomain* swizzle_in = axis(x);
+  IterDomain* swizzle_out = IterDomain::swizzle1d(swizzle_in, pt);
+
+  loop_domain_.erase(loop_domain_.begin() + x);
+  loop_domain_.insert(loop_domain_.begin() + x, swizzle_out);
+}
+
 void TensorDomain::resize(
     int64_t axis,
     Val* left_expansion,
diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h
index 4d685b4d8a5..7a176268d67 100644
--- a/csrc/ir/internal_base_nodes.h
+++ b/csrc/ir/internal_base_nodes.h
@@ -387,6 +387,8 @@ class NVF_API IterDomain : public Val {
       IterDomain* in_y,
       SwizzleMode swizzle_mode = SwizzleMode::Data);
 
+  static IterDomain* swizzle1d(IterDomain* in, ParallelType pt);
+
 protected:
  friend TensorDomain;
  friend ReplayTransformations;
@@ -830,6 +832,8 @@ class NVF_API TensorDomain : public Val {
       int64_t y,
       SwizzleMode swizzle_mode = SwizzleMode::Data);
 
+  void swizzle1d(int64_t x, ParallelType pt);
+
   // Resize an axis by left_expansion and right_expansion
   void resize(
       int64_t axis,
diff --git a/csrc/ir/internal_nodes.cpp b/csrc/ir/internal_nodes.cpp
index 6954b0250bb..0a3666bd6aa 100644
--- a/csrc/ir/internal_nodes.cpp
+++ b/csrc/ir/internal_nodes.cpp
@@ -2771,6 +2771,32 @@ std::string Swizzle::toInlineString(int indent_size) const {
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(Swizzle)
 
+Swizzle1D::Swizzle1D(
+    IrBuilderPasskey passkey,
+    IterDomain* out,
+    IterDomain* in,
+    ParallelType pt)
+    : Expr(passkey) {
+  addOutput(out);
+  addInput(in);
+  addDataAttribute(pt);
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(Swizzle1D)
+
+std::string Swizzle1D::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << out()->toString() << " = Swizzle1D("
+                          << in()->toString()
+                          << ", parallelType=" << parallelType() << ")"
+                          << std::endl;
+  return ss.str();
+}
+
+std::string Swizzle1D::toInlineString(int indent_size) const {
+  NVF_THROW("Swizzle1D can not be printed inline");
+}
+
 Resize::Resize(
     IrBuilderPasskey passkey,
     IterDomain* out,
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index a1cd7539982..a3ef4782df8 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -1995,7 +1995,41 @@ class Swizzle : public Expr {
   }
 };
 
-//! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains.
+// Swizzle1D is currently only used and handled in HostIr.
+// The main use case is to compute the indexing for ring-based overlap, where
+// `out` is stream-parallel and `in` is a function of the device id and the
+// stream index. See `HostIrEvaluator::handle(ShardByStream)` for usage.
+class Swizzle1D : public Expr {
+ public:
+  using Expr::Expr;
+
+  Swizzle1D(
+      IrBuilderPasskey passkey,
+      IterDomain* out,
+      IterDomain* in,
+      ParallelType pt);
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  const char* getOpString() const override {
+    return "Swizzle1D";
+  }
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+
+  IterDomain* in() const {
+    return inputs().at(0)->as<IterDomain>();
+  }
+
+  IterDomain* out() const {
+    return outputs().at(0)->as<IterDomain>();
+  }
+
+  ParallelType parallelType() const {
+    return attribute<ParallelType>(0);
+  }
+};
 
 //! IterDomain expression to resize
 class Resize : public Expr {
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 21d459889b8..beb7283c5a1 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -178,6 +178,8 @@ int64_t getProducingLogicalAxis(const TensorView* tv, IterDomain* id) {
     // When `unshardedSizes` is given a local tensor of shape [1, 1], it's
     // unclear the global shape is [1, D] or [D, 1] or even [2, D/2], etc.
     id = merge->outer();
+  } else if (auto* swizzle = dynamic_cast<Swizzle1D*>(def)) {
+    id = swizzle->in();
   } else {
     NVF_THROW(
         "Unexpected transforms from logical to a DID-parallel allocation "
diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index 82dbc7230a4..a25be606b98 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -768,6 +768,15 @@ TensorView* TensorView::swizzle(
   return this;
 }
 
+TensorView* TensorView::swizzle1d(int64_t x, ParallelType pt) {
+  NVF_CHECK(
+      deviceParallelTypes().contains(pt),
+      "Swizzle1D only supports device parallel types, given: ",
+      pt);
+  domain()->swizzle1d(x, pt);
+  return this;
+}
+
 TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
   NVF_ERROR(
       !container()->isA<kir::Kernel>(),
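For context, this is the scheduling pattern the new `swizzle1d` API enables, as exercised by the test below. A minimal sketch, assuming the usual nvFuser fusion setup (`Fusion`, `FusionGuard`, and `makeContigTensor` as in the test utilities; `d` is a hypothetical device count):

```cpp
Fusion fusion;
FusionGuard fg(&fusion);
const int64_t d = 4;  // hypothetical device count

TensorView* tv = makeContigTensor(2);       // loop domain: [i0, i1]
tv->outer_split(0, d);                      // [d, i0/d, i1]
tv->swizzle1d(0, ParallelType::DIDx);       // axis 0 is now a Swizzle1D output
tv->axis(0)->parallelize(ParallelType::Stream);
// At evaluation time, ShardByStream sees that the Stream axis is defined by a
// Swizzle1D and remaps the stream index with the cyclic shift implemented in
// HostIrEvaluator::handle(ShardByStream*).
```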
diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp
index 579f3d8f661..a14b225c2f2 100644
--- a/tests/cpp/test_multidevice_host_ir.cpp
+++ b/tests/cpp/test_multidevice_host_ir.cpp
@@ -507,6 +507,83 @@ TEST_F(MultiDeviceHostIrTest, SymmetricContiguousView) {
       << "Output tensor does not match expected values";
 }
 
+TEST_F(MultiDeviceTest, SwizzleWithParallelType) {
+  const int64_t d = communicator_->size();
+  const int64_t my_rank = communicator_->deviceId();
+  auto mesh = DeviceMesh::createForNumDevices(d);
+
+  auto hic = std::make_unique<HostIrContainer>();
+  FusionGuard fg(hic.get());
+  {
+    TensorView* in_tv = makeContigTensor(2);
+    TensorView* out_tv = set(in_tv);
+    hic->addInput(in_tv);
+    hic->addOutput(out_tv);
+
+    for (auto* tv : {in_tv, out_tv}) {
+      tv->setMemoryType(MemoryType::Global);
+      tv->setDeviceMesh(mesh);
+      tv->outer_split(1, d);
+      tv->axis(1)->parallelize(ParallelType::DIDx);
+      tv->setAllocationDomain(tv->getLoopDomain(), true);
+    }
+    auto* allocate_out = IrBuilder::create<kir::Allocate>(
+        out_tv, MemoryType::Global, std::vector<Val*>({}), /*zero_init=*/true);
+
+    for (auto* tv : {in_tv, out_tv}) {
+      tv->outer_split(0, d);
+      tv->swizzle1d(0, ParallelType::DIDx);
+      tv->axis(0)->parallelize(ParallelType::Stream);
+    }
+
+    auto* stream_index = IrBuilder::create<Val>(DataType::Index);
+    auto* for_loop = IrBuilder::create<ForLoop>(
+        stream_index,
+        /*start=*/hic->zeroVal(DataType::Index),
+        /*stop=*/IrBuilder::create<Val>(d - 1, DataType::Index));
+
+    TensorView* in_shard =
+        ops::newValLike(in_tv, *in_tv->getDataType())->as<TensorView>();
+    TensorView* out_shard =
+        ops::newValLike(out_tv, *out_tv->getDataType())->as<TensorView>();
+
+    for (auto* tv : {in_shard, out_shard}) {
+      tv->setDeviceMesh(mesh);
+      tv->outer_split(1, d);
+      tv->axis(1)->parallelize(ParallelType::DIDx);
+      tv->outer_split(0, d);
+      tv->swizzle1d(0, ParallelType::DIDx);
+      tv->axis(0)->parallelize(ParallelType::Stream);
+      tv->setAllocationDomain(tv->getLoopDomain(), true);
+    }
+
+    IrBuilder::create<ShardByStream>(in_shard, in_tv, stream_index);
+    IrBuilder::create<ShardByStream>(out_shard, out_tv, stream_index);
+    auto* copy = IrBuilder::create<LoadStoreOp>(
+        LoadStoreOpType::Set, out_shard, in_shard);
+
+    for_loop->body().pushBack(in_shard->definition());
+    for_loop->body().pushBack(out_shard->definition());
+    for_loop->body().pushBack(copy);
+
+    hic->pushBackTopLevelExprs(allocate_out);
+    hic->pushBackTopLevelExprs(for_loop);
+  }
+
+  HostIrEvaluator hie(std::move(hic));
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
+  at::Tensor unsharded_in = at::randn({d * 3, d * 5}, options);
+  at::Tensor sharded_in = shardTensor1D(unsharded_in, 1, mesh);
+
+  KernelArgumentHolder ins(sharded_in);
+  ins.setCacheId(0);
+  KernelArgumentHolder outs = hie.runWithInputs(ins);
+  at::Tensor out = outs[0].as<at::Tensor>();
+  at::Tensor expected_out = sharded_in;
+  expected_out.chunk(d, 0)[(my_rank + d - 1) % d].zero_();
+  EXPECT_TRUE(at::allclose(out, expected_out)) << out << " vs " << expected_out;
+}
+
 } // namespace hir
 } // namespace nvfuser
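Why the test expects `expected_out.chunk(d, 0)[(my_rank + d - 1) % d]` to stay zero: the loop runs `d - 1` iterations, and iteration `s` on rank `r` copies chunk `(s + r) % d`, so exactly one chunk is never written and keeps its zero-initialized value from the `zero_init` allocation. A small standalone check of that coverage argument (plain C++, independent of nvFuser; `d = 4` is an arbitrary example):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t d = 4;  // hypothetical device count
  for (int64_t r = 0; r < d; ++r) {
    std::vector<bool> written(d, false);
    // The host IR loop runs stream_index = 0 .. d-2 (stop is d - 1).
    for (int64_t s = 0; s < d - 1; ++s) {
      written[(s + r) % d] = true;  // chunk touched on rank r at iteration s
    }
    for (int64_t c = 0; c < d; ++c) {
      // Exactly the chunk at (r + d - 1) % d is left untouched (and zero).
      assert(written[c] == (c != (r + d - 1) % d));
    }
  }
  return 0;
}
```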