From dda9aa7c2be35ef1e604fb12b63d8a5278834657 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 22 Jan 2026 09:33:18 -0800 Subject: [PATCH] add kernel based a2av and cuda backend for d/c --- CMakeLists.txt | 2 + csrc/multidevice/alltoallv.cu | 37 ++ csrc/multidevice/cuda_p2p.cpp | 315 ++++++++++++++++++ csrc/multidevice/cuda_p2p.h | 29 ++ csrc/multidevice/dispatch_combine.cpp | 309 +++++++++++++---- csrc/multidevice/dispatch_combine.h | 4 +- tests/cpp/test_multidevice_alltoallv.cpp | 82 +++++ .../cpp/test_multidevice_dispatch_combine.cpp | 20 +- 8 files changed, 726 insertions(+), 72 deletions(-) create mode 100644 csrc/multidevice/alltoallv.cu create mode 100644 tests/cpp/test_multidevice_alltoallv.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b325b325d9c..ff76e741b4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1144,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_alltoallv.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp @@ -1393,6 +1394,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/mbarrier.cu ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/multicast.cu + ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor_memory.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/csrc/multidevice/alltoallv.cu b/csrc/multidevice/alltoallv.cu new file mode 100644 index 00000000000..9725794f838 --- /dev/null +++ b/csrc/multidevice/alltoallv.cu @@ -0,0 +1,37 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +extern "C" __global__ void alltoallv_kernel( + const unsigned char* send, + const unsigned long long* recv_ptrs, + const long long* send_offsets, + const long long* send_sizes, + const long long* recv_offsets, + long long world_size, + long long elem_size, + long long max_send_bytes) { + const long long peer = static_cast(blockIdx.y); + if (peer >= world_size) { + return; + } + const long long bytes = send_sizes[peer] * elem_size; + if (bytes == 0) { + return; + } + const long long idx = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= bytes) { + return; + } + const long long send_byte_offset = send_offsets[peer] * elem_size + idx; + const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx; + auto* dst = reinterpret_cast( + static_cast(recv_ptrs[peer])); + dst[recv_byte_offset] = send[send_byte_offset]; +} + diff --git a/csrc/multidevice/cuda_p2p.cpp b/csrc/multidevice/cuda_p2p.cpp index 6ad709fa062..8804c1a7a79 100644 --- a/csrc/multidevice/cuda_p2p.cpp +++ b/csrc/multidevice/cuda_p2p.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include "multidevice/cuda_p2p.h" +#include "nvfuser_resources/alltoallv.h" #include "nvfuser_resources/multicast.h" #include "cuda_utils.h" @@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() { } namespace { +void launchAlltoallvKernel( + const void* send, + const uint64_t* recv_ptrs, + const int64_t* send_offsets, + const int64_t* send_sizes, + const int64_t* recv_offsets, + int64_t world_size, + int64_t elem_size, + int64_t max_send_bytes, + CUstream stream) { + static CUmodule module = nullptr; + static CUfunction kernel = nullptr; + + if (module == nullptr) { + nvrtcProgram prog; + NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( + &prog, + nvfuser_resources::alltoallv_cu, + "alltoallv.cu", + 0, + nullptr, + nullptr)); + + int major = 0; + int minor = 0; + int device = 0; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); + cudaDeviceProp prop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); + major = prop.major; + minor = prop.minor; + + std::string arch_arg = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + std::vector opts = {arch_arg.c_str(), "--std=c++17"}; + // NVRTC needs CUDA headers to compile alltoallv.cu. 
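+    // Note: the include paths below assume a default CUDA installation;
+    // other setups may need to point at CUDA_HOME's include directory
+    // instead.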
+ opts.push_back("-I/usr/local/cuda/include"); + opts.push_back("-I/usr/local/cuda/include/cccl"); + + nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + if (res != NVRTC_SUCCESS) { + size_t logSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); + NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data()); + } + + size_t ptxSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); + std::vector ptx(ptxSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + CUresult load_result = cuModuleLoadData(&module, ptx.data()); + if (load_result != CUDA_SUCCESS) { + constexpr size_t kLogSize = 8192; + char error_log[kLogSize]; + char info_log[kLogSize]; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_LOG_VERBOSE}; + void* option_values[] = { + (void*)error_log, + (void*)kLogSize, + (void*)info_log, + (void*)kLogSize, + (void*)1}; + cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values); + NVF_ERROR( + false, + "Alltoallv kernel module load failed with error: ", + load_result, + "\nInfo Log:\n", + info_log, + "\nError Log:\n", + error_log); + } + + NVFUSER_CUDA_SAFE_CALL( + cuModuleGetFunction(&kernel, module, "alltoallv_kernel")); + } + + if (max_send_bytes == 0) { + return; + } + + constexpr int kThreads = 256; + const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads; + void* args_kernel[] = { + const_cast(static_cast(&send)), + const_cast(static_cast(&recv_ptrs)), + const_cast(static_cast(&send_offsets)), + const_cast(static_cast(&send_sizes)), + const_cast(static_cast(&recv_offsets)), + &world_size, + &elem_size, + &max_send_bytes}; + NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( + kernel, + blocks_x, + static_cast(world_size), + 1, + kThreads, + 1, + 1, + 0, + stream, + args_kernel, + nullptr)); +} + +std::vector serializeInt64Vector(const std::vector& values) { + std::vector bytes(values.size() * sizeof(int64_t)); + std::memcpy(bytes.data(), values.data(), bytes.size()); + return bytes; +} + +std::vector deserializeInt64Vector(const std::vector& bytes) { + NVF_CHECK( + bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size."); + const size_t count = bytes.size() / sizeof(int64_t); + std::vector values(count); + std::memcpy(values.data(), bytes.data(), bytes.size()); + return values; +} + +std::string alltoallvCountsKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank); +} + +std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank); +} void launchMulticastKernel( void* dst, @@ -710,4 +848,181 @@ void waitWithCudaBackend( } } +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& send_counts, + const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + const int64_t world_size = comm.size(); + const int64_t my_rank = comm.deviceId(); + NVF_CHECK( + send_counts.is_cuda(), "alltoallv send_counts must be CUDA tensor."); + NVF_CHECK( + send_counts.dim() == 1 && send_counts.numel() == world_size, + "alltoallv send_counts must be 1D [R]."); + + auto store = comm.getTcpStore(); + auto send_counts_cpu = send_counts.to(at::kCPU); + auto* send_ptr = 
send_counts_cpu.data_ptr(); + std::vector send_counts_vec(send_ptr, send_ptr + world_size); + + store->set( + alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec)); + + std::vector> counts_matrix(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + auto bytes = store->get(alltoallvCountsKey(tag, rank)); + counts_matrix[rank] = deserializeInt64Vector(bytes); + NVF_CHECK( + (int64_t)counts_matrix[rank].size() == world_size, + "Invalid alltoallv counts size."); + } + comm.barrier(); + for (int64_t rank = 0; rank < world_size; ++rank) { + store->deleteKey(alltoallvCountsKey(tag, rank)); + } + + std::vector recv_counts_vec(world_size, 0); + for (int64_t sender = 0; sender < world_size; ++sender) { + recv_counts_vec[sender] = counts_matrix[sender][my_rank]; + } + + std::vector send_offsets_vec(world_size, 0); + int64_t prefix = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + send_offsets_vec[rank] = prefix; + prefix += send_counts_vec[rank]; + } + + std::vector recv_offsets_vec(world_size, 0); + for (int64_t peer = 0; peer < world_size; ++peer) { + int64_t offset = 0; + for (int64_t sender = 0; sender < my_rank; ++sender) { + offset += counts_matrix[sender][peer]; + } + recv_offsets_vec[peer] = offset; + } + + int64_t total_recv = 0; + for (auto value : recv_counts_vec) { + total_recv += value; + } + + int64_t max_recv = 0; + int64_t max_send_total = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t sender = 0; sender < world_size; ++sender) { + total += counts_matrix[sender][rank]; + } + if (total > max_recv) { + max_recv = total; + } + } + + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t dest = 0; dest < world_size; ++dest) { + total += counts_matrix[rank][dest]; + } + if (total > max_send_total) { + max_send_total = total; + } + } + + int64_t max_send = 0; + for (auto value : send_counts_vec) { + if (value > max_send) { + max_send = value; + } + } + + auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto send_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + send_offsets_cpu.data_ptr(), + send_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_offsets_cpu.data_ptr(), + recv_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_counts_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_counts_cpu.data_ptr(), + recv_counts_vec.data(), + world_size * sizeof(int64_t)); + + AlltoallvMetadata metadata; + metadata.send_counts = send_counts; + metadata.recv_counts = recv_counts_cpu.to(send_counts.device()); + metadata.send_offsets = send_offsets_cpu.to(send_counts.device()); + metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device()); + metadata.total_recv = total_recv; + metadata.max_recv = max_recv; + metadata.max_send_total = max_send_total; + metadata.max_send_bytes = max_send; + metadata.world_size = world_size; + return metadata; +} + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream) { + NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); + NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); + NVF_CHECK( + (int64_t)recv_ptrs.size() == metadata.world_size, + "recv_ptrs size must match world size."); + + auto cpu_options = 
at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); + auto* ptrs = recv_ptrs_cpu.data_ptr(); + for (int64_t rank = 0; rank < metadata.world_size; ++rank) { + ptrs[rank] = + static_cast(reinterpret_cast(recv_ptrs[rank])); + } + auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); + + const int64_t elem_stride = + metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; + NVF_CHECK( + metadata.max_send_total == 0 || + send.numel() % metadata.max_send_total == 0, + "alltoallv send numel must be divisible by max_send_total."); + NVF_CHECK( + metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, + "alltoallv recv numel must be divisible by max_recv."); + + auto send_offsets = metadata.send_offsets; + auto send_counts = metadata.send_counts; + auto recv_offsets = metadata.recv_offsets; + int64_t max_send_bytes = metadata.max_send_bytes; + if (elem_stride > 1) { + send_offsets = metadata.send_offsets * elem_stride; + send_counts = metadata.send_counts * elem_stride; + recv_offsets = metadata.recv_offsets * elem_stride; + max_send_bytes = metadata.max_send_bytes * elem_stride; + } + + launchAlltoallvKernel( + send.data_ptr(), + reinterpret_cast(recv_ptrs_cuda.data_ptr()), + send_offsets.data_ptr(), + send_counts.data_ptr(), + recv_offsets.data_ptr(), + metadata.world_size, + send.element_size(), + max_send_bytes * send.element_size(), + stream); +} + +void alltoallvBarrier(const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + comm.barrier(); +} + } // namespace nvfuser diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 4947e4e6ee1..e9fd5828597 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -9,6 +9,10 @@ #include +#include +#include +#include + #include "multidevice/ipc_handle.h" namespace nvfuser { @@ -43,4 +47,29 @@ void waitWithCudaBackend( CUstream stream, int64_t root); +struct AlltoallvMetadata { + at::Tensor send_counts; // CUDA [R] + at::Tensor recv_counts; // CUDA [R] + at::Tensor send_offsets; // CUDA [R] + at::Tensor recv_offsets; // CUDA [R] + int64_t total_recv = 0; + int64_t max_recv = 0; + int64_t max_send_total = 0; + int64_t max_send_bytes = 0; + int64_t world_size = 0; +}; + +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& send_counts, + const std::string& tag); + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream); + +void alltoallvBarrier(const std::string& tag); + } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 738e27765d9..cbad812aa06 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -11,9 +11,12 @@ #include #include +#include #include #include "multidevice/communicator.h" +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" #include "utils.h" namespace nvfuser { @@ -114,53 +117,160 @@ DispatchResult doMoEDispatch( my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we - // sync/copy here. GPU-initiated comms can avoid this extra sync. + // Split metadata is exchanged via CPU (TCPStore), so we sync/copy here. 
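+  // (The NCCL path exchanges these counts with a device-side alltoall below,
+  // while the CUDA backend publishes them through the TCPStore in
+  // prepareAlltoallvMetadata.)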
auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); - auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for dispatch: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + std::vector one_split(world_size, 1); + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = + at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoEDispatch."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for dispatch: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Dispatch backend is null."); - - // Exchange per-rank token counts to build split sizes for alltoall. - std::vector one_split(world_size, 1); - waitWork(pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); - - // Convert count tensors to CPU split vectors and size the receive buffers. - auto input_splits = toSplitSizes(n_tokens_to_rank); - auto output_splits = toSplitSizes(n_tokens_from_rank); - auto total_recv = sumSplitSizes(output_splits); - - // Allocate receive buffers for payloads and metadata. - // TODO: support preallocated buffers. 
- auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); - auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); - auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - - // Alltoall exchange payloads with per-rank splits. - waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoEDispatch."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_to_rank, "moe_dispatch_counts"); + auto n_tokens_from_rank = metadata.recv_counts; + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + + // Allocate symmetric buffers for send/recv payloads. + auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, num_tokens).copy_(send_x); + auto send_topk_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_idx_flat.scalar_type(), x.device()); + send_topk_idx_sym.narrow(0, 0, num_tokens).copy_(send_topk_idx); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights_flat.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, num_tokens).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, num_tokens).copy_(send_src_idx); + auto send_src_rank_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_rank.scalar_type(), x.device()); + send_src_rank_sym.narrow(0, 0, num_tokens).copy_(send_src_rank); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_idx_sym = SymmetricTensor::allocate( + {max_recv}, topk_idx_flat.scalar_type(), x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights_flat.scalar_type(), x.device()); + auto recv_src_idx_sym = SymmetricTensor::allocate( + {max_recv}, send_src_idx.scalar_type(), x.device()); + auto recv_src_rank_sym = SymmetricTensor::allocate( + {max_recv}, send_src_rank.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_idx_handle(recv_topk_idx_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + SymmetricTensor recv_src_rank_handle(recv_src_rank_sym); + recv_x_handle.setupRemoteHandles("moe_dispatch_recv_x"); + recv_topk_idx_handle.setupRemoteHandles("moe_dispatch_recv_topk_idx"); + recv_topk_weights_handle.setupRemoteHandles("moe_dispatch_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_dispatch_recv_src_idx"); + recv_src_rank_handle.setupRemoteHandles("moe_dispatch_recv_src_rank"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_idx_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + 
std::vector recv_src_idx_ptrs(world_size); + std::vector recv_src_rank_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_idx_ptrs[rank] = + recv_topk_idx_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + recv_src_rank_ptrs[rank] = + recv_src_rank_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_idx_sym, + recv_topk_idx_sym, + metadata, + recv_topk_idx_ptrs, + stream); + alltoallvWithCudaBackend( + send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvWithCudaBackend( + send_src_rank_sym, + recv_src_rank_sym, + metadata, + recv_src_rank_ptrs, + stream); + alltoallvBarrier("moe_dispatch_counts"); + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_idx = recv_topk_idx_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); + auto recv_src_rank = recv_src_rank_sym.narrow(0, 0, total_recv); // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; @@ -212,6 +322,7 @@ CombineResult doMoECombine( n_tokens_from_rank.numel() == communicator->size(), "n_tokens_from_rank must match world size."); + const int64_t world_size = communicator->size(); c10::cuda::CUDAGuard device_guard(x.device()); // Sort by source rank so alltoall can send contiguous chunks per rank. @@ -222,32 +333,100 @@ CombineResult doMoECombine( auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. 
- auto input_splits = toSplitSizes(n_tokens_from_rank); - auto output_splits = toSplitSizes(n_tokens_to_rank); - auto total_recv = sumSplitSizes(output_splits); - auto hidden = x.size(1); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for combine: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoECombine."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for combine: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Combine backend is null."); - - // Allocate receive buffers and exchange payloads back to source ranks. - auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); - auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - - waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoECombine."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_from_rank, "moe_combine_counts"); + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + auto hidden = x.size(1); + + // Allocate symmetric buffers for send/recv payloads. 
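+  // Sizing by the per-rank maxima keeps the symmetric allocations identically
+  // shaped on every rank; only a leading slice of each buffer is used.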
+ auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, x.size(0)).copy_(send_x); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, x.size(0)).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, x.size(0)).copy_(send_src_idx); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights.scalar_type(), x.device()); + auto recv_src_idx_sym = + SymmetricTensor::allocate({max_recv}, src_idx.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + recv_x_handle.setupRemoteHandles("moe_combine_recv_x"); + recv_topk_weights_handle.setupRemoteHandles("moe_combine_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_combine_recv_src_idx"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + std::vector recv_src_idx_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvBarrier("moe_combine_counts"); + + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 5714a45a818..ceb0a2652b4 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -38,7 +38,7 @@ struct CombineResult { // is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. -// backend: Communication backend (only NCCL is supported for now). +// backend: Communication backend (CUDA or NCCL). // // Returns: // DispatchResult with recv_* tensors on this rank. @@ -86,7 +86,7 @@ NVF_API DispatchResult doMoEDispatch( // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. // n_tokens_from_rank: Tokens received from each rank (from dispatch), shape // [R]. communicator: Communicator for alltoall exchange. backend: -// Communication backend (only NCCL is supported for now). +// Communication backend (CUDA or NCCL). // // Returns: // CombineResult with tokens restored to original order on this rank. 
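As background for the CUDA path above: prepareAlltoallvMetadata derives every
offset from the full counts matrix that each rank reads back from the TCPStore.
The sketch below is illustrative only (the Offsets struct and computeOffsets
are hypothetical names, not part of this patch): a rank prefix-sums its own row
of counts[sender][dest] to get its send offsets, and, for each destination,
sums the entries of lower-ranked senders to find where its chunk begins inside
that destination's receive buffer.

#include <cstdint>
#include <vector>

struct Offsets {
  std::vector<int64_t> send; // start of each destination's chunk in my send buffer
  std::vector<int64_t> recv; // start of my chunk inside each destination's recv buffer
};

Offsets computeOffsets(
    const std::vector<std::vector<int64_t>>& counts, // counts[sender][dest]
    int64_t my_rank) {
  const auto world_size = static_cast<int64_t>(counts.size());
  Offsets out{
      std::vector<int64_t>(world_size, 0),
      std::vector<int64_t>(world_size, 0)};
  int64_t prefix = 0;
  for (int64_t dest = 0; dest < world_size; ++dest) {
    out.send[dest] = prefix; // prefix sum over my own row
    prefix += counts[my_rank][dest];
  }
  for (int64_t dest = 0; dest < world_size; ++dest) {
    for (int64_t sender = 0; sender < my_rank; ++sender) {
      out.recv[dest] += counts[sender][dest]; // lower ranks land before mine
    }
  }
  return out;
}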
diff --git a/tests/cpp/test_multidevice_alltoallv.cpp b/tests/cpp/test_multidevice_alltoallv.cpp new file mode 100644 index 00000000000..02cb21b7892 --- /dev/null +++ b/tests/cpp/test_multidevice_alltoallv.cpp @@ -0,0 +1,82 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include + +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { +namespace hir { + +class AlltoallvCudaTest : public MultiDeviceTest {}; + +TEST_F(AlltoallvCudaTest, AlltoallvAsymmetric) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto count_for = [](int64_t sender, int64_t dest) { + return (sender + dest) % 3 + 1; + }; + auto send_counts = at::empty({world_size}, int_options); + for (int64_t dest = 0; dest < world_size; ++dest) { + send_counts.index_put_({dest}, count_for(my_rank, dest)); + } + + auto metadata = prepareAlltoallvMetadata(send_counts, "test_alltoallv_counts"); + const int64_t max_recv = metadata.max_recv; + const int64_t total_send = send_counts.sum().item(); + auto send_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, at::kLong, communicator_->device()); + send_sym.narrow(0, 0, total_send) + .copy_(at::arange(total_send, int_options) + my_rank * 1000); + + auto recv_sym = SymmetricTensor::allocate( + {max_recv}, at::kLong, communicator_->device()); + SymmetricTensor recv_handle(recv_sym); + recv_handle.setupRemoteHandles("test_alltoallv_recv"); + + std::vector recv_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_ptrs[rank] = recv_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = at::cuda::getDefaultCUDAStream().stream(); + alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream); + alltoallvBarrier("test_alltoallv_counts"); + + auto recv_view = recv_sym.narrow(0, 0, metadata.total_recv); + std::vector expected_vec; + expected_vec.reserve(static_cast(metadata.total_recv)); + for (int64_t sender = 0; sender < world_size; ++sender) { + int64_t offset = 0; + for (int64_t dest = 0; dest < my_rank; ++dest) { + offset += count_for(sender, dest); + } + const int64_t count = count_for(sender, my_rank); + for (int64_t i = 0; i < count; ++i) { + expected_vec.push_back(offset + i + sender * 1000); + } + } + auto expected = at::tensor(expected_vec, int_options); + EXPECT_TRUE(at::equal(recv_view, expected)) + << "Alltoallv mismatch on rank " << my_rank; +} + +} // namespace hir +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 0d84dbc03e0..1a28c6e18d5 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -21,15 +21,21 @@ namespace nvfuser { namespace hir { -class DispatchCombineTest : public MultiDeviceTest {}; +class DispatchCombineTest + : public MultiDeviceTest, + public ::testing::WithParamInterface {}; -TEST_F(DispatchCombineTest, DispatchCombineTop1) { +TEST_P(DispatchCombineTest, DispatchCombineTop1) { if (!communicator_->is_available() || 
communicator_->size() < 2) {
    GTEST_SKIP() << "This test needs at least 2 ranks.";
  }
 
   const int64_t world_size = communicator_->size();
   const int64_t my_rank = communicator_->deviceId();
+  const auto backend = GetParam();
+  if (!communicator_->isBackendAvailable(backend)) {
+    GTEST_SKIP() << "Backend " << backend << " not available.";
+  }
   constexpr int64_t kNumExpertsPerRank = 2;
   const int64_t num_experts = world_size * kNumExpertsPerRank;
   constexpr int64_t kNumTokens = 4;
@@ -64,7 +70,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) {
       in_topk_weights,
       in_is_token_in_rank,
       num_experts,
-      CommunicatorBackend::kNccl);
+      backend);
 
   auto* combined_x = makeSymbolicTensor(2);
   auto* combined_topk_weights = makeSymbolicTensor(1);
@@ -77,7 +83,7 @@
       recv_src_rank,
       n_tokens_to_rank,
       n_tokens_from_rank,
-      CommunicatorBackend::kNccl);
+      backend);
 
   hic->pushBackTopLevelExprs(dispatch);
   hic->pushBackTopLevelExprs(combine);
@@ -119,10 +125,14 @@
       {in_topk_weights, topk_weights},
       {in_is_token_in_rank, is_token_in_rank}});
   auto combined = outputs.back().as<at::Tensor>();
-  EXPECT_TRUE(at::allclose(combined, x))
       << "Dispatch/Combine mismatch on rank " << my_rank;
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    DispatchCombineBackends,
+    DispatchCombineTest,
+    ::testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kCuda));
+
 } // namespace hir
 } // namespace nvfuser
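For context on how the pieces above fit together, here is a minimal sketch of
the call sequence the CUDA backend expects, mirroring the new alltoallv test.
It only uses APIs added in this patch (prepareAlltoallvMetadata,
SymmetricTensor, alltoallvWithCudaBackend, alltoallvBarrier); exampleAlltoallv,
the tag strings, and the assumption of a 1-D payload whose length equals
send_counts.sum() are illustrative choices, not part of the change.

#include <vector>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

#include "multidevice/cuda_p2p.h"
#include "multidevice/symmetric_tensor.h"

at::Tensor exampleAlltoallv(
    const at::Tensor& send_counts, // CUDA int64 tensor of shape [world_size]
    const at::Tensor& payload) { // 1-D CUDA tensor, numel == send_counts.sum()
  using namespace nvfuser;

  // 1. Publish per-rank counts via the TCPStore and derive offsets/sizes.
  auto metadata = prepareAlltoallvMetadata(send_counts, "example_counts");

  // 2. Stage the payload in symmetric memory sized to the per-rank maxima.
  auto send_sym = SymmetricTensor::allocate(
      {metadata.max_send_total}, payload.scalar_type(), payload.device());
  const int64_t total_send = send_counts.sum().item<int64_t>();
  send_sym.narrow(0, 0, total_send).copy_(payload);
  auto recv_sym = SymmetricTensor::allocate(
      {metadata.max_recv}, payload.scalar_type(), payload.device());

  // 3. Export the receive buffer so peers can write into it directly.
  SymmetricTensor recv_handle(recv_sym);
  recv_handle.setupRemoteHandles("example_recv");
  std::vector<void*> recv_ptrs(metadata.world_size);
  for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
    recv_ptrs[rank] = recv_handle.remoteTensor(rank).data_ptr();
  }

  // 4. Launch the kernel-based exchange, then barrier before reading the
  //    valid prefix of the receive buffer.
  auto stream = at::cuda::getDefaultCUDAStream().stream();
  alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream);
  alltoallvBarrier("example_counts");
  return recv_sym.narrow(0, 0, metadata.total_recv);
}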