From 1b54705412a071000daf64392bf1f3c393686df4 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 15 Jan 2026 08:17:02 -0800 Subject: [PATCH 1/2] avoid warp diverge in warp specialized kernel --- csrc/parallel_dimension_map.cpp | 37 +++++++++++++++++++++++++++ tests/cpp/test_circular_buffering.cpp | 32 +++++++++++------------ 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 3ee6ca64b39..7d4cad91fae 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -208,6 +208,43 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() { " and remaining active cta threads ", other_active_pts_threads); + // For warp specialization on TIDx, both the original and padded bdimx must + // be multiples of 32 to prevent warps from being split across producer and + // consumer roles. With CUDA's thread linearization (tidx + tidy * bdimx + + // tidz * bdimx * bdimy), if bdimx is not a multiple of 32, consecutive linear + // thread IDs wrap to the next tidy value mid-warp, splitting a warp across + // different roles. Example: CTA (32, 4, 2) with warp specialization on TIDx + // would pad to (48, 4, 2). Linear thread IDs 32-47 (padded producer threads, + // tidy=0) and 48-63 (original compute threads, tidy=1) occupy the same warp + // but have different roles, defeating warp specialization's purpose. + if (ws_pt == ParallelType::TIDx) { + int64_t original_tidx = getThreadCountInDim(ws_pt); + NVF_ERROR( + original_tidx % 32 == 0, + "Warp specialization on TIDx requires bdimx to be a multiple of 32 ", + "to avoid splitting warps across producer/consumer boundaries. ", + "Got bdimx = ", + original_tidx, + " with CTA shape (", + original_tidx, + ", ", + getThreadCountInDim(ParallelType::TIDy), + ", ", + getThreadCountInDim(ParallelType::TIDz), + ")"); + NVF_ERROR( + after_pad % 32 == 0, + "Warp specialization on TIDx requires padded bdimx to be a multiple of " + "32 to avoid warp diverge. " + "Got padded bdimx = ", + after_pad, + " (original: ", + original_tidx, + ", padding: ", + ws_num_threads_pad, + ")"); + } + // Apply the pad warp_specialized_padding_value_ = ws_num_threads_pad; auto offset = IrBuilder::create(ws_num_threads_pad, DataType::Index); diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index dddc6553f4c..17daebdc248 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -2515,22 +2515,11 @@ TEST_P(TmaRegisterSharing, CtaShapeShmoo) { constexpr int64_t n_stages = 2; - // If ws_pt == ParallelType::TIDx and bdim.x == 32, CUDA kernel cannot use - // register sharing. ncu reports it uses 26 register per thread. - // getNumRegisters expects 168 registers by default, so the register settings - // causes nvrtc to hang during compilation. - if (ws_pt == ParallelType::TIDx && getTmaPadThreads(ws_pt, bdim) < 32) { - CircularBufferType circular_buffer_type = WarpSpecialized(ws_pt); - tv1->circularBuffer( - n_stages, /*prefetch_distance=*/1, circular_buffer_type); - } else { - CircularBufferType circular_buffer_type = WarpSpecialized( - ws_pt, - getNumRegisters( - n_computation_threads, n_tma_branch_threads, n_total_threads)); - tv1->circularBuffer( - n_stages, /*prefetch_distance=*/1, circular_buffer_type); - } + CircularBufferType circular_buffer_type = WarpSpecialized( + ws_pt, + getNumRegisters( + n_computation_threads, n_tma_branch_threads, n_total_threads)); + tv1->circularBuffer(n_stages, /*prefetch_distance=*/1, circular_buffer_type); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({n_stages * gdimx, n_computation_threads}, options); @@ -2547,6 +2536,17 @@ TEST_P(TmaRegisterSharing, CtaShapeShmoo) { ASSERT_TRUE(str_match_pointer != nullptr); return; } + // If ws_pt == ParallelType::TIDx and CTA shape is (32, 4, 2), padded + // threads in x dim is 16, will cause warp divergence due to thread + // linearization. + if (ws_pt == ParallelType::TIDx && + getTmaPadThreads(ws_pt, bdim) % 32 != 0) { + const char* err_msg = + R"(Warp specialization on TIDx requires padded bdimx to be a multiple of 32)"; + const char* str_match_pointer = strstr(e.what(), err_msg); + ASSERT_TRUE(str_match_pointer != nullptr); + return; + } throw; } From 63ce41a2f9907a98f38f4b010631e12e7219625d Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Wed, 21 Jan 2026 18:27:01 -0800 Subject: [PATCH 2/2] fix conflict --- csrc/parallel_dimension_map.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index b269586aa88..a531585cc18 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -241,7 +241,7 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() { // tidy=0) and 48-63 (original compute threads, tidy=1) occupy the same warp // but have different roles, defeating warp specialization's purpose. if (ws_pt == ParallelType::TIDx) { - int64_t original_tidx = getThreadCountInDim(ws_pt); + int64_t original_tidx = getStaticComputeThreadsInDim(ws_pt); NVF_ERROR( original_tidx % 32 == 0, "Warp specialization on TIDx requires bdimx to be a multiple of 32 ", @@ -251,9 +251,9 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() { " with CTA shape (", original_tidx, ", ", - getThreadCountInDim(ParallelType::TIDy), + getStaticComputeThreadsInDim(ParallelType::TIDy), ", ", - getThreadCountInDim(ParallelType::TIDz), + getStaticComputeThreadsInDim(ParallelType::TIDz), ")"); NVF_ERROR( after_pad % 32 == 0,