diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp
index 16e276d2731..a531585cc18 100644
--- a/csrc/parallel_dimension_map.cpp
+++ b/csrc/parallel_dimension_map.cpp
@@ -231,6 +231,43 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
       " and remaining active cta threads ",
       other_active_pts_threads);
 
+  // For warp specialization on TIDx, both the original and padded bdimx must
+  // be multiples of 32 to prevent warps from being split across producer and
+  // consumer roles. With CUDA's thread linearization (tidx + tidy * bdimx +
+  // tidz * bdimx * bdimy), if bdimx is not a multiple of 32, consecutive linear
+  // thread IDs wrap to the next tidy value mid-warp, splitting a warp across
+  // different roles. Example: CTA (32, 4, 2) with warp specialization on TIDx
+  // would pad to (48, 4, 2). Linear thread IDs 32-47 (padded producer threads,
+  // tidy=0) and 48-63 (original compute threads, tidy=1) occupy the same warp
+  // but have different roles, defeating warp specialization's purpose.
+  if (ws_pt == ParallelType::TIDx) {
+    int64_t original_tidx = getStaticComputeThreadsInDim(ws_pt);
+    NVF_ERROR(
+        original_tidx % 32 == 0,
+        "Warp specialization on TIDx requires bdimx to be a multiple of 32 ",
+        "to avoid splitting warps across producer/consumer boundaries. ",
+        "Got bdimx = ",
+        original_tidx,
+        " with CTA shape (",
+        original_tidx,
+        ", ",
+        getStaticComputeThreadsInDim(ParallelType::TIDy),
+        ", ",
+        getStaticComputeThreadsInDim(ParallelType::TIDz),
+        ")");
+    NVF_ERROR(
+        after_pad % 32 == 0,
+        "Warp specialization on TIDx requires padded bdimx to be a multiple of "
+        "32 to avoid warp divergence. "
+        "Got padded bdimx = ",
+        after_pad,
+        " (original: ",
+        original_tidx,
+        ", padding: ",
+        ws_num_threads_pad,
+        ")");
+  }
+
   // Apply the pad
   warp_specialized_padding_value_ = ws_num_threads_pad;
   auto offset = IrBuilder::create<Val>(ws_num_threads_pad, DataType::Index);
diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp
index fb2d5299908..28a354be19c 100644
--- a/tests/cpp/test_circular_buffering.cpp
+++ b/tests/cpp/test_circular_buffering.cpp
@@ -2515,22 +2515,11 @@ TEST_P(TmaRegisterSharing, CtaShapeShmoo) {
   constexpr int64_t n_stages = 2;
 
-  // If ws_pt == ParallelType::TIDx and bdim.x == 32, CUDA kernel cannot use
-  // register sharing. ncu reports it uses 26 register per thread.
-  // getNumRegisters expects 168 registers by default, so the register settings
-  // causes nvrtc to hang during compilation.
-  if (ws_pt == ParallelType::TIDx && getTmaPadThreads(ws_pt, bdim) < 32) {
-    CircularBufferType circular_buffer_type = WarpSpecialized(ws_pt);
-    tv1->circularBuffer(
-        n_stages, /*prefetch_distance=*/1, circular_buffer_type);
-  } else {
-    CircularBufferType circular_buffer_type = WarpSpecialized(
-        ws_pt,
-        getNumRegisters(
-            n_computation_threads, n_tma_branch_threads, n_total_threads));
-    tv1->circularBuffer(
-        n_stages, /*prefetch_distance=*/1, circular_buffer_type);
-  }
+  CircularBufferType circular_buffer_type = WarpSpecialized(
+      ws_pt,
+      getNumRegisters(
+          n_computation_threads, n_tma_branch_threads, n_total_threads));
+  tv1->circularBuffer(n_stages, /*prefetch_distance=*/1, circular_buffer_type);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   at::Tensor t0 =
       at::randn({n_stages * gdimx, n_computation_threads}, options);
@@ -2547,6 +2536,17 @@
     ASSERT_TRUE(str_match_pointer != nullptr);
     return;
   }
+  // If ws_pt == ParallelType::TIDx and the CTA shape is (32, 4, 2), the
+  // number of padded threads in the x dim is 16, which causes warp
+  // divergence due to thread linearization.
+  if (ws_pt == ParallelType::TIDx &&
+      getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
+    const char* err_msg =
+        R"(Warp specialization on TIDx requires padded bdimx to be a multiple of 32)";
+    const char* str_match_pointer = strstr(e.what(), err_msg);
+    ASSERT_TRUE(str_match_pointer != nullptr);
+    return;
+  }
   throw;
 }
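
For reference, the failure mode guarded by the new `NVF_ERROR` can be reproduced with plain integer arithmetic. Below is a minimal standalone C++ sketch, not nvFuser code: the padded CTA shape (48, 4, 2) and the tidx = 32 compute/producer boundary are taken from the example in the new comment, and the linearization formula is CUDA's standard tidx + tidy * bdimx + tidz * bdimx * bdimy.

```cpp
// Standalone sketch: for the padded CTA (48, 4, 2), report every warp whose
// 32 consecutive linear thread IDs mix compute (tidx < 32) and padded
// producer (tidx >= 32) threads.
#include <cstdio>
#include <set>
#include <vector>

int main() {
  const int bdimx = 48, bdimy = 4, bdimz = 2; // padded CTA shape
  const int compute_tidx = 32;                // role boundary on TIDx
  const int n_warps = bdimx * bdimy * bdimz / 32;

  // Set of roles seen per warp: 0 = compute, 1 = producer.
  std::vector<std::set<int>> roles(n_warps);
  for (int tidz = 0; tidz < bdimz; ++tidz) {
    for (int tidy = 0; tidy < bdimy; ++tidy) {
      for (int tidx = 0; tidx < bdimx; ++tidx) {
        // CUDA linearization; a warp is 32 consecutive linear thread IDs.
        int linear = tidx + tidy * bdimx + tidz * bdimx * bdimy;
        roles[linear / 32].insert(tidx < compute_tidx ? 0 : 1);
      }
    }
  }
  for (int w = 0; w < n_warps; ++w) {
    if (roles[w].size() > 1) {
      printf("warp %d mixes compute and producer threads\n", w);
    }
  }
  return 0;
}
```

Because the padded bdimx (48) is not a multiple of 32, several warps straddle either a tidy boundary or the tidx = 32 role boundary, which is exactly the situation the `after_pad % 32 == 0` check rejects.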