From cb4d494c0f56d45d63d22bf594ed8df6e82598fd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 16 Jan 2026 19:59:25 -0800 Subject: [PATCH 1/2] Reference model uses num_chunks=4 and num_streams=2 --- tests/python/multidevice/test_overlap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 02fdb1bfa53..5ec29fcb02a 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -183,7 +183,7 @@ def row_parallel_linear_forward_reference( main_stream = torch.cuda.current_stream() worker_streams = [] for i, (inp_chunk, out_chunk) in enumerate(zip(inp_chunks, out_chunks)): - worker_stream = stream_pool.get(i) + worker_stream = stream_pool.get(i % 2) worker_streams.append(worker_stream) worker_stream.wait_stream(main_stream) with torch.cuda.stream(worker_stream): @@ -232,7 +232,7 @@ def test_row_parallel_linear_forward_reference(setup_default_process_group): def test_row_parallel_linear_forward_reference_benchmark( setup_default_process_group, benchmark ): - h, s, t = 8192, 2, 8192 + h, s, t = 8192, 4, 8192 d = dist.get_world_size() if (h * 4) % d != 0: pytest.skip( From b286468b78401c954981a5a6ae75e4c77b1e0cdc Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 16 Jan 2026 17:48:30 -0800 Subject: [PATCH 2/2] Use symmetric memory --- tests/python/multidevice/conftest.py | 13 ++++++++++-- tests/python/multidevice/test_overlap.py | 27 +++++++++++++++++------- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/tests/python/multidevice/conftest.py b/tests/python/multidevice/conftest.py index 4c67519a622..6dffae8ca16 100644 --- a/tests/python/multidevice/conftest.py +++ b/tests/python/multidevice/conftest.py @@ -11,6 +11,7 @@ import torch import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem import nvfuser_direct as nvfuser @@ -133,13 +134,21 @@ def setup_default_process_group(): torch.cuda.set_device(local_rank) - # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. + opts = dist.ProcessGroupNCCL.Options() + opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO dist.init_process_group( backend="nccl", + pg_options=opts, + # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. 
init_method="tcp://localhost:29500", world_size=world_size, rank=rank, - device_id=torch.device(f"cuda:{local_rank}"), + device_id=local_rank, ) + + symm_mem.set_backend("NCCL") + symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name) + yield + dist.destroy_process_group() diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 5ec29fcb02a..a634da471c0 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem from torch.distributed.tensor import distribute_tensor, Shard import nvfuser_direct as nvfuser @@ -168,15 +169,10 @@ def get(self, sid: int) -> torch.cuda.Stream: def row_parallel_linear_forward_reference( inp_shard: torch.Tensor, weight_shard: torch.Tensor, + out: torch.Tensor, num_chunks: int, stream_pool: StreamPool, ) -> torch.Tensor: - out = torch.empty( - inp_shard.size(0), - weight_shard.size(0), - device="cuda", - dtype=inp_shard.dtype, - ) inp_chunks = inp_shard.chunk(num_chunks) out_chunks = out.chunk(num_chunks) @@ -222,7 +218,15 @@ def test_row_parallel_linear_forward_reference(setup_default_process_group): weight_ref, mesh, placements=[Shard(-1)] ).to_local() stream_pool = StreamPool() - out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s, stream_pool) + + out = symm_mem.empty( + inp_shard.size(0), + weight_shard.size(0), + device="cuda", + dtype=inp_shard.dtype, + ) + symm_mem.rendezvous(out, group=dist.group.WORLD) + row_parallel_linear_forward_reference(inp_shard, weight_shard, out, s, stream_pool) torch.testing.assert_close(out.cpu(), out_ref) @@ -251,9 +255,16 @@ def test_row_parallel_linear_forward_reference_benchmark( ).to_local() stream_pool = StreamPool() + out = symm_mem.empty( + inp_shard.size(0), + weight_shard.size(0), + device="cuda", + dtype=inp_shard.dtype, + ) + symm_mem.rendezvous(out, group=dist.group.WORLD) warmup_fn, benchmark_fn = get_benchmark_fns( lambda: row_parallel_linear_forward_reference( - inp_shard, weight_shard, s, stream_pool + inp_shard, weight_shard, out, s, stream_pool ) ) warmup_fn()
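
Note on PATCH 1/2: with num_chunks=4 and stream_pool.get(i % 2), the reference
model now issues its four per-chunk launches on only two worker CUDA streams,
so chunk i reuses the stream of chunk i - 2 instead of the pool growing with
the chunk count. The sketch below illustrates that round-robin dispatch; the
StreamPool body and the per-chunk compute are assumptions, since only the
dispatch loop is visible in the hunk.

import torch


class StreamPool:
    # Sketch of a lazily grown pool; the real StreamPool in test_overlap.py may differ.
    def __init__(self) -> None:
        self._streams: list[torch.cuda.Stream] = []

    def get(self, sid: int) -> torch.cuda.Stream:
        # Create streams on demand so the pool size follows the largest index requested.
        while len(self._streams) <= sid:
            self._streams.append(torch.cuda.Stream())
        return self._streams[sid]


def chunked_matmul(inp, weight, out, num_chunks, stream_pool, num_streams=2):
    # Split the rows of the input and of the preallocated output into matching chunks.
    inp_chunks = inp.chunk(num_chunks)
    out_chunks = out.chunk(num_chunks)

    main_stream = torch.cuda.current_stream()
    worker_streams = []
    for i, (inp_chunk, out_chunk) in enumerate(zip(inp_chunks, out_chunks)):
        # Round-robin: chunk i shares a worker stream with chunk i - num_streams.
        worker_stream = stream_pool.get(i % num_streams)
        worker_streams.append(worker_stream)
        worker_stream.wait_stream(main_stream)
        with torch.cuda.stream(worker_stream):
            # Placeholder per-chunk compute; the reference's actual body is outside the hunk.
            torch.matmul(inp_chunk, weight.t(), out=out_chunk)
    for worker_stream in worker_streams:
        # Rejoin the main stream so the caller can safely read `out`.
        main_stream.wait_stream(worker_stream)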
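
Note on PATCH 2/2: the reference's output is now a symmetric-memory buffer,
allocated with symm_mem.empty and registered on every rank via
symm_mem.rendezvous before the reference writes into it, and the NCCL process
group is created with the zero-CTA policy, which is meant to keep NCCL's
collectives off the SMs so they overlap better with the compute streams (the
exact behavior depends on the NCCL version). The sketch below condenses the
conftest.py setup and the test-side allocation into one self-contained
snippet; the environment-variable bootstrap, the omitted init_method, and the
shapes/dtype are illustrative rather than taken from the patch.

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

# Ask NCCL for its zero-CTA policy so collectives claim as few SMs as possible
# (requires a PyTorch/NCCL recent enough to expose cta_policy).
opts = dist.ProcessGroupNCCL.Options()
opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO
dist.init_process_group(backend="nccl", pg_options=opts, device_id=local_rank)

# Route symmetric-memory allocations through the NCCL backend and enable them
# for the default group, as conftest.py now does.
symm_mem.set_backend("NCCL")
symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)

# The output buffer comes from the symmetric heap; rendezvous registers it with
# every rank in the group before any kernel or collective touches it.
out = symm_mem.empty(128, 256, device="cuda", dtype=torch.bfloat16)
symm_mem.rendezvous(out, group=dist.group.WORLD)

dist.destroy_process_group()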