Add kernel based alltoallv and cuda backend for MoE dispatch and combine #5863
base: dispatch_combine/stub
Conversation
!test
Greptile Summary

This PR adds a kernel-based alltoallv implementation and CUDA backend support for MoE dispatch/combine operations. The implementation introduces a custom CUDA kernel that copies each payload from the send buffers into the receive buffers across all ranks.

Key changes:
Issues found:
Confidence Score: 4/5
Important Files Changed
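The core piece of the summary, a kernel-based alltoallv, needs per-peer counts and the corresponding offsets into the packed send and receive buffers. The sketch below shows a plausible shape for that bookkeeping; `AlltoallvMetadata` and `prepareAlltoallvMetadata` are names taken from the sequence diagram further down, but the field layout and signature here are assumptions, and the count exchange itself (done via TCPStore inside the real function, per the diagram) is assumed to have already happened.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative layout: per-peer element counts plus their exclusive prefix
// sums, which give each peer's slice offset in the packed send/recv buffers.
struct AlltoallvMetadata {
  std::vector<int64_t> send_counts;
  std::vector<int64_t> recv_counts;
  std::vector<int64_t> send_offsets;
  std::vector<int64_t> recv_offsets;
};

// Assumes the per-peer receive counts were already exchanged across ranks.
AlltoallvMetadata prepareAlltoallvMetadata(
    const std::vector<int64_t>& send_counts,
    const std::vector<int64_t>& recv_counts) {
  AlltoallvMetadata md;
  md.send_counts = send_counts;
  md.recv_counts = recv_counts;
  md.send_offsets.assign(send_counts.size(), 0);
  md.recv_offsets.assign(recv_counts.size(), 0);
  std::exclusive_scan(
      send_counts.begin(), send_counts.end(), md.send_offsets.begin(), int64_t{0});
  std::exclusive_scan(
      recv_counts.begin(), recv_counts.end(), md.recv_offsets.begin(), int64_t{0});
  return md;
}
```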
Sequence Diagram

```mermaid
sequenceDiagram
participant App as Application
participant Dispatch as doMoEDispatch
participant Combine as doMoECombine
participant Metadata as prepareAlltoallvMetadata
participant Kernel as launchAlltoallvKernel
participant SymTensor as SymmetricTensor
participant Barrier as alltoallvBarrier
App->>Dispatch: x, topk_idx, topk_weights, is_token_in_rank
Note over Dispatch: Sort tokens by destination rank
Dispatch->>Dispatch: Reorder x, topk_idx, topk_weights by rank
alt Backend == NCCL
Dispatch->>Dispatch: Use ProcessGroup alltoall_base
else Backend == CUDA
Dispatch->>Metadata: prepareAlltoallvMetadata(n_tokens_to_rank)
Note over Metadata: Exchange counts via TCPStore<br/>Compute offsets and metadata
Metadata-->>Dispatch: AlltoallvMetadata
Dispatch->>SymTensor: Allocate symmetric send/recv buffers
SymTensor-->>Dispatch: Send/recv tensor handles
Dispatch->>SymTensor: setupRemoteHandles() for each buffer
Note over SymTensor: Exchange IPC handles across ranks
loop For each payload (x, topk_idx, topk_weights, src_idx, src_rank)
Dispatch->>Kernel: launchAlltoallvKernel(send, recv_ptrs, metadata)
Note over Kernel: CUDA kernel copies data<br/>from send to recv buffers<br/>across all ranks
end
Dispatch->>Barrier: alltoallvBarrier()
Note over Barrier: Synchronize all ranks
end
Note over Dispatch: Reorder by expert ID locally
Dispatch-->>App: recv_x, recv_topk_idx, recv_topk_weights, recv_src_idx, recv_src_rank
Note over App: ... Expert computation happens here ...
App->>Combine: x, topk_weights, src_idx, src_rank, n_tokens_to/from_rank
Note over Combine: Sort by source rank
alt Backend == NCCL
Combine->>Combine: Use ProcessGroup alltoall_base
else Backend == CUDA
Combine->>Metadata: prepareAlltoallvMetadata(n_tokens_from_rank)
Metadata-->>Combine: AlltoallvMetadata
Combine->>SymTensor: Allocate symmetric buffers
Combine->>SymTensor: setupRemoteHandles()
loop For each payload (x, topk_weights, src_idx)
Combine->>Kernel: launchAlltoallvKernel(send, recv_ptrs, metadata)
end
Combine->>Barrier: alltoallvBarrier()
end
Note over Combine: Scatter by original token index
Combine-->>App: combined_x, combined_topk_weights
```
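Two pieces of the CUDA path above are worth sketching. First, setupRemoteHandles(): per the diagram it exchanges IPC handles so each rank can write directly into its peers' receive buffers. The sketch below is a plausible shape for that step, not the PR's code; `exchangeAllHandles()` is a hypothetical stand-in for however the handles actually travel between processes (e.g. the same TCPStore used for the counts).

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Hypothetical: gathers every rank's handle into a vector indexed by rank
// (e.g. via a TCPStore round-trip). Declared only, not implemented here.
std::vector<cudaIpcMemHandle_t> exchangeAllHandles(
    const cudaIpcMemHandle_t& local_handle, int64_t n_ranks);

// Each rank exports an IPC handle for its local recv buffer and opens every
// peer's handle, yielding one writable device pointer per rank.
std::vector<void*> setupRemoteHandles(
    void* local_recv_buf, int64_t n_ranks, int64_t my_rank) {
  cudaIpcMemHandle_t local_handle;
  cudaIpcGetMemHandle(&local_handle, local_recv_buf);

  std::vector<cudaIpcMemHandle_t> all_handles =
      exchangeAllHandles(local_handle, n_ranks);

  std::vector<void*> recv_ptrs(n_ranks, nullptr);
  for (int64_t r = 0; r < n_ranks; ++r) {
    if (r == my_rank) {
      recv_ptrs[r] = local_recv_buf;  // local buffer needs no IPC mapping
    } else {
      cudaIpcOpenMemHandle(
          &recv_ptrs[r], all_handles[r], cudaIpcMemLazyEnablePeerAccess);
    }
  }
  return recv_ptrs;
}
```

Second, launchAlltoallvKernel(): the diagram only says the kernel copies each payload from the local send buffer into the peers' recv buffers, so the kernel signature and launch configuration below are assumptions.

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// One block per destination rank; the block's threads stride over that
// rank's slice of the rank-sorted send buffer and write it into the peer's
// IPC-mapped recv buffer at this rank's offset on that peer. All pointer
// and metadata arrays are assumed to be device-resident.
__global__ void alltoallvCopyKernel(
    const uint8_t* send,          // local send buffer, sorted by dest rank
    uint8_t* const* recv_ptrs,    // base pointer of each peer's recv buffer
    const int64_t* send_offsets,  // byte offset of each peer's slice in `send`
    const int64_t* send_bytes,    // byte length of each peer's slice
    const int64_t* dst_offsets,   // where this rank's data lands on each peer
    int64_t n_ranks) {
  const int64_t dst = blockIdx.x;
  if (dst >= n_ranks) {
    return;
  }
  const uint8_t* src = send + send_offsets[dst];
  uint8_t* out = recv_ptrs[dst] + dst_offsets[dst];
  for (int64_t i = threadIdx.x; i < send_bytes[dst]; i += blockDim.x) {
    out[i] = src[i];  // byte-wise for clarity; a real kernel would vectorize
  }
}

// Hypothetical launcher mirroring launchAlltoallvKernel() from the diagram.
void launchAlltoallvKernel(
    const uint8_t* send,
    uint8_t* const* recv_ptrs,
    const int64_t* send_offsets,
    const int64_t* send_bytes,
    const int64_t* dst_offsets,
    int64_t n_ranks,
    cudaStream_t stream) {
  alltoallvCopyKernel<<<static_cast<unsigned>(n_ranks), 256, 0, stream>>>(
      send, recv_ptrs, send_offsets, send_bytes, dst_offsets, n_ranks);
}
```

Because every rank writes into remote memory, the alltoallvBarrier() that follows in the diagram is what guarantees no rank reads its recv buffer before all peers have finished writing into it.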
8 files reviewed, 1 comment
```cpp
if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
  GTEST_SKIP() << "Backend " << backend << " not available.";
}
```
logic: this checks the wrong backend constant; it should check the `backend` parameter rather than the hardcoded `kNccl`.
Suggested change:

```diff
-if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) {
+if (!communicator_->isBackendAvailable(backend)) {
   GTEST_SKIP() << "Backend " << backend << " not available.";
 }
```
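For context, a minimal sketch of why the parameter matters in a parameterized test; the fixture, enum, and availability stub below are hypothetical stand-ins for the project's real types, not the PR's code.

```cpp
#include <gtest/gtest.h>

// Stand-ins for the project's real Communicator types.
enum class CommunicatorBackend { kNccl, kCuda };

static bool isBackendAvailable(CommunicatorBackend backend) {
  // Stub: the real communicator queries the runtime environment.
  return backend == CommunicatorBackend::kNccl;
}

class AlltoallvTest : public ::testing::TestWithParam<CommunicatorBackend> {};

TEST_P(AlltoallvTest, SkipsUnavailableBackend) {
  const CommunicatorBackend backend = GetParam();
  // Checking the parameter, not a hardcoded constant, lets each
  // instantiation skip based on its own backend's availability.
  if (!isBackendAvailable(backend)) {
    GTEST_SKIP() << "Backend not available.";
  }
  // ... exercise dispatch/combine with `backend` ...
}

INSTANTIATE_TEST_SUITE_P(
    Backends,
    AlltoallvTest,
    ::testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kCuda));
```

With the hardcoded `kNccl`, the CUDA instantiation would run or skip based on NCCL's availability instead of its own.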
No description provided.