From 4e5539413bfa2d01b5ad51e2b2fc7667f06ff8ca Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 17 Dec 2025 06:09:15 -0800 Subject: [PATCH] rebase main --- .../normalization_inner_outer_tma_ws.cpp | 40 +---- csrc/scheduler/normalization_inner_tma.cpp | 170 ++++++++++++++++-- csrc/scheduler/normalization_inner_tma.h | 8 +- csrc/scheduler/utils.cpp | 20 +++ csrc/scheduler/utils.h | 20 ++- tests/cpp/test_persistent_buffer.cpp | 6 +- 6 files changed, 207 insertions(+), 57 deletions(-) diff --git a/csrc/scheduler/normalization_inner_outer_tma_ws.cpp b/csrc/scheduler/normalization_inner_outer_tma_ws.cpp index 72dcf2a1d51..7062f4ef881 100644 --- a/csrc/scheduler/normalization_inner_outer_tma_ws.cpp +++ b/csrc/scheduler/normalization_inner_outer_tma_ws.cpp @@ -34,8 +34,6 @@ void getHeuristics( rparams->cparams.index_type = index_type; const auto dev_prop = at::cuda::getCurrentDeviceProperties(); const int64_t sm_count = (int64_t)dev_prop->multiProcessorCount; - constexpr int64_t reg_per_async_thread = 32L; - constexpr int64_t regs_granularity = 8L; // Params for 1st stage, inner reduction and partial outer reduction. // Inner dim: inner_vect, inner_batch, and bdimx @@ -110,33 +108,7 @@ void getHeuristics( int buffer_per_thread = buffer_per_element * elements_per_thread; return buffer_per_thread / scheduler_utils::bits_per_register; }; - // Assume each padded threads keep [tma_branch_registers] registers and all - // others are moved to computation threads. The granularity is 8. - // [tma_branch_registers] is a tunable parameter. When estimated - // compute_branch_regs is not divisible by granularity, it is rounded down - // and needs to recompute tma_branch_registers. - // For example, assuming 256 computation threads, initial register = 168, - // tma_branch_regs = 32. then (168 - 32) * 128 / 256 = 68 which is not - // divisible by 8, compute_branch_registers = 168 + 68 = 236 --> rounded - // down to 232. re-calculate [tma_branch_registers] using: borrowed - // registers = (232 - 168) * 256 / 128 = 128. 
tma_branch_registers = 168 -
-  // 128 = 40
-  auto get_register_sharing = [&](int64_t reg_per_thread,
-                                  int64_t computation_threads) {
-    int64_t tma_branch_regs = reg_per_async_thread;
-    int64_t compute_branch_regs = reg_per_thread +
-        (reg_per_thread - tma_branch_regs) * kWarpSpecializationPaddedThreads /
-        computation_threads;
-    if (compute_branch_regs % regs_granularity != 0) {
-      compute_branch_regs -= compute_branch_regs % regs_granularity;
-      tma_branch_regs = reg_per_thread -
-          (compute_branch_regs - reg_per_thread) * computation_threads /
-          kWarpSpecializationPaddedThreads;
-    }
-    compute_branch_regs = std::min(
-        compute_branch_regs, scheduler_utils::max_registers_per_thread);
-    return std::make_pair(tma_branch_regs, compute_branch_regs);
-  };
+
   auto is_enough_regs = [&](int64_t iter_unroll, int64_t bdimx, int64_t bdimy) {
     int64_t reg_count = 0;
     // cache circular buffered tv
@@ -160,8 +132,8 @@ void getHeuristics(
     reg_count += register_overhead_ws_tma;
     int64_t available_regs = getRegPerThreadGivenThreadsPerSM(
        bdimx * bdimy + kWarpSpecializationPaddedThreads);
-    auto [_, compute_branch_regs] =
-        get_register_sharing(available_regs, bdimx * bdimy);
+    auto [_, compute_branch_regs] = scheduler_utils::getRegisterSharing(
+        available_regs, bdimx * bdimy, kWarpSpecializationPaddedThreads);
     return reg_count <= compute_branch_regs;
   };
 
@@ -316,7 +288,8 @@ void getHeuristics(
         kWarpSpecializationPaddedThreads + computation_threads;
     if (total_threads > 256) {
       int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
-      ws.num_registers = get_register_sharing(reg_per_thread, bdimx * bdimy);
+      ws.num_registers = scheduler_utils::getRegisterSharing(
+          reg_per_thread, bdimx * bdimy, kWarpSpecializationPaddedThreads);
     }
     CircularBufferOptions circular_buffer_options{
         .type = ws, .stage = n_stages, .prefetch = n_stages - 1};
@@ -395,8 +368,7 @@ void getHeuristics(
           << is_circular_buffer_regs_cached << "\n"
           << "is_non_circular_buffer_gmem_to_regs: "
           << is_non_circular_buffer_gmem_to_regs << "\n";
-  debug() << "smem_persistent_buffers: "
-          << "\n";
+  debug() << "smem_persistent_buffers: " << "\n";
   for (auto buffer : rparams->smem_persistent_buffers) {
     debug() << buffer->toString() << "\n";
   }
diff --git a/csrc/scheduler/normalization_inner_tma.cpp b/csrc/scheduler/normalization_inner_tma.cpp
index fd3d1880dbd..df945a77839 100644
--- a/csrc/scheduler/normalization_inner_tma.cpp
+++ b/csrc/scheduler/normalization_inner_tma.cpp
@@ -38,12 +38,14 @@ std::unique_ptr getInnerPersistentHeuristics(
   FusionGuard fg(fusion);
   auto dev_prop = at::cuda::getCurrentDeviceProperties();
   const int64_t warp_size = dev_prop->warpSize;
+  const int64_t sm_count = dev_prop->multiProcessorCount;
   const int64_t max_threads_per_cta = dev_prop->maxThreadsPerBlock;
 
   auto params = std::make_unique<InnerNormTmaParams>(
       InnerPersistentKernelScheduler::schedulerType());
   params->tag = "Inner Persistent TMA heuristics";
   const int64_t total_redu_count = prop.inner_most_dimension_numel;
+  const int64_t total_iter_count = prop.total_iteration_numel;
 
   // Always project persistent buffers to inputs since inputs are cached in
   // shared memory, reducing the size of persistent buffers
@@ -76,18 +78,66 @@ std::unique_ptr getInnerPersistentHeuristics(
   }
   params->persistent_batch_size = pbs;
 
+  // set warp specialized circular buffer options
+  // don't use warp specialization if the total iteration count is too small
+  // TODO: tune heuristics to determine when to use the warp specialized version
+  int64_t gdimx = LaunchParams::UNINITIALIZED_VAL;
+  int64_t bdimy = LaunchParams::UNINITIALIZED_VAL;
+  int64_t bdimz = LaunchParams::UNINITIALIZED_VAL;
+  const int64_t n_compute_warp_groups = 2;
+  const int64_t n_rows_per_compute_warp_group = 2;
+  const int64_t iter_limited_stages = total_iter_count /
+      (n_compute_warp_groups * n_rows_per_compute_warp_group * sm_count);
+  const int64_t smem_size_bit = prop.max_persistent_buffer_size_bit *
+      n_compute_warp_groups * n_rows_per_compute_warp_group;
+  const int64_t smem_limited_stages =
+      (int64_t)dev_prop->sharedMemPerBlockOptin * 8 / smem_size_bit;
+  const int64_t n_stages = std::min(smem_limited_stages, iter_limited_stages);
+  if (n_stages >= 2 && bdimx == 128) {
+    gdimx = sm_count;
+    bdimx = 128; // 4 warps per warp group
+    bdimy = n_compute_warp_groups;
+    bdimz = 1; // warp specialized kernel requires static CTA shape
+    params->n_grouped_rows = n_rows_per_compute_warp_group;
+    ParallelType ws_pt = bdimy > 1 ? ParallelType::TIDy : ParallelType::TIDx;
+    WarpSpecialized ws(ws_pt);
+    if (ws_pt == ParallelType::TIDy) {
+      bdimy += 1;
+      ws.stage_slice_position = 3;
+      // Limitation in grouped reduction runtime function
+      NVF_ERROR(bdimx == 128, "bdimx must be 128 for TIDy warp specialization");
+      NVF_ERROR(
+          params->n_grouped_rows > 1,
+          "n_grouped_rows must be greater than 1 for TIDy warp specialization");
+    } else {
+      bdimx += kWarpSpecializationPaddedThreads;
+    }
+    int64_t total_threads = bdimx * bdimy * bdimz;
+    if (total_threads > 256) {
+      int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
+      int64_t computation_threads =
+          total_threads - kWarpSpecializationPaddedThreads;
+      ws.num_registers = scheduler_utils::getRegisterSharing(
+          reg_per_thread,
+          computation_threads,
+          kWarpSpecializationPaddedThreads);
+    }
+    CircularBufferOptions circular_buffer_options{
+        .type = ws, .stage = n_stages, .prefetch = n_stages - 1};
+    params->circular_buffer_options = circular_buffer_options;
+    // Set launch parameters
+    params->lparams = LaunchParams(
+        gdimx,
+        LaunchParams::UNINITIALIZED_VAL,
+        LaunchParams::UNINITIALIZED_VAL,
+        bdimx,
+        bdimy,
+        bdimz);
+  }
+
   // Set index type
   params->cparams.index_type = prop.index_type;
 
-  // Set launch parameters
-  params->lparams = LaunchParams(
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL,
-      LaunchParams::UNINITIALIZED_VAL);
-
   if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
     debug() << prop.toString() << std::endl;
     debug() << params->toString() << std::endl;
@@ -177,24 +227,55 @@ void scheduleInnerPersistent(Fusion* fusion, const InnerNormTmaParams* params) {
   // Parallelization strategy:
   // - axis(0): BIDx - each block handles one or more batch elements
   // - axis(1): Bulk - TMA asynchronously copies entire reduction dimension
-  int64_t ipos = 0, rpos = 1;
-  if (params->rows_per_block > 1) {
+  int64_t ipos = 0, rpos = 1, tidy_pos = -1, group_pos = -1;
+  if (params->circular_buffer_options.isEnable()) {
+    if (params->n_grouped_rows > 1) {
+      // [I, R] -> [I/Group, Group, R]
+      reduction_tv->split(ipos, params->n_grouped_rows);
+      group_pos = ipos + 1;
+      rpos++;
+    }
+    if (params->lparams.bdimy() > 2) {
+      NVF_ERROR_EQ(
+          std::get<WarpSpecialized>(params->circular_buffer_options.type).on,
+          ParallelType::TIDy);
+      // [I/Group, Group, R] -> [I/Group/TIDy, TIDy, Group, R]
+      reduction_tv->split(ipos, params->lparams.bdimy() - 1);
+      tidy_pos = ipos + 1;
+      rpos++;
+      group_pos++;
+    }
+    if (params->lparams.gdimx() > 1) {
+      // [I/Group/TIDy, TIDy, Group, R] -> [I/Group/TIDy/BIDx, BIDx, TIDy,
+      // Group, R]
+      reduction_tv->split(ipos, params->lparams.gdimx());
+      reduction_tv->axis(ipos + 1)->parallelize(ParallelType::BIDx);
+      rpos++;
+      tidy_pos++;
+      group_pos++;
+    }
+
+  } else if (params->n_grouped_rows > 1) {
     // [I, R] -> [I/TIDy, TIDy, R]
-    reduction_tv->split(ipos, params->rows_per_block);
-    TransformPropagator propagator(reduction_tv);
-    MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&propagator);
+    reduction_tv->split(ipos, params->n_grouped_rows);
     reduction_tv->axis(ipos)->parallelize(ParallelType::BIDx);
     reduction_tv->axis(ipos + 1)->parallelize(ParallelType::TIDy);
     rpos = ipos + 2;
   } else {
     reduction_tv->axis(ipos)->parallelize(ParallelType::BIDx);
   }
+  TransformPropagator propagator(reduction_tv);
+  MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&propagator);
   reduction_tv->axis(rpos)->parallelize(ParallelType::Bulk);
   scheduler_utils::parallelizeAllLike(reduction_tv, tma_tvs);
   // Reset reduction_tv's reduction axis back to Serial (only TMA loads use
   // Bulk)
   reduction_tv->axis(rpos)->parallelize(ParallelType::Serial);
-
+  // For TMA tvs, keep this axis Serial so that one producer serves all
+  // consumers parallelized by TIDy
+  if (tidy_pos > 0) {
+    reduction_tv->axis(tidy_pos)->parallelize(ParallelType::TIDy);
+  }
   // Transform reduction domain for efficient computation:
   // [I, R] -> [I, b, us, x, v]
   // Where:
@@ -236,6 +317,18 @@ void scheduleInnerPersistent(Fusion* fusion, const InnerNormTmaParams* params) {
         ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
   }
   scheduler_utils::parallelizeAllLike(reference_tv, non_tma_tvs);
+  if (params->circular_buffer_options.isEnable()) {
+    if (group_pos > 0) {
+      for (auto reduction_tv : reduction_tvs) {
+        reduction_tv->axis(group_pos)->parallelize(ParallelType::Group);
+      }
+    }
+  } else {
+    NVF_CHECK_EQ(
+        group_pos,
+        -1,
+        "Grouped reduction is only supported in warp specialized mode");
+  }
 
   // Helper lambda to find the vectorization position (one past TIDx axis)
   auto get_vect_pos = [](TensorView* tv) -> std::optional<int64_t> {
@@ -290,10 +383,57 @@ void scheduleInnerPersistent(Fusion* fusion, const InnerNormTmaParams* params) {
   if (params->pre_load_ldg_tvs) {
     exclude_tvs.insert(ldg_tvs.begin(), ldg_tvs.end());
   }
+  if (params->circular_buffer_options.isEnable()) {
+    // when warp specialized, the iteration domain of tma tv is scheduled as:
+    // 1. GridStrideLoop
+    // 2. BIDx
+    // 3. Serial (Compute Warp Groups, TIDy in compute warp groups)
+    // 4. Serial (Multiple TMAs share one mbarrier, serial or grouped reduction
+    //    in compute warp groups)
+    constexpr int64_t pos_after_bidx = 2;
+    for (auto tv : tma_tvs) {
+      inlineSelectedAt({tv}, tv, pos_after_bidx);
+      exclude_tvs.insert(tv);
+    }
+
+    // This happens in layer norm, where the result of the 1st reduction is
+    // used by the 2nd reduction. Since each reduction is grouped in its
+    // iteration dimension, we can't inline deeper than the group position.
+    if (group_pos > 0 && reduction_tvs.size() > 1) {
+      for (auto tv1 : reduction_tvs) {
+        for (auto tv2 : reduction_tvs) {
+          if (tv1 == tv2) {
+            continue;
+          }
+          auto all_vals = DependencyCheck::getAllValsBetween({tv1}, {tv2});
+          auto gp_tvs = ir_utils::filterByType<TensorView>(all_vals);
+          for (auto gp_tv : gp_tvs) {
+            if (gp_tv->hasBroadcast() && !exclude_tvs.contains(gp_tv)) {
+              inlineSelectedAt({gp_tv}, gp_tv, group_pos);
+              exclude_tvs.insert(gp_tv);
+            }
+          }
+        }
+      }
+    }
+  }
   std::vector<TensorView*> inline_most_tvs =
       ir_utils::allTvsExcept(fusion, exclude_tvs);
   inlineMost(inline_most_tvs);
 
+  if (params->circular_buffer_options.isEnable()) {
+    int64_t number_of_stages = params->circular_buffer_options.stage;
+    int64_t prefetch_distance = params->circular_buffer_options.prefetch;
+    CircularBufferType circular_buffer_type =
+        params->circular_buffer_options.type;
+    for (auto tv : tma_tvs) {
+      if (tv->getComputeAtPosition() > 0) {
+        tv->circularBuffer(
+            number_of_stages, prefetch_distance, circular_buffer_type);
+      }
+    }
+  }
+
   // Refine cache policies for optimal memory hierarchy usage
   refineCachePolicy(fusion);
 }
diff --git a/csrc/scheduler/normalization_inner_tma.h b/csrc/scheduler/normalization_inner_tma.h
index 15236f1825c..3fb076e8730 100644
--- a/csrc/scheduler/normalization_inner_tma.h
+++ b/csrc/scheduler/normalization_inner_tma.h
@@ -44,7 +44,7 @@ class InnerNormTmaParams : public HeuristicParams {
   bool tma_load_non_persistent_buffers = false;
 
   // Number of rows per block
-  int64_t rows_per_block = 1;
+  int64_t n_grouped_rows = 1;
 
   // Circular buffer options
   CircularBufferOptions circular_buffer_options;
@@ -62,7 +62,7 @@ class InnerNormTmaParams : public HeuristicParams {
         other->pre_load_ldg_tvs == pre_load_ldg_tvs &&
         other->tma_load_non_persistent_buffers ==
             tma_load_non_persistent_buffers &&
-        other->rows_per_block == rows_per_block &&
+        other->n_grouped_rows == n_grouped_rows &&
         other->circular_buffer_options == circular_buffer_options;
   }
 
@@ -77,7 +77,7 @@ class InnerNormTmaParams : public HeuristicParams {
        << (pre_load_ldg_tvs ? "Pre-load ldg tvs\n" : "")
        << (tma_load_non_persistent_buffers ?
"TMA load non-persistent buffers\n" : "") - << "Rows per block: " << rows_per_block << "\n"; + << "Rows per block: " << n_grouped_rows << "\n"; if (circular_buffer_options.isEnable()) { ss << circular_buffer_options << "\n"; } else { @@ -96,7 +96,7 @@ class InnerNormTmaParams : public HeuristicParams { static_cast(persistent_batch_size) << (bits - 4) ^ static_cast(pre_load_ldg_tvs) << (bits - 5) ^ static_cast(tma_load_non_persistent_buffers) << (bits - 6) ^ - static_cast(rows_per_block) << (bits - 7) ^ + static_cast(n_grouped_rows) << (bits - 7) ^ static_cast(circular_buffer_options.stage) << (bits - 8) ^ static_cast(circular_buffer_options.prefetch) << (bits - 9); return attr_hash; diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 134d514b9c1..17ac3c44328 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -3608,5 +3608,25 @@ int64_t getTmaDomainInner( return best_divisible_size; } +std::pair getRegisterSharing( + int64_t reg_per_thread, + int64_t computation_threads, + int64_t padded_threads) { + constexpr int64_t reg_per_async_thread = 32L; + constexpr int64_t regs_granularity = 8L; + int64_t tma_branch_regs = reg_per_async_thread; + int64_t compute_branch_regs = reg_per_thread + + (reg_per_thread - tma_branch_regs) * padded_threads / computation_threads; + if (compute_branch_regs % regs_granularity != 0) { + compute_branch_regs -= compute_branch_regs % regs_granularity; + tma_branch_regs = reg_per_thread - + (compute_branch_regs - reg_per_thread) * computation_threads / + padded_threads; + } + compute_branch_regs = + std::min(compute_branch_regs, scheduler_utils::max_registers_per_thread); + return std::make_pair(tma_branch_regs, compute_branch_regs); +} + } // namespace scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index fdee872e161..c8c74f23593 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -1011,6 +1011,24 @@ int64_t getTmaDomainInner( int64_t total_element, int64_t tma_domain_inner_target = 512, int64_t min_dtype_bits = 8); -} // namespace scheduler_utils +// Calculate register sharing between TMA async threads and computation threads +// for warp specialization. Returns a pair of (tma_branch_registers, +// compute_branch_registers). +// +// Assumes padded threads keep [tma_branch_registers] registers and all others +// are moved to computation threads. The granularity is 8. When estimated +// compute_branch_regs is not divisible by granularity, it is rounded down and +// tma_branch_registers is recomputed. +// +// For example, assuming 256 computation threads, initial register = 168, +// tma_branch_regs = 32. then (168 - 32) * 128 / 256 = 68 which is not +// divisible by 8, compute_branch_registers = 168 + 68 = 236 --> rounded down to +// 232. re-calculate [tma_branch_registers] using: borrowed registers = (232 - +// 168) * 256 / 128 = 128. 
tma_branch_registers = 168 - 128 = 40 +std::pair getRegisterSharing( + int64_t reg_per_thread, + int64_t computation_threads, + int64_t padded_threads); +} // namespace scheduler_utils } // namespace nvfuser diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index 14e7565c076..5d463551e9b 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -2271,7 +2271,7 @@ TEST_P(TmaPersistentTestP, TmaInnerPersistentSoftmax) { auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - auto tv0 = makeContigTensor(2, dtype); + auto tv0 = makeContigConcreteTensor({x, y}, dtype); fusion.addInput(tv0); tv0 = maybeCastOp(DataType::Float, tv0); auto res = softmax(tv0, 1); @@ -2297,8 +2297,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( testing::Values(DataType::BFloat16), testing::Values( - deviceSMCount() / 2, - 1024), // batch size, less or larger than sm count + deviceSMCount() / 2, // small batch, can't do grid stride loop + 2048), // batch size, less or larger than sm count testing::ValuesIn(Pow2Vals1to1Million)), // hidden size [](const testing::TestParamInfo& info) { auto dtype = std::get<0>(info.param);
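
As a quick sanity check of the register-sharing arithmetic documented in the new utils.h comment, the standalone sketch below reproduces the worked example (168 registers per thread, 256 computation threads, 128 padded TMA threads). It mirrors the getRegisterSharing helper added in csrc/scheduler/utils.cpp; the name getRegisterSharingSketch and the hard-coded 255-register cap (standing in for scheduler_utils::max_registers_per_thread) are illustrative assumptions, not part of the patch.

// Standalone sketch (not part of the patch): reproduces the arithmetic of the
// getRegisterSharing helper added above, using the worked example from the
// utils.h comment.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>

std::pair<int64_t, int64_t> getRegisterSharingSketch(
    int64_t reg_per_thread,
    int64_t computation_threads,
    int64_t padded_threads) {
  constexpr int64_t reg_per_async_thread = 32;  // registers kept by TMA threads
  constexpr int64_t regs_granularity = 8;       // register-count granularity
  constexpr int64_t max_registers_per_thread = 255;  // assumed cap
  int64_t tma_branch_regs = reg_per_async_thread;
  // Donate the padded (TMA) threads' surplus registers to the compute threads.
  int64_t compute_branch_regs = reg_per_thread +
      (reg_per_thread - tma_branch_regs) * padded_threads / computation_threads;
  if (compute_branch_regs % regs_granularity != 0) {
    // Round down to the granularity, then recompute how many registers the
    // TMA branch actually gave away.
    compute_branch_regs -= compute_branch_regs % regs_granularity;
    tma_branch_regs = reg_per_thread -
        (compute_branch_regs - reg_per_thread) * computation_threads /
            padded_threads;
  }
  compute_branch_regs = std::min(compute_branch_regs, max_registers_per_thread);
  return {tma_branch_regs, compute_branch_regs};
}

int main() {
  // 168 registers per thread, 256 computation threads, 128 padded TMA threads.
  auto [tma_regs, compute_regs] = getRegisterSharingSketch(168, 256, 128);
  std::cout << "tma: " << tma_regs << ", compute: " << compute_regs << "\n";
  // Prints "tma: 40, compute: 232", matching the comment. Budget check:
  // 40 * 128 + 232 * 256 == 168 * (256 + 128) == 64512 registers.
  return 0;
}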