
Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Jan 21, 2026

  • Add a first working dispatch+combine primitive for k=1 in multidevice execution, including utilities.
  • Extend Host IR evaluator plumbing to drive the new dispatch+combine path.
  • Add a C++ test.

@samnordmann samnordmann changed the title from "MoE Dispatch Combine first implementation for k=1 and Nccl backend" to "MoE Dispatch/Combine first implementation for k=1 and Nccl backend" Jan 21, 2026
@samnordmann
Collaborator Author

!test

@greptile-apps
Contributor

greptile-apps bot commented Jan 21, 2026

Greptile Overview

Greptile Summary

Adds MoE (Mixture of Experts) dispatch and combine primitives for topk=1 with NCCL backend, enabling token routing across ranks during multidevice execution.

Key changes:

  • New doMoEDispatch and doMoECombine functions in csrc/multidevice/dispatch_combine.cpp implementing alltoall-based token shuffling
  • Host IR nodes MoEDispatch and MoECombine added to communication primitives
  • Integration into Host IR evaluator execution pipeline
  • C++ test validates round-trip dispatch+combine correctness

Issues found:

  • The test hardcodes token routing for exactly 2 ranks and will fail for larger world sizes
  • Missing validation: is_token_in_rank should be verified to be a valid one-hot mask (exactly one True per row)
  • Missing validation: topk_idx values should be checked to lie within the [0, num_experts) range (a sketch of both checks follows this list)
  • Previous review threads noted GPU-CPU sync overhead in the hot path (acknowledged as a known limitation for CPU-initiated NCCL)
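
A minimal sketch of both checks, assuming the NVF_CHECK macro and the tensor names used in csrc/multidevice/dispatch_combine.cpp (is_token_in_rank, topk_idx_flat, num_experts); illustrative only, not part of the PR:

// Sketch: validate routing inputs before dispatch. Like the existing topk=1
// check, these reductions sync to the CPU via item<>().
auto row_sums = is_token_in_rank.to(at::kLong).sum(1);
NVF_CHECK(
    row_sums.eq(1).all().item<bool>(),
    "is_token_in_rank must contain exactly one True per row for topk=1.");
NVF_CHECK(
    topk_idx_flat.ge(0).logical_and(topk_idx_flat.lt(num_experts)).all().item<bool>(),
    "topk_idx values must be in [0, num_experts).");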

Confidence Score: 3/5

  • Safe for initial merge with noted test limitations
  • Core logic is sound, but the test only validates a 2-rank scenario, and the missing input validations could allow silent errors on invalid inputs
  • tests/cpp/test_multidevice_dispatch_combine.cpp (hardcoded for 2 ranks), csrc/multidevice/dispatch_combine.cpp (missing validations)

Important Files Changed

Filename | Overview
csrc/multidevice/dispatch_combine.cpp | Core MoE dispatch/combine implementation with alltoall communication and token routing logic
csrc/multidevice/dispatch_combine.h | API declarations and comprehensive documentation for dispatch/combine functions
tests/cpp/test_multidevice_dispatch_combine.cpp | Single test case with hardcoded values that won't scale properly for world_size > 2

Sequence Diagram

sequenceDiagram
    participant User
    participant HostIrEvaluator
    participant doMoEDispatch
    participant NCCL as NCCL ProcessGroup
    participant doMoECombine

    User->>HostIrEvaluator: runWithInput(x, topk_idx, is_token_in_rank)
    HostIrEvaluator->>HostIrEvaluator: handle(MoEDispatch*)
    HostIrEvaluator->>doMoEDispatch: doMoEDispatch(x, topk_idx, is_token_in_rank, ...)
    
    Note over doMoEDispatch: Compute rank_for_token from is_token_in_rank
    Note over doMoEDispatch: Sort tokens by destination rank
    Note over doMoEDispatch: Reorder x, topk_idx, create src_idx, src_rank
    
    doMoEDispatch->>doMoEDispatch: bincount to get n_tokens_to_rank (GPU→CPU sync)
    doMoEDispatch->>NCCL: alltoall_base(n_tokens_from_rank, n_tokens_to_rank)
    NCCL-->>doMoEDispatch: exchange token counts
    
    doMoEDispatch->>NCCL: alltoall_base(recv_x, send_x)
    NCCL-->>doMoEDispatch: exchange token data
    doMoEDispatch->>NCCL: alltoall_base(recv_topk_idx, send_topk_idx)
    doMoEDispatch->>NCCL: alltoall_base(recv_src_idx, send_src_idx)
    doMoEDispatch->>NCCL: alltoall_base(recv_src_rank, send_src_rank)
    NCCL-->>doMoEDispatch: all metadata exchanged
    
    Note over doMoEDispatch: Reorder received tokens by local expert ID
    doMoEDispatch-->>HostIrEvaluator: DispatchResult{recv_x, recv_topk_idx, recv_src_idx, recv_src_rank, ...}
    
    Note over User: Expert computation happens here (not shown)
    
    HostIrEvaluator->>HostIrEvaluator: handle(MoECombine*)
    HostIrEvaluator->>doMoECombine: doMoECombine(x, src_idx, src_rank, n_tokens_to_rank, ...)
    
    Note over doMoECombine: Sort tokens by source rank
    Note over doMoECombine: Convert split counts to CPU vectors (GPU→CPU sync)
    
    doMoECombine->>NCCL: alltoall_base(recv_x, send_x)
    NCCL-->>doMoECombine: exchange results back to source ranks
    doMoECombine->>NCCL: alltoall_base(recv_src_idx, send_src_idx)
    NCCL-->>doMoECombine: exchange source indices
    
    Note over doMoECombine: Scatter tokens to original positions using src_idx
    doMoECombine-->>HostIrEvaluator: CombineResult{combined_x}
    HostIrEvaluator-->>User: outputs with combined_x

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 1 comment


Comment on lines +117 to +122
// For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we
// sync/copy here. GPU-initiated comms can avoid this extra sync.
auto rank_for_token_cpu = rank_for_token.to(at::kCPU);
auto n_tokens_to_rank_cpu =
    at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong);
auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device());
Contributor

style: CPU synchronization here adds overhead by transferring rank_for_token to CPU, computing bincount, then moving back to GPU. For large token counts, this synchronization could become a bottleneck.

Suggested change
// For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we
// sync/copy here. GPU-initiated comms can avoid this extra sync.
// TODO: Consider using GPU-based bincount to avoid CPU synchronization overhead.
auto rank_for_token_cpu = rank_for_token.to(at::kCPU);
auto n_tokens_to_rank_cpu =
    at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong);

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
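
For reference, a hypothetical variant that keeps bincount on the GPU and only copies the small [world_size] count vector to the CPU for the NCCL split lists (a sketch, not the PR's code):

auto n_tokens_to_rank = at::bincount(rank_for_token, {}, world_size).to(at::kLong);
// Only world_size entries cross to the host here, instead of one entry per token.
auto n_tokens_to_rank_cpu = n_tokens_to_rank.to(at::kCPU);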

@github-actions

github-actions bot commented Jan 21, 2026

Review updated until commit 8041c46

Description

  • Add MoEDispatch and MoECombine IR nodes for k=1 MoE routing

  • Implement NCCL alltoall-based dispatch/combine functions

  • Extend HostIrEvaluator with handlers for new MoE operations

  • Add comprehensive test validating multi-rank dispatch/combine

Changes walkthrough

Relevant files

Enhancement (7 files)
dispatch_combine.h: New header defining MoE dispatch/combine APIs and result structs (+113/-0)
dispatch_combine.cpp: Implementation of doMoEDispatch and doMoECombine with NCCL alltoall (+240/-0)
communication.h: Add MoEDispatch and MoECombine IR node class definitions (+157/-0)
communication.cpp: Implement MoEDispatch and MoECombine IR node methods and validation (+137/-0)
evaluator.h: Add MoEDispatch and MoECombine handler declarations (+2/-0)
evaluator.cpp: Implement HostIrEvaluator handlers for MoEDispatch and MoECombine (+53/-0)
dispatch.h: Register MoEDispatch and MoECombine in the dispatch macro (+2/-0)

Tests (1 file)
test_multidevice_dispatch_combine.cpp: Add test for multi-rank MoE dispatch/combine with k=1 (+117/-0)

Configuration changes (1 file)
CMakeLists.txt: Add dispatch_combine.cpp to the build and the new test to the test suite (+2/-0)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Performance Critical Path

The implementation performs multiple synchronous alltoall operations (4 separate calls for recv_x, recv_topk_idx, recv_src_idx, recv_src_rank). This could be a significant performance bottleneck. Consider whether these tensors could be packed into a single alltoall operation or if the current approach is necessary due to data type differences.

// Alltoall exchange payloads with per-rank splits.
waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits));
waitWork(pg->alltoall_base(
    recv_topk_idx, send_topk_idx, output_splits, input_splits));
waitWork(pg->alltoall_base(
    recv_src_idx, send_src_idx, output_splits, input_splits));
waitWork(pg->alltoall_base(
    recv_src_rank, send_src_rank, output_splits, input_splits));
Memory Allocation

The code allocates new tensors for receive buffers without considering pre-existing allocations. The TODO comment on line 137 mentions this limitation. For production use, this could lead to memory fragmentation and reduced performance. Consider adding support for preallocated buffers.

// TODO: support preallocated buffers.
auto recv_x = at::empty({total_recv, hidden}, x.options());
auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options());
auto recv_src_idx = at::empty({total_recv}, send_src_idx.options());
auto recv_src_rank = at::empty({total_recv}, send_src_rank.options());
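
A possible direction, sketched with a hypothetical optional recv_x_buffer parameter that is not part of the current API:

// Reuse a caller-provided buffer when it is large enough; otherwise allocate.
at::Tensor recv_x = (recv_x_buffer.has_value() && recv_x_buffer->size(0) >= total_recv)
    ? recv_x_buffer->narrow(0, 0, total_recv)
    : at::empty({total_recv, hidden}, x.options());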
Limited Test Coverage

The test only covers a basic case with 4 tokens and 2 ranks. Consider adding tests for edge cases like empty token batches, maximum token counts, and different expert distributions to ensure robustness. (A world-size-agnostic routing sketch follows the quoted test below.)

TEST_F(DispatchCombineTest, DispatchCombineTop1) {
  if (!communicator_->is_available() || communicator_->size() < 2) {
    GTEST_SKIP() << "This test needs at least 2 ranks.";
  }

  const int64_t world_size = communicator_->size();
  const int64_t my_rank = communicator_->deviceId();
  constexpr int64_t kNumExpertsPerRank = 2;
  const int64_t num_experts = world_size * kNumExpertsPerRank;
  constexpr int64_t kNumTokens = 4;
  constexpr int64_t kHidden = 4;

  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());

  auto* in_x = makeSymbolicTensor(2);
  auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int);
  auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool);

  auto* recv_x = makeSymbolicTensor(2);
  auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int);
  auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int);
  auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int);
  auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int);
  auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int);

  auto* dispatch = IrBuilder::create<MoEDispatch>(
      recv_x,
      recv_topk_idx,
      recv_src_idx,
      recv_src_rank,
      n_tokens_to_rank,
      n_tokens_from_rank,
      in_x,
      in_topk_idx,
      in_is_token_in_rank,
      num_experts,
      CommunicatorBackend::kNccl);

  auto* combined_x = makeSymbolicTensor(2);
  auto* combine = IrBuilder::create<MoECombine>(
      combined_x,
      recv_x,
      recv_src_idx,
      recv_src_rank,
      n_tokens_to_rank,
      n_tokens_from_rank,
      CommunicatorBackend::kNccl);

  hic->pushBackTopLevelExprs(dispatch);
  hic->pushBackTopLevelExprs(combine);

  hic->addInput(in_x);
  hic->addInput(in_topk_idx);
  hic->addInput(in_is_token_in_rank);
  hic->addOutput(combined_x);

  HostIrEvaluator hie(std::move(hic), communicator_);

  auto float_options =
      at::TensorOptions().device(communicator_->device()).dtype(at::kFloat);
  auto int_options =
      at::TensorOptions().device(communicator_->device()).dtype(at::kLong);

  auto x = at::arange(kNumTokens * kHidden, float_options)
               .reshape({kNumTokens, kHidden}) +
      static_cast<double>(my_rank) * 1000.0;
  auto topk_idx = at::zeros({kNumTokens}, int_options);

  // Asymmetric example:
  // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens.
  auto rank_ids = at::arange(world_size, int_options);
  auto token_rank = at::tensor({0, 1, 1, 1}, int_options);
  auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids);

  // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1.
  topk_idx.index_put_({0}, 0);
  topk_idx.index_put_({1}, kNumExpertsPerRank);
  topk_idx.index_put_({2}, kNumExpertsPerRank + 1);
  topk_idx.index_put_({3}, kNumExpertsPerRank);

  auto outputs = hie.runWithInput(
      {{in_x, x},
       {in_topk_idx, topk_idx},
       {in_is_token_in_rank, is_token_in_rank}});
  auto combined = outputs.back().as<at::Tensor>();

  EXPECT_TRUE(at::allclose(combined, x))
      << "Dispatch/Combine mismatch on rank " << my_rank;
}
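
As referenced above, a world-size-agnostic routing could replace the hardcoded [0, 1, 1, 1] mapping. A sketch, illustrative only and not part of the PR:

// Spread tokens round-robin over all ranks instead of hardcoding two ranks,
// and route each token to the first local expert of its destination rank.
auto token_rank = at::arange(kNumTokens, int_options).remainder(world_size);
auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids);
auto topk_idx = token_rank * kNumExpertsPerRank;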

@samnordmann samnordmann requested a review from wujingyue January 21, 2026 15:21
Collaborator

@wujingyue wujingyue left a comment

Thanks!

First round -- I'll review combine later.

auto topk_weights_flat = flattenTopk(topk_weights, num_tokens);

// Determine destination rank per token (topk=1).
auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);
Collaborator

Suggested change
auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);
auto rank_for_token = topk_idx_flat / (num_experts / world_size);

is this equivalent and cheaper?
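
(Illustrative check, not from the thread: under contiguous partitioning, rank r owns experts [r * num_experts / world_size, (r + 1) * num_experts / world_size), so with num_experts = 4 and world_size = 2 a token with topk_idx = 3 maps to 3 / (4 / 2) = 1, matching the argmax of its one-hot is_token_in_rank row.)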

Collaborator

If so, is is_token_in_rank needed at all? The only other use is the validation several lines above but that's unnecessary given topk_idx is of shape [T, 1].

Collaborator Author

@samnordmann samnordmann Jan 22, 2026

Agreed it’s redundant for contiguous world‑size EP, but keeping is_token_in_rank preserves flexibility for non‑trivial device meshes or uneven expert‑to‑rank mappings. For this reason, I'd rather keep it.


Also, I agree there could be a cheaper solution to that, by communicating the expert-->device mesh

Collaborator

preserves flexibility for non‑trivial device meshes or uneven expert‑to‑rank mappings

I'd go YAGNI and add is_token_in_rank if/when we need it. It's better to keep the IR simpler so it's easier to test, to lower, to interpret, and to compile.

Collaborator Author

I’ll go ahead and remove is_token_in_rank as you suggest, but I want to make sure we’re aligned on the implication: without an explicit token‑to‑rank mapping we implicitly assume a full linear mesh (EP on all devices). The reason I originally kept it was the same as DeviceMesh—to support non‑trivial/partial meshes. If you’re Ok with that assumption for now, I’ll proceed.

Collaborator

If you’re Ok with that assumption for now, I’ll proceed.

Yes, I'm OK with the above implications.

In practice, people seem to be converging on two EP recipes: wide EP and EP+TP.

Wide EP applies EP across all GPUs, as you described.

EP+TP is less common, but is used in the MetaShuffling work on LLaMA-4 MoE inference. In this setup, TP is applied within a node and EP across nodes. Instead of passing an is_token_in_rank tensor of shape [seq_len / ep_size, n_experts], the destination rank can be computed on the fly as:

rank_to_send_to = expert_id / (n_experts / ep_size) * tp_size + tp_rank
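
(Illustrative numbers, not from the thread: with n_experts = 16, ep_size = 4, and tp_size = 2, a token routed to expert_id = 9 from tp_rank = 1 is sent to rank 9 / (16 / 4) * 2 + 1 = 2 * 2 + 1 = 5.)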

An additional optimization in EP+TP is that each GPU only all-to-all sends seq_len / ep_size / tp_size tokens (rather than seq_len / ep_size), followed by an intra-node all-gather. This corresponds to the Communication Deduplication section described in the MetaShuffling paper.

uneven expert‑to‑rank mappings

I don't see a clear practical benefit to assigning more experts to a given GPU. Expert weights are evenly partitioned, and while activations are not guaranteed to be balanced on every iteration, they tend to average out over time.

Collaborator Author

Ok, I did as you suggested: removed the variable and rely on those implicit assumptions. I just want to emphasize once again that, to be fully consistent with this assumption, we should also remove the very concept of DeviceMesh, even for other parallelism.

Btw, imo supporting EP+TP is very important.

Collaborator

We should also remove the very concept of DeviceMesh, even for other parallelism.

That sounds a bit too extreme, but I agree we rarely use a DeviceMesh for more than its rank and shape.

Btw, imo supporting EP+TP is very important.

Sure. Is that a problem? IIUC, ep_size and tp_size can be inferred from DeviceMesh's shape, and are therefore available to MoEDispatch and MoECombine.

TensorView* out_n_tokens_from_rank,
TensorView* in_x,
TensorView* in_topk_idx,
TensorView* in_topk_weights,
Collaborator

IIUC, topk_weights doesn't need to go through dispatch or combine. Llama4 applies the topk weights before dispatch

hidden_states = hidden_states * router_scores # [s, h]
and DeepSeek V3 does that after combine.

Collaborator Author

Agreed — dropped topk_weights. Callers should apply weights before dispatch or after combine; added comments to document that


Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 1 comment


Comment on lines 76 to 84
NVF_CHECK(
    [&]() {
      auto token_counts = is_token_in_rank.to(at::kLong).sum(1);
      auto min_val = token_counts.min().item<int64_t>();
      auto max_val = token_counts.max().item<int64_t>();
      return min_val == 1 && max_val == 1;
    }(),
    "Only topk=1 is supported. Each token must be assigned to exactly one "
    "rank.");
Contributor

style: Validation check causes unnecessary GPU-CPU synchronization on every dispatch call. The lambda executes .min() and .max() which trigger implicit CPU synchronization via .item<int64_t>(). For topk=1 validation, consider validating this constraint earlier during graph construction or accepting it as a precondition, rather than checking it at runtime on the hot path.

@samnordmann samnordmann force-pushed the dispatch_combine/stub branch from 0f48cd5 to afd948d Compare January 22, 2026 12:14
@samnordmann
Collaborator Author

!test

@samnordmann samnordmann requested a review from wujingyue January 22, 2026 17:36

Comment on lines +195 to +197
TensorView* out_topk_idx,
TensorView* out_src_idx,
TensorView* out_src_rank,
Collaborator

I understand n_tokens_to_rank and n_tokens_from_rank, but I'm not sure why these three tensors need to be outputs. At least, I didn't have to use them in

# --------------------------------------------------------------------------
# Step 4: Each rank sorts the processed tokens by rank ID.
# --------------------------------------------------------------------------
# GPU 0: tokens_for_expert_0_from_rank_0 || tokens_for_expert_1_from_rank_0 || tokens_for_expert_2_from_rank_0 || tokens_for_expert_0_from_rank_1 || tokens_for_expert_1_from_rank_1 || tokens_for_expert_2_from_rank_1
# GPU 1: tokens_for_expert_3_from_rank_0 || tokens_for_expert_4_from_rank_0 || tokens_for_expert_5_from_rank_0 || tokens_for_expert_3_from_rank_1 || tokens_for_expert_4_from_rank_1 || tokens_for_expert_5_from_rank_1
processed_tokens_by_rank = expert_first_to_rank_first(
    processed_tokens_by_expert, n_tokens_for_expert_from_rank
)
# --------------------------------------------------------------------------
# Step 5: Processed tokens are sent back to the original ranks.
# --------------------------------------------------------------------------
processed_tokens = torch.empty(n_tokens, dtype=torch.complex64, device="cuda")
# GPU 0: tokens_for_expert_0_from_rank_0 || tokens_for_expert_1_from_rank_0 || tokens_for_expert_2_from_rank_0 || tokens_for_expert_3_from_rank_0 || tokens_for_expert_4_from_rank_0 || tokens_for_expert_5_from_rank_0
# GPU 1: tokens_for_expert_0_from_rank_1 || tokens_for_expert_1_from_rank_1 || tokens_for_expert_2_from_rank_1 || tokens_for_expert_3_from_rank_1 || tokens_for_expert_4_from_rank_1 || tokens_for_expert_5_from_rank_1
dist.all_to_all_single(
    processed_tokens,
    processed_tokens_by_rank,
    n_tokens_to_rank.tolist(),
    n_tokens_from_rank.tolist(),
)

Collaborator Author

They’re required for the round‑trip semantics here. out_src_idx/out_src_rank are the metadata needed by combine to route tokens back to their original rank and restore local order (we use them to build the alltoall splits and the final index_copy_).

out_topk_idx is still needed after dispatch so the rank can route tokens to its local experts.

About out_src_idx/out_src_rank, I can go with your suggestion and drop them, but I want to be explicit about the implication: we’d be committing to the same constrained ordering as the reference test you linked, that is, routing must be fully determined by split sizes and a fixed rank<->expert mapping, with per‑rank chunk order preserved by all‑to‑all. In that model you can reconstruct order without per‑token metadata. The trade‑off is that we lose support for arbitrary routing, non‑trivial meshes, or any custom per‑token permutation, since we’d have no way to restore the original order. One potential issue I foresee is implementing a padded dispatch/combine on some over- and pre-allocated buffers -- but that could be solved later!

Assuming we are ok with those implications, I’ll proceed with the simplification.
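
A minimal sketch of the restore step described above, reusing the recv_x/recv_src_idx names from dispatch and assuming recv_src_idx holds each token's original local position (illustrative, not the PR's exact code):

// After the reverse alltoall, scatter tokens back into their original slots.
auto combined_x = at::empty({num_tokens, hidden}, recv_x.options());
combined_x.index_copy_(0, recv_src_idx, recv_x);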

Collaborator Author

@samnordmann samnordmann Jan 26, 2026

Coming back to my last comment: note that the DeepSeek API seems to use those tensors, and its Combine specifically reorders tokens back to their original positions (using src_idx/src_rank).
https://github.com/deepseek-ai/DeepEP/blob/29d31c095796f3c8ece47ee9cdcc167051bbeed9/csrc/kernels/intranode.cu#L1034

Is this an argument to keep those variables?

As a side note, related to #5857 (comment), note that the DeepSeek interface takes topk_weights.

Collaborator Author

Any thoughts on that? IMHO, reducing the number of arguments doesn't necessarily simplify the IR, because it introduces implicit assumptions.

Collaborator

note that DeepSeek interface takes topk_weights

Good point. I suspect this is because their combine kernel also covers the topk_weights multiplication. It's suboptimal to materialize s*k/ep in global memory. This PR probably doesn't need that at this very moment because it assumes k=1. I'll double check the code to make sure.

What does your kernel look like? I believe this is just a reference implementation and you are about to add a fused kernel.

Collaborator Author

What does your kernel look like? I believe this is just a reference implementation and you are about to add a fused kernel.

What kernel? I am not sure I understand. I am not writing any CUDA kernel in this PR.

Collaborator

I'll double check the code to make sure.

I got a chance to read the code. Thanks for the pointer! There are three modes:

  • Low-latency (see internode_ll.cu). Inter-node. One-hop.
  • Intra-node (see intranode.cu). One-hop.
  • Inter-node (see internode.cu). Two-hop, first RDMA and then nvlink.

Regardless of which mode, combine reads topk_weights. When k>1, it's better to fuse topk_weights multiplication into the combine kernel to avoid materializing s*k/ep tokens in global memory.

According to deepseek-ai/DeepEP#72, dispatch needs to produce topk_weights for local reduction. Therefore, dispatch does so for inter-node and doesn't for low-latency. It's yet unclear why dispatch produces topk_weights for intra-node in some cases.

What kernel? I am not sure to understand. I am not writing any cuda kernel in this PR

Never mind. I was referring to the kernel in the other PR you showed on Thursday. I wasn't sure whether you were building kernels from scratch or reusing DeepEP's. Since it's the former, we have more motivation to keep implementations simple until complexity is required.

Collaborator

out_src_idx makes sense to me now. I previously thought a local sort was performed before dispatch, but I now understand that dispatch itself handles that.

Collaborator

I'm still not sure about out_src_rank. Does it equal out_src_idx / ceil(seq_len / EP)?

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 2 comments



@samnordmann samnordmann requested a review from nsarka January 27, 2026 15:12
@samnordmann samnordmann force-pushed the dispatch_combine/stub branch from 6da9793 to 4693c53 Compare January 29, 2026 14:03
@samnordmann samnordmann requested a review from wujingyue January 29, 2026 14:04
Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments


// Asymmetric example:
// token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens.
auto rank_ids = at::arange(world_size, int_options);
auto token_rank = at::tensor({0, 1, 1, 1}, int_options);
Contributor

hardcoded to 4 tokens with routing [0, 1, 1, 1] - won't work for world_size > 2

auto topk_idx_flat = topk_idx.reshape({num_tokens});

// Determine destination rank per token (topk=1).
auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong);
Contributor

assumes is_token_in_rank is valid one-hot (exactly one 1 per row). If multiple ranks are marked or none are marked, argmax returns first/last index but routing will be incorrect. Add validation that each row has exactly one True value.

topk_is_1d || topk_is_2d,
"Only topk=1 supported. topk_idx must be shape [T] or [T, 1], got: ",
topk_idx.sizes());
auto topk_idx_flat = topk_idx.reshape({num_tokens});
Contributor

no validation that topk_idx values are within [0, num_experts) range. Invalid expert IDs could cause incorrect rank assignment or out-of-bounds issues in local expert computation.

@samnordmann
Collaborator Author

!test
