Optimize TMA inner-reduction #5887
Conversation
Review updated until commit 7d94987

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- ⚡ Recommended focus areas for review:
  - Performance regression risk for larger dtypes

Test failures (partial, pipeline still running):

- Thunder nanogpt autograd scalar mismatch in cuda tests (Medium, 1)

  | Test Name | H100 |
  |---|---|
  | thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 | ❌ |
Greptile Overview

Greptile Summary

This PR optimizes the TMA (Tensor Memory Accelerator) inner-reduction scheduler. The heuristics now fold the entire vectorization/unroll target into a single unroll factor; a sketch of this fold and the full scheduling flow follow below.

Confidence Score: 4/5
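A minimal standalone sketch of the eligibility check and the heuristic fold, assuming only what the review summary and the sequence diagram below state (the 16384-byte minimum, 256 threads per block, vectorization disabled, lastPow2 fold). The function and struct names, the example target calculation, and the local lastPow2 helper are illustrative stand-ins, not the code in csrc/scheduler/reduction_tma.cpp.

```cpp
#include <algorithm>
#include <cstdint>

// Stand-in for the scheduler utility: largest power of two <= n (for n >= 1).
int64_t lastPow2(int64_t n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  n |= (n >> 32);
  return std::max<int64_t>(1, n - (n >> 1));
}

// Eligibility: reject small reductions where TMA setup cost dominates.
bool mayUseTmaSketch(int64_t reduction_numel, int64_t dtype_bytes) {
  const int64_t total_reduction_bytes = reduction_numel * dtype_bytes;
  return total_reduction_bytes >= 16384;
}

struct TmaInnerReductionParamsSketch {
  int64_t vectorization_factor = 1;  // disabled for TMA
  int64_t threads_per_block = 256;
  int64_t unroll_factor = 1;
};

// Fold the whole vectorization/unroll budget into a single unroll factor.
TmaInnerReductionParamsSketch heuristicsSketch(int64_t reduction_numel) {
  TmaInnerReductionParamsSketch p;
  // Hypothetical target: enough serial work per thread to cover the
  // reduction, capped at a small value.
  const int64_t target_vect_unroll =
      std::min<int64_t>(16, reduction_numel / p.threads_per_block);
  p.unroll_factor = lastPow2(std::max<int64_t>(1, target_vect_unroll));
  return p;
}
```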
Sequence Diagram:

sequenceDiagram
participant Scheduler as Reduction Scheduler
participant TMA as TMA Heuristics
participant Schedule as scheduleReduction
participant Fusion as Fusion Graph
Scheduler->>TMA: Check mayUseTma(props)
Note over TMA: Verify GPU arch >= 9
Note over TMA: Check 2D inner reduction [I, R]
TMA->>TMA: Calculate total_reduction_bytes
Note over TMA: Reject if < 16384 bytes
TMA->>TMA: Check smem capacity
TMA-->>Scheduler: TMA eligible
Scheduler->>TMA: getReductionHeuristics()
Note over TMA: Set vectorization_factor = 1<br/>(disabled for TMA)
Note over TMA: Set threads_per_block = 256
TMA->>TMA: Calculate target_vect_unroll
Note over TMA: Fold into unroll_factor<br/>(lastPow2)
TMA-->>Scheduler: TmaInnerReductionParams
Scheduler->>Schedule: scheduleReduction(params)
Schedule->>Fusion: cacheInputs(true)
Schedule->>Fusion: clearMemorySpace()
Schedule->>Fusion: cacheAndForkOutputs(true)
Schedule->>Fusion: prepareForMemoryTypePromotion()
Schedule->>Schedule: Convert cached inputs to TMA
Note over Schedule: SetOpType(CpAsyncBulk)<br/>SetMemoryType(Shared)
Schedule->>Schedule: Apply splits to reduction_tv
Note over Schedule: if vectorization_factor > 1:<br/> Split for vectorization
Note over Schedule: Split for TIDx (256 threads)
Note over Schedule: if unroll_factor > 1:<br/> Split for unroll
Note over Schedule: Split for unswitch
Schedule->>Fusion: Propagate transformations
Schedule-->>Scheduler: TMA-optimized schedule
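The scheduling flow in the diagram can be sketched in nvFuser-style C++ roughly as follows. This is a hedged approximation, not the PR's scheduleReduction: the includes, the conversion calls (setOpType(LoadStoreOpType::CpAsyncBulk), setMemoryType(MemoryType::Shared)), and the split order follow the diagram's notes and common nvFuser scheduler idioms, and transform propagation is elided.

```cpp
#include <fusion.h>
#include <ir/all_nodes.h>
#include <scheduler/utils.h>

using namespace nvfuser;

// Sketch of the flow in the sequence diagram above; not the actual code.
void scheduleTmaInnerReductionSketch(
    Fusion* fusion,
    TensorView* reduction_tv,
    int64_t threads_per_block,  // 256 per the heuristics above
    int64_t unroll_factor) {
  // Cache inputs/outputs (diagram: cacheInputs, clearMemorySpace,
  // cacheAndForkOutputs, prepareForMemoryTypePromotion).
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, /*unroll=*/true);
  scheduler_utils::clearMemorySpace(fusion);
  scheduler_utils::cacheAndForkOutputs(fusion, /*unroll=*/true);
  scheduler_utils::prepareForMemoryTypePromotion(fusion);

  // Convert the cached input copies into TMA loads staged in shared memory.
  for (TensorView* tv : cached_inputs) {
    tv->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsyncBulk);
    tv->setMemoryType(MemoryType::Shared);
  }

  // Split the reduction domain of the 2D [I, R] tensor:
  // [I, R] -> [I, R/(unswitch*U*TIDx), unswitch, U, TIDx].
  reduction_tv->split(1, threads_per_block);  // TIDx
  if (unroll_factor > 1) {
    reduction_tv->split(1, unroll_factor);    // serial unroll
  }
  reduction_tv->split(1, 1);                  // unswitch
  reduction_tv->axis(0)->parallelize(ParallelType::BIDx);
  reduction_tv->axis(-1)->parallelize(ParallelType::TIDx);

  // ... propagate transforms and parallelization to the rest of the fusion
  // (TransformPropagator / parallelizeAllLike), omitted here.
}
```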
4 files reviewed, no comments
csrc/scheduler/reduction_tma.cpp (Outdated)

    auto dev_prop = at::cuda::getCurrentDeviceProperties();

    // These are derived from benchmarking.
    int64_t threads_per_block = props.has_mufu_computation ? 512 : 256;
> props.has_mufu_computation ? 512

Is this benchmarked? We can keep a simple 256 with a clear code comment.
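For reference, the suggested simplification would look something like this (a hypothetical sketch, not committed code):

```cpp
#include <cstdint>

// Derived from benchmarking; keep a single value rather than branching on
// props.has_mufu_computation until the 512-thread variant is justified by data.
constexpr int64_t kThreadsPerBlock = 256;
```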
Thanks for the extra tests. These results are useful for future heuristic fine-tuning.
3 files reviewed, 1 comment
csrc/scheduler/reduction_tma.cpp (Outdated)

    scheduler_utils::prepareForMemoryTypePromotion(fusion);

    scheduler_utils::cacheAndForkOutputs(fusion, true);
The order differs from reduction_non_tma.cpp:1357,1363 which calls cacheAndForkOutputs before prepareForMemoryTypePromotion. Verify this ordering doesn't affect output caching.
liqiangxl left a comment:
LGTM.
csrc/scheduler/reduction_tma.cpp (Outdated)

    // Vectorization is not as useful for TMA since we're not doing global loads.
    // Vectorization can hurt the performance of shared memory loads, due to bank
    // conflicts. For float32, ideal shared memory reads are achieved with no
    // vectorization. For float16, it would be ideal for vect_factor=2, but we
This statement doesn't seem right to me. Remove or add more context about why:

    // Vectorization can hurt the performance of shared memory loads, due to bank
    // conflicts.
Added more detail:

    // Vectorization can hurt the performance of shared memory loads, due to bank
    // conflicts. E.g. for float32, ideal shared memory reads are achieved with no
    // vectorization, but with a vectorization factor of 2, thread 0 and 16 will
    // both hit bank0 during their first read.
This is not considered a bank conflict; you can use ncu to confirm. See https://forums.developer.nvidia.com/t/how-to-understand-the-bank-conflict-of-shared-mem/260900/2?u=liqiangl
Very interesting thread, I was not aware of the 128-byte limit and separate transactions.
However, it's unclear whether our case is the same. Our case is similar to test2 from that page, where each thread loads 4x float32. It seems they must be issuing 4-component float32 loads per thread, i.e. each thread requests 16 bytes in a single instruction. That is the only way they could issue a 512-byte load per warp, which gets converted to 4 transactions.
This is different from our TMA kernel, because our "vectorization" gets applied as ParallelType::Serialize. This means we issue 128-byte loads, which go out as a single transaction, and that should hit bank conflicts.
However, I tried enabling vectorization=4 for the kernel, and with ncu there were no bank conflicts. I don't understand this; perhaps the compiler further optimized the inner loop into 16-byte loads per thread.
I removed the mention of bank conflicts from the code comment.
> because our "vectorization" gets applied as ParallelType::Serialize

I believe this is only true for computations. For loading from smem to regs we use ParallelType::Vectorize.

> However, I tried enabling vectorization=4 for the kernel, and with ncu there were no bank conflicts.

When vectorization=4, each thread loads 16 bytes, i.e. 4 banks. 32 threads load 512 bytes, split into 4 transactions, each with 128 bytes (32 banks).
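To make the accounting in this exchange concrete, here is a small standalone program (a hypothetical illustration, not ncu output or the generated kernel) that models 32 four-byte banks and 128-byte transactions, maps a warp's contiguous float32 accesses to banks for vectorization widths 1, 2, and 4, and reports the worst per-bank collision inside a transaction. For all three widths the result is one access per bank per transaction, consistent with ncu reporting no bank conflicts.

```cpp
#include <array>
#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  constexpr int kNumBanks = 32;
  constexpr int kBankBytes = 4;
  constexpr int kTransactionBytes = 128;

  for (int vect : {1, 2, 4}) {
    const int bytes_per_thread = vect * 4;  // float32 elements per thread
    const int threads_per_txn = kTransactionBytes / bytes_per_thread;
    const int num_txn = kWarpSize / threads_per_txn;

    int worst_conflict = 1;
    for (int txn = 0; txn < num_txn; ++txn) {
      std::array<int, kNumBanks> hits{};  // accesses per bank in this transaction
      for (int t = txn * threads_per_txn; t < (txn + 1) * threads_per_txn; ++t) {
        for (int e = 0; e < vect; ++e) {
          const int addr = (t * vect + e) * kBankBytes;  // contiguous layout
          ++hits[(addr / kBankBytes) % kNumBanks];
        }
      }
      for (int h : hits) {
        worst_conflict = h > worst_conflict ? h : worst_conflict;
      }
    }
    std::printf(
        "vect=%d: %d bytes/thread, %d transaction(s) of %d threads, "
        "max %d access(es) per bank per transaction\n",
        vect, bytes_per_thread, num_txn, threads_per_txn, worst_conflict);
  }
  return 0;
}
```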
Force-pushed from 556ad21 to 81e72b3.
!test
3 files reviewed, no comments
2 files reviewed, no comments
Optimize the TMA inner-reduction scheduler.
Copy of #5867 without serial split. Moving over to this PR to keep serial split code easy to find for future reference.
Performance of TMA vs non-TMA for float32:
[benchmark plot]