From c687a5533846036b60ebeadd52a3ef0dcc842b0f Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:07:09 -0700 Subject: [PATCH 1/9] Replace simulation rank_bools with cross-process shm_open In simulation mode (FFM), replace the N-buffers-on-one-GPU hack with POSIX shared memory (shm_open + mmap). Each rank gets a slice of a shared region, enabling real cross-process memory sharing. FFM SVM mode means GPU VA = host VA, so mmap'd addresses are dereferenceable by GPU kernels. Validated with standalone prototype on gfx1260-ffm container. Co-Authored-By: Claude Opus 4 --- .../host/memory/allocators/torch_allocator.py | 104 ++++++++++++++---- 1 file changed, 84 insertions(+), 20 deletions(-) diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index bbffb6f00..715f6f0d0 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -8,9 +8,13 @@ sub-allocations within it using bump allocation. """ +import ctypes +import ctypes.util import logging import math +import mmap import numpy as np +import os import torch from typing import Optional, Dict import struct @@ -51,27 +55,52 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int rank=cur_rank, num_ranks=num_ranks, ) + self._shm_mmap = None + self._shm_fd = None + self._shm_name = None + if is_simulation_env(): - import json - - # In simulation, each rank allocates n distinct buffers; memory_pool is a shallow view of the ith. - self.rank_bools = [torch.empty(heap_size, device=self.device, dtype=torch.int8) for _ in range(num_ranks)] - self.memory_pool = self.rank_bools[cur_rank] - - heap_views = [self.rank_bools[r].data_ptr() for r in range(num_ranks)] - out_path = f"iris_rank_{cur_rank}_allocator_views.json" - with open(out_path, "w") as f: - json.dump( - { - "rank": cur_rank, - "num_ranks": num_ranks, - "heap_views": [hex(b) for b in heap_views], - }, - f, - indent=2, - ) + self._shm_name = f"/iris-sim-heap-{os.environ.get('SLURM_JOB_ID', os.getppid())}" + total_size = heap_size * num_ranks + librt = ctypes.CDLL(ctypes.util.find_library("rt") or "librt.so.1", use_errno=True) + librt.shm_open.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_uint] + librt.shm_open.restype = ctypes.c_int + self._librt = librt + + if cur_rank == 0: + librt.shm_unlink(self._shm_name.encode()) + fd = librt.shm_open(self._shm_name.encode(), os.O_CREAT | os.O_RDWR, 0o600) + if fd < 0: + raise OSError(ctypes.get_errno(), f"shm_open create failed: {os.strerror(ctypes.get_errno())}") + os.ftruncate(fd, total_size) + else: + for _ in range(100): + fd = librt.shm_open(self._shm_name.encode(), os.O_RDWR, 0) + if fd >= 0: + break + import time + time.sleep(0.05) + else: + raise OSError(f"shm_open failed: rank 0 never created {self._shm_name}") + + self._shm_fd = fd + self._shm_mmap = mmap.mmap(fd, total_size, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE) + + my_offset = cur_rank * heap_size + buf = (ctypes.c_int8 * heap_size).from_buffer(self._shm_mmap, my_offset) + np_arr = np.ctypeslib.as_array(buf) + self.memory_pool = torch.from_numpy(np_arr).to(self.device) + + _log_rank( + logging.INFO, + "TorchAllocator: sim shm_open %s, rank %d slice at offset %d", + self._shm_name, + cur_rank, + my_offset, + rank=cur_rank, + num_ranks=num_ranks, + ) else: - self.rank_bools = None self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) self._peer_ext_mem_handles: Dict[int, object] = {} @@ -151,6 +180,23 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional """ heap_bases_array = np.zeros(self.num_ranks, dtype=np.uint64) + if is_simulation_env() and self._shm_mmap is not None: + for rank in range(self.num_ranks): + peer_offset = rank * self.heap_size + peer_buf = (ctypes.c_int8 * self.heap_size).from_buffer(self._shm_mmap, peer_offset) + peer_np = np.ctypeslib.as_array(peer_buf) + peer_tensor = torch.from_numpy(peer_np).to(self.device) + heap_bases_array[rank] = peer_tensor.data_ptr() + self.heap_bases_array = heap_bases_array + _log_rank( + logging.INFO, + "TorchAllocator: sim peer access via shm_open, %d ranks", + self.num_ranks, + rank=self.cur_rank, + num_ranks=self.num_ranks, + ) + return + if connections is not None: for handle in self._peer_ext_mem_handles.values(): try: @@ -190,7 +236,7 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional self.heap_bases_array = heap_bases_array def close(self): - """Release peer external memory handles.""" + """Release peer external memory handles and shm resources.""" for handle in self._peer_ext_mem_handles.values(): try: destroy_external_memory(handle) @@ -198,6 +244,24 @@ def close(self): pass self._peer_ext_mem_handles.clear() + if self._shm_mmap is not None: + try: + self._shm_mmap.close() + except Exception: + pass + self._shm_mmap = None + if self._shm_fd is not None: + try: + os.close(self._shm_fd) + except Exception: + pass + self._shm_fd = None + if self._shm_name is not None and self.cur_rank == 0: + try: + self._librt.shm_unlink(self._shm_name.encode()) + except Exception: + pass + def get_device(self) -> torch.device: """Get the torch device.""" return self.memory_pool.device From 3d29562ec451c8a64661ff23e5ed092892f6bd4e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 18 Jun 2026 21:07:45 +0000 Subject: [PATCH 2/9] Apply Ruff auto-fixes --- iris/host/memory/allocators/torch_allocator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index 715f6f0d0..d5e69d4ed 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -79,6 +79,7 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int if fd >= 0: break import time + time.sleep(0.05) else: raise OSError(f"shm_open failed: rank 0 never created {self._shm_name}") From f50ea19de170c0e9b0bf446493eed353aa89c2c2 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:14:47 -0700 Subject: [PATCH 3/9] Fix shm_open sim mode: keep CPU tensors, wire up peer access MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - memory_pool stays as CPU tensor (no .to(device)) so data_ptr() returns the mmap host VA — valid for FFM SVM dereference - establish_peer_access creates local mmap views for each peer's slice, stores references to prevent GC - symmetric_heap._refresh_peer_access_torch now calls establish_peer_access in sim mode instead of using raw allgather bases (which are remote VAs, invalid in this process) Co-Authored-By: Claude Opus 4 --- iris/host/memory/allocators/torch_allocator.py | 6 ++++-- iris/host/memory/symmetric_heap.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index d5e69d4ed..dc934bd91 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -90,7 +90,7 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int my_offset = cur_rank * heap_size buf = (ctypes.c_int8 * heap_size).from_buffer(self._shm_mmap, my_offset) np_arr = np.ctypeslib.as_array(buf) - self.memory_pool = torch.from_numpy(np_arr).to(self.device) + self.memory_pool = torch.from_numpy(np_arr) _log_rank( logging.INFO, @@ -182,11 +182,13 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional heap_bases_array = np.zeros(self.num_ranks, dtype=np.uint64) if is_simulation_env() and self._shm_mmap is not None: + self._shm_peer_views = [] for rank in range(self.num_ranks): peer_offset = rank * self.heap_size peer_buf = (ctypes.c_int8 * self.heap_size).from_buffer(self._shm_mmap, peer_offset) peer_np = np.ctypeslib.as_array(peer_buf) - peer_tensor = torch.from_numpy(peer_np).to(self.device) + peer_tensor = torch.from_numpy(peer_np) + self._shm_peer_views.append(peer_tensor) heap_bases_array[rank] = peer_tensor.data_ptr() self.heap_bases_array = heap_bases_array _log_rank( diff --git a/iris/host/memory/symmetric_heap.py b/iris/host/memory/symmetric_heap.py index 28165288b..8f87db44f 100644 --- a/iris/host/memory/symmetric_heap.py +++ b/iris/host/memory/symmetric_heap.py @@ -400,8 +400,10 @@ def _refresh_peer_access_torch(self, dist, all_bases_arr): from iris.host.platform.utils import is_simulation_env if is_simulation_env(): + all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} + self.allocator.establish_peer_access(all_bases) for r in range(self.num_ranks): - self.heap_bases[r] = int(all_bases_arr[r]) + self.heap_bases[r] = int(self.allocator.heap_bases_array[r]) else: all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} self.allocator.establish_peer_access(all_bases, self.fd_conns) From 122b80bf31a8d242a165ffd3afa6a48b9a0b99f3 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:17:45 -0700 Subject: [PATCH 4/9] Return CUDA device in sim mode for device check compatibility In FFM SVM mode, CPU VA = GPU VA. memory_pool is backed by shm mmap (CPU tensor) but get_device() returns cuda:N so iris device checks pass and examples work unchanged. Co-Authored-By: Claude Opus 4 --- iris/host/memory/allocators/torch_allocator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index dc934bd91..7b90ae400 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -267,6 +267,8 @@ def close(self): def get_device(self) -> torch.device: """Get the torch device.""" + if is_simulation_env(): + return torch.device(self.device) return self.memory_pool.device def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: From e84559aec0b98b3274df780359f454a9c59e700f Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:20:31 -0700 Subject: [PATCH 5/9] Register shm mmap with hipHostRegister for Triton pointer check Triton AMD driver validates pointers via hipPointerGetAttribute before kernel launch. CPU tensor pointers from shm mmap fail this check. hipHostRegister marks the shm region as device-accessible, making the check pass without patching Triton. Co-Authored-By: Claude Opus 4 --- .../host/memory/allocators/torch_allocator.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index 7b90ae400..657ac6427 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -87,6 +87,21 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int self._shm_fd = fd self._shm_mmap = mmap.mmap(fd, total_size, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE) + # Register the entire shm region with HIP so Triton's pointer check passes + mmap_base = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_mmap)) + try: + libhip = ctypes.CDLL("libamdhip64.so", use_errno=True) + libhip.hipHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + libhip.hipHostRegister.restype = ctypes.c_int + err = libhip.hipHostRegister(ctypes.c_void_p(mmap_base), ctypes.c_size_t(total_size), ctypes.c_uint(0)) + if err != 0: + _log_rank(logging.WARNING, "hipHostRegister returned %d", err, rank=cur_rank, num_ranks=num_ranks) + self._hip_registered_ptr = mmap_base + self._libhip = libhip + except Exception as e: + _log_rank(logging.WARNING, "hipHostRegister failed: %s", str(e), rank=cur_rank, num_ranks=num_ranks) + self._hip_registered_ptr = None + my_offset = cur_rank * heap_size buf = (ctypes.c_int8 * heap_size).from_buffer(self._shm_mmap, my_offset) np_arr = np.ctypeslib.as_array(buf) @@ -247,6 +262,12 @@ def close(self): pass self._peer_ext_mem_handles.clear() + if getattr(self, '_hip_registered_ptr', None) is not None: + try: + self._libhip.hipHostUnregister(ctypes.c_void_p(self._hip_registered_ptr)) + except Exception: + pass + self._hip_registered_ptr = None if self._shm_mmap is not None: try: self._shm_mmap.close() From 62518907edd7bb16e4cefb3c02e77481874124cd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 18 Jun 2026 21:21:15 +0000 Subject: [PATCH 6/9] Apply Ruff auto-fixes --- iris/host/memory/allocators/torch_allocator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index 657ac6427..0036c7ab9 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -262,7 +262,7 @@ def close(self): pass self._peer_ext_mem_handles.clear() - if getattr(self, '_hip_registered_ptr', None) is not None: + if getattr(self, "_hip_registered_ptr", None) is not None: try: self._libhip.hipHostUnregister(ctypes.c_void_p(self._hip_registered_ptr)) except Exception: From 1d9692409fd769d87c90d0d62e61d020b6cfded7 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:55:03 -0700 Subject: [PATCH 7/9] Add single-kernel message passing example Same semantics as example 31 but uses one unified kernel that branches on cur_rank instead of separate producer/consumer kernels. Produces a single kernel dispatch per rank for downstream capture tools. Co-Authored-By: Claude Opus 4 --- .../example.py | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 examples/33_message_passing_single_kernel/example.py diff --git a/examples/33_message_passing_single_kernel/example.py b/examples/33_message_passing_single_kernel/example.py new file mode 100644 index 000000000..dcbfad2fd --- /dev/null +++ b/examples/33_message_passing_single_kernel/example.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: producer-consumer message passing with a single unified kernel. + +Same semantics as example 31 (message_passing) but uses one kernel that +branches internally on cur_rank. This produces a single kernel dispatch +per rank, which is required for downstream capture tools (e.g. GUMMi). + +Producer rank writes source data to consumer's buffer and signals via flag. +Consumer rank spin-waits on flag then reads and doubles the data. +Requires exactly 2 ranks. + +Run with: + torchrun --nproc_per_node=2 --standalone example.py [--validate] +""" + +import argparse +import os +import random + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +import iris + + +@triton.jit +def message_passing_kernel( + source_buffer, + target_buffer, + flag, + buffer_size, + cur_rank, + producer_rank: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + is_producer = cur_rank == producer_rank + + if is_producer: + values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask) + iris.store(target_buffer + offsets, values, producer_rank, consumer_rank, heap_bases_ptr, mask=mask) + iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + else: + done = 0 + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) + values = iris.load(target_buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask) + values = values * 2 + iris.store(target_buffer + offsets, values, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + raise ValueError(f"Unknown datatype: {datatype}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Message passing producer-consumer example with single kernel (2 ranks).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block size") + parser.add_argument("--heap_size", type=int, default=1 << 16, help="Iris heap size") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + + ctx = iris.iris(heap_size=args["heap_size"]) + cur_rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + if world_size != 2: + raise ValueError("This example requires exactly two processes. Use: torchrun --nproc_per_node=2 ...") + + dtype = torch_dtype_from_str(args["datatype"]) + producer_rank = 0 + consumer_rank = 1 + + source_buffer = ctx.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + destination_buffer = ctx.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + destination_buffer = ctx.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + flags = ctx.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + heap_bases = ctx.get_heap_bases() + + ctx.info(f"Rank {cur_rank} launching message_passing_kernel.") + message_passing_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + cur_rank, + producer_rank, + consumer_rank, + args["block_size"], + heap_bases, + ) + + ctx.barrier() + ctx.info(f"Rank {cur_rank} has finished.") + + if args["validate"]: + ctx.info("Validating output...") + if cur_rank == consumer_rank: + expected = source_buffer * 2 + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + ctx.error(f"Validation failed. Max absolute difference: {max_diff}") + else: + ctx.info("Validation successful.") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From 19752331d0377dd2d22e6343af92d965c8696789 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:56:04 -0700 Subject: [PATCH 8/9] Use gloo backend in sim mode for FFM compatibility Co-Authored-By: Claude Opus 4 --- examples/33_message_passing_single_kernel/example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/33_message_passing_single_kernel/example.py b/examples/33_message_passing_single_kernel/example.py index dcbfad2fd..549ecfe40 100644 --- a/examples/33_message_passing_single_kernel/example.py +++ b/examples/33_message_passing_single_kernel/example.py @@ -27,6 +27,7 @@ import triton.language as tl import iris +from iris.host.platform.utils import is_simulation_env @triton.jit @@ -105,7 +106,8 @@ def main(): local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) - dist.init_process_group(backend="nccl") + backend = "gloo" if is_simulation_env() else "nccl" + dist.init_process_group(backend=backend) ctx = iris.iris(heap_size=args["heap_size"]) cur_rank = ctx.get_rank() From dc9aae7600fc350def16f5b21d3b89bc1374d90d Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 18 Jun 2026 14:57:05 -0700 Subject: [PATCH 9/9] Use gloo backend always in example 33 Co-Authored-By: Claude Opus 4 --- examples/33_message_passing_single_kernel/example.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/33_message_passing_single_kernel/example.py b/examples/33_message_passing_single_kernel/example.py index 549ecfe40..dc77da30c 100644 --- a/examples/33_message_passing_single_kernel/example.py +++ b/examples/33_message_passing_single_kernel/example.py @@ -27,7 +27,6 @@ import triton.language as tl import iris -from iris.host.platform.utils import is_simulation_env @triton.jit @@ -106,8 +105,7 @@ def main(): local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) - backend = "gloo" if is_simulation_env() else "nccl" - dist.init_process_group(backend=backend) + dist.init_process_group(backend="gloo") ctx = iris.iris(heap_size=args["heap_size"]) cur_rank = ctx.get_rank()