From 095ac5840fcbc610c13b1c80d324be7feacac496 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 27 Jan 2026 16:17:53 -0800 Subject: [PATCH 01/10] Multi-GPU version --- tests/python/multidevice/test_alphafold3.py | 174 ++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 tests/python/multidevice/test_alphafold3.py diff --git a/tests/python/multidevice/test_alphafold3.py b/tests/python/multidevice/test_alphafold3.py new file mode 100644 index 00000000000..ee3ae1f9dea --- /dev/null +++ b/tests/python/multidevice/test_alphafold3.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +# This file contains certain building blocks of the AlphaFold3 model. + +import pytest +import torch +from dataclasses import dataclass +from enum import Enum, auto + +from nvfuser_direct import FusionDefinition, DataType, TensorView + + +@dataclass +class ModelConfig: + c_z: int = 128 + c_hidden: int = 32 + n_heads: int = 4 + + +_DEFAULT_CONFIG = ModelConfig() + + +class Direction(Enum): + INCOMING = auto() # aka ending node + OUTGOING = auto() # aka starting node + + +def layer_norm( + fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView +) -> TensorView: + io_dtype = x.dtype() + x = fd.ops.cast(x, dtype=DataType.Float) + var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True) + y = fd.ops.sub(x, mean) + var = fd.ops.add(var, fd.define_scalar(1e-5)) + y = fd.ops.mul(y, fd.ops.rsqrt(var)) + shape = fd.ops.shape(x) + w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1]) + y = fd.ops.mul(y, w) + b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1]) + y = fd.ops.add(y, b) + y = fd.ops.cast(y, dtype=io_dtype) + return y + + +def gating( + fd: FusionDefinition, + z: TensorView, + w_p: TensorView, + z_in: TensorView, + w_g: TensorView, +) -> TensorView: + io_dtype = z.dtype() + p = fd.ops.linear(z, w_p) + g = fd.ops.linear(z_in, w_g) + g = fd.ops.sigmoid(g) + z = fd.ops.mul(p, g) + return fd.ops.cast(z, dtype=io_dtype) + + +# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates +# +# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure +# prediction with AlphaFold. Nature 596, 583–589 (2021). +# https://doi.org/10.1038/s41586-021-03819-2 +# (see Supplementary Methods 1.6.5 for details) +@pytest.mark.mpi +@pytest.mark.parametrize( + "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower() +) +def test_triangle_updates(direction): + c_z = _DEFAULT_CONFIG.c_z + + with FusionDefinition() as fd: + z_in = fd.define_tensor( + shape=[-1, -1, -1, c_z], + dtype=DataType.BFloat16, + contiguity=True, + ) # [b, i, j, c_z] + w_norm_in = fd.define_tensor( + shape=[c_z], dtype=DataType.BFloat16, contiguity=True + ) + b_norm_in = fd.define_tensor( + shape=[c_z], dtype=DataType.BFloat16, contiguity=True + ) + w_p_in = fd.define_tensor( + shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True + ) + w_g_in = fd.define_tensor( + shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True + ) + w_norm_out = fd.define_tensor( + shape=[c_z], dtype=DataType.BFloat16, contiguity=True + ) + b_norm_out = fd.define_tensor( + shape=[c_z], dtype=DataType.BFloat16, contiguity=True + ) + w_p_out = fd.define_tensor( + shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True + ) + w_g_out = fd.define_tensor( + shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True + ) + # Masking is used in an internal implementation: http://nv/e-4 + mask = fd.define_tensor( + shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True + ) # [b, i, j] + + batch_size = fd.ops.size(z_in, 0) + n_tokens = fd.ops.size(z_in, 1) + + z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in) + z = gating(fd, z_in, w_p_in, z_in, w_g_in) + mask = fd.ops.broadcast_in_dim( + mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2] + ) + z = fd.ops.where(mask, z, 0.0) + a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z]) + b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2]) + + match direction: + case Direction.OUTGOING: + # z_out = einsum("bikc,bjkc->bijc", a, b) + a = fd.ops.permute(a, [0, 3, 1, 2]) # [b, c, i, k] + b = fd.ops.permute(b, [0, 3, 2, 1]) # [b, c, k, j] + case Direction.INCOMING: + # z_out = einsum("bkic,bkjc->bijc", a, b) + a = fd.ops.permute(a, [0, 3, 2, 1]) # [b, c, i, k] + b = fd.ops.permute(b, [0, 3, 1, 2]) # [b, c, k, j] + z = fd.ops.matmul(a, b) # [b, c, i, j] + z = fd.ops.permute(z, [0, 2, 3, 1]) # [b, i, j, c] + + z = layer_norm(fd, z, w_norm_out, b_norm_out) + z = gating(fd, z, w_p_out, z_in, w_g_out) + fd.add_output(z) + + batch_size = 3 + n_tokens = 5 + z_in = torch.testing.make_tensor( + batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda" + ) + w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") + b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") + w_p_in = torch.testing.make_tensor( + c_z * 2, c_z, dtype=torch.bfloat16, device="cuda" + ) + w_g_in = torch.testing.make_tensor( + c_z * 2, c_z, dtype=torch.bfloat16, device="cuda" + ) + w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") + b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") + w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda") + w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda") + mask = torch.testing.make_tensor( + batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda" + ) + (z_out,) = fd.execute( + [ + z_in, + w_norm_in, + b_norm_in, + w_p_in, + w_g_in, + w_norm_out, + b_norm_out, + w_p_out, + w_g_out, + mask, + ] + ) + assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z) From ec010f322c9c2c4a6464b9fac7d641c58f88a1cd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 28 Jan 2026 14:15:36 -0800 Subject: [PATCH 02/10] WIP --- tests/python/multidevice/test_alphafold3.py | 80 ++++++++++++++++----- 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/tests/python/multidevice/test_alphafold3.py b/tests/python/multidevice/test_alphafold3.py index ee3ae1f9dea..1ac79fe95fe 100644 --- a/tests/python/multidevice/test_alphafold3.py +++ b/tests/python/multidevice/test_alphafold3.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from enum import Enum, auto +import nvfuser_direct as nvfuser from nvfuser_direct import FusionDefinition, DataType, TensorView @@ -71,11 +72,19 @@ def gating( @pytest.mark.parametrize( "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower() ) -def test_triangle_updates(direction): +def test_triangle_updates(direction, multidevice_test): + d = multidevice_test.size + cp_size = 1 + if d % (cp_size * cp_size) != 0: + pytest.skip( + f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}." + ) + dp_size = d // (cp_size * cp_size) + c_z = _DEFAULT_CONFIG.c_z with FusionDefinition() as fd: - z_in = fd.define_tensor( + z_in_tv = fd.define_tensor( shape=[-1, -1, -1, c_z], dtype=DataType.BFloat16, contiguity=True, @@ -105,17 +114,19 @@ def test_triangle_updates(direction): shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True ) # Masking is used in an internal implementation: http://nv/e-4 - mask = fd.define_tensor( + mask_tv = fd.define_tensor( shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True ) # [b, i, j] - batch_size = fd.ops.size(z_in, 0) - n_tokens = fd.ops.size(z_in, 1) + batch_size = fd.ops.size(z_in_tv, 0) + n_tokens = fd.ops.size(z_in_tv, 1) - z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in) - z = gating(fd, z_in, w_p_in, z_in, w_g_in) + z_in = layer_norm(fd, z_in_tv, w_norm_in, b_norm_in) + z = gating(fd, z_in_tv, w_p_in, z_in, w_g_in) mask = fd.ops.broadcast_in_dim( - mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2] + mask_tv, + shape=[batch_size, n_tokens, n_tokens, c_z], + broadcast_dims=[0, 1, 2], ) z = fd.ops.where(mask, z, 0.0) a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z]) @@ -137,11 +148,50 @@ def test_triangle_updates(direction): z = gating(fd, z, w_p_out, z_in, w_g_out) fd.add_output(z) - batch_size = 3 - n_tokens = 5 - z_in = torch.testing.make_tensor( - batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda" + mesh = nvfuser.multidevice.DeviceMesh( + torch.arange(d).reshape(dp_size, cp_size, cp_size) + ) + for tv in [ + z_in_tv, + w_norm_in, + b_norm_in, + w_p_in, + w_g_in, + w_norm_out, + b_norm_out, + w_p_out, + w_g_out, + mask_tv, + ]: + tv.set_device_mesh(mesh) + + for tv in [z_in, mask]: + tv.outer_split(2, cp_size) + tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x) + tv.outer_split(1, cp_size) + tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y) + tv.outer_split(0, dp_size) + tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z) + + batch_per_rank = 3 + n_tokens_per_rank = 5 + z_in_ref = torch.testing.make_tensor( + batch_per_rank * dp_size, + n_tokens_per_rank * cp_size, + n_tokens_per_rank * cp_size, + c_z, + dtype=torch.bfloat16, + device="cpu", ) + mask_ref = torch.testing.make_tensor( + batch_per_rank * dp_size, + n_tokens_per_rank * cp_size, + n_tokens_per_rank * cp_size, + dtype=torch.bool, + device="cpu", + ) + + z_in = multidevice_test.shard_tensor(z_in_ref, z_in_tv) w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") w_p_in = torch.testing.make_tensor( @@ -154,9 +204,7 @@ def test_triangle_updates(direction): b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda") w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda") w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda") - mask = torch.testing.make_tensor( - batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda" - ) + mask = multidevice_test.shard_tensor(mask_ref, mask_tv) (z_out,) = fd.execute( [ z_in, @@ -171,4 +219,4 @@ def test_triangle_updates(direction): mask, ] ) - assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z) + assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z) From d6b6cfaaf62094e94070bb29470686f2f5d62b92 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 28 Jan 2026 14:19:00 -0800 Subject: [PATCH 03/10] WIP --- tests/python/multidevice/test_alphafold3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/multidevice/test_alphafold3.py b/tests/python/multidevice/test_alphafold3.py index 1ac79fe95fe..a75a1567a22 100644 --- a/tests/python/multidevice/test_alphafold3.py +++ b/tests/python/multidevice/test_alphafold3.py @@ -165,7 +165,7 @@ def test_triangle_updates(direction, multidevice_test): ]: tv.set_device_mesh(mesh) - for tv in [z_in, mask]: + for tv in [z_in_tv, mask_tv]: tv.outer_split(2, cp_size) tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x) tv.outer_split(1, cp_size) From 27c031784845496fedb6f0551629655f4d38668d Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 28 Jan 2026 17:00:54 -0800 Subject: [PATCH 04/10] More debugging --- csrc/preseg_passes/propagate_shardings.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index b6867d238c9..459b2910cbe 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -186,6 +186,13 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { PropagateDirection::kBackward); } } + + if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { + debug() << std::endl + << "Fusion Transforms after " << name() << ":" << std::endl; + fusion->printTransforms(); + debug() << std::endl; + } } } // namespace nvfuser::preseg_passes From 9fc55417676f043ce8d3c38d6a4dc5976f0d6b7a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 29 Jan 2026 13:49:30 -0800 Subject: [PATCH 05/10] Clean up headers --- csrc/base.h | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/base.h b/csrc/base.h index 9255dc42c0d..9bceeac845c 100644 --- a/csrc/base.h +++ b/csrc/base.h @@ -7,7 +7,6 @@ // clang-format on #pragma once -#include #include #include #include From 4a35209faa94e07413dd4730808099545f6e4e4e Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 2 Feb 2026 11:51:44 -0800 Subject: [PATCH 06/10] Manual annotation for einsum out --- tests/python/multidevice/test_alphafold3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/multidevice/test_alphafold3.py b/tests/python/multidevice/test_alphafold3.py index a75a1567a22..24044c54d2f 100644 --- a/tests/python/multidevice/test_alphafold3.py +++ b/tests/python/multidevice/test_alphafold3.py @@ -144,6 +144,8 @@ def test_triangle_updates(direction, multidevice_test): z = fd.ops.matmul(a, b) # [b, c, i, j] z = fd.ops.permute(z, [0, 2, 3, 1]) # [b, i, j, c] + einsum_out = z + z = layer_norm(fd, z, w_norm_out, b_norm_out) z = gating(fd, z, w_p_out, z_in, w_g_out) fd.add_output(z) @@ -162,10 +164,11 @@ def test_triangle_updates(direction, multidevice_test): w_p_out, w_g_out, mask_tv, + einsum_out, ]: tv.set_device_mesh(mesh) - for tv in [z_in_tv, mask_tv]: + for tv in [z_in_tv, mask_tv, einsum_out]: tv.outer_split(2, cp_size) tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x) tv.outer_split(1, cp_size) From 2cfcd99d6ec56534574b6ca59b9afa6d358fd0fb Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 2 Feb 2026 20:32:35 -0800 Subject: [PATCH 07/10] Fix getCommunicationInfo --- csrc/host_ir/lower_to_communication.cpp | 42 ++++++++++++++----------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index d91fd4eda60..44f5a93510d 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -8,17 +8,14 @@ #include "host_ir/lower_to_communication.h" -#include "host_ir/container.h" -#include "ir/all_nodes.h" -#include "ir/allocation_utils.h" #include "ir/builder.h" +#include "ir/interface_nodes.h" #include "ir/internal_base_nodes.h" #include "ir/iostream.h" -#include "kernel_ir.h" +#include "logical_domain_map.h" #include "multidevice/communication.h" #include "multidevice/resharding.h" #include "multidevice/utils.h" -#include "ops/all_ops.h" namespace nvfuser { @@ -56,10 +53,11 @@ void lowerToScatter( const CommunicatorBackend backend, std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - receiver_mesh.rank() == 1, + NVF_ERROR_EQ( + receiver_mesh.rank(), + 1, "Gather only supported on a 1D mesh. Given ", - receiver_mesh); + output_tv->toString()); // Find a common device between input and receiver meshes to be the root std::vector input_devices = input_tv->getDeviceMesh().vector(); @@ -348,13 +346,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) { "getCommunicationInfo should only be called when `e` is known to be a " "communication. Given: ", e); - + NVF_ERROR_EQ( + e->inputs().size(), 1, "Expected 1 input, but got ", e->toString()); auto* producer = e->inputs().at(0)->as(); + NVF_ERROR_EQ( + e->outputs().size(), 1, "Expected 1 output, but got ", e->toString()); auto* consumer = e->outputs().at(0)->as(); - std::optional communication_info = std::nullopt; - // Fill `communication_info` instead of returning the result, so we can catch - // errors when more than one DIDs have sharding changes. + std::optional communication_info = std::nullopt; + // Fill `communication_info` instead of returning the result, so we can + // catch errors when more than one DIDs have sharding changes. auto fill_communication_info = [&](CommunicationType type, IterDomain* p_sharded_id, IterDomain* c_sharded_id) { @@ -375,19 +376,23 @@ CommunicationInfo getCommunicationInfo(Expr* e) { auto consumer_pt_to_did = mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain()); + const DeviceMesh& producer_mesh = producer->getDeviceMesh(); + const DeviceMesh& consumer_mesh = consumer->getDeviceMesh(); + const bool same_mesh = producer_mesh == consumer_mesh; + for (ParallelType pt : kParallelTypeDIDs) { + if (!haveDifferentShardings(producer, consumer, {pt})) { + continue; + } + IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt); IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt); if (p_loop_did == nullptr && c_loop_did == nullptr) { // Not sharded on this parallel type - continue; + NVF_THROW("Not sharded on this parallel type: ", pt); } - const DeviceMesh& producer_mesh = producer->getDeviceMesh(); - const DeviceMesh& consumer_mesh = consumer->getDeviceMesh(); - const bool same_mesh = producer_mesh == consumer_mesh; - if (e->isA()) { if (p_loop_did && !c_loop_did) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); @@ -435,7 +440,8 @@ CommunicationInfo getCommunicationInfo(Expr* e) { auto c_it = p2c_map.find(p_logical_id); NVF_ERROR( c_it != p2c_map.end(), - "Cannot find the mapped consumer logical ID for the producer logical " + "Cannot find the mapped consumer logical ID for the producer " + "logical " "ID ", p_logical_id->toString()); if (!c_it->second->isReduction()) { From 407e5c5b9b062d05e806c5899d40fcdab868a458 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 2 Feb 2026 22:24:15 -0800 Subject: [PATCH 08/10] NVFUSER_DUMP=segmented_fusion prints transforms for multi-GPU debugging. Multi-GPU scheduling happens before segmentation and the shardings are encoded as loop transforms. --- csrc/fusion_segmenter.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index f472d41f061..c397bd6af40 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1732,9 +1732,10 @@ std::ostream& operator<<( } void SegmentedFusion::print() const { - debug() << "Segmented_Fusion Dump: -- Re-written complete fusion:{\n"; - completeFusion()->printMath(); - debug() << "} // {Re-written complete fusion}\n"; + debug() << "Segmented_Fusion Dump: -- Re-written complete fusion:{" + << std::endl; + completeFusion()->print(); + debug() << "} // {Re-written complete fusion}" << std::endl << std::endl; debug() << this << "\n"; } From 6547aeb0ab2ddab6c4cc1554948bae1689f6c4c4 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 2 Feb 2026 23:07:09 -0800 Subject: [PATCH 09/10] Add a small repro --- tests/python/multidevice/test_multidevice.py | 59 +++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py index 99858ca2e41..965b3a7799f 100644 --- a/tests/python/multidevice/test_multidevice.py +++ b/tests/python/multidevice/test_multidevice.py @@ -29,7 +29,6 @@ def test_sizes_and_ranks(multidevice_test): @pytest.mark.mpi def test_pointwise(multidevice_test): num_devices = multidevice_test.size - mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices)) with FusionDefinition() as fd: inp_tv = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float) @@ -37,6 +36,7 @@ def test_pointwise(multidevice_test): tv2 = fd.ops.add(tv1, tv1) fd.add_output(tv2) + mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices)) for tv in [inp_tv, tv1, tv2]: tv.set_device_mesh(mesh) @@ -50,6 +50,63 @@ def test_pointwise(multidevice_test): torch.testing.assert_close(out.cpu(), out_ref) +@pytest.mark.mpi +def test_transpose(multidevice_test): + d = multidevice_test.size + cp_size = 2 + if d % (cp_size * cp_size) != 0: + pytest.skip( + f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}." + ) + dp_size = d // (cp_size * cp_size) + + c = 128 + with FusionDefinition() as fd: + inp_tv = fd.define_tensor( + (-1, c, -1, -1, cp_size), contiguity=True, dtype=DataType.BFloat16 + ) + out_tv = fd.ops.set(inp_tv) + fd.add_output(out_tv) + + mesh = nvfuser.multidevice.DeviceMesh( + torch.arange(d).reshape(dp_size, cp_size, cp_size) + ) + for tv in [inp_tv, out_tv]: + tv.set_device_mesh(mesh) + + inp_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y) + inp_tv.outer_split(3, cp_size) + inp_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x) + inp_tv.outer_split(0, dp_size) + inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z) + + out_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y) + out_tv.outer_split(3, cp_size) + out_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x) + out_tv.outer_split(0, dp_size) + out_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z) + out_tv.set_allocation_domain( + ( + out_tv.axis(3), + out_tv.axis(0), + out_tv.axis(1), + out_tv.axis(2), + out_tv.axis(4), + out_tv.axis(5), + out_tv.axis(6), + ), + True, + ) + + b = dp_size * 3 + s = cp_size * 5 + inp_ref = torch.randn(b, c, s, s, cp_size, dtype=torch.bfloat16) + out_ref = inp_ref + + inp = multidevice_test.shard_tensor(inp_ref, inp_tv) + fd.execute([inp]) + + class QkvFormat(Enum): BHSE = auto() BSHE = auto() From 54bbceb932ee0d11a922538c4cbd986823ae53fc Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 2 Feb 2026 23:09:43 -0800 Subject: [PATCH 10/10] Increase cp_size to 2 --- tests/python/multidevice/test_alphafold3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/multidevice/test_alphafold3.py b/tests/python/multidevice/test_alphafold3.py index 24044c54d2f..a18b3682808 100644 --- a/tests/python/multidevice/test_alphafold3.py +++ b/tests/python/multidevice/test_alphafold3.py @@ -74,7 +74,7 @@ def gating( ) def test_triangle_updates(direction, multidevice_test): d = multidevice_test.size - cp_size = 1 + cp_size = 2 if d % (cp_size * cp_size) != 0: pytest.skip( f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."