Skip to content
166 changes: 166 additions & 0 deletions examples/33_message_passing_single_kernel/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.

"""
Example: producer-consumer message passing with a single unified kernel.
Same semantics as example 31 (message_passing) but uses one kernel that
branches internally on cur_rank. This produces a single kernel dispatch
per rank, which is required for downstream capture tools (e.g. GUMMi).
Producer rank writes source data to consumer's buffer and signals via flag.
Consumer rank spin-waits on flag then reads and doubles the data.
Requires exactly 2 ranks.
Run with:
torchrun --nproc_per_node=2 --standalone example.py [--validate]
"""

import argparse
import os
import random

import torch
import torch.distributed as dist
import triton
import triton.language as tl

import iris


@triton.jit
def message_passing_kernel(
    source_buffer,
    target_buffer,
    flag,
    buffer_size,
    cur_rank,
    producer_rank: tl.constexpr,
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,
):
    """Unified producer/consumer kernel; the role is chosen from ``cur_rank``.

    When ``cur_rank == producer_rank`` the program copies its block of
    ``source_buffer`` into ``target_buffer`` and then performs a release
    compare-and-swap (0 -> 1) on ``flag[pid]``. Otherwise it spins on an
    acquire compare-and-swap (1 -> 0) on ``flag[pid]`` until the returned
    value is non-zero, then reads its block of ``target_buffer``, doubles it
    and stores it back.

    Args:
        source_buffer: Base pointer the producer branch loads from.
        target_buffer: Base pointer the producer branch stores to and the
            consumer branch loads from / stores back to.
        flag: Base pointer of the flag array; indexed by program id.
        buffer_size: Number of valid elements; offsets at or beyond it are masked.
        cur_rank: Rank executing this launch; compared with ``producer_rank``.
        producer_rank: Compile-time rank id used for the producer role.
        consumer_rank: Compile-time rank id used for the consumer role.
        BLOCK_SIZE: Compile-time number of elements handled per program.
        heap_bases_ptr: Passed through unchanged to every iris call.
            NOTE(review): presumably the per-rank heap base table — confirm
            against the iris device API.
    """
    # One program per BLOCK_SIZE-wide slice; the mask disables lanes past buffer_size.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < buffer_size

    # Any rank that is not the producer takes the consumer branch below.
    is_producer = cur_rank == producer_rank

    if is_producer:
        # Load with both rank arguments set to the producer, then store with
        # (producer_rank, consumer_rank). Per the module docstring this writes
        # into the consumer's buffer.
        # NOTE(review): confirm the rank-argument order of iris.load/iris.store.
        values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask)
        iris.store(target_buffer + offsets, values, producer_rank, consumer_rank, heap_bases_ptr, mask=mask)
        # Signal this block: CAS on flag[pid] with arguments (0, 1) and release
        # semantics at system scope, issued after the data store above.
        iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys")
    else:
        # Spin until the CAS (arguments 1, 0; acquire semantics) returns a
        # non-zero value for this block's flag.
        # NOTE(review): presumably atomic_cas returns the previous flag value,
        # so the loop exits once the producer's 1 has been observed and reset.
        done = 0
        while done == 0:
            done = iris.atomic_cas(
                flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys"
            )
        # Flag observed: read the block back, double it and store it in place.
        values = iris.load(target_buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask)
        values = values * 2
        iris.store(target_buffer + offsets, values, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask)


# Fix the torch and stdlib RNG seeds at import time so every run and every
# rank starts from the same random state.
torch.manual_seed(123)
random.seed(123)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
    """Translate a short datatype name into the matching ``torch.dtype``.

    Args:
        datatype: One of ``"fp16"``, ``"fp32"``, ``"int8"`` or ``"bf16"``.

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: If ``datatype`` is not one of the supported names. The
            originating ``KeyError`` is attached as ``__cause__``.
    """
    dtype_map = {
        "fp16": torch.float16,
        "fp32": torch.float32,
        "int8": torch.int8,
        "bf16": torch.bfloat16,
    }
    try:
        return dtype_map[datatype]
    except KeyError as err:
        # Chain explicitly so the traceback shows a deliberate translation
        # instead of "during handling of the above exception, another occurred".
        raise ValueError(f"Unknown datatype: {datatype}") from err


