Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 104 additions & 59 deletions tests/ccl/test_all_gather_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
Test suite for all-gather collective operation using Gluon.
"""

import os

import pytest
import torch
import torch.distributed as dist

# Try to import Gluon, skip tests if not available
try:
import iris
from iris.ccl import Config
Expand All @@ -22,77 +19,125 @@
GLUON_AVAILABLE = False


# Number of iterations of the per-test check loop (``for i in range(NUM_REPLAYS)``);
# also the denominator of the "ok" / "steps wrong" counts the test reports.
NUM_REPLAYS = 200
Comment thread
micmelesse marked this conversation as resolved.


def _all_gather(impl, src, stage_buf, result, shmem, config, async_op):
"""Stage src into the input buffer, then all-gather. Module-level (no closure
over shmem) so the test can ``del shmem`` for IPC cleanup."""
stage_buf.copy_(src)
if impl == "torch":
dist.all_gather_into_tensor(result, stage_buf)
else:
shmem.ccl.all_gather(result, stage_buf, config=config, async_op=async_op)


def _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n):
"""Resolve impl -> (stage_buf, result, config) in one place: torch uses plain
device tensors and no config; the iris backends use symmetric-heap buffers and
a use_gluon config. Output is (world_size * M, N) — block r holds rank r's input."""
if impl == "torch":
stage = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}")
result = torch.empty((world_size * M, N), dtype=dtype, device=f"cuda:{rank}")
return stage, result, None
stage = shmem.zeros((M, N), dtype=dtype)
result = shmem.zeros((world_size * M, N), dtype=dtype)
config = Config(use_gluon=(impl == "gluon"), block_size_m=block_size_m, block_size_n=block_size_n)
return stage, result, config


@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available")
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.float32,
torch.bfloat16,
],
)
@pytest.mark.parametrize("impl", ["torch", "triton", "gluon"])
@pytest.mark.parametrize("mode", ["eager_barrier", "eager_nobarrier", "graph"])
@pytest.mark.parametrize("vary", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
# block_size_n must be a multiple of threads_per_warp * num_warps (256 at defaults).
@pytest.mark.parametrize(
"M, N, block_size_m, block_size_n",
[
# block_size_n must be a multiple of (threads_per_warp * num_warps).
# With defaults (threads_per_warp=64, num_warps=4), minimum is 256.
# elems_per_thread = block_size_n / 256: higher = wider vector loads.
(256, 256, 32, 256), # Small: elems_per_thread=1 (scalar loads)
(1024, 512, 32, 512), # Medium: elems_per_thread=2 (dword loads)
(8192, 8192, 32, 1024), # Large: elems_per_thread=4 (dwordx4, optimal)
],
[(64, 8192, 32, 1024), (256, 8192, 32, 1024)],
)
def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n):
"""Test all-gather functionality using Gluon by comparing against PyTorch's implementation."""
# Ensure torch.distributed is initialized (should be done by test runner)
def test_all_gather_gluon(impl, mode, vary, dtype, M, N, block_size_m, block_size_n):
"""Drive all-gather across impl x mode x vary and check the gathered output.

mode: eager_barrier (async_op=False, trailing ctx.barrier()), eager_nobarrier
(async_op=True, no barrier), graph (HIP-graph capture+replay, async_op=True —
the host barrier can't be captured). vary=False replays identical input;
vary=True feeds a fresh input each step, surfacing stale cross-rank reads.

Rank r fills its whole input with 1 + r + replay%16 (exact integers), so output
block r must equal 1 + r + replay%16 — any >=1 mismatch is a real drop. torch
and eager_barrier are the references; per-peer-slice fail tallies show which
peers' slices dropped."""
if not dist.is_initialized():
pytest.skip("torch.distributed not initialized")
if impl == "torch" and mode == "eager_nobarrier":
pytest.skip("torch has no barrier knob; eager_barrier already covers eager torch")

