diff --git a/tests/python/multidevice/conftest.py b/tests/python/multidevice/conftest.py
index 4c67519a622..6dffae8ca16 100644
--- a/tests/python/multidevice/conftest.py
+++ b/tests/python/multidevice/conftest.py
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 
 import nvfuser_direct as nvfuser
 
@@ -133,13 +134,21 @@ def setup_default_process_group():
     torch.cuda.set_device(local_rank)
 
-    # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
+    opts = dist.ProcessGroupNCCL.Options()
+    opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO
     dist.init_process_group(
         backend="nccl",
+        pg_options=opts,
+        # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
         init_method="tcp://localhost:29500",
         world_size=world_size,
         rank=rank,
-        device_id=torch.device(f"cuda:{local_rank}"),
+        device_id=local_rank,
     )
+
+    symm_mem.set_backend("NCCL")
+    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
+
     yield
+
     dist.destroy_process_group()
diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py
index 02fdb1bfa53..a634da471c0 100644
--- a/tests/python/multidevice/test_overlap.py
+++ b/tests/python/multidevice/test_overlap.py
@@ -7,6 +7,7 @@
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 from torch.distributed.tensor import distribute_tensor, Shard
 
 import nvfuser_direct as nvfuser
@@ -168,22 +169,17 @@ def get(self, sid: int) -> torch.cuda.Stream:
 def row_parallel_linear_forward_reference(
     inp_shard: torch.Tensor,
     weight_shard: torch.Tensor,
+    out: torch.Tensor,
     num_chunks: int,
     stream_pool: StreamPool,
 ) -> torch.Tensor:
-    out = torch.empty(
-        inp_shard.size(0),
-        weight_shard.size(0),
-        device="cuda",
-        dtype=inp_shard.dtype,
-    )
     inp_chunks = inp_shard.chunk(num_chunks)
     out_chunks = out.chunk(num_chunks)
 
     main_stream = torch.cuda.current_stream()
     worker_streams = []
     for i, (inp_chunk, out_chunk) in enumerate(zip(inp_chunks, out_chunks)):
-        worker_stream = stream_pool.get(i)
+        worker_stream = stream_pool.get(i % 2)
         worker_streams.append(worker_stream)
         worker_stream.wait_stream(main_stream)
         with torch.cuda.stream(worker_stream):
@@ -222,7 +218,15 @@ def test_row_parallel_linear_forward_reference(setup_default_process_group):
         weight_ref, mesh, placements=[Shard(-1)]
     ).to_local()
     stream_pool = StreamPool()
-    out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s, stream_pool)
+
+    out = symm_mem.empty(
+        inp_shard.size(0),
+        weight_shard.size(0),
+        device="cuda",
+        dtype=inp_shard.dtype,
+    )
+    symm_mem.rendezvous(out, group=dist.group.WORLD)
+    row_parallel_linear_forward_reference(inp_shard, weight_shard, out, s, stream_pool)
 
     torch.testing.assert_close(out.cpu(), out_ref)
 
@@ -232,7 +236,7 @@ def test_row_parallel_linear_forward_reference(setup_default_process_group):
 def test_row_parallel_linear_forward_reference_benchmark(
     setup_default_process_group, benchmark
 ):
-    h, s, t = 8192, 2, 8192
+    h, s, t = 8192, 4, 8192
     d = dist.get_world_size()
     if (h * 4) % d != 0:
         pytest.skip(
@@ -251,9 +255,16 @@ def test_row_parallel_linear_forward_reference_benchmark(
     ).to_local()
     stream_pool = StreamPool()
 
+    out = symm_mem.empty(
+        inp_shard.size(0),
+        weight_shard.size(0),
+        device="cuda",
+        dtype=inp_shard.dtype,
+    )
+    symm_mem.rendezvous(out, group=dist.group.WORLD)
     warmup_fn, benchmark_fn = get_benchmark_fns(
         lambda: row_parallel_linear_forward_reference(
-            inp_shard, weight_shard, s, stream_pool
+            inp_shard, weight_shard, out, s, stream_pool
         )
     )
     warmup_fn()
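
Note on the symmetric-memory buffers introduced above: the output tensor is no longer allocated inside `row_parallel_linear_forward_reference` with `torch.empty`; each test allocates it through `torch.distributed._symmetric_memory` and calls `symm_mem.rendezvous` on it before passing it in. A minimal sketch of that pattern, assuming a NCCL process group is already initialized and using a placeholder shape and dtype rather than the tests' actual sizes:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes dist.init_process_group(backend="nccl", ...) has already run,
# as done by the setup_default_process_group fixture in conftest.py.
symm_mem.set_backend("NCCL")
symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)

# Allocate the output buffer from the symmetric-memory allocator, then
# rendezvous so every rank exchanges handles before kernels write to it.
out = symm_mem.empty(1024, 1024, device="cuda", dtype=torch.bfloat16)  # placeholder shape/dtype
symm_mem.rendezvous(out, group=dist.group.WORLD)
```

The rendezvous must run on every rank after allocation and before the buffer is used, which is why both tests call it immediately after `symm_mem.empty` and before invoking the reference kernel.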