13 changes: 11 additions & 2 deletions tests/python/multidevice/conftest.py
@@ -11,6 +11,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 
 import nvfuser_direct as nvfuser
 
@@ -133,13 +134,21 @@ def setup_default_process_group():
 
     torch.cuda.set_device(local_rank)
 
-    # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
+    opts = dist.ProcessGroupNCCL.Options()
+    opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO
     dist.init_process_group(
         backend="nccl",
+        pg_options=opts,
+        # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51.
         init_method="tcp://localhost:29500",
         world_size=world_size,
         rank=rank,
-        device_id=torch.device(f"cuda:{local_rank}"),
+        device_id=local_rank,
     )
+
+    symm_mem.set_backend("NCCL")
+    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
+
     yield
+
     dist.destroy_process_group()
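
For context, the fixture change above amounts to the following standalone setup, sketched here outside pytest. The rank/world-size plumbing via torchrun environment variables is an assumption of the sketch, and the zero-CTA policy requires a PyTorch/NCCL build that exposes ProcessGroupNCCL.Options().config.cta_policy, presumably so that communication kernels do not compete with compute kernels for SMs during overlap.

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # Sketch of the updated setup: one process per GPU, launched e.g. with torchrun.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Request NCCL's zero-CTA policy for this process group (as in the hunk above).
    opts = dist.ProcessGroupNCCL.Options()
    opts.config.cta_policy = dist.ProcessGroupNCCL.NCCL_CTA_POLICY_ZERO

    dist.init_process_group(
        backend="nccl",
        pg_options=opts,
        init_method="tcp://localhost:29500",
        world_size=world_size,
        rank=rank,
        device_id=local_rank,
    )

    # Route symmetric-memory allocations through NCCL and register the default group,
    # so tensors from symm_mem.empty(...) can later be rendezvous'd on dist.group.WORLD.
    symm_mem.set_backend("NCCL")
    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
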
31 changes: 21 additions & 10 deletions tests/python/multidevice/test_overlap.py
@@ -7,6 +7,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 from torch.distributed.tensor import distribute_tensor, Shard
 
 import nvfuser_direct as nvfuser
@@ -168,22 +169,17 @@ def get(self, sid: int) -> torch.cuda.Stream:
 def row_parallel_linear_forward_reference(
     inp_shard: torch.Tensor,
     weight_shard: torch.Tensor,
+    out: torch.Tensor,
     num_chunks: int,
     stream_pool: StreamPool,
 ) -> torch.Tensor:
-    out = torch.empty(
-        inp_shard.size(0),
-        weight_shard.size(0),
-        device="cuda",
-        dtype=inp_shard.dtype,
-    )
     inp_chunks = inp_shard.chunk(num_chunks)
     out_chunks = out.chunk(num_chunks)
 
     main_stream = torch.cuda.current_stream()
     worker_streams = []
     for i, (inp_chunk, out_chunk) in enumerate(zip(inp_chunks, out_chunks)):
-        worker_stream = stream_pool.get(i)
+        worker_stream = stream_pool.get(i % 2)
         worker_streams.append(worker_stream)
         worker_stream.wait_stream(main_stream)
         with torch.cuda.stream(worker_stream):
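
The loop body above is truncated by the hunk boundary. As a rough, self-contained sketch of the round-robin two-stream pattern the change introduces (stream_pool.get(i % 2)), with a plain matmul standing in for the sharded linear and no collective; the helper name and shapes are hypothetical:

    import torch

    def chunked_matmul_two_streams(
        inp: torch.Tensor, weight: torch.Tensor, out: torch.Tensor, num_chunks: int
    ) -> torch.Tensor:
        # Round-robin chunks over two worker streams, mirroring stream_pool.get(i % 2).
        worker_streams = [torch.cuda.Stream(), torch.cuda.Stream()]
        main_stream = torch.cuda.current_stream()
        for i, (inp_chunk, out_chunk) in enumerate(
            zip(inp.chunk(num_chunks), out.chunk(num_chunks))
        ):
            stream = worker_streams[i % 2]
            # Chunks are only valid after prior main-stream work has finished.
            stream.wait_stream(main_stream)
            with torch.cuda.stream(stream):
                torch.matmul(inp_chunk, weight.t(), out=out_chunk)
        # The caller's stream resumes only after every chunk has been computed.
        for stream in worker_streams:
            main_stream.wait_stream(stream)
        return out

Using only two worker streams caps the number of concurrently queued chunks regardless of num_chunks, which appears to be the point of the i % 2 change.
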
@@ -222,7 +218,15 @@ def test_row_parallel_linear_forward_reference(setup_default_process_group):
         weight_ref, mesh, placements=[Shard(-1)]
     ).to_local()
     stream_pool = StreamPool()
-    out = row_parallel_linear_forward_reference(inp_shard, weight_shard, s, stream_pool)
+
+    out = symm_mem.empty(
+        inp_shard.size(0),
+        weight_shard.size(0),
+        device="cuda",
+        dtype=inp_shard.dtype,
+    )
+    symm_mem.rendezvous(out, group=dist.group.WORLD)
+    row_parallel_linear_forward_reference(inp_shard, weight_shard, out, s, stream_pool)
 
     torch.testing.assert_close(out.cpu(), out_ref)
 
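The test now preallocates the output as a symmetric-memory tensor and passes it into the reference function, instead of letting the function allocate a regular CUDA tensor. A minimal sketch of that allocation pattern on its own (process-group and symmetric-memory setup as in conftest.py above; shapes and dtype are placeholders); rendezvous is the step that exchanges buffer handles across the group, after which the local rank uses the tensor like any ordinary CUDA tensor:

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # Allocate from the symmetric-memory pool, then rendezvous so ranks in the group
    # can map each other's buffers. Every rank must reach this call.
    out = symm_mem.empty(1024, 1024, device="cuda", dtype=torch.bfloat16)
    symm_mem.rendezvous(out, group=dist.group.WORLD)
    out.zero_()  # locally, it behaves like any other CUDA tensor
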
@@ -232,7 +236,7 @@ def test_row_parallel_linear_forward_reference(setup_default_process_group):
 def test_row_parallel_linear_forward_reference_benchmark(
     setup_default_process_group, benchmark
 ):
-    h, s, t = 8192, 2, 8192
+    h, s, t = 8192, 4, 8192
     d = dist.get_world_size()
     if (h * 4) % d != 0:
         pytest.skip(
@@ -251,9 +255,16 @@ def test_row_parallel_linear_forward_reference_benchmark(
     ).to_local()
 
     stream_pool = StreamPool()
+    out = symm_mem.empty(
+        inp_shard.size(0),
+        weight_shard.size(0),
+        device="cuda",
+        dtype=inp_shard.dtype,
+    )
+    symm_mem.rendezvous(out, group=dist.group.WORLD)
     warmup_fn, benchmark_fn = get_benchmark_fns(
         lambda: row_parallel_linear_forward_reference(
-            inp_shard, weight_shard, s, stream_pool
+            inp_shard, weight_shard, out, s, stream_pool
         )
     )
     warmup_fn()
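
Note that the symmetric-memory output is allocated and rendezvous'd once, outside the lambda handed to get_benchmark_fns, so the timed region presumably covers only the chunked, stream-overlapped forward pass and not the allocation or the rendezvous collective.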