From d563b3f91e97e6061fe6bada295ccae5d8b37113 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 27 Jan 2026 07:47:25 -0800 Subject: [PATCH 1/6] transpose --- csrc/scheduler/tools/domain_map.cpp | 9 +++++++ tests/cpp/test_transpose.cpp | 38 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index b25ed288fde..347928e9027 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -515,6 +515,15 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { if (ref1 == nullptr || ref2 == nullptr) { return false; } + // ref1 and ref2 have the same number of non-broadcast/reduction/device + // dimensions. This is to reject cases like tv0[i0, i1] and tv1[i0, 1] for + // which the non-vectorized load of tv1 is not the bottleneck and pointwise + // scheduler is preferred. + if (scheduler_utils::nLogicalDims(ref1) != + scheduler_utils::nLogicalDims(ref2)) { + return false; + } + // reference 1 is the global reference, so it must have dim mapped the // innermost dim of both groups auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 3d0bb16b87a..019d8b0e66a 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -20,6 +20,7 @@ #include "scheduler/transpose.h" #include "scheduler/utils.h" #include "tests/cpp/utils.h" +#include "type.h" #include "validator_utils.h" namespace nvfuser { @@ -1409,4 +1410,41 @@ TEST_F(TransposeTest, DanglingBroadcastIssue4957) { testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, NoTransposeMaverick17B) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto dtype = DataType::BFloat16; + auto tv0 = makeContigConcreteTensor({262144, 5120}, dtype); + auto tv1 = makeContigConcreteTensor({262144, 
1}, dtype); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = castOp(DataType::Float, tv0); + auto tv3 = castOp(DataType::Float, tv1); + auto tv4 = mul(tv2, tv3); + auto tv5 = castOp(dtype, tv4); + fusion.addOutput(tv5); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({262144, 5120}, options); + at::Tensor input1 = at::randn({262144, 1}, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0, input1}); + auto runtime = executor_cache.getMostRecentKernelRuntime(); + auto heuristic = runtime->schedulerHeuristics() + ->heuristicsList() + .at(0) + .get() + ->scheduler_type; + NVF_CHECK( + heuristic == SchedulerType::PointWise, + "Unexpected heuristic: ", + heuristic); + testValidate( + executor_cache.fusion(), outputs, {input0, input1}, __LINE__, __FILE__); +} + } // namespace nvfuser From 567dcb11d91fa45388263ad9303fc45e6fccb269 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 27 Jan 2026 11:51:16 -0800 Subject: [PATCH 2/6] use number of elements --- csrc/scheduler/tools/domain_map.cpp | 9 ----- csrc/scheduler/transpose.cpp | 9 ++++- tests/cpp/test_transpose.cpp | 58 ++++++++++++++++++++++------- 3 files changed, 52 insertions(+), 24 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 347928e9027..b25ed288fde 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -515,15 +515,6 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { if (ref1 == nullptr || ref2 == nullptr) { return false; } - // ref1 and ref2 have the same number of non-broadcast/reduction/device - dimensions. This is to reject cases like tv0[i0, i1] and tv1[i0, 1] for - which the non-vectorized load of tv1 is not the bottleneck and pointwise - scheduler is preferred. 
- if (scheduler_utils::nLogicalDims(ref1) != - scheduler_utils::nLogicalDims(ref2)) { - return false; - } - // reference 1 is the global reference, so it must have dim mapped the // innermost dim of both groups auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 766213ca484..2ddf9f8319b 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -491,9 +491,16 @@ std::string getTransposeRuntimeRejectReason( getReferenceTensors(data_cache, domain_map, grouped_inputs_outputs); auto reference_tensors = reference_tensors_entry.get(); TensorView* reference1 = reference_tensors[0]; - + TensorView* reference2 = reference_tensors[1]; auto [shape_in_ref1, n_elems] = getLoopDomainSizes(data_cache, runtime_info, reference1, domain_map); + auto [_, n_elems2] = + getLoopDomainSizes(data_cache, runtime_info, reference2, domain_map); + if (n_elems != n_elems2) { + return "Transpose scheduler does not perform well on problem sizes with " + "different number of elements. n_elems1: " + + std::to_string(n_elems) + " n_elems2: " + std::to_string(n_elems2); + } auto innermost_info_entry = getInnerMostDimInfoInReference( data_cache, reference_tensors, reference1, domain_map); diff --git a/tests/cpp/test_transpose.cpp b/tests/cpp/test_transpose.cpp index 019d8b0e66a..6ca6cf19ca3 100644 --- a/tests/cpp/test_transpose.cpp +++ b/tests/cpp/test_transpose.cpp @@ -286,11 +286,14 @@ TEST_F(TransposeTest, FusionScheduleTransposeNoReference) { // x->broadcast--add->z // y->broadcast-/ +// pointwise: 61% +// transpose: 39% TEST_F(TransposeTest, FusionScheduleBroadcastOnly) { for (bool contig0 : {true, false}) { for (bool contig1 : {true, false}) { - Fusion fusion; - FusionGuard fg(&fusion); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; auto tv0 = contig0 ? 
makeContigConcreteTensor({-1, 1, -1}) : makeConcreteTensor({-1, 1, -1}); auto tv1 = contig1 ? makeContigConcreteTensor({-1, -1, 1}) @@ -304,10 +307,24 @@ TEST_F(TransposeTest, FusionScheduleBroadcastOnly) { at::Tensor input0 = at::randn({1024, 1, 256}, options); at::Tensor input1 = at::randn({1024, 1024, 1}, options); - auto cg_outputs = - scheduleAndRun(&fusion, SchedulerType::Transpose, {input0, input1}) - .outputs; - testValidate(&fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0, input1}); + auto runtime = executor_cache.getMostRecentKernelRuntime(); + auto heuristic = runtime->schedulerHeuristics() + ->heuristicsList() + .at(0) + .get() + ->scheduler_type; + NVF_CHECK( + heuristic == SchedulerType::PointWise, + "Unexpected heuristic: ", + heuristic); + testValidate( + executor_cache.fusion(), + outputs, + {input0, input1}, + __LINE__, + __FILE__); } } } @@ -629,8 +646,9 @@ TEST_F(TransposeTest, FusionTransposeViewSelfMapping) { // t2->broadcast->sub->mul->relu->t6 // t1------------------' TEST_F(TransposeTest, FusionScheduleTransposeMissingDim) { - Fusion fusion; - FusionGuard fg(&fusion); + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; auto tv0 = makeContigTensor(3); auto tv1 = makeContigConcreteTensor({1, -1, 1}); @@ -649,12 +667,24 @@ TEST_F(TransposeTest, FusionScheduleTransposeMissingDim) { at::Tensor input1 = at::randn({1, 512, 1}, options); at::Tensor input2 = at::randn({512}, options); - auto cg_outputs = - scheduleAndRun( - &fusion, SchedulerType::Transpose, {input0, input1, input2}) - .outputs; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({input0, input1, input2}); + auto runtime = executor_cache.getMostRecentKernelRuntime(); + auto heuristic = runtime->schedulerHeuristics() + ->heuristicsList() + 
.at(0) + .get() + ->scheduler_type; + NVF_CHECK( + heuristic == SchedulerType::PointWise, + "Unexpected heuristic: ", + heuristic); testValidate( - &fusion, cg_outputs, {input0, input1, input2}, __LINE__, __FILE__); + executor_cache.fusion(), + outputs, + {input0, input1, input2}, + __LINE__, + __FILE__); } // x->sin->transpose->cos->y @@ -1410,7 +1440,7 @@ TEST_F(TransposeTest, DanglingBroadcastIssue4957) { testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, NoTransposeMaverick17B) { +TEST_F(TransposeTest, NoTransposeMaverick17B) { auto fusion_ptr = std::make_unique(); FusionGuard fg(fusion_ptr.get()); Fusion& fusion = *fusion_ptr; From d807b3cc7db5b529d2538e9162a7911a76916b9d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 27 Jan 2026 18:26:29 -0800 Subject: [PATCH 3/6] bcast map --- csrc/scheduler/tools/domain_map.cpp | 20 ++++++++++++++++++++ csrc/scheduler/transpose.cpp | 9 +-------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index b25ed288fde..a03deec3442 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -515,6 +515,26 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { if (ref1 == nullptr || ref2 == nullptr) { return false; } + // when all loop domains are mapped to each other, it means the two groups + // are due to broadcast. They are not considered as valid groups since + // pointwise scheduler handles broadcast well. 
+ int64_t num_mapped = 0, bcast_domains = 0; + const auto& ref1_loop = ref1->getLoopDomain(); + const auto& ref2_loop = ref2->getLoopDomain(); + if (ref1_loop.size() == ref2_loop.size()) { + for (auto i : arange(ref1_loop.size())) { + if (domain_map.getComputeAtMap().areMapped( + ref1_loop[i], ref2_loop[i], IdMappingMode::PERMISSIVE)) { + num_mapped++; + } + if (ref1_loop[i]->isBroadcast() || ref2_loop[i]->isBroadcast()) { + bcast_domains++; + } + } + if (num_mapped == (int64_t)ref1_loop.size() && bcast_domains > 0) { + return false; + } + } // reference 1 is the global reference, so it must have dim mapped the // innermost dim of both groups auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 2ddf9f8319b..766213ca484 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -491,16 +491,9 @@ std::string getTransposeRuntimeRejectReason( getReferenceTensors(data_cache, domain_map, grouped_inputs_outputs); auto reference_tensors = reference_tensors_entry.get(); TensorView* reference1 = reference_tensors[0]; - TensorView* reference2 = reference_tensors[1]; + auto [shape_in_ref1, n_elems] = getLoopDomainSizes(data_cache, runtime_info, reference1, domain_map); - auto [_, n_elems2] = - getLoopDomainSizes(data_cache, runtime_info, reference2, domain_map); - if (n_elems != n_elems2) { - return "Transpose scheduler does not perform well on problem sizes with " - "different number of elements. 
n_elems1: " + - std::to_string(n_elems) + " n_elems2: " + std::to_string(n_elems2); - } auto innermost_info_entry = getInnerMostDimInfoInReference( data_cache, reference_tensors, reference1, domain_map); From 86a826c972dcd625480ae1bf9852bc2e82aa1de6 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Tue, 27 Jan 2026 18:38:16 -0800 Subject: [PATCH 4/6] bcast map --- csrc/scheduler/tools/domain_map.cpp | 40 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index a03deec3442..cca7ba960eb 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -8,6 +8,8 @@ #include #include +#include + namespace nvfuser { namespace scheduler_tools { @@ -515,25 +517,29 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { if (ref1 == nullptr || ref2 == nullptr) { return false; } - // when all loop domains are mapped to each other, it means the two groups - // are due to broadcast. They are not considered as valid groups since - // pointwise scheduler handles broadcast well. - int64_t num_mapped = 0, bcast_domains = 0; + + // For grouping caused by permutation, the corresponding loop domains should + // not be all mapped to each other. If they are, it means the two groups are + // due to broadcast. In this case, they are not considered as valid groups + // since the broadcast tensor has a smaller size and pointwise scheduler + // handles broadcast well through unrolling and caching at all levels. 
const auto& ref1_loop = ref1->getLoopDomain(); const auto& ref2_loop = ref2->getLoopDomain(); - if (ref1_loop.size() == ref2_loop.size()) { - for (auto i : arange(ref1_loop.size())) { - if (domain_map.getComputeAtMap().areMapped( - ref1_loop[i], ref2_loop[i], IdMappingMode::PERMISSIVE)) { - num_mapped++; - } - if (ref1_loop[i]->isBroadcast() || ref2_loop[i]->isBroadcast()) { - bcast_domains++; - } - } - if (num_mapped == (int64_t)ref1_loop.size() && bcast_domains > 0) { - return false; - } + const auto& ca_map = domain_map.getComputeAtMap(); + const bool all_mapped = std::ranges::equal( + ref1_loop, ref2_loop, [&](IterDomain* id1, IterDomain* id2) { + return ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE); + }); + if (all_mapped) { + // Not required, just to validate the assumption that all_mapped implies + // any_bcast + const bool any_bcast = + std::ranges::any_of( + ref1_loop, [](IterDomain* id) { return id->isBroadcast(); }) || + std::ranges::any_of( + ref2_loop, [](IterDomain* id) { return id->isBroadcast(); }); + NVF_ERROR(any_bcast, "all_mapped implies any_bcast"); + return false; } // reference 1 is the global reference, so it must have dim mapped the // innermost dim of both groups From 4ff918db325ad5378ae6971452e812c348bbcd86 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 28 Jan 2026 07:58:33 -0800 Subject: [PATCH 5/6] use allocation domain --- csrc/scheduler/tools/domain_map.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index cca7ba960eb..d9b06531da6 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -518,13 +518,21 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { return false; } + // reference 1 is the global reference, so it must have dim mapped the + // innermost dim of both groups + auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); + auto 
mapped_id = domain_map.getMappedAllocDimIn(ref1, innermost2); + if (mapped_id == nullptr) { + return false; + } + // For grouping caused by permutation, the corresponding loop domains should // not be all mapped to each other. If they are, it means the two groups are // due to broadcast. In this case, they are not considered as valid groups // since the broadcast tensor has a smaller size and pointwise scheduler // handles broadcast well through unrolling and caching at all levels. - const auto& ref1_loop = ref1->getLoopDomain(); - const auto& ref2_loop = ref2->getLoopDomain(); + const auto& ref1_loop = ref1->getMaybeAllocationDomain(); + const auto& ref2_loop = ref2->getMaybeAllocationDomain(); const auto& ca_map = domain_map.getComputeAtMap(); const bool all_mapped = std::ranges::equal( ref1_loop, ref2_loop, [&](IterDomain* id1, IterDomain* id2) { @@ -538,13 +546,13 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { ref1_loop, [](IterDomain* id) { return id->isBroadcast(); }) || std::ranges::any_of( ref2_loop, [](IterDomain* id) { return id->isBroadcast(); }); - NVF_ERROR(any_bcast, "all_mapped implies any_bcast"); + NVF_ERROR( + any_bcast, + "all_mapped implies any_bcast, ca_map:\n", + ca_map.toString()); return false; } - // reference 1 is the global reference, so it must have dim mapped the - // innermost dim of both groups - auto innermost2 = scheduler_utils::innerMostAllocDim(ref2); - return domain_map.getMappedAllocDimIn(ref1, innermost2) != nullptr; + return true; } int64_t TransposeDomainMap::getInnerLeafDim( From cbeb32d7a3bb87ae5ccc86a7f567756db0f3ddf6 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 30 Jan 2026 12:42:53 -0800 Subject: [PATCH 6/6] comment --- csrc/scheduler/tools/domain_map.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index d9b06531da6..482d4f3aa88 100644 --- a/csrc/scheduler/tools/domain_map.cpp 
+++ b/csrc/scheduler/tools/domain_map.cpp @@ -526,11 +526,16 @@ bool TransposeDomainMap::hasAtLeastTwoValidGroups(Fusion* fusion) { return false; } - // For grouping caused by permutation, the corresponding loop domains should - // not be all mapped to each other. If they are, it means the two groups are - // due to broadcast. In this case, they are not considered as valid groups - // since the broadcast tensor has a smaller size and pointwise scheduler - // handles broadcast well through unrolling and caching at all levels. + // For grouping caused by permutation, the corresponding allocation domains + // should not be all mapped to each other. If they are, it means the two + // groups are due to broadcast. In this case, they are not considered as valid + // groups since the broadcast tensor has a smaller size and pointwise + // scheduler handles broadcast well through unrolling and caching at all + // levels. For example, in TransposeTest.NoTransposeMaverick17B, two inputs + // are tv0[i0, i1] and tv1[i2, b3] where i0/i2 and i1/b3 are mapped to each + // other. However, tv0 and tv1 are in two different groups because of the + // broadcast. In this case, we should use the pointwise scheduler instead of + // the transpose scheduler. const auto& ref1_loop = ref1->getMaybeAllocationDomain(); const auto& ref2_loop = ref2->getMaybeAllocationDomain(); const auto& ca_map = domain_map.getComputeAtMap();