Optimize TMA inner-reduction and add TMA serial-split #5867
Changes from all commits:
```diff
@@ -17,6 +17,54 @@
 namespace nvfuser {
 namespace reduction {
 namespace tma {
+namespace {
+
+// Find the smallest split factor that:
+// 1. Evenly divides reduction_numel
+// 2. Results in an element count that is divisible by alignment
+// 3. Results in an element count that is inside [lower_elem_bound,
+//    upper_elem_bound]
+int64_t getTmaSplit(
+    int64_t numel,
+    int64_t alignment,
+    int64_t lower_elem_bound,
+    int64_t upper_elem_bound) {
+  // Lower & upper bounds for the split factor
+  const int64_t split_lower = ceilDiv(numel, upper_elem_bound);
+  const int64_t split_upper = std::max(int64_t(1), numel / lower_elem_bound);
+
+  // Rather than linearly searching the whole range, use the fact that any
+  // divisor <= sqrt(numel) will be paired with another divisor >= sqrt(numel).
+  // Therefore we can stop at sqrt(numel) since we want to minimize the split
+  // divisor.
+  int64_t sqrt_n = int64_t(std::ceil(std::sqrt(double(numel))));
+  for (int64_t d = split_lower; d <= std::min(split_upper, sqrt_n); d++) {
+    if (numel % d == 0) {
+      int64_t tma_elems = numel / d;
+      if (tma_elems % alignment == 0) {
+        return d;
+      }
+    }
+  }
+
+  // The previous loop searched where the small divisor is within the range
+  // [split_lower, split_upper]. Now we check for cases where the large divisor
+  // is within that range.
+  for (int64_t d = sqrt_n; d >= 1; d--) {
+    if (numel % d == 0) {
+      int64_t paired = numel / d;
+      if (split_lower <= paired && paired <= split_upper) {
+        int64_t tma_elems = numel / paired;
+        if (tma_elems % alignment == 0) {
+          return paired;
+        }
+      }
+    }
+  }
+
+  return 0;
+}
+} // namespace
 
 std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
     Fusion* fusion,
```
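For reference, here is a minimal standalone sketch of the divisor search above, with a worked example. The local `ceilDiv` helper is a stand-in for nvfuser's utility of the same name, and the concrete numbers in `main` are illustrative assumptions, not values taken from this PR.

```cpp
// Standalone sketch of the two-pass divisor search in getTmaSplit.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int64_t getTmaSplit(
    int64_t numel,
    int64_t alignment,
    int64_t lower_elem_bound,
    int64_t upper_elem_bound) {
  const int64_t split_lower = ceilDiv(numel, upper_elem_bound);
  const int64_t split_upper = std::max(int64_t(1), numel / lower_elem_bound);
  int64_t sqrt_n = int64_t(std::ceil(std::sqrt(double(numel))));
  // Pass 1: the split factor itself is a small divisor (<= sqrt(numel)).
  for (int64_t d = split_lower; d <= std::min(split_upper, sqrt_n); d++) {
    if (numel % d == 0 && (numel / d) % alignment == 0) {
      return d;
    }
  }
  // Pass 2: the in-range split factor is the large cofactor numel / d.
  for (int64_t d = sqrt_n; d >= 1; d--) {
    if (numel % d == 0) {
      int64_t paired = numel / d;
      if (split_lower <= paired && paired <= split_upper &&
          (numel / paired) % alignment == 0) {
        return paired;
      }
    }
  }
  return 0; // no valid split; the caller falls back to the non-TMA path
}

int main() {
  // fp32, numel = 32768: lower bound = 16384 B / 4 B = 4096 elements,
  // upper bound = (232448 B / 4 B) / 2 = 29056 elements (assumed smem budget),
  // alignment = 16 B / 4 B = 4 elements.
  int64_t split = getTmaSplit(32768, 4, 4096, 29056);
  std::printf("split = %lld\n", (long long)split); // prints "split = 2"
}
```

With these bounds the first pass returns 2: two serial TMA loads of 16384 elements (64 KB) each, which is at least 16 KB, under half the assumed shared-memory budget, and 16-byte aligned.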
```diff
@@ -26,6 +74,35 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
   FusionGuard fg(fusion);
 
   auto dev_prop = at::cuda::getCurrentDeviceProperties();
 
+  uint64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
+  uint64_t smem_elems = dev_prop->sharedMemPerBlockOptin / dtype_bytes;
+
+  // Search for a suitable split factor. Lower and upper bounds are based on
+  // heuristics. We require TMA loads to be at least 16KB to overcome TMA
+  // overhead, and we require those loads to consume less than half of the
+  // available shared memory, to maintain reasonable occupancy.
+  constexpr int64_t min_tma_bytes = 16384;
+  const int64_t lower_elem_bound = min_tma_bytes / dtype_bytes;
+  const int64_t upper_elem_bound = smem_elems / 2;
+
+  // TMA requires 16-byte alignment after any splits
+  const int64_t aligned_elems = 16 / dtype_bytes;
+
+  const int64_t tma_split_factor = getTmaSplit(
+      props.inner_most_dimension_numel,
+      aligned_elems,
+      lower_elem_bound,
+      upper_elem_bound);
+
+  // If no valid split factor was found, fall back to non-TMA
+  if (tma_split_factor == 0) {
+    return nullptr;
+  }
+
+  // These are derived from benchmarking.
+  int64_t threads_per_block = props.has_mufu_computation ? 512 : 256;
+
   const int64_t max_threads_per_sm = dev_prop->maxThreadsPerMultiProcessor;
   const int64_t target_threads_per_sm = max_threads_per_sm / 2;
```
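To make the bounds concrete, here is a small sketch of the arithmetic for bf16 and fp32 widths, assuming a hypothetical 232448-byte `sharedMemPerBlockOptin` (the real value comes from the device properties):

```cpp
// Bounds arithmetic for two dtype widths; the smem budget is an assumption.
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t smem_bytes = 232448; // assumed sharedMemPerBlockOptin
  for (uint64_t dtype_bytes : {uint64_t(2), uint64_t(4)}) { // bf16, fp32
    const uint64_t smem_elems = smem_bytes / dtype_bytes;
    const int64_t lower_elem_bound = 16384 / dtype_bytes; // >= 16KB per load
    const int64_t upper_elem_bound = smem_elems / 2;      // <= half of smem
    const int64_t aligned_elems = 16 / dtype_bytes;       // 16B TMA alignment
    std::printf("%lluB dtype: elems in [%lld, %lld], alignment %lld\n",
                (unsigned long long)dtype_bytes, (long long)lower_elem_bound,
                (long long)upper_elem_bound, (long long)aligned_elems);
  }
}
```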
```diff
@@ -36,14 +113,12 @@ std::unique_ptr<TmaInnerReductionParams> getReductionHeuristics(
       target_threads_per_sm,
       props.has_mufu_computation);
 
-  // Initialize split factors
-  int64_t vectorization_factor =
-      std::min(target_vect_unroll, props.vectorize_factor);
-  int64_t threads_per_block = 128;
-  int64_t unroll_factor = target_vect_unroll / vectorization_factor;
+  // The TMA kernel doesn't do vectorization, since it works out of shared
+  // memory. Instead, fold the entire vect_unroll into unroll_factor.
+  int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);
 
   auto params = std::make_unique<TmaInnerReductionParams>();
-  params->vectorization_factor = vectorization_factor;
+  params->tma_split_factor = tma_split_factor;
   params->threads_per_block = threads_per_block;
   params->unroll_factor = unroll_factor;
```
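The folding step relies on `scheduler_utils::lastPow2`, which rounds down to a power of two. A stand-in implementation to illustrate the behavior; the bit-smearing trick below is a sketch, not necessarily nvfuser's exact code:

```cpp
// Stand-in for scheduler_utils::lastPow2: largest power of two <= n (n >= 1).
#include <algorithm>
#include <cassert>
#include <cstdint>

int64_t lastPow2(int64_t n) {
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
  n |= n >> 32;
  return std::max(int64_t(1), n - (n >> 1));
}

int main() {
  // e.g. a target_vect_unroll of 6 folds to an unroll_factor of 4
  assert(lastPow2(6) == 4);
  assert(lastPow2(8) == 8);
  assert(lastPow2(1) == 1);
}
```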
```diff
@@ -65,6 +140,8 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
 
   scheduler_utils::prepareForMemoryTypePromotion(fusion);
 
+  scheduler_utils::cacheAndForkOutputs(fusion, true);
+
   std::vector<TensorView*> tma_tvs;
   for (auto [tv, input_idx] : cached_inputs) {
     if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
```
```diff
@@ -87,14 +164,23 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   bool has_red_axis = dim_analysis.second;
   NVF_ERROR(has_iter_axis && has_red_axis);
 
+  if (rparams->tma_split_factor > 1) {
+    reduction_tv->split(1, rparams->tma_split_factor, false);
```
> **Collaborator:** Do we need to consider a divisible split for better performance when setting this?
>
> **Author:** Yes, this is actually needed for correctness, not just performance. Lowering will fail if the reduction size is not divisible by the split; I believe this is a restriction of 1D TMA. I added a new function, `getTmaSplit`, that searches for a divisible split. Note that the TMA scheduler will now return nullptr (falling back to the non-TMA path) when no valid split factor is found.
```diff
+    reduction_tv->axis(1)->parallelize(ParallelType::Serial);
+
+    for (auto tma_tv : tma_tvs) {
+      tma_tv->split(1, rparams->tma_split_factor, false);
+    }
```
> **Collaborator:** We shouldn't need to split manually here.
>
> **Author:** It fails to propagate for bfloat16. Claude thinks this is due to casts blocking something in the maximum spanning tree.
```diff
+  }
+
   // Propagate the merges to all TMA TVs
   TransformPropagator tma_propagator(reduction_tv);
   SetSelector tma_selector({tma_tvs.begin(), tma_tvs.end()});
   MaxLogicalDomainInfoSpanningTree(reduction_tv, &tma_selector)
       .traverse(&tma_propagator);
 
   int64_t iter_axis = 0;
-  int64_t inner_reduce_axis = 1;
+  int64_t inner_reduce_axis = rparams->tma_split_factor > 1 ? 2 : 1;
 
   // Schedule TMA tvs as [BIDx, Bulk]
   tma_tvs[0]->axis(iter_axis)->parallelize(ParallelType::BIDx);
```
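To see why `inner_reduce_axis` moves from 1 to 2: `split(1, factor, /*inner_split=*/false)` is an outer split, so `[I, R]` becomes `[I, factor, R/factor]`, with axis 1 the serial TMA-split loop and axis 2 the per-load box. A shape-only trace under assumed extents:

```cpp
// Shape trace of the outer serial split; extents are assumptions.
#include <cstdio>

int main() {
  const long long R = 32768; // reduction extent (assumed)
  const long long F = 2;     // tma_split_factor (assumed)
  // After reduction_tv->split(1, F, /*inner_split=*/false):
  //   axis 0: I      -> BIDx
  //   axis 1: F      -> Serial (one TMA load per iteration)
  //   axis 2: R / F  -> the box each TMA load brings into shared memory
  std::printf("[I, %lld, %lld]\n", F, R / F); // prints "[I, 2, 16384]"
}
```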
```diff
@@ -104,26 +190,21 @@ void scheduleReduction(Fusion* fusion, const TmaInnerReductionParams* rparams) {
   // Non-TMA scheduling
   //
   // Apply splits following the pattern:
-  // [I, R] -> [I, R/vect, vect]
-  //        -> [I, R/vect/tidx, tidx, vect]
-  //        -> [I, R/vect/tidx/unroll, unroll, tidx, vect]
-
-  // Split 1: Vectorization factor (innermost serial split for TMA)
-  reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
-  reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::Serial);
+  // [I, R] -> [I, R/tidx, tidx]
+  //        -> [I, R/tidx/unroll, unroll, tidx]
 
-  // Split 2: TIDx (always applied)
+  // Split 1: TIDx (always applied)
   reduction_tv->split(inner_reduce_axis, rparams->threads_per_block);
   reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx);
 
-  // Split 3: Inner unroll (outside of TIDx)
+  // Split 2: Inner unroll (outside of TIDx)
   if (rparams->unroll_factor > 1) {
     reduction_tv->split(inner_reduce_axis, rparams->unroll_factor);
     reduction_tv->axis(inner_reduce_axis + 1)
         ->parallelize(ParallelType::Unroll);
   }
 
-  // Split 4: Unswitch (always applied)
+  // Split 3: Unswitch (always applied)
   reduction_tv->split(inner_reduce_axis, 1);
   reduction_tv->axis(inner_reduce_axis + 1)
       ->parallelize(ParallelType::Unswitch);
```
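Tracing the split pattern above with assumed parameters (a 16384-element box after the TMA split, `threads_per_block = 512`, `unroll_factor = 8`) gives the final reduction-domain shape; a quick sketch:

```cpp
// Extents after the TIDx / unroll / unswitch splits; inputs are assumptions.
#include <cstdio>

int main() {
  const long long R = 16384;  // per-box reduction extent (assumed)
  const long long tidx = 512; // threads_per_block (assumed)
  const long long unroll = 8; // unroll_factor (assumed)
  // [I, R] -> [I, R/tidx, tidx]
  //        -> [I, R/tidx/unroll, unroll, tidx]
  //        -> [I, R/tidx/unroll, 1 (unswitch), unroll, tidx]
  std::printf("[I, %lld, 1, %lld, %lld]\n",
              R / tidx / unroll, unroll, tidx); // prints "[I, 4, 1, 8, 512]"
}
```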