csrc/scheduler/reduction.cpp (14 changes: 5 additions & 9 deletions)
@@ -206,17 +206,13 @@ bool mayUseTma(
     return false;
   }
 
-  // For small TMA sizes, the smem indirection is not worth it.
-  if (props.total_reduction_numel < 128) {
-    return false;
-  }
+  int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
 
-  // Require reduction dim fits into smem, until we add iteration over large
-  // reduction dim.
-  const int64_t smem_elems = (dev_prop->sharedMemPerBlockOptin * 8) /
-      props.max_dtype_size_bit_for_vectorization;
+  // Minimum TMA transfer size, below which it seems much slower than non-TMA.
+  uint64_t min_tma_bytes = 16384;
 
-  if (props.inner_most_dimension_numel > smem_elems) {
+  if (total_reduction_bytes < min_tma_bytes) {
     return false;
   }
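The new gate is purely byte-based: with min_tma_bytes = 16384, a float32 reduction needs at least 4096 elements and a bfloat16 reduction at least 8192 elements before TMA is considered. A minimal standalone sketch of the check (the wrapper function is illustrative, not part of the PR):

#include <cstdint>

// Sketch: the byte-based gate from mayUseTma() above, isolated for clarity.
bool passesMinTmaBytes(int64_t total_reduction_numel, int64_t dtype_size_bit) {
  const int64_t dtype_bytes = dtype_size_bit / 8;
  const uint64_t total_reduction_bytes = total_reduction_numel * dtype_bytes;
  const uint64_t min_tma_bytes = 16384;
  return total_reduction_bytes >= min_tma_bytes;
}

// passesMinTmaBytes(4096, 32) == true  (exactly 16 KB of float32)
// passesMinTmaBytes(4096, 16) == false (only 8 KB of bfloat16)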
csrc/scheduler/reduction_tma.cpp (115 changes: 98 additions & 17 deletions)
@@ -17,6 +17,54 @@
 namespace nvfuser {
 namespace reduction {
 namespace tma {
+namespace {
+
+// Find the smallest split factor that:
+// 1. Evenly divides numel
+// 2. Results in an element count that is divisible by alignment
+// 3. Results in an element count that is inside [lower_elem_bound,
+//    upper_elem_bound]
+int64_t getTmaSplit(
+    int64_t numel,
+    int64_t alignment,
+    int64_t lower_elem_bound,
+    int64_t upper_elem_bound) {
+  // Lower & upper bounds for the split factor
+  const int64_t split_lower = ceilDiv(numel, upper_elem_bound);
+  const int64_t split_upper = std::max(int64_t(1), numel / lower_elem_bound);
+
+  // Rather than linearly searching the whole range, use the fact that any
+  // divisor <= sqrt(numel) is paired with another divisor >= sqrt(numel).
+  // Therefore we can stop at sqrt(numel), since we want to minimize the split
+  // divisor.
+  int64_t sqrt_n = int64_t(std::ceil(std::sqrt(double(numel))));
+  for (int64_t d = split_lower; d <= std::min(split_upper, sqrt_n); d++) {
+    if (numel % d == 0) {
+      int64_t tma_elems = numel / d;
+      if (tma_elems % alignment == 0) {
+        return d;
+      }
+    }
+  }
+
+  // The previous loop searched for a small divisor within the range
+  // [split_lower, split_upper]. Now check for cases where the large divisor
+  // of a pair falls within that range.
+  for (int64_t d = sqrt_n; d >= 1; d--) {
+    if (numel % d == 0) {
+      int64_t paired = numel / d;
+      if (split_lower <= paired && paired <= split_upper) {
+        int64_t tma_elems = numel / paired;
+        if (tma_elems % alignment == 0) {
+          return paired;
+        }
+      }
+    }
+  }
+
+  return 0;
+}
+} // namespace
 
 std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
     Fusion* fusion,
@@ -26,6 +74,35 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   FusionGuard fg(fusion);
 
   auto dev_prop = at::cuda::getCurrentDeviceProperties();
+
+  uint64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t smem_elems = dev_prop->sharedMemPerBlockOptin / dtype_bytes;
+
+  // Search for a suitable split factor. The lower and upper bounds are
+  // heuristic: we require TMA loads to be at least 16 KB to overcome TMA
+  // overhead, and to consume less than half of the available shared memory
+  // to maintain reasonable occupancy.
+  constexpr int64_t min_tma_bytes = 16384;
+  const int64_t lower_elem_bound = min_tma_bytes / dtype_bytes;
+  const int64_t upper_elem_bound = smem_elems / 2;
+
+  // TMA requires 16-byte alignment after any splits
+  const int64_t aligned_elems = 16 / dtype_bytes;
+
+  const int64_t tma_split_factor = getTmaSplit(
+      props.inner_most_dimension_numel,
+      aligned_elems,
+      lower_elem_bound,
+      upper_elem_bound);
+
+  // If no valid split factor was found, fall back to non-TMA
+  if (tma_split_factor == 0) {
+    return nullptr;
+  }
+
+  // These are derived from benchmarking.
+  int64_t threads_per_block = props.has_mufu_computation ? 512 : 256;
+
   const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
   const int64_t target_threads_per_sm = max_threads_per_sm / 2;
 
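To make the bounds concrete, assume float32 data and a device with 228 KB of opt-in shared memory per block (both assumptions; sharedMemPerBlockOptin varies by GPU):

// Illustrative numbers only, not part of the PR.
constexpr int64_t smem_bytes = 228 * 1024;                // assumed sharedMemPerBlockOptin
constexpr int64_t dtype_bytes = 4;                        // float32
constexpr int64_t smem_elems = smem_bytes / dtype_bytes;  // 58368 elements
constexpr int64_t lower_elem_bound = 16384 / dtype_bytes; // 4096 elements
constexpr int64_t upper_elem_bound = smem_elems / 2;      // 29184 elements
constexpr int64_t aligned_elems = 16 / dtype_bytes;       // 4 elements

Under these assumptions, getTmaSplit() looks for the smallest divisor of inner_most_dimension_numel whose cofactor is a multiple of 4 and lies in [4096, 29184].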
@@ -36,14 +113,12 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
       target_threads_per_sm,
       props.has_mufu_computation);
 
-  // Initialize split factors
-  int64_t vectorization_factor =
-      std::min(target_vect_unroll, props.vectorize_factor);
-  int64_t threads_per_block = 128;
-  int64_t unroll_factor = target_vect_unroll / vectorization_factor;
+  // The TMA kernel doesn't vectorize, since it works out of shared memory.
+  // Instead, fold the entire vect_unroll into unroll_factor.
+  int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
 
   auto params = std::make_unique<TmaInnerReductionParams>();
-  params->vectorization_factor = vectorization_factor;
+  params->tma_split_factor = tma_split_factor;
   params->threads_per_block = threads_per_block;
   params->unroll_factor = unroll_factor;
 
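Assuming scheduler_utils::lastPow2 returns the largest power of two not exceeding its argument, this rounds an odd or non-power-of-2 target down rather than letting it leak into the unroll factor, e.g.:

// Assumed behavior of scheduler_utils::lastPow2:
//   lastPow2(6) == 4
//   lastPow2(8) == 8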
@@ -65,6 +140,8 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
 
   scheduler_utils::prepareForMemoryTypePromotion(fusion);
 
+  scheduler_utils::cacheAndForkOutputs(fusion, true);
+
   std::vector<TensorView*> tma_tvs;
   for (auto [tv, input_idx] : cached_inputs) {
     if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
@@ -87,14 +164,23 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   bool has_red_axis = dim_analysis.second;
   NVF_ERROR(has_iter_axis && has_red_axis);
 
+  if (rparams->tma_split_factor > 1) {
+    reduction_tv->split(1, rparams->tma_split_factor, false);
Collaborator:
Do we need to consider a divisible split for better performance when setting this tma_split_factor?

Collaborator (Author):
Yes, this is actually needed for correctness, not just performance. Lowering will fail if the reduction size is not divisible by the split; I believe this is a restriction of 1D TMA.

I added a new function getTmaSplit() to search for a valid split size. This fixes failures with large non-divisible sizes.

Note that the TMA scheduler now returns nullptr during heuristic checking if the splitting fails. The splitting logic is complicated, and I don't think we should duplicate it in mayUseTma().
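To illustrate what the search accepts and rejects, here is a hedged standalone sketch: it brute-forces the same three criteria getTmaSplit() encodes (divisibility, alignment, element-count bounds) without the sqrt optimization, so it is only suitable as a cross-check:

#include <cstdint>
#include <cstdio>

// Brute-force reference for getTmaSplit()'s criteria (illustrative only).
int64_t bruteForceSplit(int64_t numel, int64_t align, int64_t lo, int64_t hi) {
  for (int64_t d = 1; d <= numel; ++d) {
    if (numel % d != 0) {
      continue;
    }
    const int64_t elems = numel / d;
    if (elems >= lo && elems <= hi && elems % align == 0) {
      return d; // smallest valid split factor
    }
  }
  return 0; // no valid split; the heuristic falls back to non-TMA
}

int main() {
  // float32 bounds from the assumed-device example above: align 4, [4096, 29184].
  std::printf("%lld\n", (long long)bruteForceSplit(49152, 4, 4096, 29184));  // 2
  std::printf("%lld\n", (long long)bruteForceSplit(100003, 4, 4096, 29184)); // 0 (prime)
  return 0;
}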
+    reduction_tv->axis(1)->parallelize(ParallelType::Serial);
+
+    for (auto tma_tv : tma_tvs) {
+      tma_tv->split(1, rparams->tma_split_factor, false);
+    }
Collaborator:
There should be no need to manually split tma_tv; TransformPropagator should be able to propagate the transforms.

Collaborator (Author):
It fails to propagate for bfloat16. Claude thinks this is due to casts blocking something in the maximum spanning tree.
+  }
+
   // Propagate the merges to all TMA TVs
   TransformPropagator tma_propagator(reduction_tv);
   SetSelector tma_selector({tma_tvs.begin(), tma_tvs.end()});
   MaxLogicalDomainInfoSpanningTree(reduction_tv, &tma_selector)
       .traverse(&tma_propagator);
 
   int64_t iter_axis = 0;
-  int64_t inner_reduce_axis = 1;
+  int64_t inner_reduce_axis = rparams->tma_split_factor > 1 ? 2 : 1;
 
   // Schedule TMA tvs as [BIDx, Bulk]
   tma_tvs[0]->axis(iter_axis)->parallelize(ParallelType::BIDx);
@@ -104,26 +190,21 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // Non-TMA scheduling
   //
   // Apply splits following the pattern:
-  //   [I, R] -> [I, R/vect, vect]
-  //          -> [I, R/vect/tidx, tidx, vect]
-  //          -> [I, R/vect/tidx/unroll, unroll, tidx, vect]
-
-  // Split 1: Vectorization factor (innermost serial split for TMA)
-  reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
-  reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::Serial);
+  //   [I, R] -> [I, R/tidx, tidx]
+  //          -> [I, R/tidx/unroll, unroll, tidx]
 
-  // Split 2: TIDx (always applied)
+  // Split 1: TIDx (always applied)
   reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
   reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx);
 
-  // Split 3: Inner unroll (outside of TIDx)
+  // Split 2: Inner unroll (outside of TIDx)
   if (rparams->unroll_factor > 1) {
     reduction_tv->split(inner_reduce_axis, rparams->unroll_factor);
     reduction_tv->axis(inner_reduce_axis + 1)
         ->parallelize(ParallelType::Unroll);
   }
 
-  // Split 4: Unswitch (always applied)
+  // Split 3: Unswitch (always applied)
   reduction_tv->split(inner_reduce_axis, 1);
   reduction_tv->axis(inner_reduce_axis + 1)
       ->parallelize(ParallelType::Unswitch);
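For a concrete (assumed) configuration — tma_split_factor = 2, threads_per_block = 256, unroll_factor = 4, and a reduction of 49152 elements — the reduction tensor's domain evolves roughly as follows (a sketch of the transforms above, not actual scheduler output):

// [I, 49152]
// TMA outer split by 2    -> [I, 2, 24576]          axis 1: Serial (Bulk on TMA tvs)
// Split 1: TIDx by 256    -> [I, 2, 96, 256]        axis 3: TIDx
// Split 2: Unroll by 4    -> [I, 2, 24, 4, 256]     axis 3: Unroll
// Split 3: Unswitch by 1  -> [I, 2, 24, 1, 4, 256]  axis 3: Unswitch

with inner_reduce_axis = 2, since the TMA split inserted one extra axis.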
csrc/scheduler/reduction_tma.h (4 changes: 2 additions & 2 deletions)
@@ -19,8 +19,8 @@ class TmaInnerReductionParams : public HeuristicParams {
       SchedulerType scheduler_type = SchedulerType::Reduction)
       : HeuristicParams(scheduler_type) {};
 
-  // Inner serial split factor (similar to vectorization for non-TMA)
-  int64_t vectorization_factor = 1;
+  // Outer serial split factor for the reduction dimension, to better fit into smem.
+  int64_t tma_split_factor = 1;
 
   // Number of threads per block for TIDx parallelization
   int64_t threads_per_block = 1;
tests/cpp/test_reduction.cpp (22 changes: 12 additions & 10 deletions)
@@ -2815,11 +2815,6 @@ class TmaInnerReductionTest
 
   // Check if we expect TMA to be used based on mayUseTma() conditions
   bool expectTmaUsed(DataType dtype, int64_t reduction_size) {
-    // Skip TMA for small reductions
-    if (reduction_size < 128) {
-      return false;
-    }
-
     // TMA requires 16-byte alignment (vectorize_factor > 1)
     int64_t dtype_size_bit = dataTypeSizeBit(dtype);
     int64_t dtype_bytes = dtype_size_bit / 8;
@@ -2828,11 +2823,18 @@ class TmaInnerReductionTest
       return false;
     }
 
-    // Reduction dim must fit into smem
-    auto dev_prop = at::cuda::getCurrentDeviceProperties();
-    int64_t smem_elems =
-        (dev_prop->sharedMemPerBlockOptin * 8) / dtype_size_bit;
-    if (reduction_size > smem_elems) {
+    uint64_t total_reduction_bytes = reduction_size * dtype_bytes;
+
+    // Minimum TMA transfer size, below which it seems much slower than
+    // non-TMA.
+    uint64_t min_tma_bytes = 16384;
+
+    if (total_reduction_bytes < min_tma_bytes) {
       return false;
     }
 
+    // This test has a few large, non-power-of-2 shapes, which TMA will reject
+    // when searching for a suitable split.
+    if (reduction_size > 1024 * 1024 && (reduction_size % 32) != 0) {
+      return false;
+    }
+
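A few hedged spot checks of this predicate (float32 = 32 bits; assuming the 16-byte-alignment check above passes for these sizes):

// expectTmaUsed(DataType::Float, 2048)    -> false: 2048 * 4 B = 8 KB  < 16 KB minimum
// expectTmaUsed(DataType::Float, 8192)    -> true:  8192 * 4 B = 32 KB >= 16 KB
// expectTmaUsed(DataType::Float, 1050000) -> false: > 1 Mi elements and 1050000 % 32 == 16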