avoid warp diverge in warp specialized kernel #5830
base: main
Conversation
Review updated until commit 63ce41a

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review
  - Error Message Clarity
Test failures

- (High, 186) CUDA driver/runtime mismatch breaking nvFuser tests on dlcluster_h100. Failing tests (H100):
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/1024_3_1_0
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_1_0_0
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_1_1_1
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_2_0_1
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_3_0_0
  - BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize128_ItemsPerThread4
  - BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread5
  - ClusterReductionTest.SimpleFusionAllReduce/cluster_10_dtype_float
  - ClusterReductionTest.SimpleFusionNotAllReduce/cluster_10_dtype___bfloat
  - ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype___bfloat
  - ... with 176 more test failures omitted. Check internal logs.
- (High, 16) CUDA driver too old on dlcluster_h100 causing device_count_ensure_non_zero failures in RNGTest. Failing tests (H100):
  - .thunder.tests.opinfos
  - .thunder.tests.test_apex_cross_entropy_executor
  - .thunder.tests.test_auto_register_torchops
  - .thunder.tests.test_cudnn_executor
  - .thunder.tests.test_einops
  - .thunder.tests.test_grad
  - .thunder.tests.test_nvfuser
  - .thunder.tests.test_ops
  - .thunder.tests.test_sdpaex_executor
  - .thunder.tests.test_torch_compile_executor
  - ... with 6 more test failures omitted. Check internal logs.
- (Medium, 1) Thunder NVFuser nanoGPT autograd mismatch in test_networks. Failing test (GB200): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32
Greptile Summary

Added compile-time validation to prevent warp divergence in warp-specialized kernels when using TIDx specialization. The change enforces that both the original and padded TIDx extents (bdimx) are multiples of the warp size (32).

Confidence Score: 5/5

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as Test Suite
    participant PDM as ParallelDimensionMap
    participant Validator as Warp Divergence Validator
    Test->>PDM: adjustMappingsForWarpSpecialization(TIDx)
    PDM->>PDM: Calculate other_active_threads (bdimy * bdimz)
    PDM->>PDM: Calculate padding: 128 / other_active_threads
    PDM->>PDM: Calculate after_pad = original_tidx + padding
    alt ws_pt == TIDx
        PDM->>Validator: Check original_tidx % 32 == 0
        alt original_tidx not multiple of 32
            Validator-->>Test: ERROR: bdimx must be multiple of 32
        else original_tidx is valid
            PDM->>Validator: Check after_pad % 32 == 0
            alt after_pad not multiple of 32
                Validator-->>Test: ERROR: padded bdimx must be multiple of 32
            else after_pad is valid
                PDM->>PDM: Apply padding to dimension map
                PDM-->>Test: Success
            end
        end
    else ws_pt != TIDx
        PDM->>PDM: Apply padding to dimension map
        PDM-->>Test: Success
    end
```
!test
2 files reviewed, 1 comment
```cpp
if (ws_pt == ParallelType::TIDx &&
    getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
```
style: the condition checks the padding amount, but the validation is stated in terms of the total (original + padding). This works for the current test cases, where original_tidx is always a multiple of 32, but would fail if the test suite added a case like dim3(96, 8, 1), where original=96 (divisible by 32), pad=16 (not divisible by 32), and after_pad=112 (not divisible by 32).
Suggested change:

```diff
 if (ws_pt == ParallelType::TIDx &&
-    getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
+    (bdim.x + getTmaPadThreads(ws_pt, bdim)) % 32 != 0) {
```
is the test suite intended to only cover cases where original bdimx is a multiple of 32?
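To make the comparison between the two checks concrete, here is a minimal sketch that evaluates both predicates on a few block shapes. It assumes the padding rule 128 / (bdimy * bdimz) from the sequence diagram above; `BlockDim` and `tmaPadThreads` are hypothetical stand-ins, not nvFuser APIs.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins, mirroring the earlier sketch.
struct BlockDim {
  int64_t x;
  int64_t y;
  int64_t z;
};

int64_t tmaPadThreads(const BlockDim& bdim) {
  return 128 / (bdim.y * bdim.z);
}

int main() {
  // Compare the check in the PR (padding amount alone) with the one from the
  // suggestion (original extent plus padding). Under this padding rule the two
  // agree whenever bdim.x is already a multiple of 32 and can disagree
  // otherwise, e.g. (112, 8, 1).
  const BlockDim cases[] = {{128, 1, 1}, {128, 4, 1}, {96, 8, 1}, {112, 8, 1}};
  for (const BlockDim& bdim : cases) {
    const int64_t pad = tmaPadThreads(bdim);
    const bool reject_pad_only = pad % 32 != 0;
    const bool reject_total = (bdim.x + pad) % 32 != 0;
    std::printf(
        "bdim=(%lld,%lld,%lld) pad=%lld  pad%%32 rejects=%d  (x+pad)%%32 rejects=%d\n",
        static_cast<long long>(bdim.x), static_cast<long long>(bdim.y),
        static_cast<long long>(bdim.z), static_cast<long long>(pad),
        static_cast<int>(reject_pad_only), static_cast<int>(reject_total));
  }
  return 0;
}
```

Under that assumed rule, the two predicates only diverge when the original bdimx is not itself a multiple of 32, which is what the question above is getting at.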
!test
No description provided.