9 changes: 7 additions & 2 deletions csrc/scheduler/reduction.cpp
@@ -206,8 +206,13 @@ bool mayUseTma(
     return false;
   }
 
-  // For small TMA sizes, the smem indirection is not worth it.
-  if (props.total_reduction_numel < 128) {
+  int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
+
+  // Minimum TMA transfer size, below which it seems much slower than non-TMA.
+  uint64_t min_tma_bytes = 16384;
+
+  if (total_reduction_bytes < min_tma_bytes) {
     return false;
   }

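Note (illustration, not part of the diff): to make the new cutoff concrete, the sketch below mirrors the size check added to mayUseTma(); the helper name and values are hypothetical, not nvFuser API. With a 16 KiB minimum, fp32 reductions need at least 4096 elements and fp16 reductions at least 8192 elements before TMA is considered.

#include <cstdint>
#include <cstdio>

// Standalone mirror of the size check added to mayUseTma() (hypothetical
// helper, not part of the patch).
bool clearsTmaMinimum(int64_t reduction_numel, int64_t dtype_size_bit) {
  const int64_t dtype_bytes = dtype_size_bit / 8;
  const uint64_t total_reduction_bytes =
      static_cast<uint64_t>(reduction_numel) * static_cast<uint64_t>(dtype_bytes);
  const uint64_t min_tma_bytes = 16384; // 16 KiB cutoff from this change
  return total_reduction_bytes >= min_tma_bytes;
}

int main() {
  std::printf("%d\n", clearsTmaMinimum(4096, 32)); // fp32: 16 KiB -> 1 (TMA considered)
  std::printf("%d\n", clearsTmaMinimum(4096, 16)); // fp16:  8 KiB -> 0 (falls back to non-TMA)
  return 0;
}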
26 changes: 19 additions & 7 deletions csrc/scheduler/reduction_tma.cpp
@@ -26,6 +26,15 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   FusionGuard fg(fusion);
 
   auto dev_prop = at::cuda::getCurrentDeviceProperties();
+
+  // Vectorization is not as useful for TMA since we're not doing global loads.
+  // For now we don't do it.
+  int64_t vectorization_factor = 1;
+
+  // Benchmarking shows some benefit to 512 block size for has_mufu_computation,
+  // but it's not across the board. Stick to 256 for now.
+  int64_t threads_per_block = 256;
+
   const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
   const int64_t target_threads_per_sm = max_threads_per_sm / 2;

@@ -36,11 +45,9 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
       target_threads_per_sm,
       props.has_mufu_computation);
 
-  // Initialize split factors
-  int64_t vectorization_factor =
-      std::min(target_vect_unroll, props.vectorize_factor);
-  int64_t threads_per_block = 128;
-  int64_t unroll_factor = target_vect_unroll / vectorization_factor;
+  // Since vectorization isn't currently performed, fold the entire vect_unroll
+  // factor into unroll.
+  int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
 
   auto params = std::make_unique<TmaInnerReductionParams>();
   params->vectorization_factor = vectorization_factor;
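Note (illustration, not part of the diff): the unroll fold above uses scheduler_utils::lastPow2; the sketch below assumes the usual largest-power-of-two-not-exceeding-n semantics for that helper.

#include <cassert>
#include <cstdint>

// Hedged sketch: assumes scheduler_utils::lastPow2(n) returns the largest
// power of two that does not exceed n (n >= 1).
int64_t lastPow2(int64_t n) {
  int64_t p = 1;
  while (p * 2 <= n) {
    p *= 2;
  }
  return p;
}

int main() {
  // With vectorization_factor pinned to 1, the whole vect/unroll budget is
  // folded into the serial unroll factor, rounded down to a power of two.
  assert(lastPow2(6) == 4);
  assert(lastPow2(8) == 8);
  assert(lastPow2(1) == 1);
  return 0;
}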
@@ -63,6 +70,8 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // fusion segmentation
   scheduler_utils::clearMemorySpace(fusion);
 
+  scheduler_utils::cacheAndForkOutputs(fusion, true);
+
   scheduler_utils::prepareForMemoryTypePromotion(fusion);
 
   std::vector<TensorView*> tma_tvs;
@@ -109,8 +118,11 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // -> [I, R/vect/tidx/unroll, unroll, tidx, vect]
 
   // Split 1: Vectorization factor (innermost serial split for TMA)
-  reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
-  reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::Serial);
+  if (rparams->vectorization_factor > 1) {
+    reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
+    reduction_tv->axis(inner_reduce_axis + 1)
+        ->parallelize(ParallelType::Serial);
+  }
 
   // Split 2: TIDx (always applied)
   reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
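Note (illustration, not part of the diff): the sketch below mimics the extent bookkeeping of the splits in scheduleReduction(), assuming exact divisibility and the new default vectorization_factor == 1, so the innermost serial split is skipped and the reduction domain ends up as [R/tidx/unroll, unroll, tidx].

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t R = 32768;                 // reduction extent (example value)
  const int64_t vectorization_factor = 1;  // new default in this change
  const int64_t threads_per_block = 256;
  const int64_t unroll_factor = 4;         // example value

  // Each split divides the outer extent and inserts the factor just inside it,
  // pushing previously created inner extents further to the right.
  std::vector<int64_t> domain{R};
  auto split = [&](int64_t factor) {
    domain[0] /= factor;
    domain.insert(domain.begin() + 1, factor);
  };

  if (vectorization_factor > 1) {
    split(vectorization_factor);  // skipped when the factor is 1
  }
  split(threads_per_block);  // TIDx
  split(unroll_factor);      // serial unroll

  // Prints: 32 4 256  ->  [R/tidx/unroll, unroll, tidx]
  for (int64_t e : domain) {
    std::printf("%lld ", static_cast<long long>(e));
  }
  std::printf("\n");
  return 0;
}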
14 changes: 9 additions & 5 deletions tests/cpp/test_reduction.cpp
@@ -2815,11 +2815,6 @@ class TmaInnerReductionTest

   // Check if we expect TMA to be used based on mayUseTma() conditions
   bool expectTmaUsed(DataType dtype, int64_t reduction_size) {
-    // Skip TMA for small reductions
-    if (reduction_size < 128) {
-      return false;
-    }
-
     // TMA requires 16-byte alignment (vectorize_factor > 1)
     int64_t dtype_size_bit = dataTypeSizeBit(dtype);
     int64_t dtype_bytes = dtype_size_bit / 8;
@@ -2828,6 +2823,15 @@
       return false;
     }
 
+    uint64_t total_reduction_bytes = reduction_size * dtype_bytes;
+
+    // Minimum TMA transfer size, below which it seems much slower than non-TMA.
+    uint64_t min_tma_bytes = 16384;
+
+    if (total_reduction_bytes < min_tma_bytes) {
+      return false;
+    }
+
     // Reduction dim must fit into smem
     auto dev_prop = at::cuda::getCurrentDeviceProperties();
     int64_t smem_elems =