def parse_args():
    """Parse the example's command-line options.

    Returns:
        dict: Option values keyed by destination name (``datatype``,
        ``buffer_size``, ``block_size``, ``heap_size``, ``validate``).
    """
    supported_datatypes = ["fp16", "fp32", "int8", "bf16"]

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Message passing producer-consumer example with single kernel (2 ranks).",
    )

    # Registration order is kept as-is because it determines the --help layout.
    cli.add_argument(
        "-t",
        "--datatype",
        default="fp32",
        type=str,
        choices=supported_datatypes,
        help="Datatype of computation",
    )
    cli.add_argument("-s", "--buffer_size", default=4096, type=int, help="Buffer size")
    cli.add_argument("-b", "--block_size", default=512, type=int, help="Block size")
    cli.add_argument("--heap_size", default=1 << 16, type=int, help="Iris heap size")
    cli.add_argument("-v", "--validate", action="store_true", help="Validate output against reference")

    namespace = cli.parse_args()
    return vars(namespace)


def main():
    """Run the two-rank single-kernel message-passing example.

    Sets the CUDA device from ``LOCAL_RANK``, initializes a gloo process
    group and an iris context, allocates the source/destination/flag buffers
    through the context, launches ``message_passing_kernel`` once per rank
    and, with ``--validate``, checks the consumer's destination buffer
    against ``source_buffer * 2``.

    Raises:
        ValueError: If the job was not started with exactly two ranks, or if
            ``--datatype`` is not a supported name.
    """
    args = parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="gloo")

    ctx = iris.iris(heap_size=args["heap_size"])
    cur_rank = ctx.get_rank()
    world_size = ctx.get_num_ranks()

    if world_size != 2:
        raise ValueError("This example requires exactly two processes. Use: torchrun --nproc_per_node=2 ...")

    dtype = torch_dtype_from_str(args["datatype"])
    producer_rank = 0
    consumer_rank = 1

    source_buffer = ctx.zeros(args["buffer_size"], device="cuda", dtype=dtype)
    if dtype.is_floating_point:
        destination_buffer = ctx.randn(args["buffer_size"], device="cuda", dtype=dtype)
    else:
        ii = torch.iinfo(dtype)
        destination_buffer = ctx.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)

    n_elements = source_buffer.numel()
    # One flag per program. The launch grid is derived from the same block
    # count so the kernel's flag[pid] accesses can never outrun the flag array.
    num_blocks = triton.cdiv(n_elements, args["block_size"])
    grid = (num_blocks,)
    flags = ctx.zeros((num_blocks,), device="cuda", dtype=torch.int32)

    heap_bases = ctx.get_heap_bases()

    # Make sure both ranks have finished initializing their buffers and flags
    # before either launches. Otherwise the producer's remote store and flag
    # signal could land first and be overwritten by the consumer's own
    # initialization, which corrupts the data or leaves the consumer spinning.
    ctx.barrier()

    ctx.info(f"Rank {cur_rank} launching message_passing_kernel.")
    message_passing_kernel[grid](
        source_buffer,
        destination_buffer,
        flags,
        n_elements,
        cur_rank,
        producer_rank,
        consumer_rank,
        args["block_size"],
        heap_bases,
    )

    ctx.barrier()
    ctx.info(f"Rank {cur_rank} has finished.")

    if args["validate"]:
        ctx.info("Validating output...")
        if cur_rank == consumer_rank:
            # NOTE(review): source_buffer is zero-filled on every rank, so the
            # expected result is all zeros. This checks that the random initial
            # destination contents were overwritten, not the doubling itself.
            expected = source_buffer * 2
            if not torch.allclose(destination_buffer, expected, atol=1):
                max_diff = (destination_buffer - expected).abs().max().item()
                ctx.error(f"Validation failed. Max absolute difference: {max_diff}")
            else:
                ctx.info("Validation successful.")

    ctx.barrier()
    dist.destroy_process_group()


