Skip to content
130 changes: 110 additions & 20 deletions iris/host/memory/allocators/torch_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
sub-allocations within it using bump allocation.
"""

import ctypes
import ctypes.util
import logging
import math
import mmap
import numpy as np
import os
import torch
from typing import Optional, Dict
import struct
Expand Down Expand Up @@ -51,27 +55,68 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int
rank=cur_rank,
num_ranks=num_ranks,
)
self._shm_mmap = None
self._shm_fd = None
self._shm_name = None

if is_simulation_env():
import json

# In simulation, each rank allocates n distinct buffers; memory_pool is a shallow view of the ith.
self.rank_bools = [torch.empty(heap_size, device=self.device, dtype=torch.int8) for _ in range(num_ranks)]
self.memory_pool = self.rank_bools[cur_rank]

heap_views = [self.rank_bools[r].data_ptr() for r in range(num_ranks)]
out_path = f"iris_rank_{cur_rank}_allocator_views.json"
with open(out_path, "w") as f:
json.dump(
{
"rank": cur_rank,
"num_ranks": num_ranks,
"heap_views": [hex(b) for b in heap_views],
},
f,
indent=2,
)
self._shm_name = f"/iris-sim-heap-{os.environ.get('SLURM_JOB_ID', os.getppid())}"
Comment thread
mawad-amd marked this conversation as resolved.
total_size = heap_size * num_ranks
librt = ctypes.CDLL(ctypes.util.find_library("rt") or "librt.so.1", use_errno=True)
librt.shm_open.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_uint]
librt.shm_open.restype = ctypes.c_int
self._librt = librt

if cur_rank == 0:
librt.shm_unlink(self._shm_name.encode())
fd = librt.shm_open(self._shm_name.encode(), os.O_CREAT | os.O_RDWR, 0o600)
if fd < 0:
raise OSError(ctypes.get_errno(), f"shm_open create failed: {os.strerror(ctypes.get_errno())}")
os.ftruncate(fd, total_size)
else:
for _ in range(100):
fd = librt.shm_open(self._shm_name.encode(), os.O_RDWR, 0)
if fd >= 0:
break
import time

time.sleep(0.05)
else:
raise OSError(f"shm_open failed: rank 0 never created {self._shm_name}")

self._shm_fd = fd
self._shm_mmap = mmap.mmap(fd, total_size, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE)

# Register the entire shm region with HIP so Triton's pointer check passes
mmap_base = ctypes.addressof(ctypes.c_char.from_buffer(self._shm_mmap))
try:
libhip = ctypes.CDLL("libamdhip64.so", use_errno=True)
libhip.hipHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint]
libhip.hipHostRegister.restype = ctypes.c_int
err = libhip.hipHostRegister(ctypes.c_void_p(mmap_base), ctypes.c_size_t(total_size), ctypes.c_uint(0))
if err != 0:
_log_rank(logging.WARNING, "hipHostRegister returned %d", err, rank=cur_rank, num_ranks=num_ranks)
self._hip_registered_ptr = mmap_base
self._libhip = libhip
except Exception as e:
_log_rank(logging.WARNING, "hipHostRegister failed: %s", str(e), rank=cur_rank, num_ranks=num_ranks)
self._hip_registered_ptr = None

my_offset = cur_rank * heap_size
buf = (ctypes.c_int8 * heap_size).from_buffer(self._shm_mmap, my_offset)
np_arr = np.ctypeslib.as_array(buf)
self.memory_pool = torch.from_numpy(np_arr)

_log_rank(
logging.INFO,
"TorchAllocator: sim shm_open %s, rank %d slice at offset %d",
self._shm_name,
cur_rank,
my_offset,
rank=cur_rank,
num_ranks=num_ranks,
)
else:
self.rank_bools = None
self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8)

self._peer_ext_mem_handles: Dict[int, object] = {}
Expand Down Expand Up @@ -151,6 +196,25 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional
"""
heap_bases_array = np.zeros(self.num_ranks, dtype=np.uint64)

