Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions csrc/scheduler/normalization_inner_tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,18 @@ void scheduleInnerPersistentMultiwave(
ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
}
scheduler_utils::parallelizeAllLike(reference_tv, non_tma_tvs);
if (params->circular_buffer_options.isEnable()) {
if (group_pos > 0) {
for (auto reduction_tv : reduction_tvs) {
reduction_tv->axis(group_pos)->parallelize(ParallelType::Group);
}
}
} else {
NVF_CHECK_EQ(
group_pos,
-1,
"Grouped reduction is only supported in warp specialized mode");
}

// Apply vectorization
applyVectorization(fusion, params, setup);
Expand All @@ -380,10 +392,57 @@ void scheduleInnerPersistentMultiwave(
if (params->pre_load_ldg_tvs) {
exclude_tvs.insert(setup.ldg_tvs.begin(), setup.ldg_tvs.end());
}
if (params->circular_buffer_options.isEnable()) {
// when warp specialized, the iteration domain of tma tv is scheduled as:
// 1. GridStrideLoop
// 2. BIDx
// 3. Serial (Compute Warp Groups, TIDy in compute warp groups)
// 4. Serial (Multiple TMAs share one mbarrier, serial or grouped reduction
// in compute warp groups)
constexpr int64_t pos_after_bidx = 2;
for (auto tv : tma_tvs) {
inlineSelectedAt({tv}, tv, pos_after_bidx);
exclude_tvs.insert(tv);
}

// Happens in layer norm where the result of the 1st reduction is used by
// the 2nd reduction. Since each reduction is grouped in its iteration
// dimension we can't inline deeper than the group position.
if (group_pos > 0 && reduction_tvs.size() > 1) {
for (auto tv1 : reduction_tvs) {
for (auto tv2 : reduction_tvs) {
if (tv1 == tv2) {
continue;
}
auto all_vals = DependencyCheck::getAllValsBetween({tv1}, {tv2});
auto gp_tvs = ir_utils::filterByType<TensorView>(all_vals);
for (auto gp_tv : gp_tvs) {
if (gp_tv->hasBroadcast() && !exclude_tvs.contains(gp_tv)) {
inlineSelectedAt({gp_tv}, gp_tv, group_pos);
exclude_tvs.insert(gp_tv);
}
}
}
}
}
}
std::vector<TensorView*> inline_most_tvs =
ir_utils::allTvsExcept(fusion, exclude_tvs);
inlineMost(inline_most_tvs);

if (params->circular_buffer_options.isEnable()) {
int64_t number_of_stages = params->circular_buffer_options.stage;
int64_t prefetch_distance = params->circular_buffer_options.prefetch;
CircularBufferType circular_buffer_type =
params->circular_buffer_options.type;
for (auto tv : tma_tvs) {
if (tv->getComputeAtPosition() > 0) {
tv->circularBuffer(
number_of_stages, prefetch_distance, circular_buffer_type);
}
}
}

// Refine cache policies for optimal memory hierarchy usage
refineCachePolicy(fusion);
}
Expand Down
Loading