# Run the example only when the file is executed as a script (e.g. by
# torchrun); importing the module must not start a distributed job.
if __name__ == "__main__":
    main()
130 changes: 110 additions & 20 deletions iris/host/memory/allocators/torch_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
sub-allocations within it using bump allocation.
"""

import ctypes
import ctypes.util
import logging
import math
import mmap
import numpy as np
import os
import torch
from typing import Optional, Dict
import struct
Expand Down Expand Up @@ -51,27 +55,68 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int
rank=cur_rank,
num_ranks=num_ranks,
)
self._shm_mmap = None
self._shm_fd = None
self._shm_name = None

if is_simulation_env():
import json

# In simulation, each rank allocates n distinct buffers; memory_pool is a shallow view of the ith.
self.rank_bools = [torch.empty(heap_size, device=self.device, dtype=torch.int8) for _ in range(num_ranks)]
self.memory_pool = self.rank_bools[cur_rank]

heap_views = [self.rank_bools[r].data_ptr() for r in range(num_ranks)]
out_path = f"iris_rank_{cur_rank}_allocator_views.json"
with open(out_path, "w") as f:
json.dump(
{
"rank": cur_rank,
"num_ranks": num_ranks,
"heap_views": [hex(b) for b in heap_views],
},
f,
indent=2,
)
self._shm_name = f"/iris-sim-heap-{os.environ.get('SLURM_JOB_ID', os.getppid())}"
Comment thread
mawad-amd marked this conversation as resolved.
total_size = heap_size * num_ranks
librt = ctypes.CDLL(ctypes.util.find_library("rt") or "librt.so.1", use_errno=True)
librt.shm_open.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_uint]
librt.shm_open.restype = ctypes.c_int
self._librt = librt

if cur_rank == 0:
librt.shm_unlink(self._shm_name.encode())
fd = librt.shm_open(self._shm_name.encode(), os.O_CREAT | os.O_RDWR, 0o600)
if fd < 0:
raise OSError(ctypes.get_errno(), f"shm_open create failed: {os.strerror(ctypes.get_errno())}")
os.ftruncate(fd, total_size)
else:
for _ in range(100):
fd = librt.shm_open(self._shm_name.encode(), os.O_RDWR, 0)
if fd >= 0:
break
import time

time.sleep(0.05)
else:
raise OSError(f"shm_open failed: rank 0 never created {self._shm_name}")

self._shm_fd = fd
self._shm_mmap = mmap.mmap(fd, total_size, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE)

# Register the entire shm region with HIP so Triton's pointer check passes
mmap_base = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_mmap))
try:
libhip = ctypes.CDLL("libamdhip64.so", use_errno=True)
libhip.hipHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
libhip.hipHostRegister.restype = ctypes.c_int
err = libhip.hipHostRegister(ctypes.c_void_p(mmap_base), ctypes.c_size_t(total_size), ctypes.c_uint(0))
if err != 0:
_log_rank(logging.WARNING, "hipHostRegister returned %d", err, rank=cur_rank, num_ranks=num_ranks)
self._hip_registered_ptr = mmap_base
self._libhip = libhip
except Exception as e:
_log_rank(logging.WARNING, "hipHostRegister failed: %s", str(e), rank=cur_rank, num_ranks=num_ranks)
self._hip_registered_ptr = None

my_offset = cur_rank * heap_size
buf = (ctypes.c_int8 * heap_size).from_buffer(self._shm_mmap, my_offset)
np_arr = np.ctypeslib.as_array(buf)
self.memory_pool = torch.from_numpy(np_arr)

_log_rank(
logging.INFO,
"TorchAllocator: sim shm_open %s, rank %d slice at offset %d",
self._shm_name,
cur_rank,
my_offset,
rank=cur_rank,
num_ranks=num_ranks,
)
else:
self.rank_bools = None
self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8)

self._peer_ext_mem_handles: Dict[int, object] = {}
Expand Down Expand Up @@ -151,6 +196,25 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional
"""
heap_bases_array = np.zeros(self.num_ranks, dtype=np.uint64)

