Conversation

@liqiangxl
Collaborator

No description provided.

@github-actions

github-actions bot commented Jan 15, 2026

Review updated until commit 63ce41a

Description

  • Add validation to prevent warp divergence in warp specialized kernels

  • Ensure TIDx warp specialization requires bdimx to be multiple of 32

  • Validate both original and padded thread counts for warp specialization

  • Update tests to verify new validation error messages

Changes walkthrough

Relevant files
Enhancement
parallel_dimension_map.cpp
Add TIDx warp specialization validation checks                     

csrc/parallel_dimension_map.cpp

  • Added detailed comment explaining warp specialization requirements for
    TIDx
  • Added NVF_ERROR checks to validate original bdimx is multiple of 32
  • Added NVF_ERROR checks to validate padded bdimx is multiple of 32
  • Prevents warps from being split across producer/consumer boundaries
  • +37/-0

Tests
test_circular_buffering.cpp
Update tests for TIDx warp specialization validation

tests/cpp/test_circular_buffering.cpp

  • Simplified circular buffer setup by removing conditional register
    logic
  • Added error message validation for TIDx warp specialization padding
  • Added check for non-multiple-of-32 padded threads causing warp
    divergence
  • Updated test to verify new validation error messages
  • +16/-16 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Error Message Clarity

    The error messages for warp specialization on TIDx could be more concise. The current messages are very detailed and might be overwhelming. Consider whether the essential information can be conveyed more succinctly while maintaining clarity.

    NVF_ERROR(
        original_tidx % 32 == 0,
        "Warp specialization on TIDx requires bdimx to be a multiple of 32 ",
        "to avoid splitting warps across producer/consumer boundaries. ",
        "Got bdimx = ",
        original_tidx,
        " with CTA shape (",
        original_tidx,
        ", ",
        getStaticComputeThreadsInDim(ParallelType::TIDy),
        ", ",
        getStaticComputeThreadsInDim(ParallelType::TIDz),
        ")");
    NVF_ERROR(
        after_pad % 32 == 0,
        "Warp specialization on TIDx requires padded bdimx to be a multiple of "
        "32 to avoid warp diverge. "
        "Got padded bdimx = ",
        after_pad,
        " (original: ",
        original_tidx,
        ", padding: ",
        ws_num_threads_pad,
        ")");
    Test Coverage Completeness

    The test modification removes the conditional logic that previously handled edge cases differently. Ensure that all previously covered scenarios are still adequately tested, particularly the case where register sharing couldn't be used with bdimx == 32.

    CircularBufferType circular_buffer_type = WarpSpecialized(
        ws_pt,
        getNumRegisters(
            n_computation_threads, n_tma_branch_threads, n_total_threads));
    tv1->circularBuffer(n_stages, /*prefetch_distance=*/1, circular_buffer_type);

    Test failures

    • (High, 186) CUDA driver/runtime mismatch breaking nvFuser tests on dlcluster_h100

      Test Name (H100)
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/1024_3_1_0
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_1_0_0
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_1_1_1
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_2_0_1
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_3_0_0
      BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize128_ItemsPerThread4
      BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread5
      ClusterReductionTest.SimpleFusionAllReduce/cluster_10_dtype_float
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_10_dtype___bfloat
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype___bfloat
      ... with 176 more test failures omitted. Check internal logs.
    • (High, 16) CUDA driver too old on dlcluster_h100 causing device_count_ensure_non_zero failures in RNGTest

      Test Name (H100)
      .thunder.tests.opinfos
      .thunder.tests.test_apex_cross_entropy_executor
      .thunder.tests.test_auto_register_torchops
      .thunder.tests.test_cudnn_executor
      .thunder.tests.test_einops
      .thunder.tests.test_grad
      .thunder.tests.test_nvfuser
      .thunder.tests.test_ops
      .thunder.tests.test_sdpaex_executor
      .thunder.tests.test_torch_compile_executor
      ... with 6 more test failures omitted. Check internal logs.
    • (Medium, 1) Thunder NVFuser nanoGPT autograd mismatch in test_networks

      Test Name (GB200)
      thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 15, 2026

    Greptile Summary

    Added compile-time validation to prevent warp divergence in warp-specialized kernels when using TIDx specialization. The change enforces that both the original and padded bdimx must be multiples of 32 to prevent CUDA's thread linearization from splitting warps across producer and consumer roles.

    Key changes:

    • Added two validation checks in adjustMappingsForWarpSpecialization() for TIDx specialization
    • Removed workaround code in test that disabled register sharing for problematic configurations
    • Updated test to properly handle the new validation error for CTA shapes that would cause warp divergence

    Technical context:
    With CUDA's thread linearization formula (tidx + tidy * bdimx + tidz * bdimx * bdimy), if bdimx is not a multiple of 32, consecutive linear thread IDs can wrap to the next tidy value mid-warp. For example, CTA (32, 4, 2) with 16-thread padding becomes (48, 4, 2), causing threads 32-47 (padded producers at tidy=0) and threads 48-63 (compute threads at tidy=1) to occupy the same warp despite having different roles.

    Confidence Score: 5/5

    • This PR is safe to merge with no significant risks
    • The changes add necessary compile-time validation to prevent a subtle but serious warp divergence bug in GPU kernels. The validation logic is mathematically sound, the error messages are clear and informative, and the test coverage has been properly updated to verify the new checks work correctly. The implementation correctly identifies problematic CTA configurations before code generation.
    • No files require special attention

    Important Files Changed

    Filename Overview
    csrc/parallel_dimension_map.cpp Added validation to prevent warp divergence in TIDx specialization by ensuring both original and padded bdimx are multiples of 32
    tests/cpp/test_circular_buffering.cpp Updated test to expect the new validation error for TIDx warp specialization with non-32-aligned padding, removed workaround code

    Sequence Diagram

    sequenceDiagram
        participant Test as Test Suite
        participant PDM as ParallelDimensionMap
        participant Validator as Warp Divergence Validator
        
        Test->>PDM: adjustMappingsForWarpSpecialization(TIDx)
        PDM->>PDM: Calculate other_active_threads (bdimy * bdimz)
        PDM->>PDM: Calculate padding: 128 / other_active_threads
        PDM->>PDM: Calculate after_pad = original_tidx + padding
        
        alt ws_pt == TIDx
            PDM->>Validator: Check original_tidx % 32 == 0
            alt original_tidx not multiple of 32
                Validator-->>Test: ERROR: bdimx must be multiple of 32
            else original_tidx is valid
                PDM->>Validator: Check after_pad % 32 == 0
                alt after_pad not multiple of 32
                    Validator-->>Test: ERROR: padded bdimx must be multiple of 32
                else after_pad is valid
                    PDM->>PDM: Apply padding to dimension map
                    PDM-->>Test: Success
                end
            end
        else ws_pt != TIDx
            PDM->>PDM: Apply padding to dimension map
            PDM-->>Test: Success
        end
    

    @liqiangxl
    Collaborator Author

    !test

    @greptile-apps
    Contributor

    greptile-apps bot left a comment

    2 files reviewed, 1 comment


    Comment on lines +2542 to +2543

    if (ws_pt == ParallelType::TIDx &&
        getTmaPadThreads(ws_pt, bdim) % 32 != 0) {

    Style: the condition checks the padding amount, but the validation checks the total (original + padding). This works for the current test cases, where original_tidx is always a multiple of 32, but would fail if a test added a case like dim3(96, 8, 1), where original = 96 (divisible by 32), pad = 16 (not divisible by 32), and after_pad = 112 (not divisible by 32).

    Suggested change:

        - if (ws_pt == ParallelType::TIDx &&
        -     getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
        + if (ws_pt == ParallelType::TIDx &&
        +     (bdim.x + getTmaPadThreads(ws_pt, bdim)) % 32 != 0) {

    is the test suite intended to only cover cases where original bdimx is a multiple of 32?

    @liqiangxl
    Collaborator Author

    !test