# Size heap to fit input (M*N) + output (max_ranks*M*N) with headroom
max_ranks = int(os.environ.get("WORLD_SIZE", 8))
elem_size = torch.tensor([], dtype=dtype).element_size()
needed = (1 + max_ranks) * M * N * elem_size
heap_size = max(2**30, int(needed * 2)) # 2x headroom, minimum 1GB
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()

# Each rank has an M x N input tensor
# Output is (world_size * M, N) - concatenated along dimension 0
pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}")
# Fill with deterministic values for easier debugging
pytorch_input_tensor.fill_(float(rank + 1))

# Create output tensor for PyTorch: (world_size * M, N)
pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}")

# Run PyTorch's all_gather_into_tensor to get reference output
shmem.barrier()
dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor)
torch.cuda.synchronize()
# Resolve (impl, mode) up front; the body runs straight-line off these.
async_op = mode != "eager_barrier"
capture = mode == "graph"

# Now set up Iris Gluon all_gather
iris_input_tensor = shmem.zeros((M, N), dtype=dtype)
iris_input_tensor.copy_(pytorch_input_tensor)
shmem = iris.iris(2**33) # 8 GB
Comment thread
micmelesse marked this conversation as resolved.
Outdated
rank, world_size = shmem.get_rank(), shmem.get_num_ranks()
src = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}")
stage_buf, result, config = _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n)
shmem.barrier()
Comment thread
micmelesse marked this conversation as resolved.
Outdated

iris_output_tensor = shmem.zeros((world_size * M, N), dtype=dtype)
def fill_src(replay):
src.fill_(float(1 + rank + (replay % 16)))

# Run Iris Gluon all_gather
shmem.barrier()
config = Config(use_gluon=True, block_size_m=block_size_m, block_size_n=block_size_n)
shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
# Warmup (runs lazy JIT/setup), then capture the step once if in graph mode.
fill_src(0)
_all_gather(impl, src, stage_buf, result, shmem, config, async_op)
torch.cuda.synchronize()
shmem.barrier()

# Compare results
atol = 1e-3 if dtype == torch.float16 else 1e-5
max_diff = torch.abs(iris_output_tensor - pytorch_output_tensor).max().item()

graph = None
if capture:
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
graph = torch.cuda.CUDAGraph()
graph.capture_begin()
_all_gather(impl, src, stage_buf, result, shmem, config, async_op)
graph.capture_end()
torch.cuda.current_stream().wait_stream(stream)
Comment thread
micmelesse marked this conversation as resolved.

atol = 0.5 # exact integer inputs; >=1 mismatch is a real drop
failures = [] # (step, max|diff|, bad_slices)
block_fail = [0] * world_size # steps each peer slice dropped
try:
assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), (
f"Max difference: {max_diff}, expected < {atol}\n"
f"Rank {rank}: Iris Gluon output doesn't match PyTorch's all_gather_into_tensor"
for i in range(NUM_REPLAYS):
replay = i if vary else 0
fill_src(replay)
if capture:
graph.replay()
else:
_all_gather(impl, src, stage_buf, result, shmem, config, async_op)
torch.cuda.synchronize()
Comment thread
micmelesse marked this conversation as resolved.
diffs = [
torch.abs(result[r * M : (r + 1) * M] - float(1 + r + (replay % 16))).max().item()
for r in range(world_size)
]
bad = [r for r in range(world_size) if diffs[r] > atol]
for r in bad:
block_fail[r] += 1
if bad:
failures.append((i, round(max(diffs[r] for r in bad), 4), bad))
print(
f"[rank {rank}] all_gather impl={impl} mode={mode} vary={vary} dtype={dtype} "
f"{M}x{N}: {NUM_REPLAYS - len(failures)}/{NUM_REPLAYS} ok; "
f"per-peer-slice fail counts={block_fail}" + (f"; first FAIL={failures[0]}" if failures else ""),
flush=True,
)
assert not failures, (
f"impl={impl} mode={mode} vary={vary} dtype={dtype} {M}x{N}: "
f"{len(failures)}/{NUM_REPLAYS} steps wrong (first {failures[0]}; per-peer-slice "
f"fail counts={block_fail})."
)
finally:
if graph is not None:
del graph
# Final barrier to ensure all ranks complete before test cleanup
# This helps with test isolation when running multiple tests
# Note: shmem.barrier() already does cuda.synchronize()
Expand Down
153 changes: 87 additions & 66 deletions tests/ccl/test_all_to_all_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
import torch.distributed as dist

# Try to import Gluon, skip tests if not available
try:
import iris
from iris.ccl import Config
Expand All @@ -20,83 +19,105 @@
GLUON_AVAILABLE = False


@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available")
@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.float32,
torch.bfloat16,
],
)
@pytest.mark.parametrize(
"M, N",
[
(128, 64), # Small
(1024, 256), # Medium
(8192, 8192), # Large
],
)
def test_all_to_all_gluon(dtype, M, N):
"""Test all-to-all functionality using Gluon with traffic shaping by comparing against PyTorch's implementation."""
# Ensure torch.distributed is initialized (should be done by test runner)
if not dist.is_initialized():
pytest.skip("torch.distributed not initialized")
# Number of iterations of the per-test check loop (``for i in range(NUM_REPLAYS)``);
# also the denominator of the "ok" / "steps wrong" counts the test reports.
NUM_REPLAYS = 200