if is_simulation_env() and self._shm_mmap is not None:
self._shm_peer_views = []
for rank in range(self.num_ranks):
peer_offset = rank * self.heap_size
peer_buf = (ctypes.c_int8 * self.heap_size).from_buffer(self._shm_mmap, peer_offset)
peer_np = np.ctypeslib.as_array(peer_buf)
peer_tensor = torch.from_numpy(peer_np)
self._shm_peer_views.append(peer_tensor)
heap_bases_array[rank] = peer_tensor.data_ptr()
self.heap_bases_array = heap_bases_array
_log_rank(
logging.INFO,
"TorchAllocator: sim peer access via shm_open, %d ranks",
self.num_ranks,
rank=self.cur_rank,
num_ranks=self.num_ranks,
)
return
Comment thread
mawad-amd marked this conversation as resolved.

if connections is not None:
for handle in self._peer_ext_mem_handles.values():
try:
Expand Down Expand Up @@ -190,16 +254,42 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional
self.heap_bases_array = heap_bases_array

def close(self):
    """Release peer external memory handles and shm resources.

    Every step is best-effort: a failure in one step is logged at DEBUG level
    and does not prevent the remaining steps from running. Steps run in this
    order: destroy peer external-memory handles, unregister the HIP host
    registration (if one was recorded), close the shared-memory mapping,
    close its file descriptor, and, on rank 0 only, unlink the
    shared-memory name.
    """

    def _note_failure(step, exc):
        # Keep cleanup non-fatal but leave a trace instead of swallowing silently.
        _log_rank(
            logging.DEBUG,
            "TorchAllocator.close: %s failed: %s",
            step,
            str(exc),
            rank=self.cur_rank,
            num_ranks=self.num_ranks,
        )

    for handle in self._peer_ext_mem_handles.values():
        try:
            destroy_external_memory(handle)
        except Exception as exc:
            _note_failure("destroy_external_memory", exc)
    self._peer_ext_mem_handles.clear()

    # getattr: the attribute is only assigned on the simulation path.
    if getattr(self, "_hip_registered_ptr", None) is not None:
        try:
            self._libhip.hipHostUnregister(ctypes.c_void_p(self._hip_registered_ptr))
        except Exception as exc:
            _note_failure("hipHostUnregister", exc)
        self._hip_registered_ptr = None
    if self._shm_mmap is not None:
        try:
            # NOTE(review): mmap.close() raises BufferError while buffer
            # exports are alive (e.g. tensors built on top of the mapping);
            # that case is now logged instead of being hidden.
            self._shm_mmap.close()
        except Exception as exc:
            _note_failure("mmap close", exc)
        self._shm_mmap = None
    if self._shm_fd is not None:
        try:
            os.close(self._shm_fd)
        except Exception as exc:
            _note_failure("shm fd close", exc)
        self._shm_fd = None
    # Only rank 0 removes the name from the system.
    if self._shm_name is not None and self.cur_rank == 0:
        try:
            self._librt.shm_unlink(self._shm_name.encode())
        except Exception as exc:
            _note_failure("shm_unlink", exc)

def get_device(self) -> torch.device:
    """Return the torch device this allocator reports for its heap.

    Outside simulation this is the device of the backing ``memory_pool``
    tensor. In simulation it is built from ``self.device`` instead of being
    read off the pool tensor.
    """
    if not is_simulation_env():
        return self.memory_pool.device
    return torch.device(self.device)
Comment thread
mawad-amd marked this conversation as resolved.

def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor:
Expand Down
4 changes: 3 additions & 1 deletion iris/host/memory/symmetric_heap.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,10 @@ def _refresh_peer_access_torch(self, dist, all_bases_arr):
from iris.host.platform.utils import is_simulation_env

if is_simulation_env():
all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)}
self.allocator.establish_peer_access(all_bases)
for r in range(self.num_ranks):
self.heap_bases[r] = int(all_bases_arr[r])
self.heap_bases[r] = int(self.allocator.heap_bases_array[r])
else:
all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)}
self.allocator.establish_peer_access(all_bases, self.fd_conns)
Expand Down
Loading