37 changes: 37 additions & 0 deletions csrc/parallel_dimension_map.cpp
@@ -231,6 +231,43 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
" and remaining active cta threads ",
other_active_pts_threads);

// For warp specialization on TIDx, both the original and padded bdimx must
// be multiples of 32 to prevent warps from being split across producer and
// consumer roles. With CUDA's thread linearization (tidx + tidy * bdimx +
// tidz * bdimx * bdimy), if bdimx is not a multiple of 32, consecutive linear
// thread IDs wrap to the next tidy value mid-warp, splitting a warp across
// different roles. Example: CTA (32, 4, 2) with warp specialization on TIDx
// would pad to (48, 4, 2). Linear thread IDs 32-47 (padded producer threads,
// tidy=0) and 48-63 (original compute threads, tidy=1) occupy the same warp
// but have different roles, defeating warp specialization's purpose.
if (ws_pt == ParallelType::TIDx) {
int64_t original_tidx = getStaticComputeThreadsInDim(ws_pt);
NVF_ERROR(
original_tidx % 32 == 0,
"Warp specialization on TIDx requires bdimx to be a multiple of 32 ",
"to avoid splitting warps across producer/consumer boundaries. ",
"Got bdimx = ",
original_tidx,
" with CTA shape (",
original_tidx,
", ",
getStaticComputeThreadsInDim(ParallelType::TIDy),
", ",
getStaticComputeThreadsInDim(ParallelType::TIDz),
")");
NVF_ERROR(
after_pad % 32 == 0,
"Warp specialization on TIDx requires padded bdimx to be a multiple of "
"32 to avoid warp diverge. "
"Got padded bdimx = ",
after_pad,
" (original: ",
original_tidx,
", padding: ",
ws_num_threads_pad,
")");
}

// Apply the pad
warp_specialized_padding_value_ = ws_num_threads_pad;
auto offset = IrBuilder::create<Val>(ws_num_threads_pad, DataType::Index);
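To make the linearization arithmetic in the new comment concrete, here is a standalone host-side C++ sketch (not part of the PR; the (48, 4, 2) padded shape and the tidx >= 32 producer role are taken from the example in the comment) that enumerates warp 1 of the padded CTA and recovers each thread's role:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Padded CTA shape from the comment's example: (48, 4, 2). The original
  // bdimx was 32, so threads with tidx >= 32 are the padded producer threads.
  const int64_t bdimx = 48;
  const int64_t original_bdimx = 32;
  // Warps are consecutive groups of 32 linear IDs under
  // linear = tidx + tidy * bdimx + tidz * bdimx * bdimy; invert that mapping
  // for warp 1 (linear IDs 32..63).
  for (int64_t linear = 32; linear < 64; ++linear) {
    const int64_t tidx = linear % bdimx;
    const int64_t tidy = linear / bdimx;  // tidz is 0 for this warp
    const char* role = tidx < original_bdimx ? "compute" : "producer";
    std::printf("linear=%2lld tidx=%2lld tidy=%lld role=%s\n",
                static_cast<long long>(linear), static_cast<long long>(tidx),
                static_cast<long long>(tidy), role);
  }
  // Linear IDs 32..47 print as producers (tidy=0) and 48..63 as compute
  // threads (tidy=1): one warp holding both roles, which the new check rejects.
  return 0;
}
```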
32 changes: 16 additions & 16 deletions tests/cpp/test_circular_buffering.cpp
@@ -2515,22 +2515,11 @@ TEST_P(TmaRegisterSharing, CtaShapeShmoo) {

constexpr int64_t n_stages = 2;

-  // If ws_pt == ParallelType::TIDx and bdim.x == 32, CUDA kernel cannot use
-  // register sharing. ncu reports it uses 26 register per thread.
-  // getNumRegisters expects 168 registers by default, so the register settings
-  // causes nvrtc to hang during compilation.
-  if (ws_pt == ParallelType::TIDx && getTmaPadThreads(ws_pt, bdim) < 32) {
-    CircularBufferType circular_buffer_type = WarpSpecialized(ws_pt);
-    tv1->circularBuffer(
-        n_stages, /*prefetch_distance=*/1, circular_buffer_type);
-  } else {
-    CircularBufferType circular_buffer_type = WarpSpecialized(
-        ws_pt,
-        getNumRegisters(
-            n_computation_threads, n_tma_branch_threads, n_total_threads));
-    tv1->circularBuffer(
-        n_stages, /*prefetch_distance=*/1, circular_buffer_type);
-  }
+  CircularBufferType circular_buffer_type = WarpSpecialized(
+      ws_pt,
+      getNumRegisters(
+          n_computation_threads, n_tma_branch_threads, n_total_threads));
+  tv1->circularBuffer(n_stages, /*prefetch_distance=*/1, circular_buffer_type);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({n_stages * gdimx, n_computation_threads}, options);
@@ -2547,6 +2536,17 @@ TEST_P(TmaRegisterSharing, CtaShapeShmoo) {
ASSERT_TRUE(str_match_pointer != nullptr);
return;
}
// If ws_pt == ParallelType::TIDx and the CTA shape is (32, 4, 2), the x
// dimension is padded by 16 threads to 48, which is not a multiple of 32
// and would split a warp across roles under thread linearization.
if (ws_pt == ParallelType::TIDx &&
getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
Review comment (Contributor) on lines +2542 to +2543:
style: condition checks padding amount but validation checks total (original + padding). works for current test cases where original_tidx is always a multiple of 32, but would fail if test added case like dim3(96, 8, 1) where original=96 (divisible by 32), pad=16 (not divisible), but after_pad=112 (not divisible by 32)

Suggested change
-    if (ws_pt == ParallelType::TIDx &&
-        getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
+    if (ws_pt == ParallelType::TIDx &&
+        (bdim.x + getTmaPadThreads(ws_pt, bdim)) % 32 != 0) {

is the test suite intended to only cover cases where original bdimx is a multiple of 32?
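The discrepancy the reviewer describes can be checked with plain arithmetic. A minimal standalone sketch (the pad values are illustrative assumptions standing in for getTmaPadThreads, and bdimx = 48 is a hypothetical case, not one the suite currently runs):

```cpp
#include <cstdio>

int main() {
  // Each entry: {bdimx, pad}. pad stands in for getTmaPadThreads(ws_pt, bdim).
  const int cases[][2] = {
      {32, 16},  // (32, 4, 2) case from the test comment: both checks flag it
      {96, 16},  // reviewer's dim3(96, 8, 1) example: both checks flag it
      {48, 16},  // hypothetical: pad % 32 != 0, but (48 + 16) % 32 == 0
  };
  for (const auto& c : cases) {
    const int bdimx = c[0];
    const int pad = c[1];
    std::printf("bdimx=%d pad=%d | pad %% 32 != 0: %d | (bdimx+pad) %% 32 != 0: %d\n",
                bdimx, pad, pad % 32 != 0, (bdimx + pad) % 32 != 0);
  }
  // Only the last row disagrees; there the kernel-side check on the *original*
  // bdimx (48 % 32 != 0) would fire first with a different error message.
  return 0;
}
```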

const char* err_msg =
R"(Warp specialization on TIDx requires padded bdimx to be a multiple of 32)";
const char* str_match_pointer = strstr(e.what(), err_msg);
ASSERT_TRUE(str_match_pointer != nullptr);
return;
}
throw;
}
