diff --git a/examples/33_message_passing_single_kernel/example.py b/examples/33_message_passing_single_kernel/example.py new file mode 100644 index 000000000..dc77da30c --- /dev/null +++ b/examples/33_message_passing_single_kernel/example.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: producer-consumer message passing with a single unified kernel. + +Same semantics as example 31 (message_passing) but uses one kernel that +branches internally on cur_rank. This produces a single kernel dispatch +per rank, which is required for downstream capture tools (e.g. GUMMi). + +Producer rank writes source data to consumer's buffer and signals via flag. +Consumer rank spin-waits on flag then reads and doubles the data. +Requires exactly 2 ranks. + +Run with: + torchrun --nproc_per_node=2 --standalone example.py [--validate] +""" + +import argparse +import os +import random + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +import iris + + +@triton.jit +def message_passing_kernel( + source_buffer, + target_buffer, + flag, + buffer_size, + cur_rank, + producer_rank: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + is_producer = cur_rank == producer_rank + + if is_producer: + values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask) + iris.store(target_buffer + offsets, values, producer_rank, consumer_rank, heap_bases_ptr, mask=mask) + iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + else: + done = 0 + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) + values = iris.load(target_buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask) + values = values * 2 + iris.store(target_buffer + offsets, values, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + raise ValueError(f"Unknown datatype: {datatype}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Message passing producer-consumer example with single kernel (2 ranks).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block size") + parser.add_argument("--heap_size", type=int, default=1 << 16, help="Iris heap size") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + cur_rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + if world_size != 2: + raise ValueError("This example requires exactly two processes. Use: torchrun --nproc_per_node=2 ...") + + dtype = torch_dtype_from_str(args["datatype"]) + producer_rank = 0 + consumer_rank = 1 + + source_buffer = ctx.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + destination_buffer = ctx.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + destination_buffer = ctx.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + flags = ctx.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + heap_bases = ctx.get_heap_bases() + + ctx.info(f"Rank {cur_rank} launching message_passing_kernel.") + message_passing_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + cur_rank, + producer_rank, + consumer_rank, + args["block_size"], + heap_bases, + ) + + ctx.barrier() + ctx.info(f"Rank {cur_rank} has finished.") + + if args["validate"]: + ctx.info("Validating output...") + if cur_rank == consumer_rank: + expected = source_buffer * 2 + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + ctx.error(f"Validation failed. Max absolute difference: {max_diff}") + else: + ctx.info("Validation successful.") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/iris/host/memory/allocators/torch_allocator.py b/iris/host/memory/allocators/torch_allocator.py index bbffb6f00..0036c7ab9 100644 --- a/iris/host/memory/allocators/torch_allocator.py +++ b/iris/host/memory/allocators/torch_allocator.py @@ -8,9 +8,13 @@ sub-allocations within it using bump allocation. """ +import ctypes +import ctypes.util import logging import math +import mmap import numpy as np +import os import torch from typing import Optional, Dict import struct @@ -51,27 +55,68 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int rank=cur_rank, num_ranks=num_ranks, ) + self._shm_mmap = None + self._shm_fd = None + self._shm_name = None + if is_simulation_env(): - import json - - # In simulation, each rank allocates n distinct buffers; memory_pool is a shallow view of the ith. - self.rank_bools = [torch.empty(heap_size, device=self.device, dtype=torch.int8) for _ in range(num_ranks)] - self.memory_pool = self.rank_bools[cur_rank] - - heap_views = [self.rank_bools[r].data_ptr() for r in range(num_ranks)] - out_path = f"iris_rank_{cur_rank}_allocator_views.json" - with open(out_path, "w") as f: - json.dump( - { - "rank": cur_rank, - "num_ranks": num_ranks, - "heap_views": [hex(b) for b in heap_views], - }, - f, - indent=2, - ) + self._shm_name = f"/iris-sim-heap-{os.environ.get('SLURM_JOB_ID', os.getppid())}" + total_size = heap_size * num_ranks + librt = ctypes.CDLL(ctypes.util.find_library("rt") or "librt.so.1", use_errno=True) + librt.shm_open.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_uint] + librt.shm_open.restype = ctypes.c_int + self._librt = librt + + if cur_rank == 0: + librt.shm_unlink(self._shm_name.encode()) + fd = librt.shm_open(self._shm_name.encode(), os.O_CREAT | os.O_RDWR, 0o600) + if fd < 0: + raise OSError(ctypes.get_errno(), f"shm_open create failed: {os.strerror(ctypes.get_errno())}") + os.ftruncate(fd, total_size) + else: + for _ in range(100): + fd = librt.shm_open(self._shm_name.encode(), os.O_RDWR, 0) + if fd >= 0: + break + import time + + time.sleep(0.05) + else: + raise OSError(f"shm_open failed: rank 0 never created {self._shm_name}") + + self._shm_fd = fd + self._shm_mmap = mmap.mmap(fd, total_size, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE) + + # Register the entire shm region with HIP so Triton's pointer check passes + mmap_base = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_mmap)) + try: + libhip = ctypes.CDLL("libamdhip64.so", use_errno=True) + libhip.hipHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + libhip.hipHostRegister.restype = ctypes.c_int + err = libhip.hipHostRegister(ctypes.c_void_p(mmap_base), ctypes.c_size_t(total_size), ctypes.c_uint(0)) + if err != 0: + _log_rank(logging.WARNING, "hipHostRegister returned %d", err, rank=cur_rank, num_ranks=num_ranks) + self._hip_registered_ptr = mmap_base + self._libhip = libhip + except Exception as e: + _log_rank(logging.WARNING, "hipHostRegister failed: %s", str(e), rank=cur_rank, num_ranks=num_ranks) + self._hip_registered_ptr = None + + my_offset = cur_rank * heap_size + buf = (ctypes.c_int8 * heap_size).from_buffer(self._shm_mmap, my_offset) + np_arr = np.ctypeslib.as_array(buf) + self.memory_pool = torch.from_numpy(np_arr) + + _log_rank( + logging.INFO, + "TorchAllocator: sim shm_open %s, rank %d slice at offset %d", + self._shm_name, + cur_rank, + my_offset, + rank=cur_rank, + num_ranks=num_ranks, + ) else: - self.rank_bools = None self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) self._peer_ext_mem_handles: Dict[int, object] = {} @@ -151,6 +196,25 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional """ heap_bases_array = np.zeros(self.num_ranks, dtype=np.uint64) + if is_simulation_env() and self._shm_mmap is not None: + self._shm_peer_views = [] + for rank in range(self.num_ranks): + peer_offset = rank * self.heap_size + peer_buf = (ctypes.c_int8 * self.heap_size).from_buffer(self._shm_mmap, peer_offset) + peer_np = np.ctypeslib.as_array(peer_buf) + peer_tensor = torch.from_numpy(peer_np) + self._shm_peer_views.append(peer_tensor) + heap_bases_array[rank] = peer_tensor.data_ptr() + self.heap_bases_array = heap_bases_array + _log_rank( + logging.INFO, + "TorchAllocator: sim peer access via shm_open, %d ranks", + self.num_ranks, + rank=self.cur_rank, + num_ranks=self.num_ranks, + ) + return + if connections is not None: for handle in self._peer_ext_mem_handles.values(): try: @@ -190,7 +254,7 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional self.heap_bases_array = heap_bases_array def close(self): - """Release peer external memory handles.""" + """Release peer external memory handles and shm resources.""" for handle in self._peer_ext_mem_handles.values(): try: destroy_external_memory(handle) @@ -198,8 +262,34 @@ def close(self): pass self._peer_ext_mem_handles.clear() + if getattr(self, "_hip_registered_ptr", None) is not None: + try: + self._libhip.hipHostUnregister(ctypes.c_void_p(self._hip_registered_ptr)) + except Exception: + pass + self._hip_registered_ptr = None + if self._shm_mmap is not None: + try: + self._shm_mmap.close() + except Exception: + pass + self._shm_mmap = None + if self._shm_fd is not None: + try: + os.close(self._shm_fd) + except Exception: + pass + self._shm_fd = None + if self._shm_name is not None and self.cur_rank == 0: + try: + self._librt.shm_unlink(self._shm_name.encode()) + except Exception: + pass + def get_device(self) -> torch.device: """Get the torch device.""" + if is_simulation_env(): + return torch.device(self.device) return self.memory_pool.device def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: diff --git a/iris/host/memory/symmetric_heap.py b/iris/host/memory/symmetric_heap.py index 28165288b..8f87db44f 100644 --- a/iris/host/memory/symmetric_heap.py +++ b/iris/host/memory/symmetric_heap.py @@ -400,8 +400,10 @@ def _refresh_peer_access_torch(self, dist, all_bases_arr): from iris.host.platform.utils import is_simulation_env if is_simulation_env(): + all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} + self.allocator.establish_peer_access(all_bases) for r in range(self.num_ranks): - self.heap_bases[r] = int(all_bases_arr[r]) + self.heap_bases[r] = int(self.allocator.heap_bases_array[r]) else: all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} self.allocator.establish_peer_access(all_bases, self.fd_conns)