if is_simulation_env() and self._shm_mmap is not None:
self._shm_peer_views = []
for rank in range(self.num_ranks):
peer_offset = rank * self.heap_size
peer_buf = (ctypes.c_int8 * self.heap_size).from_buffer(self._shm_mmap, peer_offset)
peer_np = np.ctypeslib.as_array(peer_buf)
peer_tensor = torch.from_numpy(peer_np)
self._shm_peer_views.append(peer_tensor)
heap_bases_array[rank] = peer_tensor.data_ptr()
self.heap_bases_array = heap_bases_array
_log_rank(
logging.INFO,
"TorchAllocator: sim peer access via shm_open, %d ranks",
self.num_ranks,
rank=self.cur_rank,
num_ranks=self.num_ranks,
)
return
Comment thread
mawad-amd marked this conversation as resolved.

if connections is not None:
for handle in self._peer_ext_mem_handles.values():
try:
Expand Down Expand Up @@ -190,16 +254,42 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional
self.heap_bases_array = heap_bases_array

def close(self):
    """Release peer external memory handles and shm resources.

    Teardown is best-effort: every step is attempted independently, and a
    failure in one step is logged at DEBUG level rather than raised, so the
    remaining resources are still released. Each piece of state is cleared
    once it has been handled, which makes repeated calls harmless.

    Steps, in order:
      1. Destroy every peer external-memory handle.
      2. Unregister the shm mapping from HIP, if a registered pointer exists.
      3. Close the mmap, then the shm file descriptor.
      4. On rank 0 only, unlink the shm name; the name is then forgotten on
         every rank so a later call cannot unlink it a second time.
    """
    log = logging.getLogger(__name__)

    for handle in self._peer_ext_mem_handles.values():
        try:
            destroy_external_memory(handle)
        except Exception as exc:
            log.debug("close: destroy_external_memory failed: %s", exc)
    self._peer_ext_mem_handles.clear()

    # getattr: the attribute is only assigned on the simulation shm path.
    registered_ptr = getattr(self, "_hip_registered_ptr", None)
    if registered_ptr is not None:
        # Unregister before the mapping is closed, while the address is
        # still backed by the mmap.
        try:
            err = self._libhip.hipHostUnregister(ctypes.c_void_p(registered_ptr))
            if err:
                log.debug("close: hipHostUnregister returned %s", err)
        except Exception as exc:
            log.debug("close: hipHostUnregister failed: %s", exc)
        self._hip_registered_ptr = None

    if self._shm_mmap is not None:
        try:
            self._shm_mmap.close()
        except Exception as exc:
            # mmap.close() can fail (e.g. BufferError) while views created
            # over the mapping are still alive; dropping our reference below
            # lets the mapping go once those views are collected.
            log.debug("close: mmap close failed: %s", exc)
        self._shm_mmap = None

    if self._shm_fd is not None:
        try:
            os.close(self._shm_fd)
        except Exception as exc:
            log.debug("close: closing shm fd failed: %s", exc)
        self._shm_fd = None

    shm_name = self._shm_name
    if shm_name is not None:
        if self.cur_rank == 0:
            try:
                self._librt.shm_unlink(shm_name.encode())
            except Exception as exc:
                log.debug("close: shm_unlink(%s) failed: %s", shm_name, exc)
        # Forget the name so a second close() does not unlink a segment that
        # may since have been re-created under the same name.
        self._shm_name = None

def get_device(self) -> torch.device:
    """Return the torch device associated with this allocator.

    Outside a simulation environment this is the device of the
    ``memory_pool`` tensor. In a simulation environment it is built from
    ``self.device`` instead.
    """
    if not is_simulation_env():
        return self.memory_pool.device
    # NOTE(review): presumably the simulated pool is a host-backed tensor, so
    # its own .device would not be the configured one — confirm.
    return torch.device(self.device)
Comment thread
mawad-amd marked this conversation as resolved.

def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor:
Expand Down
4 changes: 3 additions & 1 deletion iris/host/memory/symmetric_heap.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,10 @@ def _refresh_peer_access_torch(self, dist, all_bases_arr):
from iris.host.platform.utils import is_simulation_env

if is_simulation_env():
all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)}
self.allocator.establish_peer_access(all_bases)
for r in range(self.num_ranks):
self.heap_bases[r] = int(all_bases_arr[r])
self.heap_bases[r] = int(self.allocator.heap_bases_array[r])
else:
all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)}
self.allocator.establish_peer_access(all_bases, self.fd_conns)
Expand Down
Loading