diff --git a/tests/ccl/test_all_gather_gluon.py b/tests/ccl/test_all_gather_gluon.py index 3c57baead..8d8d660db 100644 --- a/tests/ccl/test_all_gather_gluon.py +++ b/tests/ccl/test_all_gather_gluon.py @@ -11,7 +11,6 @@ import torch import torch.distributed as dist -# Try to import Gluon, skip tests if not available try: import iris from iris.ccl import Config @@ -22,31 +21,63 @@ GLUON_AVAILABLE = False +NUM_REPLAYS = 200 + + +def _all_gather(impl, src, stage_buf, result, shmem, config, async_op): + """Stage src into the input buffer, then all-gather. Module-level (no closure + over shmem) so the test can ``del shmem`` for IPC cleanup.""" + stage_buf.copy_(src) + if impl == "torch": + dist.all_gather_into_tensor(result, stage_buf) + else: + shmem.ccl.all_gather(result, stage_buf, config=config, async_op=async_op) + + +def _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n): + """Resolve impl -> (stage_buf, result, config) in one place: torch uses plain + device tensors and no config; the iris backends use symmetric-heap buffers and + a use_gluon config. Output is (world_size * M, N) — block r holds rank r's input.""" + if impl == "torch": + stage = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}") + result = torch.empty((world_size * M, N), dtype=dtype, device=f"cuda:{rank}") + return stage, result, None + stage = shmem.zeros((M, N), dtype=dtype) + result = shmem.zeros((world_size * M, N), dtype=dtype) + config = Config(use_gluon=(impl == "gluon"), block_size_m=block_size_m, block_size_n=block_size_n) + return stage, result, config + + @pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available") -@pytest.mark.parametrize( - "dtype", - [ - torch.float16, - torch.float32, - torch.bfloat16, - ], -) +@pytest.mark.parametrize("impl", ["torch", "triton", "gluon"]) +@pytest.mark.parametrize("mode", ["eager_barrier", "eager_nobarrier", "graph"]) +@pytest.mark.parametrize("vary", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +# block_size_n must be a multiple of threads_per_warp * num_warps (256 at defaults). @pytest.mark.parametrize( "M, N, block_size_m, block_size_n", - [ - # block_size_n must be a multiple of (threads_per_warp * num_warps). - # With defaults (threads_per_warp=64, num_warps=4), minimum is 256. - # elems_per_thread = block_size_n / 256: higher = wider vector loads. - (256, 256, 32, 256), # Small: elems_per_thread=1 (scalar loads) - (1024, 512, 32, 512), # Medium: elems_per_thread=2 (dword loads) - (8192, 8192, 32, 1024), # Large: elems_per_thread=4 (dwordx4, optimal) - ], + [(64, 8192, 32, 1024), (256, 8192, 32, 1024)], ) -def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n): - """Test all-gather functionality using Gluon by comparing against PyTorch's implementation.""" - # Ensure torch.distributed is initialized (should be done by test runner) +def test_all_gather_gluon(impl, mode, vary, dtype, M, N, block_size_m, block_size_n): + """Drive all-gather across impl x mode x vary and check the gathered output. + + mode: eager_barrier (async_op=False, trailing ctx.barrier()), eager_nobarrier + (async_op=True, no barrier), graph (HIP-graph capture+replay, async_op=True — + the host barrier can't be captured). vary=False replays identical input; + vary=True feeds a fresh input each step, surfacing stale cross-rank reads. + + Rank r fills its whole input with 1 + r + replay%16 (exact integers), so output + block r must equal 1 + r + replay%16 — any >=1 mismatch is a real drop. torch + and eager_barrier are the references; per-peer-slice fail tallies show which + peers' slices dropped.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") + if impl == "torch" and mode == "eager_nobarrier": + pytest.skip("torch has no barrier knob; eager_barrier already covers eager torch") + + # Resolve (impl, mode) up front; the body runs straight-line off these. + async_op = mode != "eager_barrier" + capture = mode == "graph" # Size heap to fit input (M*N) + output (max_ranks*M*N) with headroom max_ranks = int(os.environ.get("WORLD_SIZE", 8)) @@ -54,45 +85,67 @@ def test_all_gather_gluon(dtype, M, N, block_size_m, block_size_n): needed = (1 + max_ranks) * M * N * elem_size heap_size = max(2**30, int(needed * 2)) # 2x headroom, minimum 1GB shmem = iris.iris(heap_size) - rank = shmem.get_rank() - world_size = shmem.get_num_ranks() - - # Each rank has an M x N input tensor - # Output is (world_size * M, N) - concatenated along dimension 0 - pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") - # Fill with deterministic values for easier debugging - pytorch_input_tensor.fill_(float(rank + 1)) - - # Create output tensor for PyTorch: (world_size * M, N) - pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}") - - # Run PyTorch's all_gather_into_tensor to get reference output + rank, world_size = shmem.get_rank(), shmem.get_num_ranks() + torch.cuda.set_device(rank) + src = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}") + stage_buf, result, config = _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n) shmem.barrier() - dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor) - torch.cuda.synchronize() - - # Now set up Iris Gluon all_gather - iris_input_tensor = shmem.zeros((M, N), dtype=dtype) - iris_input_tensor.copy_(pytorch_input_tensor) - iris_output_tensor = shmem.zeros((world_size * M, N), dtype=dtype) + def fill_src(replay): + src.fill_(float(1 + rank + (replay % 16))) - # Run Iris Gluon all_gather - shmem.barrier() - config = Config(use_gluon=True, block_size_m=block_size_m, block_size_n=block_size_n) - shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) + # Warmup (runs lazy JIT/setup), then capture the step once if in graph mode. + fill_src(0) + _all_gather(impl, src, stage_buf, result, shmem, config, async_op) torch.cuda.synchronize() + shmem.barrier() - # Compare results - atol = 1e-3 if dtype == torch.float16 else 1e-5 - max_diff = torch.abs(iris_output_tensor - pytorch_output_tensor).max().item() - + graph = None + if capture: + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + graph = torch.cuda.CUDAGraph() + graph.capture_begin() + _all_gather(impl, src, stage_buf, result, shmem, config, async_op) + graph.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + atol = 0.5 # exact integer inputs; >=1 mismatch is a real drop + failures = [] # (step, max|diff|, bad_slices) + block_fail = [0] * world_size # steps each peer slice dropped try: - assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), ( - f"Max difference: {max_diff}, expected < {atol}\n" - f"Rank {rank}: Iris Gluon output doesn't match PyTorch's all_gather_into_tensor" + for i in range(NUM_REPLAYS): + replay = i if vary else 0 + fill_src(replay) + if capture: + graph.replay() + else: + _all_gather(impl, src, stage_buf, result, shmem, config, async_op) + torch.cuda.synchronize() + diffs = [ + torch.abs(result[r * M : (r + 1) * M] - float(1 + r + (replay % 16))).max().item() + for r in range(world_size) + ] + bad = [r for r in range(world_size) if diffs[r] > atol] + for r in bad: + block_fail[r] += 1 + if bad: + failures.append((i, round(max(diffs[r] for r in bad), 4), bad)) + print( + f"[rank {rank}] all_gather impl={impl} mode={mode} vary={vary} dtype={dtype} " + f"{M}x{N}: {NUM_REPLAYS - len(failures)}/{NUM_REPLAYS} ok; " + f"per-peer-slice fail counts={block_fail}" + (f"; first FAIL={failures[0]}" if failures else ""), + flush=True, + ) + assert not failures, ( + f"impl={impl} mode={mode} vary={vary} dtype={dtype} {M}x{N}: " + f"{len(failures)}/{NUM_REPLAYS} steps wrong (first {failures[0]}; per-peer-slice " + f"fail counts={block_fail})." ) finally: + if graph is not None: + del graph # Final barrier to ensure all ranks complete before test cleanup # This helps with test isolation when running multiple tests # Note: shmem.barrier() already does cuda.synchronize() diff --git a/tests/ccl/test_all_to_all_gluon.py b/tests/ccl/test_all_to_all_gluon.py index 04902726f..e8e115698 100644 --- a/tests/ccl/test_all_to_all_gluon.py +++ b/tests/ccl/test_all_to_all_gluon.py @@ -9,7 +9,6 @@ import torch import torch.distributed as dist -# Try to import Gluon, skip tests if not available try: import iris from iris.ccl import Config @@ -20,83 +19,106 @@ GLUON_AVAILABLE = False -@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available") -@pytest.mark.parametrize( - "dtype", - [ - torch.float16, - torch.float32, - torch.bfloat16, - ], -) -@pytest.mark.parametrize( - "M, N", - [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large - ], -) -def test_all_to_all_gluon(dtype, M, N): - """Test all-to-all functionality using Gluon with traffic shaping by comparing against PyTorch's implementation.""" - # Ensure torch.distributed is initialized (should be done by test runner) - if not dist.is_initialized(): - pytest.skip("torch.distributed not initialized") +NUM_REPLAYS = 200 - heap_size = 2**33 # 8GB - shmem = iris.iris(heap_size) - rank = shmem.get_rank() - world_size = shmem.get_num_ranks() - # PyTorch's all_to_all format: each rank has M x N data to send to all ranks - # Create input data: each rank has its own M x N chunk - # For rank r, the data it sends to all ranks is the same (M x N tensor) - pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") - # Fill with deterministic values for easier debugging - pytorch_input_tensor.fill_(float(rank)) +def _all_to_all(src, stage_buf, result, shmem, config, async_op): + """Stage src into the input buffer, then all-to-all. Module-level (no closure + over shmem) so the test can ``del shmem`` for IPC cleanup. triton + (use_gluon=False) and gluon (use_gluon=True) both dispatch through + iris.ccl.all_to_all; async_op=True skips the capture-illegal trailing barrier.""" + stage_buf.copy_(src) + shmem.ccl.all_to_all(result, stage_buf, config=config, async_op=async_op) - # PyTorch all_to_all expects list of tensors: input_list[i] is sent to rank i - # Since we're sending the same data to all ranks, we replicate it - pytorch_input_list = [pytorch_input_tensor.clone() for _ in range(world_size)] - pytorch_output_list = [torch.zeros(M, N, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] - # Run PyTorch's all_to_all to get reference output - shmem.barrier() - dist.all_to_all(pytorch_output_list, pytorch_input_list) - torch.cuda.synchronize() - - # Convert PyTorch output to concatenated format for comparison - # pytorch_output_list[i] contains data received from rank i - pytorch_output_concat = torch.zeros(M, N * world_size, dtype=dtype, device=f"cuda:{rank}") - for target_rank in range(world_size): - pytorch_output_concat[:, target_rank * N : (target_rank + 1) * N] = pytorch_output_list[target_rank] +@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available") +@pytest.mark.parametrize("impl", ["triton", "gluon"]) +@pytest.mark.parametrize("mode", ["eager_barrier", "eager_nobarrier", "graph"]) +@pytest.mark.parametrize("vary", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("M, N", [(128, 1024), (256, 1024)]) +def test_all_to_all_gluon(impl, mode, vary, dtype, M, N): + """Drive all-to-all across impl x mode x vary and check the exchanged output. + mode/vary as in test_all_gather_gluon. No torch arm — torch.distributed + all_to_all uses a different (row-split) layout and can't share this harness; + eager_barrier is the reference (correct when properly synced). + + Layout is iris's (M, N*world_size) column-chunks: rank r fills its whole input + with 1 + r + replay%16, so output chunk c (columns [c*N:(c+1)*N]) must equal + 1 + c + replay%16 (chunk c is rank c's data) — any >=1 mismatch is a real drop. + Per-source-chunk fail tallies show which chunks dropped.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") - # Now set up Iris Gluon all_to_all format - # Iris format: concatenated tensor (M, N * world_size) - # input[:, i*N:(i+1)*N] contains data to send to rank i - # Since we're sending the same M x N data to all ranks, we replicate it - iris_input_concat = shmem.zeros((M, N * world_size), dtype=dtype) - for target_rank in range(world_size): - iris_input_concat[:, target_rank * N : (target_rank + 1) * N] = pytorch_input_tensor + # Resolve mode up front; the body runs straight-line off these. + async_op = mode != "eager_barrier" + capture = mode == "graph" + + shmem = iris.iris(2**33) # 8 GB + rank, world_size = shmem.get_rank(), shmem.get_num_ranks() + torch.cuda.set_device(rank) + width = N * world_size + src = torch.empty((M, width), dtype=dtype, device=f"cuda:{rank}") + stage_buf = shmem.zeros((M, width), dtype=dtype) + result = shmem.zeros((M, width), dtype=dtype) + config = Config(use_gluon=(impl == "gluon")) + shmem.barrier() - iris_output_concat = shmem.zeros((M, N * world_size), dtype=dtype) + def fill_src(replay): + src.fill_(float(1 + rank + (replay % 16))) - # Run Iris Gluon all_to_all with traffic shaping enabled - shmem.barrier() - config = Config(use_gluon=True) # Enable Gluon with traffic shaping - shmem.ccl.all_to_all(iris_output_concat, iris_input_concat, config=config) + # Warmup (runs lazy JIT/setup), then capture the step once if in graph mode. + fill_src(0) + _all_to_all(src, stage_buf, result, shmem, config, async_op) torch.cuda.synchronize() + shmem.barrier() - # Compare results - atol = 1e-3 if dtype == torch.float16 else 1e-5 - max_diff = torch.abs(iris_output_concat - pytorch_output_concat).max().item() - + graph = None + if capture: + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + graph = torch.cuda.CUDAGraph() + graph.capture_begin() + _all_to_all(src, stage_buf, result, shmem, config, async_op) + graph.capture_end() + torch.cuda.current_stream().wait_stream(stream) + + atol = 0.5 # exact integer inputs; >=1 mismatch is a real drop + failures = [] # (step, max|diff|, bad_chunks) + chunk_fail = [0] * world_size # steps each source chunk dropped try: - assert torch.allclose(iris_output_concat, pytorch_output_concat, atol=atol), ( - f"Max difference: {max_diff}, expected < {atol}\n" - f"Rank {rank}: Iris Gluon output doesn't match PyTorch's all_to_all" + for i in range(NUM_REPLAYS): + replay = i if vary else 0 + fill_src(replay) + if capture: + graph.replay() + else: + _all_to_all(src, stage_buf, result, shmem, config, async_op) + torch.cuda.synchronize() + diffs = [ + torch.abs(result[:, c * N : (c + 1) * N] - float(1 + c + (replay % 16))).max().item() + for c in range(world_size) + ] + bad = [c for c in range(world_size) if diffs[c] > atol] + for c in bad: + chunk_fail[c] += 1 + if bad: + failures.append((i, round(max(diffs[c] for c in bad), 4), bad)) + print( + f"[rank {rank}] all_to_all impl={impl} mode={mode} vary={vary} dtype={dtype} " + f"{M}x{width}: {NUM_REPLAYS - len(failures)}/{NUM_REPLAYS} ok; " + f"per-source-chunk fail counts={chunk_fail}" + (f"; first FAIL={failures[0]}" if failures else ""), + flush=True, + ) + assert not failures, ( + f"impl={impl} mode={mode} vary={vary} dtype={dtype} {M}x{width}: " + f"{len(failures)}/{NUM_REPLAYS} steps wrong (first {failures[0]}; per-source-chunk " + f"fail counts={chunk_fail})." ) finally: + if graph is not None: + del graph # Final barrier to ensure all ranks complete before test cleanup # This helps with test isolation when running multiple tests # Note: shmem.barrier() already does cuda.synchronize()