diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp
index 5b11a13e929..44d00ce950d 100644
--- a/csrc/scheduler/reduction.cpp
+++ b/csrc/scheduler/reduction.cpp
@@ -206,8 +206,13 @@ bool mayUseTma(
     return false;
   }
 
-  // For small TMA sizes, the smem indirection is not worth it.
-  if (props.total_reduction_numel < 128) {
+  int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
+
+  // Minimum TMA transfer size, below which it seems much slower than non-TMA.
+  uint64_t min_tma_bytes = 16384;
+
+  if (total_reduction_bytes < min_tma_bytes) {
     return false;
   }
 
diff --git a/csrc/scheduler/reduction_tma.cpp b/csrc/scheduler/reduction_tma.cpp
index 3a3ac4fe319..d26f20087c2 100644
--- a/csrc/scheduler/reduction_tma.cpp
+++ b/csrc/scheduler/reduction_tma.cpp
@@ -26,6 +26,15 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   FusionGuard fg(fusion);
 
   auto dev_prop = at::cuda::getCurrentDeviceProperties();
+
+  // Vectorization is not as useful for TMA since we're not doing global loads.
+  // For now we don't do it.
+  int64_t vectorization_factor = 1;
+
+  // Benchmarking shows some benefit to 512 block size for has_mufu_computation,
+  // but it's not across the board. Stick to 256 for now.
+  int64_t threads_per_block = 256;
+
   const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
   const int64_t target_threads_per_sm = max_threads_per_sm / 2;
 
@@ -36,11 +45,9 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
       target_threads_per_sm,
       props.has_mufu_computation);
 
-  // Initialize split factors
-  int64_t vectorization_factor =
-      std::min(target_vect_unroll, props.vectorize_factor);
-  int64_t threads_per_block = 128;
-  int64_t unroll_factor = target_vect_unroll / vectorization_factor;
+  // Since vectorization isn't currently performed, fold the entire vect_unroll
+  // factor into unroll.
+  int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
 
   auto params = std::make_unique<TmaInnerReductionParams>();
   params->vectorization_factor = vectorization_factor;
@@ -63,6 +70,8 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // fusion segmentation
   scheduler_utils::clearMemorySpace(fusion);
 
+  scheduler_utils::cacheAndForkOutputs(fusion, true);
+
   scheduler_utils::prepareForMemoryTypePromotion(fusion);
 
   std::vector<TensorView*> tma_tvs;
@@ -109,8 +118,11 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // -> [I, R/vect/tidx/unroll, unroll, tidx, vect]
 
   // Split 1: Vectorization factor (innermost serial split for TMA)
-  reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
-  reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::Serial);
+  if (rparams->vectorization_factor > 1) {
+    reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
+    reduction_tv->axis(inner_reduce_axis + 1)
+        ->parallelize(ParallelType::Serial);
+  }
 
   // Split 2: TIDx (always applied)
   reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
diff --git a/tests/cpp/test_reduction.cpp b/tests/cpp/test_reduction.cpp
index 827418ac232..15c027d0e5b 100644
--- a/tests/cpp/test_reduction.cpp
+++ b/tests/cpp/test_reduction.cpp
@@ -2815,11 +2815,6 @@ class TmaInnerReductionTest
 
   // Check if we expect TMA to be used based on mayUseTma() conditions
   bool expectTmaUsed(DataType dtype, int64_t reduction_size) {
-    // Skip TMA for small reductions
-    if (reduction_size < 128) {
-      return false;
-    }
-
     // TMA requires 16-byte alignment (vectorize_factor > 1)
     int64_t dtype_size_bit = dataTypeSizeBit(dtype);
     int64_t dtype_bytes = dtype_size_bit / 8;
@@ -2828,6 +2823,15 @@
       return false;
     }
 
+    uint64_t total_reduction_bytes = reduction_size * dtype_bytes;
+
+    // Minimum TMA transfer size, below which it seems much slower than non-TMA.
+    uint64_t min_tma_bytes = 16384;
+
+    if (total_reduction_bytes < min_tma_bytes) {
+      return false;
+    }
+
     // Reduction dim must fit into smem
     auto dev_prop = at::cuda::getCurrentDeviceProperties();
    int64_t smem_elems =
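For illustration, here is a minimal standalone sketch of the byte-based size gate that the diff adds to mayUseTma() and mirrors in expectTmaUsed(). The free-function form, its name, and the constant name are assumptions of this sketch, not nvFuser API; in the diff the check is inlined and uses props.total_reduction_numel and props.max_dtype_size_bit_for_vectorization.

#include <cstdint>

// Below roughly 16 KiB per reduction, TMA seems much slower than the non-TMA
// schedule, so the scheduler should fall back.
constexpr uint64_t kMinTmaBytes = 16384;

// Hypothetical helper mirroring the inlined check in the diff.
inline bool reductionMeetsTmaSizeFloor(
    int64_t total_reduction_numel,
    int64_t max_dtype_size_bit) {
  const int64_t dtype_bytes = max_dtype_size_bit / 8;
  const uint64_t total_reduction_bytes =
      static_cast<uint64_t>(total_reduction_numel) *
      static_cast<uint64_t>(dtype_bytes);
  return total_reduction_bytes >= kMinTmaBytes;
}

// Example: 2048 fp32 elements -> 8192 bytes, below the floor, so the TMA path
// is skipped; 8192 fp16 elements -> 16384 bytes, which meets it.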