-
Notifications
You must be signed in to change notification settings - Fork 76
MoE Dispatch/Combine first implementation for k=1 and Nccl backend
#5857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
k=1 and Nccl backendk=1 and Nccl backend
|
!test |
Greptile OverviewGreptile SummaryAdds MoE (Mixture of Experts) dispatch and combine primitives for topk=1 with NCCL backend, enabling token routing across ranks during multidevice execution. Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant HostIrEvaluator
participant doMoEDispatch
participant NCCL as NCCL ProcessGroup
participant doMoECombine
User->>HostIrEvaluator: runWithInput(x, topk_idx, is_token_in_rank)
HostIrEvaluator->>HostIrEvaluator: handle(MoEDispatch*)
HostIrEvaluator->>doMoEDispatch: doMoEDispatch(x, topk_idx, is_token_in_rank, ...)
Note over doMoEDispatch: Compute rank_for_token from is_token_in_rank
Note over doMoEDispatch: Sort tokens by destination rank
Note over doMoEDispatch: Reorder x, topk_idx, create src_idx, src_rank
doMoEDispatch->>doMoEDispatch: bincount to get n_tokens_to_rank (GPU→CPU sync)
doMoEDispatch->>NCCL: alltoall_base(n_tokens_from_rank, n_tokens_to_rank)
NCCL-->>doMoEDispatch: exchange token counts
doMoEDispatch->>NCCL: alltoall_base(recv_x, send_x)
NCCL-->>doMoEDispatch: exchange token data
doMoEDispatch->>NCCL: alltoall_base(recv_topk_idx, send_topk_idx)
doMoEDispatch->>NCCL: alltoall_base(recv_src_idx, send_src_idx)
doMoEDispatch->>NCCL: alltoall_base(recv_src_rank, send_src_rank)
NCCL-->>doMoEDispatch: all metadata exchanged
Note over doMoEDispatch: Reorder received tokens by local expert ID
doMoEDispatch-->>HostIrEvaluator: DispatchResult{recv_x, recv_topk_idx, recv_src_idx, recv_src_rank, ...}
Note over User: Expert computation happens here (not shown)
HostIrEvaluator->>HostIrEvaluator: handle(MoECombine*)
HostIrEvaluator->>doMoECombine: doMoECombine(x, src_idx, src_rank, n_tokens_to_rank, ...)
Note over doMoECombine: Sort tokens by source rank
Note over doMoECombine: Convert split counts to CPU vectors (GPU→CPU sync)
doMoECombine->>NCCL: alltoall_base(recv_x, send_x)
NCCL-->>doMoECombine: exchange results back to source ranks
doMoECombine->>NCCL: alltoall_base(recv_src_idx, send_src_idx)
NCCL-->>doMoECombine: exchange source indices
Note over doMoECombine: Scatter tokens to original positions using src_idx
doMoECombine-->>HostIrEvaluator: CombineResult{combined_x}
HostIrEvaluator-->>User: outputs with combined_x
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
9 files reviewed, 1 comment
| // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we | ||
| // sync/copy here. GPU-initiated comms can avoid this extra sync. | ||
| auto rank_for_token_cpu = rank_for_token.to(at::kCPU); | ||
| auto n_tokens_to_rank_cpu = | ||
| at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); | ||
| auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: CPU synchronization here adds overhead by transferring rank_for_token to CPU, computing bincount, then moving back to GPU. For large token counts, this synchronization could become a bottleneck.
| // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we | |
| // sync/copy here. GPU-initiated comms can avoid this extra sync. | |
| auto rank_for_token_cpu = rank_for_token.to(at::kCPU); | |
| auto n_tokens_to_rank_cpu = | |
| at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); | |
| auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); | |
| // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we | |
| // sync/copy here. GPU-initiated comms can avoid this extra sync. | |
| // TODO: Consider using GPU-based bincount to avoid CPU synchronization overhead. | |
| auto rank_for_token_cpu = rank_for_token.to(at::kCPU); | |
| auto n_tokens_to_rank_cpu = | |
| at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
Review updated until commit 8041c46 Description
|
| Relevant files | |||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Enhancement | 7 files
| ||||||||||||||
| Tests | 1 files
| ||||||||||||||
| Configuration changes | 1 files
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Performance Critical Path
|
wujingyue
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
First round -- I'll review combine later.
| auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); | ||
|
|
||
| // Determine destination rank per token (topk=1). | ||
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); | |
| auto rank_for_token = topk_idx_flat / (num_experts / world_size); |
is this equivalent and cheaper?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If so, is is_token_in_rank needed at all? The only other use is the validation several lines above but that's unnecessary given topk_idx is of shape [T, 1].
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed it’s redundant for contiguous world‑size EP, but keeping is_token_in_rank preserves flexibility for non‑trivial device meshes or uneven expert‑to‑rank mappings. For this reason, I'd rather keep it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed it’s redundant for contiguous world‑size EP, but keeping
is_token_in_rankpreserves flexibility for non‑trivial device meshes or uneven expert‑to‑rank mappings. For this reason, I'd rather keep it.
Also, I agree there could be a cheaper solution to that, by communicating the expert-->device mesh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
preserves flexibility for non‑trivial device meshes or uneven expert‑to‑rank mappings
I'd go YAGNI and add is_token_in_rank if/when we need it. It's better to keep the IR simpler so it's easier to test, to lower, to interpret, and to compile.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ll go ahead and remove is_token_in_rank as you suggest, but I want to make sure we’re aligned on the implication: without an explicit token‑to‑rank mapping we implicitly assume a full linear mesh (EP on all devices). The reason I originally kept it was the same as DeviceMesh—to support non‑trivial/partial meshes. If you’re Ok with that assumption for now, I’ll proceed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you’re Ok with that assumption for now, I’ll proceed.
Yes, I'm OK with the above implications.
In practice, people seem to be converging on two EP recipes: wide EP and EP+TP.
Wide EP applies EP across all GPUs, as you described.
EP+TP is less common, but is used in the MetaShuffling work on LLaMA-4 MoE inference. In this setup, TP is applied within a node and EP across nodes. Instead of passing an is_token_in_rank tensor of shape [seq_len / ep_size, n_experts], the destination rank can be computed on the fly as:
rank_to_send_to = expert_id / (n_experts / ep_size) * tp_size + tp_rank
An additional optimization in EP+TP is that each GPU only all-to-all sends seq_len / ep_size / tp_size tokens (rather than seq_len / ep_size), followed by an intra-node all-gather. This corresponds to the Communication Deduplication section described in the MetaShuffling paper.
uneven expert‑to‑rank mappings
I don't see a clear practical benefit to assigning more experts to a given GPU. Expert weights are evenly partitioned, and while activations are not guaranteed to be balanced on every iteration, they tend to average out over time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I did as you suggest, removed the variable and rely on those implicit assumption. I just want to emphasize once again that in order to be fully consistent with this assumption we should also remove the very concept of DeviceMesh, even for other parallelism.
Btw, imo supporting EP+TP is very important.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also remove the very concept of DeviceMesh, even for other parallelism.
That sounds a bit too extreme, but I agree we rarely use a DeviceMesh for more than its rank and shape.
Btw, imo supporting EP+TP is very important.
Sure. Is that a problem? IIUC, ep_size and tp_size can be inferred from DeviceMesh's shape, and are therefore available to MoEDispatch and MoECombine.
csrc/multidevice/communication.h
Outdated
| TensorView* out_n_tokens_from_rank, | ||
| TensorView* in_x, | ||
| TensorView* in_topk_idx, | ||
| TensorView* in_topk_weights, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, topk_weights doesn't need to go through dispatch or combine. Llama4 applies the topk weights before dispatch
Fuser/tests/python/test_moe.py
Line 114 in 0c147ae
| hidden_states = hidden_states * router_scores # [s, h] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed — dropped topk_weights. Callers should apply weights before dispatch or after combine; added comments to document that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
9 files reviewed, 1 comment
| NVF_CHECK( | ||
| [&]() { | ||
| auto token_counts = is_token_in_rank.to(at::kLong).sum(1); | ||
| auto min_val = token_counts.min().item<int64_t>(); | ||
| auto max_val = token_counts.max().item<int64_t>(); | ||
| return min_val == 1 && max_val == 1; | ||
| }(), | ||
| "Only topk=1 is supported. Each token must be assigned to exactly one " | ||
| "rank."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: Validation check causes unnecessary GPU-CPU synchronization on every dispatch call. The lambda executes .min() and .max() which trigger implicit CPU synchronization via .item<int64_t>(). For topk=1 validation, consider validating this constraint earlier during graph construction or accepting it as a precondition, rather than checking it at runtime on the hot path.
0f48cd5 to
afd948d
Compare
|
!test |
| auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); | ||
|
|
||
| // Determine destination rank per token (topk=1). | ||
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
preserves flexibility for non‑trivial device meshes or uneven expert‑to‑rank mappings
I'd go YAGNI and add is_token_in_rank if/when we need it. It's better to keep the IR simpler so it's easier to test, to lower, to interpret, and to compile.
| TensorView* out_topk_idx, | ||
| TensorView* out_src_idx, | ||
| TensorView* out_src_rank, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand n_tokens_to_rank and n_tokens_from_rank, but I'm not sure why these three tensors need to be outputs. At least, I didn't have to use them in
Fuser/tests/python/multidevice/test_expert_parallel.py
Lines 165 to 185 in f8b6785
| # -------------------------------------------------------------------------- | |
| # Step 4: Each rank sorts the processed tokens by rank ID. | |
| # -------------------------------------------------------------------------- | |
| # GPU 0: tokens_for_expert_0_from_rank_0 || tokens_for_expert_1_from_rank_0 || tokens_for_expert_2_from_rank_0 || tokens_for_expert_0_from_rank_1 || tokens_for_expert_1_from_rank_1 || tokens_for_expert_2_from_rank_1 | |
| # GPU 1: tokens_for_expert_3_from_rank_0 || tokens_for_expert_4_from_rank_0 || tokens_for_expert_5_from_rank_0 || tokens_for_expert_3_from_rank_1 || tokens_for_expert_4_from_rank_1 || tokens_for_expert_5_from_rank_1 | |
| processed_tokens_by_rank = expert_first_to_rank_first( | |
| processed_tokens_by_expert, n_tokens_for_expert_from_rank | |
| ) | |
| # -------------------------------------------------------------------------- | |
| # Step 5: Processed tokens are sent back to the original ranks. | |
| # -------------------------------------------------------------------------- | |
| processed_tokens = torch.empty(n_tokens, dtype=torch.complex64, device="cuda") | |
| # GPU 0: tokens_for_expert_0_from_rank_0 || tokens_for_expert_1_from_rank_0 || tokens_for_expert_2_from_rank_0 || tokens_for_expert_3_from_rank_0 || tokens_for_expert_4_from_rank_0 || tokens_for_expert_5_from_rank_0 | |
| # GPU 1: tokens_for_expert_0_from_rank_1 || tokens_for_expert_1_from_rank_1 || tokens_for_expert_2_from_rank_1 || tokens_for_expert_3_from_rank_1 || tokens_for_expert_4_from_rank_1 || tokens_for_expert_5_from_rank_1 | |
| dist.all_to_all_single( | |
| processed_tokens, | |
| processed_tokens_by_rank, | |
| n_tokens_to_rank.tolist(), | |
| n_tokens_from_rank.tolist(), | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They’re required for the round‑trip semantics here. out_src_idx/out_src_rank are the metadata needed by combine to route tokens back to their original rank and restore local order (we use them to build the alltoall splits and the final index_copy_).
out_topk_idx is still needed after dispatch so the rank can route tokens to its local experts.
about out_src_idx/out_src_rank, I can go with your suggestion and drop them, but I want to be explicit about the implication: we’d be committing to the same constrained ordering as the reference test you linked, that is, routing must be fully determined by split sizes and a fixed rank<->expert mapping, with per‑rank chunk order preserved by all‑to‑all. In that model you can reconstruct order without per‑token metadata. The trade‑off is that we lose support for arbitrary routing/non‑trivial meshes or any custom per‑token permutation, since we’d have no way to restore original order. One potential issue I foresee is implementing a padded dispatch/combined on some over- and pre-allocated buffers -- but that could be solved later!
Assuming we are ok with those implications, I’ll proceed with the simplification.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coming back on my last comment: note that DeepSeek API seems to use those tensors and their Combine specifically reorders back to original token positions (using src_idx/src_rank).
https://github.com/deepseek-ai/DeepEP/blob/29d31c095796f3c8ece47ee9cdcc167051bbeed9/csrc/kernels/intranode.cu#L1034
Is this an argument to keep those variables?
As a side note, related to #5857 (comment), note that DeepSeek interface takes topk_weights
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any thoughts on that? Imho reducing the number of arguments doesn't necessarily simplify the IR because it introduces implicit assumptions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that DeepSeek interface takes topk_weights
Good point. I suspect this is because their combine kernel also covers the topk_weights multiplication. It's suboptimal to materialize s*k/ep in global memory. This PR probably doesn't need that at this very moment because it assumes k=1. I'll double check the code to make sure.
What does your kernel look like? I believe this is just a reference implementation and you are about to add a fused kernel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does your kernel look like? I believe this is just a reference implementation and you are about to add a fused kernel.
What kernel ? I am not sure to understand. I am not writing any cuda kernel in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll double check the code to make sure.
I got a chance to read the code. Thanks for the pointer! There are three modes:
- Low-latency (see internode_ll.cu). Inter-node. One-hop.
- Intra-node (see intranode.cu). One-hop.
- Inter-node (see internode.cu). Two-hop, first RDMA and then nvlink.
Regardless of which mode, combine reads topk_weights. When k>1, it's better to fuse topk_weights multiplication into the combine kernel to avoid materializing s*k/ep tokens in global memory.
According to deepseek-ai/DeepEP#72, dispatch needs to produce topk_weights for local reduction. Therefore, dispatch does so for inter-node and doesn't for low-latency. It's yet unclear why dispatch produces topk_weights for intra-node in some cases.
What kernel? I am not sure to understand. I am not writing any cuda kernel in this PR
Never mind. I was referring to the kernel in the other PR you showed on Thursday. I wasn't sure you were building kernels from scratch or reusing DeepEP's. Since the former, we have more motivations to keep implementations simple until complexity is required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out_src_idx makes sense to me now. I previously thought a local sort was performed before dispatch, but I now understand that dispatch itself handles that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still not sure about out_src_rank. Does it equal out_src_idx / ceil(seq_len / EP)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 2 comments
| auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); | ||
|
|
||
| // Determine destination rank per token (topk=1). | ||
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you’re Ok with that assumption for now, I’ll proceed.
Yes, I'm OK with the above implications.
In practice, people seem to be converging on two EP recipes: wide EP and EP+TP.
Wide EP applies EP across all GPUs, as you described.
EP+TP is less common, but is used in the MetaShuffling work on LLaMA-4 MoE inference. In this setup, TP is applied within a node and EP across nodes. Instead of passing an is_token_in_rank tensor of shape [seq_len / ep_size, n_experts], the destination rank can be computed on the fly as:
rank_to_send_to = expert_id / (n_experts / ep_size) * tp_size + tp_rank
An additional optimization in EP+TP is that each GPU only all-to-all sends seq_len / ep_size / tp_size tokens (rather than seq_len / ep_size), followed by an intra-node all-gather. This corresponds to the Communication Deduplication section described in the MetaShuffling paper.
uneven expert‑to‑rank mappings
I don't see a clear practical benefit to assigning more experts to a given GPU. Expert weights are evenly partitioned, and while activations are not guaranteed to be balanced on every iteration, they tend to average out over time.
6da9793 to
4693c53
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 3 comments
| // Asymmetric example: | ||
| // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. | ||
| auto rank_ids = at::arange(world_size, int_options); | ||
| auto token_rank = at::tensor({0, 1, 1, 1}, int_options); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hardcoded to 4 tokens with routing [0, 1, 1, 1] - won't work for world_size > 2
| auto topk_idx_flat = topk_idx.reshape({num_tokens}); | ||
|
|
||
| // Determine destination rank per token (topk=1). | ||
| auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assumes is_token_in_rank is valid one-hot (exactly one 1 per row). If multiple ranks are marked or none are marked, argmax returns first/last index but routing will be incorrect. Add validation that each row has exactly one True value.
| topk_is_1d || topk_is_2d, | ||
| "Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ", | ||
| topk_idx.sizes()); | ||
| auto topk_idx_flat = topk_idx.reshape({num_tokens}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no validation that topk_idx values are within [0, num_experts) range. Invalid expert IDs could cause incorrect rank assignment or out-of-bounds issues in local expert computation.
|
!test |
Uh oh!
There was an error while loading. Please reload this page.