From f06fd561e241d7e2c14836af8f3f235ad66df672 Mon Sep 17 00:00:00 2001
From: tbqh
Date: Wed, 14 Jan 2026 03:35:43 -0800
Subject: [PATCH 1/6] Optimize TMA inner reduction

---
 csrc/scheduler/reduction.cpp     | 14 +++------
 csrc/scheduler/reduction_tma.cpp | 49 ++++++++++++++++++++------------
 csrc/scheduler/reduction_tma.h   |  4 +--
 tests/cpp/test_reduction.cpp     | 14 +++------
 4 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp
index 5b11a13e929..e40a7034816 100644
--- a/csrc/scheduler/reduction.cpp
+++ b/csrc/scheduler/reduction.cpp
@@ -206,17 +206,11 @@ bool mayUseTma(
     return false;
   }
 
-  // For small TMA sizes, the smem indirection is not worth it.
-  if (props.total_reduction_numel < 128) {
-    return false;
-  }
+  int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
 
-  // Require reduction dim fits into smem, until we add iteration over large
-  // reduction dim.
-  const int64_t smem_elems = (dev_prop->sharedMemPerBlockOptin * 8) /
-      props.max_dtype_size_bit_for_vectorization;
-
-  if (props.inner_most_dimension_numel > smem_elems) {
+  // For small TMA sizes, the smem indirection is not worth it.
+  if (total_reduction_bytes < 16384) {
     return false;
   }
 
diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index 3a3ac4fe319..7e988233f29 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -36,14 +36,20 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
       target_threads_per_sm,
       props.has_mufu_computation);
 
-  // Initialize split factors
-  int64_t vectorization_factor =
-      std::min(target_vect_unroll, props.vectorize_factor);
-  int64_t threads_per_block = 128;
-  int64_t unroll_factor = target_vect_unroll / vectorization_factor;
+  const int64_t smem_elems = (dev_prop->sharedMemPerBlockOptin * 8) /
+      props.max_dtype_size_bit_for_vectorization;
+
+  const int64_t half_smem_elems = scheduler_utils::lastPow2(smem_elems) / 2;
+
+  // Split the reduction dimension to reduce smem usage.
+  uint64_t tma_split_factor =
+      ceilDiv(props.inner_most_dimension_numel, half_smem_elems);
+
+  int64_t threads_per_block = 256;
+  int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
 
   auto params = std::make_unique<TmaInnerReductionParams>();
-  params->vectorization_factor = vectorization_factor;
+  params->tma_split_factor = tma_split_factor;
   params->threads_per_block = threads_per_block;
   params->unroll_factor = unroll_factor;
 
@@ -65,6 +71,8 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
 
   scheduler_utils::prepareForMemoryTypePromotion(fusion);
 
+  scheduler_utils::cacheAndForkOutputs(fusion, true);
+
   std::vector<TensorView*> tma_tvs;
   for (auto [tv, input_idx] : cached_inputs) {
     if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
@@ -87,6 +95,15 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   bool has_red_axis = dim_analysis.second;
   NVF_ERROR(has_iter_axis && has_red_axis);
 
+  if (rparams->tma_split_factor > 1) {
+    reduction_tv->split(1, rparams->tma_split_factor, false);
+    reduction_tv->axis(1)->parallelize(ParallelType::Serial);
+
+    for (auto tma_tv : tma_tvs) {
+      tma_tv->split(1, rparams->tma_split_factor, false);
+    }
+  }
+
   // Propagate the merges to all TMA TVs
   TransformPropagator tma_propagator(reduction_tv);
   SetSelector tma_selector({tma_tvs.begin(), tma_tvs.end()});
@@ -94,7 +111,7 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
       .traverse(&tma_propagator);
 
   int64_t iter_axis = 0;
-  int64_t inner_reduce_axis = 1;
+  int64_t inner_reduce_axis = rparams->tma_split_factor > 1 ? 2 : 1;
 
   // Schedule TMA tvs as [BIDx, Bulk]
   tma_tvs[0]->axis(iter_axis)->parallelize(ParallelType::BIDx);
@@ -104,26 +121,22 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // Non-TMA scheduling
   //
   // Apply splits following the pattern:
-  // [I, R] -> [I, R/vect, vect]
-  //        -> [I, R/vect/tidx, tidx, vect]
-  //        -> [I, R/vect/tidx/unroll, unroll, tidx, vect]
-
-  // Split 1: Vectorization factor (innermost serial split for TMA)
-  reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
-  reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::Serial);
+  // [I, R] -> [I, R/tidx, tidx]
+  //        -> [I, R/tidx/unroll, unroll, tidx]
 
-  // Split 2: TIDx (always applied)
-  reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
+  // Split 1: TIDx (always applied)
+  reduction_tv->split(inner_reduce_axis, 256);
   reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx);
+  reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp();
 
-  // Split 3: Inner unroll (outside of TIDx)
+  // Split 2: Inner unroll (outside of TIDx)
   if (rparams->unroll_factor > 1) {
     reduction_tv->split(inner_reduce_axis, rparams->unroll_factor);
     reduction_tv->axis(inner_reduce_axis + 1)
         ->parallelize(ParallelType::Unroll);
   }
 
-  // Split 4: Unswitch (always applied)
+  // Split 3: Unswitch (always applied)
   reduction_tv->split(inner_reduce_axis, 1);
   reduction_tv->axis(inner_reduce_axis + 1)
       ->parallelize(ParallelType::Unswitch);
diff --git a/csrc/scheduler/reduction_tma.h b/csrc/scheduler/reduction_tma.h
index 29efaa23188..9909edae374 100644
--- a/csrc/scheduler/reduction_tma.h
+++ b/csrc/scheduler/reduction_tma.h
@@ -19,8 +19,8 @@ class TmaInnerReductionParams : public HeuristicParams {
       SchedulerType scheduler_type = SchedulerType::Reduction)
       : HeuristicParams(scheduler_type) {};
 
-  // Inner serial split factor (similar to vectorization for non-TMA)
-  int64_t vectorization_factor = 1;
+  // Outer serial split factor for the reduction dimension, to fit into smem.
+  int64_t tma_split_factor = 1;
 
   // Number of threads per block for TIDx parallelization
   int64_t threads_per_block = 1;
diff --git a/tests/cpp/test_reduction.cpp b/tests/cpp/test_reduction.cpp
index afbd59984f6..57922acd499 100644
--- a/tests/cpp/test_reduction.cpp
+++ b/tests/cpp/test_reduction.cpp
@@ -2815,11 +2815,6 @@ class TmaInnerReductionTest
 
   // Check if we expect TMA to be used based on mayUseTma() conditions
   bool expectTmaUsed(DataType dtype, int64_t reduction_size) {
-    // Skip TMA for small reductions
-    if (reduction_size < 128) {
-      return false;
-    }
-
    // TMA requires 16-byte alignment (vectorize_factor > 1)
    int64_t dtype_size_bit = dataTypeSizeBit(dtype);
    int64_t dtype_bytes = dtype_size_bit / 8;
@@ -2828,11 +2823,10 @@ class TmaInnerReductionTest
      return false;
    }
 
-    // Reduction dim must fit into smem
-    auto dev_prop = at::cuda::getCurrentDeviceProperties();
-    int64_t smem_elems =
-        (dev_prop->sharedMemPerBlockOptin * 8) / dtype_size_bit;
-    if (reduction_size > smem_elems) {
+    uint64_t total_reduction_bytes = reduction_size * dtype_bytes;
+
+    // Skip TMA for small reductions
+    if (total_reduction_bytes < 16384) {
      return false;
    }
 

From 6f283184f4f2a35a58f069250671c3c9c048fadd Mon Sep 17 00:00:00 2001
From: tbqh
Date: Fri, 23 Jan 2026 03:34:40 -0800
Subject: [PATCH 2/6] Use "rparams->threads_per_block" for consistency

---
 csrc/scheduler/reduction_tma.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index 7e988233f29..d07c1a86ee6 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -125,7 +125,7 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   //        -> [I, R/tidx/unroll, unroll, tidx]
 
   // Split 1: TIDx (always applied)
-  reduction_tv->split(inner_reduce_axis, 256);
+  reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
   reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx);
   reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp();
 

From 5dc00e40a19187248dcb32978aba564c9652fc81 Mon Sep 17 00:00:00 2001
From: tbqh
Date: Mon, 26 Jan 2026 00:35:26 -0800
Subject: [PATCH 3/6] Add algorithm to search for suitable TMA split size

---
 csrc/scheduler/reduction_tma.cpp | 75 +++++++++++++++++++++++++++++---
 tests/cpp/test_reduction.cpp    |  6 +++
 2 files changed, 75 insertions(+), 6 deletions(-)

diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index d07c1a86ee6..cb855996dfb 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -17,6 +17,54 @@ namespace nvfuser {
 namespace reduction {
 namespace tma {
 
+namespace {
+
+// Find the smallest split factor that:
+// 1. Evenly divides numel
+// 2. Results in an element count that is divisible by alignment
+// 3. Results in an element count that is inside [lower_elem_bound,
+//    upper_elem_bound]
+int64_t getTmaSplit(
+    int64_t numel,
+    int64_t alignment,
+    int64_t lower_elem_bound,
+    int64_t upper_elem_bound) {
+  // Lower & upper bounds for the split factor
+  const int64_t split_lower = ceilDiv(numel, upper_elem_bound);
+  const int64_t split_upper = std::max(int64_t(1), numel / lower_elem_bound);
+
+  // Rather than linearly searching the whole range, use the fact that any
+  // divisor <= sqrt(numel) will be paired with another divisor >= sqrt(numel).
+  // Therefore we can stop at sqrt(numel) since we want to minimize the split
+  // divisor.
+  int64_t sqrt_n = int64_t(std::ceil(std::sqrt(double(numel))));
+  for (int64_t d = split_lower; d <= std::min(split_upper, sqrt_n); d++) {
+    if (numel % d == 0) {
+      int64_t tma_elems = numel / d;
+      if (tma_elems % alignment == 0) {
+        return d;
+      }
+    }
+  }
+
+  // The previous loop searched where the small divisor is within the range
+  // [split_lower, split_upper]. Now we check for cases where the large divisor
+  // is within that range.
+  for (int64_t d = sqrt_n; d >= 1; d--) {
+    if (numel % d == 0) {
+      int64_t paired = numel / d;
+      if (split_lower <= paired && paired <= split_upper) {
+        int64_t tma_elems = numel / paired;
+        if (tma_elems % alignment == 0) {
+          return paired;
+        }
+      }
+    }
+  }
+
+  return 0;
+}
+} // namespace
 
 std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
     Fusion* fusion,
@@ -36,14 +84,29 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
       target_threads_per_sm,
       props.has_mufu_computation);
 
-  const int64_t smem_elems = (dev_prop->sharedMemPerBlockOptin * 8) /
-      props.max_dtype_size_bit_for_vectorization;
+  uint64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t smem_elems = dev_prop->sharedMemPerBlockOptin / dtype_bytes;
+
+  // Heuristics: require that TMA loads be at least 16KB and consume at most
+  // half of shared memory.
+  constexpr int64_t min_tma_bytes = 16384;
+  const int64_t lower_elem_bound = min_tma_bytes / dtype_bytes;
+  const int64_t upper_elem_bound = smem_elems / 2;
 
-  const int64_t half_smem_elems = scheduler_utils::lastPow2(smem_elems) / 2;
+  // TMA requires 16-byte alignment after any splits
+  const int64_t aligned_elems = 16 / dtype_bytes;
 
-  // Split the reduction dimension to reduce smem usage.
-  uint64_t tma_split_factor =
-      ceilDiv(props.inner_most_dimension_numel, half_smem_elems);
+  // Search for a suitable split factor
+  const int64_t tma_split_factor = getTmaSplit(
+      props.inner_most_dimension_numel,
+      aligned_elems,
+      lower_elem_bound,
+      upper_elem_bound);
+
+  // If no valid split factor was found, fall back to non-TMA
+  if (tma_split_factor == 0) {
+    return nullptr;
+  }
 
   int64_t threads_per_block = 256;
   int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
diff --git a/tests/cpp/test_reduction.cpp b/tests/cpp/test_reduction.cpp
index 57922acd499..0c1bcb7c7ed 100644
--- a/tests/cpp/test_reduction.cpp
+++ b/tests/cpp/test_reduction.cpp
@@ -2830,6 +2830,12 @@ class TmaInnerReductionTest
      return false;
    }
 
+    // This test has a few large, non-power-of-2 shapes, which TMA will reject
+    // when searching for a suitable split.
+    if (reduction_size > 1024 * 1024 && (reduction_size % 32) != 0) {
+      return false;
+    }
+
    return true;
  }
 

From bbf43512b7111cf98aee0f278f12c4189243a582 Mon Sep 17 00:00:00 2001
From: tbqh
Date: Mon, 26 Jan 2026 00:36:07 -0800
Subject: [PATCH 4/6] Remove padToMultipleOfWarp()

---
 csrc/scheduler/reduction_tma.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index cb855996dfb..2e20c411c00 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -190,7 +190,6 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // Split 1: TIDx (always applied)
   reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
   reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx);
-  reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp();
 
   // Split 2: Inner unroll (outside of TIDx)
   if (rparams->unroll_factor > 1) {

From 35f6b0d775e36232155a7ee0bf667d08ce4e7974 Mon Sep 17 00:00:00 2001
From: tbqh
Date: Mon, 26 Jan 2026 02:06:53 -0800
Subject: [PATCH 5/6] Add explanation for heuristics and tweak threads_per_block for mufu kernels

---
 csrc/scheduler/reduction.cpp     |  6 ++++--
 csrc/scheduler/reduction_tma.cpp | 32 +++++++++++++++++++-------------
 tests/cpp/test_reduction.cpp    |  6 ++++--
 3 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp
index e40a7034816..1ef03929dd3 100644
--- a/csrc/scheduler/reduction.cpp
+++ b/csrc/scheduler/reduction.cpp
@@ -209,8 +209,10 @@ bool mayUseTma(
   int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
   uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
 
-  // For small TMA sizes, the smem indirection is not worth it.
-  if (total_reduction_bytes < 16384) {
+  // Minimum TMA transfer size, below which TMA is much slower than non-TMA.
+  uint64_t min_tma_bytes = 16384;
+
+  if (total_reduction_bytes < min_tma_bytes) {
     return false;
   }
 
diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index 2e20c411c00..37ae9cdba0f 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -74,21 +74,14 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   FusionGuard fg(fusion);
 
   auto dev_prop = at::cuda::getCurrentDeviceProperties();
-  const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
-  const int64_t target_threads_per_sm = max_threads_per_sm / 2;
-
-  auto target_vect_unroll = reduction_scheduler_utils::getVectUnroll(
-      props.max_dtype_size_bit_for_vectorization,
-      props.vectorize_factor,
-      props.n_tensor_inputs,
-      target_threads_per_sm,
-      props.has_mufu_computation);
 
   uint64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
   uint64_t smem_elems = dev_prop->sharedMemPerBlockOptin / dtype_bytes;
 
-  // Heuristics: require that TMA loads be at least 16KB and consume at most
-  // half of shared memory.
+  // Search for a suitable split factor. Lower and upper bounds are based on
+  // heuristics: we require that TMA loads be at least 16KB to overcome TMA
+  // overhead, and that they consume less than half of the available shared
+  // memory to maintain reasonable occupancy.
   constexpr int64_t min_tma_bytes = 16384;
   const int64_t lower_elem_bound = min_tma_bytes / dtype_bytes;
   const int64_t upper_elem_bound = smem_elems / 2;
@@ -96,7 +89,6 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   // TMA requires 16-byte alignment after any splits
   const int64_t aligned_elems = 16 / dtype_bytes;
 
-  // Search for a suitable split factor
   const int64_t tma_split_factor = getTmaSplit(
       props.inner_most_dimension_numel,
       aligned_elems,
@@ -108,7 +100,21 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
     return nullptr;
   }
 
-  int64_t threads_per_block = 256;
+  // These are derived from benchmarking.
+  int64_t threads_per_block = has_mufu_computation ? 512 : 256;
+
+  const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
+  const int64_t target_threads_per_sm = max_threads_per_sm / 2;
+
+  auto target_vect_unroll = reduction_scheduler_utils::getVectUnroll(
+      props.max_dtype_size_bit_for_vectorization,
+      props.vectorize_factor,
+      props.n_tensor_inputs,
+      target_threads_per_sm,
+      props.has_mufu_computation);
+
+  // The TMA kernel doesn't vectorize, since it works out of shared memory.
+  // Instead, fold the entire vect_unroll into unroll_factor.
   int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
diff --git a/tests/cpp/test_reduction.cpp b/tests/cpp/test_reduction.cpp
index 0c1bcb7c7ed..ba36f8fd608 100644
--- a/tests/cpp/test_reduction.cpp
+++ b/tests/cpp/test_reduction.cpp
@@ -2825,8 +2825,10 @@ class TmaInnerReductionTest
 
    uint64_t total_reduction_bytes = reduction_size * dtype_bytes;
 
-    // Skip TMA for small reductions
-    if (total_reduction_bytes < 16384) {
+    // Minimum TMA transfer size, below which TMA is much slower than non-TMA.
+    uint64_t min_tma_bytes = 16384;
+
+    if (total_reduction_bytes < min_tma_bytes) {
      return false;
    }
 

From 477530881483e64e39a1108482ce955c6fb939ce Mon Sep 17 00:00:00 2001
From: tbqh
Date: Mon, 26 Jan 2026 02:12:38 -0800
Subject: [PATCH 6/6] Fix compile error

---
 csrc/scheduler/reduction_tma.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index 37ae9cdba0f..ff322770347 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -101,7 +101,7 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   }
 
   // These are derived from benchmarking.
-  int64_t threads_per_block = has_mufu_computation ? 512 : 256;
+  int64_t threads_per_block = props.has_mufu_computation ? 512 : 256;
 
   const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
   const int64_t target_threads_per_sm = max_threads_per_sm / 2;
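
A standalone sketch of the split search introduced in PATCH 3, for intuition
about what the heuristic picks. getTmaSplit() below mirrors the patch's
function; everything in main() is an assumption for illustration (fp32 dtype,
a 228 KiB smem opt-in limit typical of Hopper-class devices, and the example
shapes), not a value taken from the patches.

// Sketch: mirrors getTmaSplit() from PATCH 3 so the heuristic can be
// sanity-checked off-device. Device limits and shapes below are assumed.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Smallest divisor of numel whose paired quotient (the TMA tile size in
// elements) is divisible by alignment and falls in [lo, hi].
int64_t getTmaSplit(int64_t numel, int64_t alignment, int64_t lo, int64_t hi) {
  const int64_t split_lower = ceilDiv(numel, hi);
  const int64_t split_upper = std::max(int64_t(1), numel / lo);
  const int64_t sqrt_n = int64_t(std::ceil(std::sqrt(double(numel))));
  // Small divisors first: every divisor <= sqrt(numel) pairs with one
  // >= sqrt(numel), so stopping at sqrt(numel) still finds the minimum.
  for (int64_t d = split_lower; d <= std::min(split_upper, sqrt_n); d++) {
    if (numel % d == 0 && (numel / d) % alignment == 0) {
      return d;
    }
  }
  // Then large divisors in range, reached through their small pair.
  for (int64_t d = sqrt_n; d >= 1; d--) {
    if (numel % d == 0) {
      const int64_t paired = numel / d;
      if (split_lower <= paired && paired <= split_upper &&
          (numel / paired) % alignment == 0) {
        return paired;
      }
    }
  }
  return 0; // no valid split; the scheduler falls back to non-TMA
}

int main() {
  // Assumed device/dtype: fp32 (4 B) with 228 KiB of opt-in smem.
  const int64_t dtype_bytes = 4;
  const int64_t smem_elems = 228 * 1024 / dtype_bytes; // 58368 elements
  const int64_t lo = 16384 / dtype_bytes; // >= 16 KiB per TMA load
  const int64_t hi = smem_elems / 2;      // <= half of smem
  const int64_t align = 16 / dtype_bytes; // 16 B TMA alignment

  // 32768 splits by 2 and 1 << 20 splits by 64, both into 16384-element
  // tiles; 99991 is prime, so no divisor lands in range and split = 0.
  for (int64_t numel : {int64_t(32768), int64_t(1) << 20, int64_t(99991)}) {
    const int64_t split = getTmaSplit(numel, align, lo, hi);
    std::printf(
        "numel=%lld -> split=%lld (tile=%lld elems)\n",
        (long long)numel,
        (long long)split,
        (long long)(split ? numel / split : 0));
  }
  return 0;
}

Under these assumed bounds, minimizing the split factor keeps the serial loop
over TMA stages as short as the smem budget allows, and the prime-sized case
shows why the PATCH 3 test change expects large non-power-of-2 shapes to take
the non-TMA fallback.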