41 changes: 40 additions & 1 deletion csrc/scheduler/tools/domain_map.cpp
@@ -8,6 +8,8 @@
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>

#include <ranges>

namespace nvfuser {
namespace scheduler_tools {

@@ -515,10 +517,47 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) {
if (ref1 == nullptr || ref2 == nullptr) {
return false;
}

// reference 1 is the global reference, so it must have a dim mapped to the
// innermost dim of both groups
auto innermost2 = scheduler_utils::innerMostAllocDim(ref2);
- return domain_map.getMappedAllocDimIn(ref1, innermost2) != nullptr;
auto mapped_id = domain_map.getMappedAllocDimIn(ref1, innermost2);
if (mapped_id == nullptr) {
return false;
}

// For grouping caused by a permutation, the corresponding allocation domains
// should not all be mapped to each other. If they are, the two groups are
// due to a broadcast rather than a transpose. In that case they are not
// considered valid groups, since the broadcast tensor has a smaller size and
// the pointwise scheduler handles broadcast well through unrolling and
// caching at all levels. For example, in TransposeTest.NoTransposeMaverick17B,
// the two inputs are tv0[i0, i1] and tv1[i2, b3], where i0/i2 and i1/b3 are
// mapped to each other. tv0 and tv1 still land in two different groups
// because of the broadcast, so the pointwise scheduler should be used
// instead of the transpose scheduler.
const auto& ref1_loop = ref1->getMaybeAllocationDomain();
const auto& ref2_loop = ref2->getMaybeAllocationDomain();
const auto& ca_map = domain_map.getComputeAtMap();
const bool all_mapped = std::ranges::equal(
ref1_loop, ref2_loop, [&](IterDomain* id1, IterDomain* id2) {
return ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE);
});
if (all_mapped) {
// Not required for correctness; this only validates the assumption that
// all_mapped implies any_bcast.
const bool any_bcast =
std::ranges::any_of(
ref1_loop, [](IterDomain* id) { return id->isBroadcast(); }) ||
std::ranges::any_of(
ref2_loop, [](IterDomain* id) { return id->isBroadcast(); });
NVF_ERROR(
any_bcast,
"all_mapped implies any_bcast, ca_map:\n",
ca_map.toString());
return false;
}
return true;
}

int64_t TransposeDomainMap::getInnerLeafDim(
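The all-mapped rejection added above can be illustrated outside of nvfuser. Below is a minimal standalone sketch, not nvfuser API: plain ints stand in for IterDomains (negative values model broadcast domains), and a toy relation stands in for ComputeAtMap::areMapped with IdMappingMode::PERMISSIVE, under which a broadcast domain maps to anything. It shows why the tv0[i0, i1] / tv1[i2, b3] case from the comment is rejected while a genuine permutation is kept.

#include <cassert>
#include <ranges>
#include <vector>

// Toy stand-in for ComputeAtMap::areMapped(PERMISSIVE): negative ids model
// broadcast domains, which map to any id; concrete ids map only to themselves.
bool areMappedToy(int id1, int id2) {
  return id1 < 0 || id2 < 0 || id1 == id2;
}

int main() {
  // Broadcast case: tv0[i0, i1] vs tv1[i2, b3], with i0/i2 and i1/b3 mapped.
  // Every position maps, so the grouping comes from the broadcast and the
  // transpose scheduler is rejected (all_mapped == true).
  std::vector<int> ref1{0, 1};  // allocation domain of tv0: [i0, i1]
  std::vector<int> ref2{0, -1}; // allocation domain of tv1: [i2, b3]
  assert(std::ranges::equal(ref1, ref2, areMappedToy));

  // Genuine permutation: the second reference's allocation order is [i1, i0],
  // so the positional pairwise check fails and both groups stay valid.
  std::vector<int> ref2_transposed{1, 0};
  assert(!std::ranges::equal(ref1, ref2_transposed, areMappedToy));
  return 0;
}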
94 changes: 81 additions & 13 deletions tests/cpp/test_transpose.cpp
@@ -20,6 +20,7 @@
#include "scheduler/transpose.h"
#include "scheduler/utils.h"
#include "tests/cpp/utils.h"
#include "type.h"
#include "validator_utils.h"

namespace nvfuser {
@@ -285,11 +286,14 @@ TEST_F(TransposeTest, FusionScheduleTransposeNoReference) {

// x->broadcast--add->z
// y->broadcast-/
// pointwise: 61%
// transpose: 39%
TEST_F(TransposeTest, FusionScheduleBroadcastOnly) {
for (bool contig0 : {true, false}) {
for (bool contig1 : {true, false}) {
- Fusion fusion;
- FusionGuard fg(&fusion);
auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;
auto tv0 = contig0 ? makeContigConcreteTensor({-1, 1, -1})
: makeConcreteTensor({-1, 1, -1});
auto tv1 = contig1 ? makeContigConcreteTensor({-1, -1, 1})
@@ -303,10 +307,24 @@ TEST_F(TransposeTest, FusionScheduleBroadcastOnly) {
at::Tensor input0 = at::randn({1024, 1, 256}, options);
at::Tensor input1 = at::randn({1024, 1024, 1}, options);

- auto cg_outputs =
- scheduleAndRun(&fusion, SchedulerType::Transpose, {input0, input1})
- .outputs;
- testValidate(&fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__);
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({input0, input1});
auto runtime = executor_cache.getMostRecentKernelRuntime();
auto heuristic = runtime->schedulerHeuristics()
->heuristicsList()
.at(0)
.get()
->scheduler_type;
NVF_CHECK(
heuristic == SchedulerType::PointWise,
"Unexpected heuristic: ",
heuristic);
testValidate(
executor_cache.fusion(),
outputs,
{input0, input1},
__LINE__,
__FILE__);
}
}
}
@@ -628,8 +646,9 @@ TEST_F(TransposeTest, FusionTransposeViewSelfMapping) {
// t2->broadcast->sub->mul->relu->t6
// t1------------------'
TEST_F(TransposeTest, FusionScheduleTransposeMissingDim) {
- Fusion fusion;
- FusionGuard fg(&fusion);
auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;

auto tv0 = makeContigTensor(3);
auto tv1 = makeContigConcreteTensor({1, -1, 1});
@@ -648,12 +667,24 @@
at::Tensor input1 = at::randn({1, 512, 1}, options);
at::Tensor input2 = at::randn({512}, options);

- auto cg_outputs =
- scheduleAndRun(
- &fusion, SchedulerType::Transpose, {input0, input1, input2})
- .outputs;
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({input0, input1, input2});
auto runtime = executor_cache.getMostRecentKernelRuntime();
auto heuristic = runtime->schedulerHeuristics()
->heuristicsList()
.at(0)
.get()
->scheduler_type;
NVF_CHECK(
heuristic == SchedulerType::PointWise,
"Unexpected heuristic: ",
heuristic);
testValidate(
- &fusion, cg_outputs, {input0, input1, input2}, __LINE__, __FILE__);
executor_cache.fusion(),
outputs,
{input0, input1, input2},
__LINE__,
__FILE__);
}

// x->sin->transpose->cos->y
@@ -1409,4 +1440,41 @@ TEST_F(TransposeTest, DanglingBroadcastIssue4957) {
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
}

TEST_F(TransposeTest, NoTransposeMaverick17B) {
auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
Fusion& fusion = *fusion_ptr;

auto dtype = DataType::BFloat16;
auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype);
auto tv1 = makeContigConcreteTensor({262144, 1}, dtype);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = castOp(DataType::Float, tv0);
auto tv3 = castOp(DataType::Float, tv1);
auto tv4 = mul(tv2, tv3);
auto tv5 = castOp(dtype, tv4);
fusion.addOutput(tv5);

auto options =
at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
at::Tensor input0 = at::randn({262144, 5120}, options);
at::Tensor input1 = at::randn({262144, 1}, options);

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({input0, input1});
auto runtime = executor_cache.getMostRecentKernelRuntime();
auto heuristic = runtime->schedulerHeuristics()
->heuristicsList()
.at(0)
.get()
->scheduler_type;
NVF_CHECK(
heuristic == SchedulerType::PointWise,
"Unexpected heuristic: ",
heuristic);
testValidate(
executor_cache.fusion(), outputs, {input0, input1}, __LINE__, __FILE__);
}

} // namespace nvfuser
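A side note on the test changes: the heuristic assertion added to FusionScheduleBroadcastOnly, FusionScheduleTransposeMissingDim, and NoTransposeMaverick17B is repeated verbatim. A hypothetical helper, not part of this PR and assuming only the FusionExecutorCache / FusionKernelRuntime calls already exercised in the tests (and the includes already used by tests/cpp/test_transpose.cpp), could factor it out:

namespace {
// Returns the scheduler type chosen for the first (here, only) segment of the
// most recently compiled runtime.
SchedulerType getMostRecentSchedulerType(FusionExecutorCache& executor_cache) {
  return executor_cache.getMostRecentKernelRuntime()
      ->schedulerHeuristics()
      ->heuristicsList()
      .at(0)
      ->scheduler_type;
}
} // namespace

Each test body would then reduce to NVF_CHECK(getMostRecentSchedulerType(executor_cache) == SchedulerType::PointWise, ...).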