From b6dd8c37e84ef034c71f77f7e6f1ac7fca7f8cbc Mon Sep 17 00:00:00 2001 From: tbqh Date: Wed, 28 Jan 2026 00:44:07 -0800 Subject: [PATCH 1/5] Optimize reduction_tma --- csrc/scheduler/reduction.cpp | 9 +++++++-- csrc/scheduler/reduction_tma.cpp | 30 ++++++++++++++---------------- csrc/scheduler/reduction_tma.h | 3 --- tests/cpp/test_reduction.cpp | 14 +++++++++----- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 5b11a13e929..44d00ce950d 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -206,8 +206,13 @@ bool mayUseTma( return false; } - // For small TMA sizes, the smem indirection is not worth it. - if (props.total_reduction_numel < 128) { + int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8; + uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes; + + // Minimum TMA transfer size, below which it seems much slower than non-TMA. + uint64_t min_tma_bytes = 16384; + + if (total_reduction_bytes < min_tma_bytes) { return false; } diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp index 3a3ac4fe319..329eaafbd99 100644 --- a/csrc/scheduler/reduction_tma.cpp +++ b/csrc/scheduler/reduction_tma.cpp @@ -26,6 +26,10 @@ std::unique_ptr getReductionHeuristics( FusionGuard fg(fusion); auto dev_prop = at::cuda::getCurrentDeviceProperties(); + + // These are derived from benchmarking. + int64_t threads_per_block = props.has_mufu_computation ? 
512 : 256; + const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor; const int64_t target_threads_per_sm = max_threads_per_sm / 2; @@ -36,14 +40,11 @@ std::unique_ptr getReductionHeuristics( target_threads_per_sm, props.has_mufu_computation); - // Initialize split factors - int64_t vectorization_factor = - std::min(target_vect_unroll, props.vectorize_factor); - int64_t threads_per_block = 128; - int64_t unroll_factor = target_vect_unroll / vectorization_factor; + // TMA kernel doesn't do vectorization due to working out of shared memory. + // Instead, fold the entire vect_unroll into unroll_factor. + int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll); auto params = std::make_unique(); - params->vectorization_factor = vectorization_factor; params->threads_per_block = threads_per_block; params->unroll_factor = unroll_factor; @@ -65,6 +66,8 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) { scheduler_utils::prepareForMemoryTypePromotion(fusion); + scheduler_utils::cacheAndForkOutputs(fusion, true); + std::vector tma_tvs; for (auto [tv, input_idx] : cached_inputs) { if (auto load_op = dynamic_cast(tv->definition())) { @@ -104,26 +107,21 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) { // Non-TMA scheduling // // Apply splits following the pattern: - // [I, R] -> [I, R/vect, vect] - // -> [I, R/vect/tidx, tidx, vect] - // -> [I, R/vect/tidx/unroll, unroll, tidx, vect] - - // Split 1: Vectorization factor (innermost serial split for TMA) - reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor); - reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::Serial); + // [I, R] -> [I, R/tidx, tidx] + // -> [I, R/tidx/unroll, unroll, tidx] - // Split 2: TIDx (always applied) + // Split 1: TIDx (always applied) reduction_tv->split(inner_reduce_axis, rparams->threads_per_block); reduction_tv->axis(inner_reduce_axis + 
1)->parallelize(ParallelType::TIDx); - // Split 3: Inner unroll (outside of TIDx) + // Split 2: Inner unroll (outside of TIDx) if (rparams->unroll_factor > 1) { reduction_tv->split(inner_reduce_axis, rparams->unroll_factor); reduction_tv->axis(inner_reduce_axis + 1) ->parallelize(ParallelType::Unroll); } - // Split 4: Unswitch (always applied) + // Split 3: Unswitch (always applied) reduction_tv->split(inner_reduce_axis, 1); reduction_tv->axis(inner_reduce_axis + 1) ->parallelize(ParallelType::Unswitch); diff --git a/csrc/scheduler/reduction_tma.h b/csrc/scheduler/reduction_tma.h index 01a89d304ad..c950478493f 100644 --- a/csrc/scheduler/reduction_tma.h +++ b/csrc/scheduler/reduction_tma.h @@ -19,9 +19,6 @@ class TmaInnerReductionParams : public HeuristicParams { SchedulerType scheduler_type = SchedulerType::Reduction) : HeuristicParams(scheduler_type) {}; - // Inner serial split factor (similar to vectorization for non-TMA) - int64_t vectorization_factor = 1; - // Number of threads per block for TIDx parallelization int64_t threads_per_block = 1; diff --git a/tests/cpp/test_reduction.cpp b/tests/cpp/test_reduction.cpp index 827418ac232..15c027d0e5b 100644 --- a/tests/cpp/test_reduction.cpp +++ b/tests/cpp/test_reduction.cpp @@ -2815,11 +2815,6 @@ class TmaInnerReductionTest // Check if we expect TMA to be used based on mayUseTma() conditions bool expectTmaUsed(DataType dtype, int64_t reduction_size) { - // Skip TMA for small reductions - if (reduction_size < 128) { - return false; - } - // TMA requires 16-byte alignment (vectorize_factor > 1) int64_t dtype_size_bit = dataTypeSizeBit(dtype); int64_t dtype_bytes = dtype_size_bit / 8; @@ -2828,6 +2823,15 @@ class TmaInnerReductionTest return false; } + uint64_t total_reduction_bytes = reduction_size * dtype_bytes; + + // Minimum TMA transfer size, below which it seems much slower than non-TMA. 
+ uint64_t min_tma_bytes = 16384; + + if (total_reduction_bytes < min_tma_bytes) { + return false; + } + // Reduction dim must fit into smem auto dev_prop = at::cuda::getCurrentDeviceProperties(); int64_t smem_elems = From 871689d8a4f791c4d07beb62737f88c7a409cf31 Mon Sep 17 00:00:00 2001 From: tbqh Date: Thu, 29 Jan 2026 17:46:02 -0800 Subject: [PATCH 2/5] Bring back vectorization_factor, set threads_per_block=256, add comments --- csrc/scheduler/reduction_tma.cpp | 35 ++++++++++++++++++++++++-------- csrc/scheduler/reduction_tma.h | 3 +++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp index 329eaafbd99..7e5c1973f6e 100644 --- a/csrc/scheduler/reduction_tma.cpp +++ b/csrc/scheduler/reduction_tma.cpp @@ -27,8 +27,16 @@ std::unique_ptr getReductionHeuristics( auto dev_prop = at::cuda::getCurrentDeviceProperties(); - // These are derived from benchmarking. - int64_t threads_per_block = props.has_mufu_computation ? 512 : 256; + // Vectorization is not as useful for TMA since we're not doing global loads. + // Vectorization can hurt the performance of shared memory loads, due to bank + // conflicts. For float32, ideal shared memory reads is achieved with no + // vectorization. For float16, it would be ideal for vect_factor=2, but we + // don't currently handle this. + int64_t vectorization_factor = 1; + + // Benchmarking shows some benefit to 512 block size for has_mufu_computation, + // but it's not across the board. Stick to 256 for now. + int64_t threads_per_block = 256; const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor; const int64_t target_threads_per_sm = max_threads_per_sm / 2; @@ -40,11 +48,12 @@ std::unique_ptr getReductionHeuristics( target_threads_per_sm, props.has_mufu_computation); - // TMA kernel doesn't do vectorization due to working out of shared memory. - // Instead, fold the entire vect_unroll into unroll_factor. 
+ // Since vectorization isn't currently performed, fold the entire vect_unroll + // factor into unroll. int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll); auto params = std::make_unique(); + params->vectorization_factor = vectorization_factor; params->threads_per_block = threads_per_block; params->unroll_factor = unroll_factor; @@ -107,21 +116,29 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) { // Non-TMA scheduling // // Apply splits following the pattern: - // [I, R] -> [I, R/tidx, tidx] - // -> [I, R/tidx/unroll, unroll, tidx] + // [I, R] -> [I, R/vect, vect] + // -> [I, R/vect/tidx, tidx, vect] + // -> [I, R/vect/tidx/unroll, unroll, tidx, vect] + + // Split 1: Vectorization factor (innermost serial split for TMA) + if (rparams->vectorization_factor > 1) { + reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor); + reduction_tv->axis(inner_reduce_axis + 1) + ->parallelize(ParallelType::Serial); + } - // Split 1: TIDx (always applied) + // Split 2: TIDx (always applied) reduction_tv->split(inner_reduce_axis, rparams->threads_per_block); reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx); - // Split 2: Inner unroll (outside of TIDx) + // Split 3: Inner unroll (outside of TIDx) if (rparams->unroll_factor > 1) { reduction_tv->split(inner_reduce_axis, rparams->unroll_factor); reduction_tv->axis(inner_reduce_axis + 1) ->parallelize(ParallelType::Unroll); } - // Split 3: Unswitch (always applied) + // Split 4: Unswitch (always applied) reduction_tv->split(inner_reduce_axis, 1); reduction_tv->axis(inner_reduce_axis + 1) ->parallelize(ParallelType::Unswitch); diff --git a/csrc/scheduler/reduction_tma.h b/csrc/scheduler/reduction_tma.h index c950478493f..01a89d304ad 100644 --- a/csrc/scheduler/reduction_tma.h +++ b/csrc/scheduler/reduction_tma.h @@ -19,6 +19,9 @@ class TmaInnerReductionParams : public HeuristicParams { SchedulerType scheduler_type = SchedulerType::Reduction) : 
HeuristicParams(scheduler_type) {}; + // Inner serial split factor (similar to vectorization for non-TMA) + int64_t vectorization_factor = 1; + // Number of threads per block for TIDx parallelization int64_t threads_per_block = 1; From c24d65ab2bd96f0b2d4a08aabd38b0a99b13086a Mon Sep 17 00:00:00 2001 From: tbqh Date: Fri, 30 Jan 2026 15:03:43 -0800 Subject: [PATCH 3/5] Add comment about bank conflicts --- csrc/scheduler/reduction_tma.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp index 7e5c1973f6e..07df887fdd5 100644 --- a/csrc/scheduler/reduction_tma.cpp +++ b/csrc/scheduler/reduction_tma.cpp @@ -29,9 +29,10 @@ std::unique_ptr getReductionHeuristics( // Vectorization is not as useful for TMA since we're not doing global loads. // Vectorization can hurt the performance of shared memory loads, due to bank - // conflicts. For float32, ideal shared memory reads is achieved with no - // vectorization. For float16, it would be ideal for vect_factor=2, but we - // don't currently handle this. + // conflicts. E.g. for float32, ideal shared memory reads are achieved with no + // vectorization, but with a vectorization factor of 2, thread 0 and 16 will + // both hit bank0 during their first read. For float16, it would be ideal for + // vect_factor=2, but we don't currently handle this. 
int64_t vectorization_factor = 1; // Benchmarking shows some benefit to 512 block size for has_mufu_computation, From a8f35b7cd47fe60e06a97fab68034653cd0044d7 Mon Sep 17 00:00:00 2001 From: tbqh Date: Fri, 30 Jan 2026 15:04:28 -0800 Subject: [PATCH 4/5] Consistent order of cacheAndForkOutputs() --- csrc/scheduler/reduction_tma.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp index 07df887fdd5..2f38e5d8ffd 100644 --- a/csrc/scheduler/reduction_tma.cpp +++ b/csrc/scheduler/reduction_tma.cpp @@ -74,10 +74,10 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) { // fusion segmentation scheduler_utils::clearMemorySpace(fusion); - scheduler_utils::prepareForMemoryTypePromotion(fusion); - scheduler_utils::cacheAndForkOutputs(fusion, true); + scheduler_utils::cacheAndForkOutputs(fusion, true); + scheduler_utils::prepareForMemoryTypePromotion(fusion); std::vector tma_tvs; for (auto [tv, input_idx] : cached_inputs) { if (auto load_op = dynamic_cast(tv->definition())) { From c5a49fb1b9804277308eecfea6480f3cc9460f31 Mon Sep 17 00:00:00 2001 From: tbqh Date: Fri, 30 Jan 2026 16:55:08 -0800 Subject: [PATCH 5/5] Update vectorization comment --- csrc/scheduler/reduction_tma.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp index 2f38e5d8ffd..d26f20087c2 100644 --- a/csrc/scheduler/reduction_tma.cpp +++ b/csrc/scheduler/reduction_tma.cpp @@ -28,11 +28,7 @@ std::unique_ptr getReductionHeuristics( auto dev_prop = at::cuda::getCurrentDeviceProperties(); // Vectorization is not as useful for TMA since we're not doing global loads. - // Vectorization can hurt the performance of shared memory loads, due to bank - // conflicts. E.g. for float32, ideal shared memory reads are achieved with no - // vectorization, but with a vectorization factor of 2, thread 0 and 16 will - // both hit bank0 during their first read.
For float16, it would be ideal for - // vect_factor=2, but we don't currently handle this. + // For now we don't do it. int64_t vectorization_factor = 1; // Benchmarking shows some benefit to 512 block size for has_mufu_computation,