From 1c72baa4974215ecdd827a7522c33f4c14ba0d89 Mon Sep 17 00:00:00 2001 From: ganyi Date: Wed, 17 Jun 2026 03:38:44 +0000 Subject: [PATCH] aiter asm pa have right acc Signed-off-by: ganyi --- atom/model_ops/minimax_m3/sparse_attn.py | 595 +++++++++++++++++++++++ atom/models/minimax_m3.py | 278 ++++++++--- atom/utils/envs.py | 7 + 3 files changed, 818 insertions(+), 62 deletions(-) diff --git a/atom/model_ops/minimax_m3/sparse_attn.py b/atom/model_ops/minimax_m3/sparse_attn.py index d2d8b20985..f19bfb865c 100644 --- a/atom/model_ops/minimax_m3/sparse_attn.py +++ b/atom/model_ops/minimax_m3/sparse_attn.py @@ -28,6 +28,166 @@ SPARSE_BLOCK_SIZE = 128 +# --------------------------------------------------------------------------- +# Fused qknorm + RoPE + KV insert (SHUFFLE main cache writer). +# +# Reference / fallback for ``aiter.fused_minimax_m3_qknorm_rope_kv_insert`` that +# writes the MAIN K/V cache in page-16 SHUFFLE layout (via +# ``aiter.reshape_and_cache(asm_layout=True)``) instead of the plain page-128 +# layout the aiter kernel writes. This lets AITER ASM paged-attention +# (``pa_fwd_asm``) read the M3 main KV cache during decode. +# +# Math matches the aiter op test oracle exactly (gemma_rmsnorm + neox partial +# RoPE in fp32, cast to the qkv dtype). Correctness is prioritized over speed: +# this is a reference path, not the perf hot path. +# --------------------------------------------------------------------------- +def _gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + xf = x.float() + variance = xf.pow(2).mean(dim=-1, keepdim=True) + return xf * torch.rsqrt(variance + eps) * (1.0 + weight.float()) + + +def _apply_rope_neox_partial( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rotary_dim: int, +) -> torch.Tensor: + half = rotary_dim // 2 + cos_sin = cos_sin_cache[positions].float() + cos = cos_sin[..., :half].unsqueeze(1) + sin = cos_sin[..., half:].unsqueeze(1) + + rot = x[..., :rotary_dim] + x1 = rot[..., :half] + x2 = rot[..., half:] + out = x.clone() + out[..., :half] = x1 * cos - x2 * sin + out[..., half:rotary_dim] = x2 * cos + x1 * sin + return out + + +def _norm_rope( + x: torch.Tensor, + weight: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + eps: float, + rotary_dim: int, + dtype: torch.dtype, +) -> torch.Tensor: + """gemma_rmsnorm followed by partial neox RoPE, in fp32, cast to ``dtype``.""" + normed = _gemma_rmsnorm(x.float(), weight, eps) + return _apply_rope_neox_partial(normed, positions, cos_sin_cache, rotary_dim).to( + dtype + ) + + +@torch.no_grad() +def minimax_m3_fused_qknorm_rope_kv_insert_shuffle( + qkv: torch.Tensor, # [num_tokens, q_size + 2*kv_size + iq_size + ik_size] + q_norm_weight: torch.Tensor, # [head_dim] + k_norm_weight: torch.Tensor, # [head_dim] + cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] + positions: torch.Tensor, # [num_tokens] int + num_heads: int, + num_kv_heads: int, + rotary_dim: int, + eps: float, + index_q_norm_weight: torch.Tensor, # [idx_head_dim] + index_k_norm_weight: torch.Tensor, # [idx_head_dim] + num_index_heads: int, + slot_mapping: torch.Tensor, # [num_tokens] int64 logical slots + kv_cache_k: torch.Tensor, # SHUFFLE K cache [phys, num_kv_heads, head_dim//x, 16, x] + kv_cache_v: torch.Tensor, # SHUFFLE V cache [phys, num_kv_heads, 16//x, head_dim, x] + index_cache: torch.Tensor, # index K cache, viewable as [-1, idx_head_dim] + q_out: torch.Tensor, # [num_tokens, q_size] normed+roped q + index_q_out: torch.Tensor, # [num_tokens, iq_size] normed+roped index_q + idx_head_dim: int, +) -> None: + """Reference for the fused M3 qknorm/rope/kv-insert, SHUFFLE main-cache variant. + + Reproduces the exact semantics of ``aiter.fused_minimax_m3_qknorm_rope_kv_insert`` + (gemma rmsnorm + partial neox RoPE on q/k/index_q/index_k, raw V), but writes + the main K/V into the page-16 SHUFFLE cache via + ``aiter.reshape_and_cache(asm_layout=True)`` (the proven SHUFFLE writer), and + scatters index_k into ``index_cache.view(-1, idx_head_dim)[slot_mapping]``. + """ + import aiter + + num_tokens = qkv.shape[0] + head_dim = q_norm_weight.shape[-1] + dtype = qkv.dtype + + q_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + iq_size = num_index_heads * head_dim + ik_size = idx_head_dim + + q_in, k_in, v_in, iq_in, ik_in = qkv.split( + [q_size, kv_size, kv_size, iq_size, ik_size], dim=-1 + ) + + # q / index_q -> normed + roped outputs. + q_ref = _norm_rope( + q_in.view(num_tokens, num_heads, head_dim), + q_norm_weight, + positions, + cos_sin_cache, + eps, + rotary_dim, + dtype, + ).view(num_tokens, q_size) + q_out.copy_(q_ref) + + iq_ref = _norm_rope( + iq_in.view(num_tokens, num_index_heads, head_dim), + index_q_norm_weight, + positions, + cos_sin_cache, + eps, + rotary_dim, + dtype, + ).view(num_tokens, iq_size) + index_q_out.copy_(iq_ref) + + # k -> normed + roped; v -> raw. Both written to the SHUFFLE main cache. + k_ref = _norm_rope( + k_in.view(num_tokens, num_kv_heads, head_dim), + k_norm_weight, + positions, + cos_sin_cache, + eps, + rotary_dim, + dtype, + ).view(num_tokens, num_kv_heads, head_dim) + v_raw = v_in.view(num_tokens, num_kv_heads, head_dim).contiguous() + + aiter.reshape_and_cache( + k_ref, + v_raw, + kv_cache_k, + kv_cache_v, + slot_mapping, + kv_cache_dtype="auto", + k_scale=None, + v_scale=None, + asm_layout=True, + ) + + # index_k -> single head, normed + roped; plain scatter into index_cache. + ik_ref = _norm_rope( + ik_in.view(num_tokens, 1, idx_head_dim), + index_k_norm_weight, + positions, + cos_sin_cache, + eps, + rotary_dim, + dtype, + ).view(num_tokens, idx_head_dim) + index_cache.view(-1, idx_head_dim)[slot_mapping] = ik_ref.to(index_cache.dtype) + + def _is_fp8_kv_cache_tensor(kv_cache: torch.Tensor) -> bool: fp8_dtypes = ( getattr(torch, "float8_e4m3fn", None), @@ -537,3 +697,438 @@ def minimax_m3_sparse_attn_decode( output.stride(2), NUM_TOPK_CHUNKS=num_topk_chunks, ) + + +# --------------------------------------------------------------------------- +# Page-16 SHUFFLE prefill kernel. Identical math to _gqa_sparse_fwd_kernel; +# only the K/V load addressing differs: the plain page-128 cache is replaced by +# a page-16 SHUFFLE cache split into separate K and V tensors. +# +# K SHUFFLE: [num_phys_blocks, num_kv_heads, head_dim//x, 16, x] +# V SHUFFLE: [num_phys_blocks, num_kv_heads, 16//x, head_dim, x] (x = 16//itemsize) +# +# A selected logical 128-block ``blk`` maps to 8 physical 16-pages: +# phys_page(j) = block_table[blk]*PAGES_PER_SPARSE_BLOCK + j, j in 0..7 +# Within a sparse-block-local position p in [0,128): j = p//16, intra = p%16. +# For (head h, intra s in [0,16), dim d in [0,128)): +# K element: k_cache[phys, h, d//x, s, d%x] +# V element: v_cache[phys, h, s//x, d, s%x] +# --------------------------------------------------------------------------- +@triton.heuristics( + { + "BLOCK_SIZE_D": lambda args: triton.next_power_of_2(args["head_dim"]), + "BLOCK_SIZE_H": lambda args: triton.next_power_of_2(args["gqa_group_size"]), + "BLOCK_SIZE_T": lambda args: triton.next_power_of_2(args["max_topk"]), + "BLOCK_SIZE_QH": lambda args: args["BLOCK_SIZE_Q"] + * triton.next_power_of_2(args["gqa_group_size"]), + } +) +@triton.jit +def _gqa_sparse_fwd_kernel_shuffle( + q_ptr, # [total_q, num_heads, head_dim] + k_cache_ptr, # SHUFFLE K [num_phys_blocks, num_kv_heads, head_dim//x, 16, x] + v_cache_ptr, # SHUFFLE V [num_phys_blocks, num_kv_heads, 16//x, head_dim, x] + t_ptr, # topk_idx: [num_kv_heads, total_q, topk] + o_ptr, # [total_q, num_heads, head_dim] + block_table_ptr, # [num_reqs, max_blocks] + cu_seqlens_q, + cu_seqblocks_q, + seq_lens, + prefix_lens, + num_kv_heads, + gqa_group_size, + head_dim, + max_topk, + num_q_loop, + sm_scale, + stride_qn, + stride_qh, + stride_qd, + # K SHUFFLE strides: [blk, h, d//x, s(16), x] + stride_k_blk, + stride_k_h, + stride_k_dx, + stride_k_s, + stride_k_x, + # V SHUFFLE strides: [blk, h, s//x, d, x] + stride_v_blk, + stride_v_h, + stride_v_sx, + stride_v_d, + stride_v_x, + stride_th, + stride_tn, + stride_tk, + stride_on, + stride_oh, + stride_od, + stride_bt_b, + X: tl.constexpr, # 16 // dtype.itemsize (bf16 -> 8) + PAGES_PER_BLOCK: tl.constexpr, # PAGES_PER_SPARSE_BLOCK (8) + ASM_PAGE: tl.constexpr, # ASM_PAGE_SIZE (16) + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # == SPARSE_BLOCK_SIZE (128) + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_QH: tl.constexpr, +): + sm_scale_log2e = sm_scale * 1.4426950409 + pid_q = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_b = tl.program_id(2) + pid_h = pid_kh * gqa_group_size + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + q_block_start = tl.load(cu_seqblocks_q + pid_b) + q_block_len = tl.load(cu_seqblocks_q + pid_b + 1) - q_block_start + seq_len = tl.load(seq_lens + pid_b) + prefix_len = tl.load(prefix_lens + pid_b) + if pid_q * num_q_loop >= q_block_len: + return + real_q_loop = min(num_q_loop, q_block_len - pid_q * num_q_loop) + bt_row = block_table_ptr + pid_b * stride_bt_b + off_n = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + d_mask = off_d < head_dim + # SHUFFLE decomposition of the 128 sparse-block-local positions: + # j = p // 16 (which physical 16-page), s = p % 16 (intra-page position). + j_of_n = off_n // ASM_PAGE + s_of_n = off_n % ASM_PAGE + # SHUFFLE decomposition of the head_dim: dx = d // x, dr = d % x. + dx_of_d = off_d // X + dr_of_d = off_d % X + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + t_ptr_j = t_ptr + (q_block_start + pid_q_j) * stride_tn + pid_kh * stride_th + off_t = tl.arange(0, BLOCK_SIZE_T) + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < max_topk, other=-1) + real_topk = tl.sum((topk_idx >= 0).to(tl.int32), axis=0) + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, gqa_group_size, head_dim), + strides=(stride_qn, stride_qh, stride_qd), + offsets=(pid_q_j * BLOCK_SIZE_Q, 0, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(2, 1, 0), + ) + q = tl.load(q_ptrs, boundary_check=(0, 1, 2), padding_option="zero") + off_q = ( + tl.arange(0, BLOCK_SIZE_Q)[:, None] + + pid_q_j * BLOCK_SIZE_Q + + prefix_len + - tl.arange(0, BLOCK_SIZE_K)[None, :] + ) + m_i = tl.full((BLOCK_SIZE_QH,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_QH,), float("-inf"), dtype=tl.float32) + acc_o = tl.zeros((BLOCK_SIZE_QH, BLOCK_SIZE_D), dtype=tl.float32) + q = tl.reshape(q, BLOCK_SIZE_QH, BLOCK_SIZE_D) + for _ in range(real_topk): + blk = tl.load(t_ptr_j).to(tl.int32) + t_ptr_j = t_ptr_j + stride_tk + c = blk * BLOCK_SIZE_K + # logical 128-page id -> base physical 16-page (logical*8 + j). + base_phys = tl.load(bt_row + blk).to(tl.int64) * PAGES_PER_BLOCK + pos = c + off_n + pos_mask = pos < seq_len + # physical page + intra-page position for each of the 128 positions. + phys_n = base_phys + j_of_n # [BLOCK_SIZE_K] + # K SHUFFLE address for [d (rows), p (cols)]: + # k_cache[phys, h, d//x, s, d%x] + k = tl.load( + k_cache_ptr + + phys_n[None, :] * stride_k_blk + + pid_kh * stride_k_h + + dx_of_d[:, None] * stride_k_dx + + s_of_n[None, :] * stride_k_s + + dr_of_d[:, None] * stride_k_x, + mask=d_mask[:, None] & pos_mask[None, :], + other=0.0, + ) + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + # causal: q_abs_pos - k_off >= block_start (c) + qk += tl.where(off_q[:, None, :] >= c, 0, float("-inf")) + qk = tl.reshape(qk, BLOCK_SIZE_QH, BLOCK_SIZE_K) + qk += tl.dot(q, k) * sm_scale_log2e + qk += tl.where(pos_mask[None, :], 0, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + acc_o = acc_o * tl.exp2(m_i - m_ij)[:, None] + # V SHUFFLE address for [p (rows), d (cols)]: + # v_cache[phys, h, s//x, d, s%x] + v = tl.load( + v_cache_ptr + + phys_n[:, None] * stride_v_blk + + pid_kh * stride_v_h + + (s_of_n[:, None] // X) * stride_v_sx + + off_d[None, :] * stride_v_d + + (s_of_n[:, None] % X) * stride_v_x, + mask=pos_mask[:, None] & d_mask[None, :], + other=0.0, + ) + acc_o += tl.dot(p.to(v.dtype), v) + m_i = m_ij + lse_i = m_ij + tl.log2(tl.exp2(lse_i - m_ij) + l_ij) + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + acc_o = tl.reshape(acc_o, BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_D) + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, gqa_group_size, head_dim), + strides=(stride_on, stride_oh, stride_od), + offsets=(pid_q_j * BLOCK_SIZE_Q, 0, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(2, 1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1, 2)) + + +@torch.no_grad() +def minimax_m3_sparse_attn_shuffle( + q: torch.Tensor, # [total_q, num_heads, head_dim] + k_cache: torch.Tensor, # SHUFFLE K [num_phys_blocks, num_kv_heads, head_dim//x, 16, x] + v_cache: torch.Tensor, # SHUFFLE V [num_phys_blocks, num_kv_heads, 16//x, head_dim, x] + topk_idx: torch.Tensor, # [num_kv_heads, total_q, topk] + block_table: torch.Tensor, # [batch, max_blocks] logical 128-granularity + cu_seqlens_q: torch.Tensor, # [batch+1] int32 + seq_lens: torch.Tensor, # [batch] int32 + prefix_lens: torch.Tensor, # [batch] int32 + max_query_len: int, + num_kv_heads: int, + sm_scale: float, + output: torch.Tensor, # [total_q, num_heads, head_dim] +) -> None: + """GQA block-sparse prefill attention over a page-16 SHUFFLE KV cache. + + Math-identical to ``minimax_m3_sparse_attn`` (online base-2 softmax, causal + diagonal, GQA group reshaping); only the K/V load addressing differs. Each + selected logical 128-block expands to ``PAGES_PER_SPARSE_BLOCK`` physical + 16-pages (physical = logical*8 + j). bf16 only. + """ + assert q.dtype == torch.bfloat16, "shuffle prefill kernel is bf16-only" + assert not _is_fp8_kv_cache_tensor(k_cache), "shuffle prefill kernel is bf16-only" + total_q, num_heads, head_dim = q.shape + batch = cu_seqlens_q.shape[0] - 1 + topk = topk_idx.shape[-1] + gqa_group_size = num_heads // num_kv_heads + x = 16 // k_cache.element_size() # bf16 -> 8 + grid = (max_query_len, num_kv_heads, batch) + _gqa_sparse_fwd_kernel_shuffle[grid]( + q, + k_cache, + v_cache, + topk_idx, + output, + block_table, + cu_seqlens_q, + cu_seqlens_q, # cu_seqblocks_q == cu_seqlens_q when block_size_q == 1 + seq_lens, + prefix_lens, + num_kv_heads, + gqa_group_size, + head_dim, + topk, + 1, # num_q_loop + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride(4), + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), + v_cache.stride(4), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + block_table.stride(0), + X=x, + PAGES_PER_BLOCK=PAGES_PER_SPARSE_BLOCK, + ASM_PAGE=ASM_PAGE_SIZE, + BLOCK_SIZE_Q=1, + BLOCK_SIZE_K=SPARSE_BLOCK_SIZE, + num_stages=1, + ) + + +# --------------------------------------------------------------------------- +# ASM paged-attention decode path (alternative to the Triton split-K decode). +# +# The lightning-indexer selects, per request, up to `topk` 128-token blocks. +# AITER's bf16 `pa_fwd_asm` only ships a blkSz=16 kernel, so the main KV cache +# is paged at 16 and each selected logical 128-block expands into its 8 physical +# 16-pages (physical = block_table[blk]*8 + j). We compact those selected pages +# into a dense per-request block_table + a context_lens giving the true number +# of valid KV tokens, then hand them to pa_fwd_asm. +# +# Only valid when per-rank num_kv_heads == 1 (topk is per-kv-head; ASM PA shares +# one block_table across all kv heads, so it cannot express per-kv-head +# selection). The caller asserts this. +# +# The partial (seq_len-containing) 128-block must be packed LAST so that +# pa_fwd_asm's tail mask (context_lens % 16) applies to it; all other selected +# blocks contribute full 8x16-page groups packed contiguously at the front. +# --------------------------------------------------------------------------- +# Physical KV page size for the ASM decode path. The bf16 gfx950 pa_fwd_asm +# kernels only ship a blkSz=16 variant (Gqa in {8,16}); the model's logical +# sparse block is 128, so each selected 128-block expands into +# PAGES_PER_SPARSE_BLOCK == 8 contiguous physical 16-pages. +ASM_PAGE_SIZE = 16 +PAGES_PER_SPARSE_BLOCK = SPARSE_BLOCK_SIZE // ASM_PAGE_SIZE # 8 + + +@triton.jit +def _build_sparse_block_table_kernel( + t_ptr, # topk_idx: [1, batch, topk] int32, 0-indexed 128-blocks, -1 pad + block_table_ptr, # logical block_table [batch, max_blocks] int32 (128-granularity) + seq_lens_ptr, # [batch] int32 + sparse_bt_ptr, # out: compacted 16-page block_table [batch, topk*8] int32 + sparse_ctx_ptr, # out: compacted context_lens [batch] int32 + max_topk, + sm_block_size: tl.constexpr, # logical sparse block size (128) + pages_per_block: tl.constexpr, # 16-pages per sparse block (8) + asm_page_size: tl.constexpr, # physical page size (16) + stride_tn, + stride_tk, + stride_bt_b, + stride_sbt_b, + BLOCK_SIZE_T: tl.constexpr, +): + pid_b = tl.program_id(0) + seq_len = tl.load(seq_lens_ptr + pid_b) + # logical 128-block containing the last valid token (the partial tail block). + last_blk = (seq_len - 1) // sm_block_size + bt_row = block_table_ptr + pid_b * stride_bt_b + t_row = t_ptr + pid_b * stride_tn + sbt_row = sparse_bt_ptr + pid_b * stride_sbt_b + + off_t = tl.arange(0, BLOCK_SIZE_T) + blk = tl.load(t_row + off_t * stride_tk, mask=off_t < max_topk, other=-1) + valid = blk >= 0 + is_tail = valid & (blk == last_blk) + is_full = valid & (blk != last_blk) + + # Stable compaction in units of SPARSE BLOCKS: full blocks first (in + # selection order), tail block last. Each sparse block then expands to + # `pages_per_block` physical 16-pages. + n_full = tl.sum(is_full.to(tl.int32), axis=0) + n_valid = tl.sum(valid.to(tl.int32), axis=0) + earlier_full = tl.cumsum(is_full.to(tl.int32), axis=0) - is_full.to(tl.int32) + slot = tl.where(is_full, earlier_full, n_full) # tail -> slot n_full + + # logical 128-page id of each selected block -> 8 physical 16-pages: + # physical = logical_id * pages_per_block + j (matches block_convert) + logical_page = tl.load(bt_row + blk, mask=valid, other=0).to(tl.int32) + base_phys = logical_page * pages_per_block # [BLOCK_SIZE_T] + dst_base = slot * pages_per_block # [BLOCK_SIZE_T] + + for j in range(pages_per_block): + write_mask = valid & (dst_base + j < BLOCK_SIZE_T * pages_per_block) + tl.store(sbt_row + dst_base + j, base_phys + j, mask=write_mask) + + # true valid token count: full blocks contribute 128 each, tail the remainder. + tail_tokens = seq_len - last_blk * sm_block_size + has_tail = tl.sum(is_tail.to(tl.int32), axis=0) > 0 + ctx = n_full * sm_block_size + tl.where(has_tail, tail_tokens, 0) + ctx = tl.where(has_tail, ctx, tl.minimum(n_valid * sm_block_size, seq_len)) + tl.store(sparse_ctx_ptr + pid_b, ctx) + + +@torch.no_grad() +def minimax_m3_build_sparse_block_table( + topk_idx: torch.Tensor, # [1, batch, topk] int32 (num_kv_heads == 1) + block_table: torch.Tensor, # [batch, max_blocks] int32, logical 128-granularity + seq_lens: torch.Tensor, # [batch] int32 +) -> tuple[torch.Tensor, torch.Tensor]: + """Compact per-request selected 128-blocks into a dense 16-page block_table + + context_lens for `pa_fwd_asm`. + + Each selected logical 128-block expands to its 8 physical 16-pages + (``logical_id * 8 + j``, matching ``block_convert``). The partial tail block + is packed last so pa_fwd_asm's tail mask (context_lens % 16) lands on it. + + Returns (sparse_bt [batch, topk*8] int32, sparse_ctx_lens [batch] int32). + The compacted width is fixed (topk*8), so the grid is shape-constant + (cudagraph-safe). + """ + assert topk_idx.shape[0] == 1, "ASM PA decode requires num_kv_heads == 1" + batch = topk_idx.shape[1] + topk = topk_idx.shape[-1] + width = topk * PAGES_PER_SPARSE_BLOCK + sparse_bt = torch.zeros((batch, width), dtype=torch.int32, device=topk_idx.device) + sparse_ctx = torch.zeros((batch,), dtype=torch.int32, device=topk_idx.device) + _build_sparse_block_table_kernel[(batch,)]( + topk_idx, + block_table, + seq_lens, + sparse_bt, + sparse_ctx, + topk, + SPARSE_BLOCK_SIZE, + PAGES_PER_SPARSE_BLOCK, + ASM_PAGE_SIZE, + topk_idx.stride(1), + topk_idx.stride(2), + block_table.stride(0), + sparse_bt.stride(0), + BLOCK_SIZE_T=triton.next_power_of_2(topk), + ) + return sparse_bt, sparse_ctx + + +@torch.no_grad() +def minimax_m3_sparse_attn_decode_asm( + q: torch.Tensor, # [batch, num_heads, head_dim==128] + k_cache: torch.Tensor, # SHUFFLE K [num_blocks, num_kv_heads, head_dim//x, 16, x] + v_cache: torch.Tensor, # SHUFFLE V [num_blocks, num_kv_heads, 16//x, head_dim, x] + topk_idx: torch.Tensor, # [num_kv_heads, batch, topk] int32 + block_table: torch.Tensor, # [batch, max_blocks] int32, logical 128-granularity + seq_lens: torch.Tensor, # [batch] int32 + num_kv_heads: int, + sm_scale: float, + output: torch.Tensor, # [batch, num_heads, head_dim] + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, +) -> None: + """Block-sparse decode attention via the AITER ASM paged-attention kernel. + + The lightning-indexer's selected logical 128-blocks are compacted into a + dense PHYSICAL 16-page block_table (each 128-block -> 8 pages, tail packed + last) + exact context_lens, then fed to ``pa_fwd_asm`` over the page-16 + SHUFFLE KV cache. ASM PA's bf16 gfx950 kernels cover Gqa in {8, 16} (M3 decode + is gqa 16 at tp4 / 8 at tp8) and blkSz=16, hence the 16-page expansion. + + Requires per-rank num_kv_heads == 1 (the indexer top-k is per-kv-head; one + shared block_table cannot express per-kv-head selection) and head_dim == 128. + """ + from atom.model_ops.base_attention import run_pa_fwd_asm + + assert num_kv_heads == 1, ( + "minimax_m3_sparse_attn_decode_asm requires per-rank num_kv_heads == 1;" + f" got {num_kv_heads}. Use the Triton split-K decode path for GQA." + ) + assert q.shape[-1] == 128, "ASM paged-attention requires head_dim == 128." + + sparse_bt, sparse_ctx = minimax_m3_build_sparse_block_table( + topk_idx, block_table, seq_lens + ) + run_pa_fwd_asm( + q=q, + k_cache=k_cache, + v_cache=v_cache, + block_tables=sparse_bt, + context_lens=sparse_ctx, + k_scale=k_scale, + v_scale=v_scale, + out=output, + max_qlen=1, + high_precision=0, + ) diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index cd66cc49cc..92b14a01fb 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -44,11 +44,15 @@ ) from atom.model_ops.minimax_m3.sparse_attn import ( SPARSE_BLOCK_SIZE, + minimax_m3_fused_qknorm_rope_kv_insert_shuffle, minimax_m3_sparse_attn, minimax_m3_sparse_attn_decode, + minimax_m3_sparse_attn_decode_asm, + minimax_m3_sparse_attn_shuffle, ) from atom.model_ops.swiglu_oai import swiglu_oai_split from atom.model_ops.utils import atom_parameter +from atom.utils import envs from atom.models.utils import ( IntermediateTensors, PPMissingLayer, @@ -283,6 +287,7 @@ def __init__( cache_config: str = "bf16", ) -> None: super().__init__() + self.layer_num = layer_id self.hidden_size = config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -492,6 +497,16 @@ def __init__( self.index_rotary_emb = self.rotary_emb self.kv_cache = torch.tensor([]) self.index_cache = torch.tensor([]) + # ASM decode path (ATOM_M3_SPARSE_USE_ASM_PA): page-16 SHUFFLE K/V *views* + # of `self.kv_cache`, derived lazily by `_ensure_asm_shuffle_views()`. + # The allocation/binding is unchanged from the non-ASM path (the backend + # gives us the plain page-128 `self.kv_cache`); we only reinterpret its + # bytes as page-16 SHUFFLE here. index cache stays the page-128 + # `index_cache` above. + self.kv_cache_k = torch.tensor([]) + self.kv_cache_v = torch.tensor([]) + self.k_scale = self.v_scale = None + self._use_asm_pa = bool(envs.ATOM_M3_SPARSE_USE_ASM_PA) compilation_config = get_current_atom_config().compilation_config if self.layer_name in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer: {self.layer_name}") @@ -511,6 +526,55 @@ def forward( ) return self.o_proj(attn_output) + def _ensure_asm_shuffle_views(self) -> None: + """Lazily derive the page-16 SHUFFLE K/V views from ``self.kv_cache``. + + The backend binds ``self.kv_cache`` as the plain page-128 combined cache + (shape ``[N, 2, 128, num_kv_heads, head_dim]``) -- identical to the + non-ASM path, no allocation change. AITER's ``pa_fwd_asm`` / the page-16 + SHUFFLE writer need a 5D page-16 view, so we reinterpret each layer's K/V + slice (no bytes moved): one logical 128-page == 8 contiguous physical + 16-pages, so ``N`` logical blocks become ``N*8`` physical 16-pages: + K: [N*8, num_kv_heads, head_dim//x, 16, x] + V: [N*8, num_kv_heads, 16//x, head_dim, x] + Both the write (reshape_and_cache asm_layout=True) and the read + (pa_fwd_asm) go through THESE views, so the page-16 interpretation is + self-consistent regardless of the plain layout. Idempotent: rebuilds only + when the underlying ``self.kv_cache`` storage changes. + """ + from atom.model_ops.minimax_m3.sparse_attn import ( + ASM_PAGE_SIZE, + PAGES_PER_SPARSE_BLOCK, + ) + + if self.kv_cache.numel() == 0: + return + key_cache, value_cache = self.kv_cache.unbind(1) # each [N, 128, h, hd] + if ( + self.kv_cache_k.numel() != 0 + and self.kv_cache_k.data_ptr() == key_cache.data_ptr() + ): + return # views already derived from this storage + x = 16 // self.kv_cache.element_size() + num_blocks = key_cache.shape[0] + num_phys16 = num_blocks * PAGES_PER_SPARSE_BLOCK + # .view (not .reshape): each unbound slice is contiguous, so this is + # guaranteed zero-copy -- writes must land in the real pool, never a copy. + self.kv_cache_k = key_cache.view( + num_phys16, + self.num_kv_heads, + self.head_dim // x, + ASM_PAGE_SIZE, + x, + ) + self.kv_cache_v = value_cache.view( + num_phys16, + self.num_kv_heads, + ASM_PAGE_SIZE // x, + self.head_dim, + x, + ) + def _insert_kv( self, k: torch.Tensor, @@ -518,21 +582,42 @@ def _insert_kv( index_k: torch.Tensor, slot_mapping: torch.Tensor, ) -> None: - if self.kv_cache.numel() == 0 or self.index_cache.numel() == 0: + if self.index_cache.numel() == 0: return - key_cache, value_cache = self.kv_cache.unbind(1) - kv_cache_dtype = "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype - aiter.reshape_and_cache( - k.view(-1, self.num_kv_heads, self.head_dim), - v.view(-1, self.num_kv_heads, self.head_dim), - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype=kv_cache_dtype, - k_scale=None, - v_scale=None, - asm_layout=False, - ) + if self._use_asm_pa: + self._ensure_asm_shuffle_views() + if self.kv_cache_k.numel() == 0: + return + # Page-16 SHUFFLE write for the ASM decode path. + aiter.reshape_and_cache( + k.view(-1, self.num_kv_heads, self.head_dim), + v.view(-1, self.num_kv_heads, self.head_dim), + self.kv_cache_k, + self.kv_cache_v, + slot_mapping, + kv_cache_dtype="auto", + k_scale=None, + v_scale=None, + asm_layout=True, + ) + else: + if self.kv_cache.numel() == 0: + return + key_cache, value_cache = self.kv_cache.unbind(1) + kv_cache_dtype = ( + "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype + ) + aiter.reshape_and_cache( + k.view(-1, self.num_kv_heads, self.head_dim), + v.view(-1, self.num_kv_heads, self.head_dim), + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype=kv_cache_dtype, + k_scale=None, + v_scale=None, + asm_layout=False, + ) self.index_cache.view(-1, self.idx_head_dim)[slot_mapping] = index_k.to( self.index_cache.dtype ) @@ -569,19 +654,35 @@ def _run_prefill_sparse( self.scaling, ) output = torch.empty_like(q) - minimax_m3_sparse_attn( - q, - self.kv_cache, - topk_idx, - block_tables, - cu_seqlens_q, - seq_lens, - prefix_lens, - prefill_metadata.max_query_len, - self.num_kv_heads, - self.scaling, - output, - ) + if self._use_asm_pa: + minimax_m3_sparse_attn_shuffle( + q, + self.kv_cache_k, + self.kv_cache_v, + topk_idx, + block_tables, + cu_seqlens_q, + seq_lens, + prefix_lens, + prefill_metadata.max_query_len, + self.num_kv_heads, + self.scaling, + output, + ) + else: + minimax_m3_sparse_attn( + q, + self.kv_cache, + topk_idx, + block_tables, + cu_seqlens_q, + seq_lens, + prefix_lens, + prefill_metadata.max_query_len, + self.num_kv_heads, + self.scaling, + output, + ) return output def _run_decode_sparse( @@ -605,16 +706,37 @@ def _run_decode_sparse( self.scaling, ) output = torch.empty_like(q) - minimax_m3_sparse_attn_decode( - q, - self.kv_cache, - topk_idx, - decode_metadata.block_table, - decode_metadata.seq_lens, - self.num_kv_heads, - self.scaling, - output, - ) + if self._use_asm_pa: + if self.num_kv_heads != 1: + raise NotImplementedError( + "ATOM_M3_SPARSE_USE_ASM_PA requires per-rank num_kv_heads == 1 " + "(tensor-parallel size >= 4); ASM PA shares one block_table " + f"across kv heads. Got num_kv_heads={self.num_kv_heads}." + ) + minimax_m3_sparse_attn_decode_asm( + q, + self.kv_cache_k, + self.kv_cache_v, + topk_idx, + decode_metadata.block_table, + decode_metadata.seq_lens, + self.num_kv_heads, + self.scaling, + output, + k_scale=self.k_scale, + v_scale=self.v_scale, + ) + else: + minimax_m3_sparse_attn_decode( + q, + self.kv_cache, + topk_idx, + decode_metadata.block_table, + decode_metadata.seq_lens, + self.num_kv_heads, + self.scaling, + output, + ) return output def sparse_attention_forward_impl( @@ -623,10 +745,17 @@ def sparse_attention_forward_impl( positions: torch.Tensor, ) -> torch.Tensor: fwd_ctx = get_forward_context() + if self._use_asm_pa: + # Derive the page-16 SHUFFLE K/V views from the (plain) self.kv_cache + # the backend bound. No-op once derived / when cache is unbound. + self._ensure_asm_shuffle_views() + # self.kv_cache is the source of truth for "cache bound" in both paths; + # the ASM views are derived from it. + main_cache_unbound = self.kv_cache.numel() == 0 if ( fwd_ctx.context.is_dummy_run or fwd_ctx.attn_metadata is None - or self.kv_cache.numel() == 0 + or main_cache_unbound or self.index_cache.numel() == 0 ): return torch.empty( @@ -653,30 +782,55 @@ def sparse_attention_forward_impl( (qkv.shape[0], self.index_q_size), dtype=qkv.dtype, device=qkv.device ) cos_sin_cache = _minimax_m3_cos_sin_cache(self.rotary_emb, qkv.dtype) - aiter.fused_minimax_m3_qknorm_rope_kv_insert( - qkv, - self.q_norm.weight, - self.k_norm.weight, - cos_sin_cache, - positions, - self.num_heads, - self.num_kv_heads, - self.rotary_emb.rotary_dim, - self.q_norm.variance_epsilon, - self.index_q_norm.weight, - self.index_k_norm.weight, - self.num_idx_heads, - sparse_metadata.slot_mapping, - self.kv_cache, - self.index_cache, - self.kv_cache.shape[2], - q, - index_q, - sparse_metadata.slot_mapping, - self.kv_cache_dtype, - None, - None, - ) + if self._use_asm_pa: + # Triton fallback that writes the main KV cache in page-16 + # SHUFFLE layout (the aiter fused kernel writes plain page-128). + minimax_m3_fused_qknorm_rope_kv_insert_shuffle( + qkv, + self.q_norm.weight, + self.k_norm.weight, + cos_sin_cache, + positions, + self.num_heads, + self.num_kv_heads, + self.rotary_emb.rotary_dim, + self.q_norm.variance_epsilon, + self.index_q_norm.weight, + self.index_k_norm.weight, + self.num_idx_heads, + sparse_metadata.slot_mapping, + self.kv_cache_k, + self.kv_cache_v, + self.index_cache, + q, + index_q, + self.idx_head_dim, + ) + else: + aiter.fused_minimax_m3_qknorm_rope_kv_insert( + qkv, + self.q_norm.weight, + self.k_norm.weight, + cos_sin_cache, + positions, + self.num_heads, + self.num_kv_heads, + self.rotary_emb.rotary_dim, + self.q_norm.variance_epsilon, + self.index_q_norm.weight, + self.index_k_norm.weight, + self.num_idx_heads, + sparse_metadata.slot_mapping, + self.kv_cache, + self.index_cache, + self.kv_cache.shape[2], + q, + index_q, + sparse_metadata.slot_mapping, + self.kv_cache_dtype, + None, + None, + ) q = q.view(-1, self.num_heads, self.head_dim) index_q = index_q.view(-1, self.num_idx_heads, self.idx_head_dim) if getattr(sparse_metadata, "num_prefills", 0) > 0: diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 46554050b5..37b2a038d8 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -102,6 +102,13 @@ "ATOM_USE_GLUON_PA_DECODE": lambda: ( os.getenv("ATOM_USE_GLUON_PA_DECODE", "0") == "1" ), + # MiniMax-M3 sparse decode: use the AITER ASM paged-attention kernel + # (pa_fwd_asm over the page-16 SHUFFLE KV cache) instead of the Triton + # split-K decode. Only valid when per-rank num_kv_heads == 1 (tp >= 4); + # falls back to the Triton path otherwise. Set to 1 to enable. + "ATOM_M3_SPARSE_USE_ASM_PA": lambda: ( + os.getenv("ATOM_M3_SPARSE_USE_ASM_PA", "0") == "1" + ), # --- Plugin Mode --- "ATOM_DISABLE_VLLM_PLUGIN": lambda: ( os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1"