heap_size = 2**33 # 8GB
shmem = iris.iris(heap_size)
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()

# PyTorch's all_to_all format: each rank has M x N data to send to all ranks
# Create input data: each rank has its own M x N chunk
# For rank r, the data it sends to all ranks is the same (M x N tensor)
pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}")
# Fill with deterministic values for easier debugging
pytorch_input_tensor.fill_(float(rank))
def _all_to_all(src, stage_buf, result, shmem, config, async_op):
"""Stage src into the input buffer, then all-to-all. Module-level (no closure
over shmem) so the test can ``del shmem`` for IPC cleanup. triton
(use_gluon=False) and gluon (use_gluon=True) both dispatch through
iris.ccl.all_to_all; async_op=True skips the capture-illegal trailing barrier."""
stage_buf.copy_(src)
shmem.ccl.all_to_all(result, stage_buf, config=config, async_op=async_op)

# PyTorch all_to_all expects list of tensors: input_list[i] is sent to rank i
# Since we're sending the same data to all ranks, we replicate it
pytorch_input_list = [pytorch_input_tensor.clone() for _ in range(world_size)]
pytorch_output_list = [torch.zeros(M, N, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)]

# Run PyTorch's all_to_all to get reference output
shmem.barrier()
dist.all_to_all(pytorch_output_list, pytorch_input_list)
torch.cuda.synchronize()

