Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 163 additions & 2 deletions atom/model_ops/minimax_m3/index_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

# One sparse block == one KV page.
SPARSE_BLOCK_SIZE = 128
# Size of one physical page in the page-16 SHUFFLE ASM/gluon cache layout.
_PHYSICAL_PAGE_SIZE = 16
# Number of physical 16-pages backing one logical 128-block. Derived from
# SPARSE_BLOCK_SIZE so the two values cannot drift apart (evaluates to 8);
# must match sparse_attn.PAGES_PER_SPARSE_BLOCK. Used by the fused block-table
# emission in the topk kernels.
PAGES_PER_SPARSE_BLOCK = SPARSE_BLOCK_SIZE // _PHYSICAL_PAGE_SIZE


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -199,10 +203,18 @@ def _topk_index_kernel(
stride_ti_h,
stride_ti_n,
stride_ti_t,
# --- fused sparse block-table emission (ASM/gluon prefill path) ---
block_table_ptr, # [batch, max_blocks] int32 logical 128-granularity (or dummy)
sparse_bt_ptr, # out: [total_q, topk*pages_per_block] int32 (or dummy)
sparse_ctx_ptr, # out: [total_q] int32 (or dummy)
stride_bt_b,
stride_sbt_n,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_T: tl.constexpr,
MASK_INIT: tl.constexpr,
MASK_LOCAL: tl.constexpr,
pages_per_block: tl.constexpr, # 16-pages per sparse block (8)
EMIT_SPARSE_BT: tl.constexpr, # fuse compaction iff True (num_idx_heads==1)
):
tl.static_assert(BLOCK_SIZE_K > BLOCK_SIZE_T)
pid_q = tl.program_id(0)
Expand Down Expand Up @@ -282,6 +294,48 @@ def _topk_index_kernel(
topk_idx = tl.where(store_mask & valid_mask, topk_idx, -1)
tl.store(ti_ptrs, topk_idx.to(ti_ptrs.dtype.element_ty), mask=store_mask)

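# --- fused sparse block-table build (per-query-token causal compaction) ---
# Mirrors _build_sparse_block_table_prefill_kernel over the in-register
# selection. Only the first index head emits (ASM/gluon needs num_idx_heads
# == 1). Token absolute pos p = prefix_len + pid_q * sample_interval (this
# path is used with sample_interval == 1); causal self-block = p // block_size,
# causal length = p + 1.
# NOTE(review): the zero-fill of the unused tail below spans BLOCK_SIZE_T *
# pages_per_block columns, while each sparse_bt row is allocated with topk *
# pages_per_block entries -- it stays inside the row only if BLOCK_SIZE_T ==
# topk. Confirm, or add `off_w < topk * pages_per_block` to that store mask.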
if EMIT_SPARSE_BT and pid_h == 0:
p = prefix_len + pid_q * sample_interval
self_blk = p // block_size
causal_len = p + 1
bt_blk = tl.where(off_t < topk, topk_idx, -1)
# causal: drop any selected block above the self-block (defensive; the
# indexer already caps selection at valid_blocks == self_blk + 1).
bt_valid = (bt_blk >= 0) & (bt_blk <= self_blk)
bt_is_tail = bt_valid & (bt_blk == self_blk)
bt_is_full = bt_valid & (bt_blk < self_blk)
bt_n_full = tl.sum(bt_is_full.to(tl.int32), axis=0)
bt_n_valid = tl.sum(bt_valid.to(tl.int32), axis=0)
bt_earlier_full = tl.cumsum(bt_is_full.to(tl.int32), axis=0) - bt_is_full.to(
tl.int32
)
bt_slot = tl.where(bt_is_full, bt_earlier_full, bt_n_full) # tail -> n_full

bt_row = block_table_ptr + pid_b * stride_bt_b
bt_logical_page = tl.load(bt_row + bt_blk, mask=bt_valid, other=0).to(tl.int32)
bt_base_phys = bt_logical_page * pages_per_block
bt_dst_base = bt_slot * pages_per_block

sbt_row = sparse_bt_ptr + (block_start + pid_q) * stride_sbt_n
for pj in range(pages_per_block):
tl.store(sbt_row + bt_dst_base + pj, bt_base_phys + pj, mask=bt_valid)
bt_n_used = bt_n_valid * pages_per_block
off_w = tl.arange(0, BLOCK_SIZE_T * pages_per_block)
tl.store(sbt_row + off_w, tl.zeros_like(off_w), mask=off_w >= bt_n_used)

bt_tail_tokens = causal_len - self_blk * block_size
bt_has_tail = tl.sum(bt_is_tail.to(tl.int32), axis=0) > 0
bt_ctx = bt_n_full * block_size + tl.where(bt_has_tail, bt_tail_tokens, 0)
bt_ctx = tl.where(
bt_has_tail, bt_ctx, tl.minimum(bt_n_valid * block_size, causal_len)
)
tl.store(sparse_ctx_ptr + (block_start + pid_q), bt_ctx)


# ---------------------------------------------------------------------------
# Decode index-score kernel (split-K over seq blocks). Decode == one query
Expand Down Expand Up @@ -531,9 +585,17 @@ def _topk_index_merge_kernel(
stride_tif_h,
stride_tif_b,
stride_tif_t,
# --- fused sparse block-table emission (ASM/gluon decode path) ---
block_table_ptr, # [batch, max_blocks] int32 logical 128-granularity (or dummy)
sparse_bt_ptr, # out: [batch, topk*pages_per_block] int32 (or dummy)
sparse_ctx_ptr, # out: [batch] int32 (or dummy)
stride_bt_b,
stride_sbt_b,
NUM_TOPK_CHUNKS: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_T: tl.constexpr,
pages_per_block: tl.constexpr, # 16-pages per sparse block (8)
EMIT_SPARSE_BT: tl.constexpr, # fuse compaction iff True (num_idx_heads==1)
):
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
Expand Down Expand Up @@ -595,6 +657,44 @@ def _topk_index_merge_kernel(
tif_ptrs, topk_idx_final.to(ti_final_ptr.dtype.element_ty), mask=store_mask
)

# --- fused sparse block-table build (per-request decode compaction) ---
# Mirrors _build_sparse_block_table_kernel over the in-register selection,
# avoiding a second kernel launch + topk_idx HBM round-trip. Only the first
# index head emits (ASM/gluon path requires num_idx_heads == 1).
if EMIT_SPARSE_BT and pid_h == 0:
last_blk = (seq_len - 1) // block_size
bt_blk = tl.where(off_t < topk, topk_idx_final, -1)
bt_valid = bt_blk >= 0
bt_is_tail = bt_valid & (bt_blk == last_blk)
bt_is_full = bt_valid & (bt_blk != last_blk)
bt_n_full = tl.sum(bt_is_full.to(tl.int32), axis=0)
bt_n_valid = tl.sum(bt_valid.to(tl.int32), axis=0)
bt_earlier_full = tl.cumsum(bt_is_full.to(tl.int32), axis=0) - bt_is_full.to(
tl.int32
)
bt_slot = tl.where(bt_is_full, bt_earlier_full, bt_n_full) # tail -> n_full

bt_row = block_table_ptr + pid_b * stride_bt_b
bt_logical_page = tl.load(bt_row + bt_blk, mask=bt_valid, other=0).to(tl.int32)
bt_base_phys = bt_logical_page * pages_per_block
bt_dst_base = bt_slot * pages_per_block

sbt_row = sparse_bt_ptr + pid_b * stride_sbt_b
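# Write valid slots -> their physical pages, then zero the unused tail (0 is an
# in-bounds page id). NOTE(review): the zero-fill spans BLOCK_SIZE_T *
# pages_per_block columns, while each sparse_bt row is allocated with topk *
# pages_per_block entries -- it stays inside the row only if BLOCK_SIZE_T ==
# topk. Confirm, or add `off_w < topk * pages_per_block` to that store mask.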
for pj in range(pages_per_block):
tl.store(sbt_row + bt_dst_base + pj, bt_base_phys + pj, mask=bt_valid)
bt_n_used = bt_n_valid * pages_per_block
off_w = tl.arange(0, BLOCK_SIZE_T * pages_per_block)
tl.store(sbt_row + off_w, tl.zeros_like(off_w), mask=off_w >= bt_n_used)

bt_tail_tokens = seq_len - last_blk * block_size
bt_has_tail = tl.sum(bt_is_tail.to(tl.int32), axis=0) > 0
bt_ctx = bt_n_full * block_size + tl.where(bt_has_tail, bt_tail_tokens, 0)
bt_ctx = tl.where(
bt_has_tail, bt_ctx, tl.minimum(bt_n_valid * block_size, seq_len)
)
tl.store(sparse_ctx_ptr + pid_b, bt_ctx)


# ---------------------------------------------------------------------------
# Python wrappers
Expand All @@ -614,12 +714,19 @@ def minimax_m3_index_topk(
local_blocks: int,
num_kv_heads: int,
sm_scale: float,
) -> torch.Tensor:
emit_sparse_block_table: bool = False,
):
"""Index block-score + top-k selection. block_size_q == 1 (per-token).

Returns topk_idx [num_kv_heads, total_q, topk] of 0-indexed block ids
(right-padded with -1). M3 has num_idx_heads == num_kv_heads, so the
per-index-head top-k maps 1:1 to kv heads (no index-head reduction needed).

When ``emit_sparse_block_table`` is True and num_idx_heads == 1, the topk
kernel ALSO fuses the per-query-token page-16 SHUFFLE block-table compaction
and returns ``(topk_idx, sparse_bt [total_q, topk*8], sparse_ctx [total_q])``
ready for the ASM prefill kernel -- saving a separate build launch + topk_idx
HBM round-trip. If num_idx_heads != 1 the flag is silently ignored and only
``topk_idx`` is returned, so callers must not unpack a tuple in that case.
"""
total_q, num_idx_heads, head_dim = idx_q.shape
assert (
Expand Down Expand Up @@ -665,6 +772,20 @@ def minimax_m3_index_topk(
dtype=torch.int32,
device=idx_q.device,
)
emit = emit_sparse_block_table and num_idx_heads == 1
if emit:
sparse_bt = torch.empty(
(total_q, topk * PAGES_PER_SPARSE_BLOCK),
dtype=torch.int32,
device=idx_q.device,
)
sparse_ctx = torch.empty((total_q,), dtype=torch.int32, device=idx_q.device)
sbt_arg, sctx_arg = sparse_bt, sparse_ctx
bt_stride0, sbt_stride0 = block_table.stride(0), sparse_bt.stride(0)
else:
sbt_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device)
sctx_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device)
bt_stride0, sbt_stride0 = 0, 0
# block_size_q == 1 -> query blocks coincide with query tokens.
grid_topk = (max_query_len, batch, num_idx_heads)
_topk_index_kernel[grid_topk](
Expand All @@ -684,9 +805,18 @@ def minimax_m3_index_topk(
topk_idx.stride(0),
topk_idx.stride(1),
topk_idx.stride(2),
block_table,
sbt_arg,
sctx_arg,
bt_stride0,
sbt_stride0,
MASK_INIT=False,
MASK_LOCAL=False,
pages_per_block=PAGES_PER_SPARSE_BLOCK,
EMIT_SPARSE_BT=emit,
)
if emit:
return topk_idx, sparse_bt, sparse_ctx
return topk_idx


Expand All @@ -702,10 +832,17 @@ def minimax_m3_index_topk_decode(
local_blocks: int,
num_kv_heads: int,
sm_scale: float,
) -> torch.Tensor:
emit_sparse_block_table: bool = False,
):
"""Decode index block-score + top-k, both split-K (cudagraph-safe).

Returns topk_idx [num_kv_heads, batch, topk] (0-indexed block ids, -1 pad).

When ``emit_sparse_block_table`` is True and num_idx_heads == 1, the merge
kernel ALSO fuses the page-16 SHUFFLE block-table compaction and returns
``(topk_idx, sparse_bt [batch, topk*8], sparse_ctx [batch])`` ready for the
ASM/gluon decode kernel -- saving a separate build launch + topk_idx HBM
round-trip. If num_idx_heads != 1 the flag is silently ignored and only
``topk_idx`` is returned, so callers must not unpack a tuple in that case.
"""
total_q, num_idx_heads, head_dim = idx_q.shape
assert (
Expand Down Expand Up @@ -804,6 +941,21 @@ def minimax_m3_index_topk_decode(
topk_idx_partial.stride(2),
topk_idx_partial.stride(3),
)
emit = emit_sparse_block_table and num_idx_heads == 1
if emit:
sparse_bt = torch.empty(
(batch, topk * PAGES_PER_SPARSE_BLOCK),
dtype=torch.int32,
device=idx_q.device,
)
sparse_ctx = torch.empty((batch,), dtype=torch.int32, device=idx_q.device)
sbt_arg, sctx_arg = sparse_bt, sparse_ctx
bt_stride0, sbt_stride0 = block_table.stride(0), sparse_bt.stride(0)
else:
# dummy 1-elem tensors so the kernel always has valid pointers.
sbt_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device)
sctx_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device)
bt_stride0, sbt_stride0 = 0, 0
_topk_index_merge_kernel[(batch, num_idx_heads)](
topk_score_partial,
topk_idx_partial,
Expand All @@ -822,6 +974,15 @@ def minimax_m3_index_topk_decode(
topk_idx.stride(0),
topk_idx.stride(1),
topk_idx.stride(2),
block_table,
sbt_arg,
sctx_arg,
bt_stride0,
sbt_stride0,
NUM_TOPK_CHUNKS=num_topk_chunks,
pages_per_block=PAGES_PER_SPARSE_BLOCK,
EMIT_SPARSE_BT=emit,
)
if emit:
return topk_idx, sparse_bt, sparse_ctx
return topk_idx
Loading