
Conversation

@tbqh (Collaborator) commented Jan 28, 2026

Optimize the TMA inner-reduction scheduler.

Copy of #5867 without the serial split. Moving over to this PR so the serial-split code stays easy to find for future reference.

Performance of TMA vs non-TMA for float32:
[benchmark screenshot: 2026-01-28_14-02]

github-actions bot commented Jan 28, 2026

Review updated until commit 7d94987

Description

  • Change TMA usage condition from element count to bytes with 16KB minimum

  • Set threads_per_block=256 and vectorization_factor=1 for TMA scheduler

  • Add cacheAndForkOutputs() call and improve vectorization split logic

  • Update test expectations to match new TMA byte-based criteria

Changes walkthrough

Relevant files

Enhancement

  • reduction.cpp (csrc/scheduler/reduction.cpp): Update TMA usage condition to byte-based threshold (+7/-2)
      • Change TMA eligibility check from element count to total bytes
      • Set minimum TMA transfer size to 16384 bytes (16 KB)
      • Add dtype_bytes calculation for accurate size checking

  • reduction_tma.cpp (csrc/scheduler/reduction_tma.cpp): Optimize TMA scheduler parameters and caching (+19/-7)
      • Set vectorization_factor=1 and threads_per_block=256
      • Add comments explaining vectorization limitations for TMA
      • Modify unroll_factor calculation using lastPow2
      • Add cacheAndForkOutputs() call for output caching
      • Make vectorization split conditional on vectorization_factor > 1

Tests

  • test_reduction.cpp (tests/cpp/test_reduction.cpp): Update TMA test expectations for byte-based criteria (+9/-5)
      • Update test expectation logic to match new byte-based TMA criteria
      • Replace element count check with total bytes comparison
      • Use 16384 bytes minimum threshold for TMA usage expectations

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Performance regression risk for larger dtypes

    The change from an element-based threshold (128 elements) to a byte-based threshold (16384 bytes) could cause performance regressions for larger data types. For float64, the old threshold was 128 elements (1024 bytes), but the new threshold is 16384 bytes. This means TMA won't be used for float64 reduction sizes between 1024 and 16384 bytes, potentially making these cases slower than before; a per-dtype comparison is sketched after the snippet below.

    int64_t dtype_bytes = props.max_dtype_size_bit_for_vectorization / 8;
    uint64_t total_reduction_bytes = props.total_reduction_numel * dtype_bytes;
    
    // Minimum TMA transfer size, below which it seems much slower than non-TMA.
    uint64_t min_tma_bytes = 16384;
    
    if (total_reduction_bytes < min_tma_bytes) {
      return false;
    }
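
    For reference, a back-of-the-envelope comparison of the two cutoffs per dtype (a standalone illustration; only the 128-element and 16384-byte values are taken from this PR):

      // Illustrative only: translate the old element-count cutoff and the new
      // byte cutoff into element counts for common dtypes.
      #include <cstdint>
      #include <cstdio>

      int main() {
        const uint64_t min_tma_bytes = 16384;
        const struct {
          const char* name;
          uint64_t bytes;
        } dtypes[] = {{"float16", 2}, {"float32", 4}, {"float64", 8}};
        for (const auto& d : dtypes) {
          // Smallest reduction size (in elements) that still qualifies for TMA.
          std::printf(
              "%s: old cutoff = 128 elements (%llu bytes), new cutoff = %llu elements\n",
              d.name,
              static_cast<unsigned long long>(128 * d.bytes),
              static_cast<unsigned long long>(min_tma_bytes / d.bytes));
        }
        return 0;
      }
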
    Hardcoded performance parameters

    The threads_per_block is hardcoded to 256 and vectorization_factor is hardcoded to 1. These values may not be optimal across all GPU architectures and reduction patterns. Consider making these configurable or adding architecture-specific heuristics.

    int64_t threads_per_block = 256;
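
    One possible direction for this (purely illustrative, not part of the PR; the environment variable name is hypothetical): keep 256 as the default but allow an override so the value can be tuned per architecture without recompiling.

      // Hypothetical override hook; NVFUSER_TMA_THREADS_PER_BLOCK is an
      // invented name, not an existing nvFuser option.
      #include <cstdint>
      #include <cstdlib>

      int64_t tmaThreadsPerBlock() {
        int64_t threads_per_block = 256; // benchmarked default from this PR
        if (const char* env = std::getenv("NVFUSER_TMA_THREADS_PER_BLOCK")) {
          threads_per_block = std::strtoll(env, nullptr, 10);
        }
        return threads_per_block;
      }
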
    Missing validation for vectorization factor

    The code now only applies the vectorization split when rparams->vectorization_factor > 1, but the parameter is always set to 1. This suggests the vectorization logic may be dead code or the parameter setting needs to be revisited to ensure it's actually used when beneficial.

    if (rparams->vectorization_factor > 1) {
      reduction_tv->split(inner_reduce_axis, rparams->vectorization_factor);
      reduction_tv->axis(inner_reduce_axis + 1)
          ->parallelize(ParallelType::Serial);
    }

    Test failures (partial, pipeline still running)

    • (Medium, 1) Thunder nanogpt autograd scalar mismatch in cuda tests

      Test name (H100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

    greptile-apps bot (Contributor) commented Jan 28, 2026

    Greptile Overview

    Greptile Summary

    This PR optimizes the TMA (Tensor Memory Accelerator) inner-reduction scheduler by making three key changes:

    • Disabled vectorization for TMA (set to 1) since TMA handles bulk transfers and vectorization is less beneficial
    • Increased threads_per_block from 128 to 256 based on benchmarking results
    • Changed TMA eligibility threshold from element count (128 elements) to byte-based (16384 bytes), which better handles small data types

    The heuristics now fold the entire target_vect_unroll into the unroll factor since vectorization is disabled. The vectorization split is now conditional (only applied if vectorization_factor > 1). Tests were updated to match the new byte-based threshold logic. The ordering of cacheAndForkOutputs was also made consistent with the non-TMA scheduler.
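
    A minimal sketch of that folding (names follow the PR description; the actual code in reduction_tma.cpp may differ in detail):

      // With vectorization disabled for TMA, the whole target_vect_unroll
      // budget becomes the unroll factor, rounded down to a power of two.
      #include <algorithm>
      #include <cstdint>

      // Reference lastPow2: largest power of two <= n, for n >= 1.
      int64_t lastPow2(int64_t n) {
        int64_t p = 1;
        while (p * 2 <= n) {
          p *= 2;
        }
        return p;
      }

      int64_t unrollFactorForTma(int64_t target_vect_unroll) {
        const int64_t vectorization_factor = 1; // disabled for TMA
        return lastPow2(
            std::max<int64_t>(target_vect_unroll / vectorization_factor, 1));
      }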

    Confidence Score: 4/5

    • This PR is safe to merge with low risk - changes are well-documented optimizations based on benchmarking
    • The changes are straightforward performance optimizations with clear rationale in comments. The byte-based threshold is more accurate than element count for determining TMA viability. Tests are updated consistently. The ordering issue from the previous thread has been addressed. Minor confidence reduction due to hardcoded constants (256, 16384) that may need tuning for different hardware.
    • No files require special attention - all changes are consistent and well-justified

    Important Files Changed

    • csrc/scheduler/reduction_tma.cpp: Optimized TMA inner-reduction heuristics by disabling vectorization, setting threads_per_block to 256, and folding vectorization into the unroll factor; added a conditional vectorization split
    • csrc/scheduler/reduction.cpp: Changed the TMA eligibility check from an element-count threshold (128) to a byte-based threshold (16384 bytes) for better performance on small dtypes

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as Reduction Scheduler
        participant TMA as TMA Heuristics
        participant Schedule as scheduleReduction
        participant Fusion as Fusion Graph
        
        Scheduler->>TMA: Check mayUseTma(props)
        Note over TMA: Verify GPU arch >= 9
        Note over TMA: Check 2D inner reduction [I, R]
        TMA->>TMA: Calculate total_reduction_bytes
        Note over TMA: Reject if < 16384 bytes
        TMA->>TMA: Check smem capacity
        TMA-->>Scheduler: TMA eligible
        
        Scheduler->>TMA: getReductionHeuristics()
        Note over TMA: Set vectorization_factor = 1<br/>(disabled for TMA)
        Note over TMA: Set threads_per_block = 256
        TMA->>TMA: Calculate target_vect_unroll
        Note over TMA: Fold into unroll_factor<br/>(lastPow2)
        TMA-->>Scheduler: TmaInnerReductionParams
        
        Scheduler->>Schedule: scheduleReduction(params)
        Schedule->>Fusion: cacheInputs(true)
        Schedule->>Fusion: clearMemorySpace()
        Schedule->>Fusion: cacheAndForkOutputs(true)
        Schedule->>Fusion: prepareForMemoryTypePromotion()
        
        Schedule->>Schedule: Convert cached inputs to TMA
        Note over Schedule: SetOpType(CpAsyncBulk)<br/>SetMemoryType(Shared)
        
        Schedule->>Schedule: Apply splits to reduction_tv
        Note over Schedule: if vectorization_factor > 1:<br/>  Split for vectorization
        Note over Schedule: Split for TIDx (256 threads)
        Note over Schedule: if unroll_factor > 1:<br/>  Split for unroll
        Note over Schedule: Split for unswitch
        
        Schedule->>Fusion: Propagate transformations
        Schedule-->>Scheduler: TMA-optimized schedule
    

    greptile-apps bot (Contributor) left a comment


    4 files reviewed, no comments


    auto dev_prop = at::cuda::getCurrentDeviceProperties();

    // These are derived from benchmarking.
    int64_t threads_per_block = props.has_mufu_computation ? 512 : 256;
    Collaborator

    Is "props.has_mufu_computation ? 512" benchmarked?
    We can keep a simple 256 with a clear code comment.

    Collaborator Author

    Alright, I set it to 256 and added a comment.

    Here is a benchmark of reduction(sin(T0)), comparing TMA at different sizes against non-TMA.
    [benchmark screenshot: 2026-01-29_16-21]

    And a similar comparison for a plain reduction (no mufu):
    [benchmark screenshot: 2026-01-29_19-31]

    Collaborator

    Thanks for the extra tests. These results are useful for future heuristic fine-tuning.

    greptile-apps bot (Contributor) left a comment

    3 files reviewed, 1 comment


    Comment on lines 76 to 78
    scheduler_utils::prepareForMemoryTypePromotion(fusion);

    scheduler_utils::cacheAndForkOutputs(fusion, true);
    Contributor

    The order differs from reduction_non_tma.cpp:1357,1363 which calls cacheAndForkOutputs before prepareForMemoryTypePromotion. Verify this ordering doesn't affect output caching.

    @liqiangxl (Collaborator) left a comment

    LGTM.

    // Vectorization is not as useful for TMA since we're not doing global loads.
    // Vectorization can hurt the performance of shared memory loads, due to bank
    // conflicts. For float32, ideal shared memory reads is achieved with no
    // vectorization. For float16, it would be ideal for vect_factor=2, but we
    Collaborator

    This statement doesn't seem right to me. Remove or add more context about why.

      // Vectorization can hurt the performance of shared memory loads, due to bank
      // conflicts.
    

    Collaborator Author

    Added more detail:

      // Vectorization can hurt the performance of shared memory loads, due to bank
      // conflicts. E.g. for float32, ideal shared memory reads are achieved with no
      // vectorization, but with a vectorization factor of 2, thread 0 and 16 will
      // both hit bank0 during their first read.
    

    Collaborator

    This is not considered a bank conflict; you can use ncu to confirm. See https://forums.developer.nvidia.com/t/how-to-understand-the-bank-conflict-of-shared-mem/260900/2?u=liqiangl

    Collaborator Author

    Very interesting thread; I was not aware of the 128-byte limit and separate transactions.

    However, it's unclear whether our case is the same. Our case is similar to test2 from that page, where each thread loads 4x float32. It seems they must be issuing 4-component float32 loads per thread, i.e. each thread requests 16 bytes in a single instruction. That is the only way they could issue a 512-byte load per warp, which gets converted to 4 transactions.

    This is different from our TMA kernel, because our "vectorization" gets applied as ParallelType::Serialize. This means we issue 128-byte loads per warp, which issue as a single transaction, and this should hit bank conflicts.

    However, I tried enabling vectorization=4 for the kernel, and with ncu there were no bank conflicts. I don't understand this; perhaps the compiler further optimized the inner loop into 16-byte loads per thread.

    I removed the mention of bank conflicts from the code comment.

    Collaborator

    > because our "vectorization" gets applied as ParallelType::Serialize

    I believe this is only true for computations. For loading from smem to regs we use ParallelType::Vectorize.

    > However, I tried enabling vectorization=4 for the kernel, and with ncu there were no bank conflicts. I don't understand this; perhaps the compiler further optimized the inner loop into 16-byte loads per thread.

    When vectorization=4, each thread loads 16 bytes, i.e. 4 banks. 32 threads load 512 bytes, split into 4 transactions, each with 128 bytes (32 banks).
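
    A small worked example of that arithmetic (illustrative only; assumes the usual 32 banks of 4 bytes and 128-byte shared-memory transactions):

      #include <cstdio>

      int main() {
        const int warp_size = 32;
        const int bytes_per_thread = 16;       // vectorization = 4 x float32
        const int bytes_per_transaction = 128; // one shared-memory transaction
        const int bytes_per_warp = warp_size * bytes_per_thread;         // 512
        const int transactions = bytes_per_warp / bytes_per_transaction; // 4
        // Each 128-byte transaction spans 32 distinct 4-byte banks, so the
        // 8 threads it serves (128 / 16) do not conflict with each other.
        std::printf(
            "bytes per warp = %d, transactions = %d\n",
            bytes_per_warp,
            transactions);
        return 0;
      }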

    @tbqh force-pushed the tbqh/inner_reduce_performance2 branch from 556ad21 to 81e72b3 on January 30, 2026 at 23:07
    @tbqh (Collaborator Author) commented Jan 30, 2026

    !test

    greptile-apps bot (Contributor) left a comment

    3 files reviewed, no comments


    greptile-apps bot (Contributor) left a comment

    2 files reviewed, no comments