# Convert PyTorch output to concatenated format for comparison
# pytorch_output_list[i] contains data received from rank i
pytorch_output_concat = torch.zeros(M, N * world_size, dtype=dtype, device=f"cuda:{rank}")
for target_rank in range(world_size):
pytorch_output_concat[:, target_rank * N : (target_rank + 1) * N] = pytorch_output_list[target_rank]
@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available")
@pytest.mark.parametrize("impl", ["triton", "gluon"])
@pytest.mark.parametrize("mode", ["eager_barrier", "eager_nobarrier", "graph"])
@pytest.mark.parametrize("vary", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
@pytest.mark.parametrize("M, N", [(128, 1024), (256, 1024)])
def test_all_to_all_gluon(impl, mode, vary, dtype, M, N):
"""Drive all-to-all across impl x mode x vary and check the exchanged output.
mode/vary as in test_all_gather_gluon. No torch arm — torch.distributed
all_to_all uses a different (row-split) layout and can't share this harness;
eager_barrier is the reference (correct when properly synced).

Layout is iris's (M, N*world_size) column-chunks: rank r fills its whole input
with 1 + r + replay%16, so output chunk c (columns [c*N:(c+1)*N]) must equal
1 + c + replay%16 (chunk c is rank c's data) — any >=1 mismatch is a real drop.
Per-source-chunk fail tallies show which chunks dropped."""
if not dist.is_initialized():
pytest.skip("torch.distributed not initialized")

# Now set up Iris Gluon all_to_all format
# Iris format: concatenated tensor (M, N * world_size)
# input[:, i*N:(i+1)*N] contains data to send to rank i
# Since we're sending the same M x N data to all ranks, we replicate it
iris_input_concat = shmem.zeros((M, N * world_size), dtype=dtype)
for target_rank in range(world_size):
iris_input_concat[:, target_rank * N : (target_rank + 1) * N] = pytorch_input_tensor
# Resolve mode up front; the body runs straight-line off these.
async_op = mode != "eager_barrier"
capture = mode == "graph"

shmem = iris.iris(2**33) # 8 GB
rank, world_size = shmem.get_rank(), shmem.get_num_ranks()
width = N * world_size
src = torch.empty((M, width), dtype=dtype, device=f"cuda:{rank}")
stage_buf = shmem.zeros((M, width), dtype=dtype)
result = shmem.zeros((M, width), dtype=dtype)
config = Config(use_gluon=(impl == "gluon"))
shmem.barrier()
Comment thread
micmelesse marked this conversation as resolved.

iris_output_concat = shmem.zeros((M, N * world_size), dtype=dtype)
def fill_src(replay):
src.fill_(float(1 + rank + (replay % 16)))

# Run Iris Gluon all_to_all with traffic shaping enabled
shmem.barrier()
config = Config(use_gluon=True) # Enable Gluon with traffic shaping
shmem.ccl.all_to_all(iris_output_concat, iris_input_concat, config=config)
# Warmup (runs lazy JIT/setup), then capture the step once if in graph mode.
fill_src(0)
_all_to_all(src, stage_buf, result, shmem, config, async_op)
torch.cuda.synchronize()
shmem.barrier()

# Compare results
atol = 1e-3 if dtype == torch.float16 else 1e-5
max_diff = torch.abs(iris_output_concat - pytorch_output_concat).max().item()

graph = None
if capture:
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
graph = torch.cuda.CUDAGraph()
graph.capture_begin()
_all_to_all(src, stage_buf, result, shmem, config, async_op)
graph.capture_end()
torch.cuda.current_stream().wait_stream(stream)
Comment thread
micmelesse marked this conversation as resolved.

atol = 0.5 # exact integer inputs; >=1 mismatch is a real drop
failures = [] # (step, max|diff|, bad_chunks)
chunk_fail = [0] * world_size # steps each source chunk dropped
try:
assert torch.allclose(iris_output_concat, pytorch_output_concat, atol=atol), (
f"Max difference: {max_diff}, expected < {atol}\n"
f"Rank {rank}: Iris Gluon output doesn't match PyTorch's all_to_all"
for i in range(NUM_REPLAYS):
replay = i if vary else 0
fill_src(replay)
if capture:
graph.replay()
else:
_all_to_all(src, stage_buf, result, shmem, config, async_op)
torch.cuda.synchronize()
diffs = [
torch.abs(result[:, c * N : (c + 1) * N] - float(1 + c + (replay % 16))).max().item()
for c in range(world_size)
]
bad = [c for c in range(world_size) if diffs[c] > atol]
for c in bad:
chunk_fail[c] += 1
if bad:
failures.append((i, round(max(diffs[c] for c in bad), 4), bad))
print(
f"[rank {rank}] all_to_all impl={impl} mode={mode} vary={vary} dtype={dtype} "
f"{M}x{width}: {NUM_REPLAYS - len(failures)}/{NUM_REPLAYS} ok; "
f"per-source-chunk fail counts={chunk_fail}" + (f"; first FAIL={failures[0]}" if failures else ""),
flush=True,
)
Comment thread
micmelesse marked this conversation as resolved.
assert not failures, (
f"impl={impl} mode={mode} vary={vary} dtype={dtype} {M}x{width}: "
f"{len(failures)}/{NUM_REPLAYS} steps wrong (first {failures[0]}; per-source-chunk "
f"fail counts={chunk_fail})."
)
finally:
if graph is not None:
del graph
# Final barrier to ensure all ranks complete before test cleanup
# This helps with test isolation when running multiple tests
# Note: shmem.barrier() already does cuda.synchronize()
Expand Down
Loading