diff --git a/csrc/scheduler/normalization_inner_tma.cpp b/csrc/scheduler/normalization_inner_tma.cpp index acea291e777..0050add7457 100644 --- a/csrc/scheduler/normalization_inner_tma.cpp +++ b/csrc/scheduler/normalization_inner_tma.cpp @@ -365,6 +365,18 @@ void scheduleInnerPersistentMultiwave( ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()}); } scheduler_utils::parallelizeAllLike(reference_tv, non_tma_tvs); + if (params->circular_buffer_options.isEnable()) { + if (group_pos > 0) { + for (auto reduction_tv : reduction_tvs) { + reduction_tv->axis(group_pos)->parallelize(ParallelType::Group); + } + } + } else { + NVF_CHECK_EQ( + group_pos, + -1, + "Grouped reduction is only supported in warp specialized mode"); + } // Apply vectorization applyVectorization(fusion, params, setup); @@ -380,10 +392,57 @@ void scheduleInnerPersistentMultiwave( if (params->pre_load_ldg_tvs) { exclude_tvs.insert(setup.ldg_tvs.begin(), setup.ldg_tvs.end()); } + if (params->circular_buffer_options.isEnable()) { + // when warp specialized, the iteration domain of tma tv is scheduled as: + // 1. GridStrideLoop + // 2. BIDx + // 3. Serial (Compute Warp Groups, TIDy in compute warp groups) + // 4. Serial (Multiple TMAs share one mbarrier, serial or grouped reduction + // in compuate warp groups) + constexpr int64_t pos_after_bidx = 2; + for (auto tv : tma_tvs) { + inlineSelectedAt({tv}, tv, pos_after_bidx); + exclude_tvs.insert(tv); + } + + // Happens in layer norm where the result of the 1st reduction is used by + // the 2nd reduction. Since each reduction is grouped in its iteration + // dimension we can't inline deeper than the group position. + if (group_pos > 0 && reduction_tvs.size() > 1) { + for (auto tv1 : reduction_tvs) { + for (auto tv2 : reduction_tvs) { + if (tv1 == tv2) { + continue; + } + auto all_vals = DependencyCheck::getAllValsBetween({tv1}, {tv2}); + auto gp_tvs = ir_utils::filterByType(all_vals); + for (auto gp_tv : gp_tvs) { + if (gp_tv->hasBroadcast() && !exclude_tvs.contains(gp_tv)) { + inlineSelectedAt({gp_tv}, gp_tv, group_pos); + exclude_tvs.insert(gp_tv); + } + } + } + } + } + } std::vector inline_most_tvs = ir_utils::allTvsExcept(fusion, exclude_tvs); inlineMost(inline_most_tvs); + if (params->circular_buffer_options.isEnable()) { + int64_t number_of_stages = params->circular_buffer_options.stage; + int64_t prefetch_distance = params->circular_buffer_options.prefetch; + CircularBufferType circular_buffer_type = + params->circular_buffer_options.type; + for (auto tv : tma_tvs) { + if (tv->getComputeAtPosition() > 0) { + tv->circularBuffer( + number_of_stages, prefetch_distance, circular_buffer_type); + } + } + } + // Refine cache policies for optimal memory hierarchy usage refineCachePolicy(fusion); }