diff --git a/atom/model_ops/minimax_m3/index_topk.py b/atom/model_ops/minimax_m3/index_topk.py index f2d1dc9a21..cc4a2c7a18 100644 --- a/atom/model_ops/minimax_m3/index_topk.py +++ b/atom/model_ops/minimax_m3/index_topk.py @@ -25,6 +25,10 @@ # One sparse block == one KV page. SPARSE_BLOCK_SIZE = 128 +# Physical 16-pages per logical 128-block for the page-16 SHUFFLE ASM/gluon cache +# (must match sparse_attn.PAGES_PER_SPARSE_BLOCK). Used by the fused block-table +# emission in the topk kernels. +PAGES_PER_SPARSE_BLOCK = 8 # --------------------------------------------------------------------------- @@ -199,10 +203,18 @@ def _topk_index_kernel( stride_ti_h, stride_ti_n, stride_ti_t, + # --- fused sparse block-table emission (ASM/gluon prefill path) --- + block_table_ptr, # [batch, max_blocks] int32 logical 128-granularity (or dummy) + sparse_bt_ptr, # out: [total_q, topk*pages_per_block] int32 (or dummy) + sparse_ctx_ptr, # out: [total_q] int32 (or dummy) + stride_bt_b, + stride_sbt_n, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, MASK_INIT: tl.constexpr, MASK_LOCAL: tl.constexpr, + pages_per_block: tl.constexpr, # 16-pages per sparse block (8) + EMIT_SPARSE_BT: tl.constexpr, # fuse compaction iff True (num_idx_heads==1) ): tl.static_assert(BLOCK_SIZE_K > BLOCK_SIZE_T) pid_q = tl.program_id(0) @@ -282,6 +294,48 @@ def _topk_index_kernel( topk_idx = tl.where(store_mask & valid_mask, topk_idx, -1) tl.store(ti_ptrs, topk_idx.to(ti_ptrs.dtype.element_ty), mask=store_mask) + # --- fused sparse block-table build (per-query-token causal compaction) --- + # Mirrors _build_sparse_block_table_prefill_kernel over the in-register + # selection. Only the first index head emits (ASM/gluon needs num_idx_heads + # == 1). Token absolute pos p = prefix_len + pid_q (sample_interval == 1); + # causal self-block = p // block_size, length p + 1. + if EMIT_SPARSE_BT and pid_h == 0: + p = prefix_len + pid_q * sample_interval + self_blk = p // block_size + causal_len = p + 1 + bt_blk = tl.where(off_t < topk, topk_idx, -1) + # causal: drop any selected block above the self-block (defensive; the + # indexer already caps selection at valid_blocks == self_blk + 1). + bt_valid = (bt_blk >= 0) & (bt_blk <= self_blk) + bt_is_tail = bt_valid & (bt_blk == self_blk) + bt_is_full = bt_valid & (bt_blk < self_blk) + bt_n_full = tl.sum(bt_is_full.to(tl.int32), axis=0) + bt_n_valid = tl.sum(bt_valid.to(tl.int32), axis=0) + bt_earlier_full = tl.cumsum(bt_is_full.to(tl.int32), axis=0) - bt_is_full.to( + tl.int32 + ) + bt_slot = tl.where(bt_is_full, bt_earlier_full, bt_n_full) # tail -> n_full + + bt_row = block_table_ptr + pid_b * stride_bt_b + bt_logical_page = tl.load(bt_row + bt_blk, mask=bt_valid, other=0).to(tl.int32) + bt_base_phys = bt_logical_page * pages_per_block + bt_dst_base = bt_slot * pages_per_block + + sbt_row = sparse_bt_ptr + (block_start + pid_q) * stride_sbt_n + for pj in range(pages_per_block): + tl.store(sbt_row + bt_dst_base + pj, bt_base_phys + pj, mask=bt_valid) + bt_n_used = bt_n_valid * pages_per_block + off_w = tl.arange(0, BLOCK_SIZE_T * pages_per_block) + tl.store(sbt_row + off_w, tl.zeros_like(off_w), mask=off_w >= bt_n_used) + + bt_tail_tokens = causal_len - self_blk * block_size + bt_has_tail = tl.sum(bt_is_tail.to(tl.int32), axis=0) > 0 + bt_ctx = bt_n_full * block_size + tl.where(bt_has_tail, bt_tail_tokens, 0) + bt_ctx = tl.where( + bt_has_tail, bt_ctx, tl.minimum(bt_n_valid * block_size, causal_len) + ) + tl.store(sparse_ctx_ptr + (block_start + pid_q), bt_ctx) + # --------------------------------------------------------------------------- # Decode index-score kernel (split-K over seq blocks). Decode == one query @@ -531,9 +585,17 @@ def _topk_index_merge_kernel( stride_tif_h, stride_tif_b, stride_tif_t, + # --- fused sparse block-table emission (ASM/gluon decode path) --- + block_table_ptr, # [batch, max_blocks] int32 logical 128-granularity (or dummy) + sparse_bt_ptr, # out: [batch, topk*pages_per_block] int32 (or dummy) + sparse_ctx_ptr, # out: [batch] int32 (or dummy) + stride_bt_b, + stride_sbt_b, NUM_TOPK_CHUNKS: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, + pages_per_block: tl.constexpr, # 16-pages per sparse block (8) + EMIT_SPARSE_BT: tl.constexpr, # fuse compaction iff True (num_idx_heads==1) ): pid_b = tl.program_id(0) pid_h = tl.program_id(1) @@ -595,6 +657,44 @@ def _topk_index_merge_kernel( tif_ptrs, topk_idx_final.to(ti_final_ptr.dtype.element_ty), mask=store_mask ) + # --- fused sparse block-table build (per-request decode compaction) --- + # Mirrors _build_sparse_block_table_kernel over the in-register selection, + # avoiding a second kernel launch + topk_idx HBM round-trip. Only the first + # index head emits (ASM/gluon path requires num_idx_heads == 1). + if EMIT_SPARSE_BT and pid_h == 0: + last_blk = (seq_len - 1) // block_size + bt_blk = tl.where(off_t < topk, topk_idx_final, -1) + bt_valid = bt_blk >= 0 + bt_is_tail = bt_valid & (bt_blk == last_blk) + bt_is_full = bt_valid & (bt_blk != last_blk) + bt_n_full = tl.sum(bt_is_full.to(tl.int32), axis=0) + bt_n_valid = tl.sum(bt_valid.to(tl.int32), axis=0) + bt_earlier_full = tl.cumsum(bt_is_full.to(tl.int32), axis=0) - bt_is_full.to( + tl.int32 + ) + bt_slot = tl.where(bt_is_full, bt_earlier_full, bt_n_full) # tail -> n_full + + bt_row = block_table_ptr + pid_b * stride_bt_b + bt_logical_page = tl.load(bt_row + bt_blk, mask=bt_valid, other=0).to(tl.int32) + bt_base_phys = bt_logical_page * pages_per_block + bt_dst_base = bt_slot * pages_per_block + + sbt_row = sparse_bt_ptr + pid_b * stride_sbt_b + # write valid slots -> their pages; unused tail -> 0 (in-bounds page id). + for pj in range(pages_per_block): + tl.store(sbt_row + bt_dst_base + pj, bt_base_phys + pj, mask=bt_valid) + bt_n_used = bt_n_valid * pages_per_block + off_w = tl.arange(0, BLOCK_SIZE_T * pages_per_block) + tl.store(sbt_row + off_w, tl.zeros_like(off_w), mask=off_w >= bt_n_used) + + bt_tail_tokens = seq_len - last_blk * block_size + bt_has_tail = tl.sum(bt_is_tail.to(tl.int32), axis=0) > 0 + bt_ctx = bt_n_full * block_size + tl.where(bt_has_tail, bt_tail_tokens, 0) + bt_ctx = tl.where( + bt_has_tail, bt_ctx, tl.minimum(bt_n_valid * block_size, seq_len) + ) + tl.store(sparse_ctx_ptr + pid_b, bt_ctx) + # --------------------------------------------------------------------------- # Python wrappers @@ -614,12 +714,19 @@ def minimax_m3_index_topk( local_blocks: int, num_kv_heads: int, sm_scale: float, -) -> torch.Tensor: + emit_sparse_block_table: bool = False, +): """Index block-score + top-k selection. block_size_q == 1 (per-token). Returns topk_idx [num_kv_heads, total_q, topk] of 0-indexed block ids (right-padded with -1). M3 has num_idx_heads == num_kv_heads, so the per-index-head top-k maps 1:1 to kv heads (no index-head reduction needed). + + When ``emit_sparse_block_table`` is True (requires num_idx_heads == 1), the + topk kernel ALSO fuses the per-query-token page-16 SHUFFLE block-table + compaction and returns ``(topk_idx, sparse_bt [total_q, topk*8], sparse_ctx + [total_q])`` ready for the ASM prefill kernel -- saving a separate build + launch + topk_idx HBM round-trip. """ total_q, num_idx_heads, head_dim = idx_q.shape assert ( @@ -665,6 +772,20 @@ def minimax_m3_index_topk( dtype=torch.int32, device=idx_q.device, ) + emit = emit_sparse_block_table and num_idx_heads == 1 + if emit: + sparse_bt = torch.empty( + (total_q, topk * PAGES_PER_SPARSE_BLOCK), + dtype=torch.int32, + device=idx_q.device, + ) + sparse_ctx = torch.empty((total_q,), dtype=torch.int32, device=idx_q.device) + sbt_arg, sctx_arg = sparse_bt, sparse_ctx + bt_stride0, sbt_stride0 = block_table.stride(0), sparse_bt.stride(0) + else: + sbt_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device) + sctx_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device) + bt_stride0, sbt_stride0 = 0, 0 # block_size_q == 1 -> query blocks coincide with query tokens. grid_topk = (max_query_len, batch, num_idx_heads) _topk_index_kernel[grid_topk]( @@ -684,9 +805,18 @@ def minimax_m3_index_topk( topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), + block_table, + sbt_arg, + sctx_arg, + bt_stride0, + sbt_stride0, MASK_INIT=False, MASK_LOCAL=False, + pages_per_block=PAGES_PER_SPARSE_BLOCK, + EMIT_SPARSE_BT=emit, ) + if emit: + return topk_idx, sparse_bt, sparse_ctx return topk_idx @@ -702,10 +832,17 @@ def minimax_m3_index_topk_decode( local_blocks: int, num_kv_heads: int, sm_scale: float, -) -> torch.Tensor: + emit_sparse_block_table: bool = False, +): """Decode index block-score + top-k, both split-K (cudagraph-safe). Returns topk_idx [num_kv_heads, batch, topk] (0-indexed block ids, -1 pad). + + When ``emit_sparse_block_table`` is True (requires num_idx_heads == 1), the + merge kernel ALSO fuses the page-16 SHUFFLE block-table compaction and returns + ``(topk_idx, sparse_bt [batch, topk*8], sparse_ctx [batch])`` ready for the + ASM/gluon decode kernel -- saving a separate build launch + topk_idx HBM + round-trip. """ total_q, num_idx_heads, head_dim = idx_q.shape assert ( @@ -804,6 +941,21 @@ def minimax_m3_index_topk_decode( topk_idx_partial.stride(2), topk_idx_partial.stride(3), ) + emit = emit_sparse_block_table and num_idx_heads == 1 + if emit: + sparse_bt = torch.empty( + (batch, topk * PAGES_PER_SPARSE_BLOCK), + dtype=torch.int32, + device=idx_q.device, + ) + sparse_ctx = torch.empty((batch,), dtype=torch.int32, device=idx_q.device) + sbt_arg, sctx_arg = sparse_bt, sparse_ctx + bt_stride0, sbt_stride0 = block_table.stride(0), sparse_bt.stride(0) + else: + # dummy 1-elem tensors so the kernel always has valid pointers. + sbt_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device) + sctx_arg = torch.empty(1, dtype=torch.int32, device=idx_q.device) + bt_stride0, sbt_stride0 = 0, 0 _topk_index_merge_kernel[(batch, num_idx_heads)]( topk_score_partial, topk_idx_partial, @@ -822,6 +974,15 @@ def minimax_m3_index_topk_decode( topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), + block_table, + sbt_arg, + sctx_arg, + bt_stride0, + sbt_stride0, NUM_TOPK_CHUNKS=num_topk_chunks, + pages_per_block=PAGES_PER_SPARSE_BLOCK, + EMIT_SPARSE_BT=emit, ) + if emit: + return topk_idx, sparse_bt, sparse_ctx return topk_idx diff --git a/atom/model_ops/minimax_m3/sparse_attn.py b/atom/model_ops/minimax_m3/sparse_attn.py index 159a3f20e5..e06bade517 100644 --- a/atom/model_ops/minimax_m3/sparse_attn.py +++ b/atom/model_ops/minimax_m3/sparse_attn.py @@ -1266,9 +1266,17 @@ def _build_sparse_block_table_kernel( base_phys = logical_page * pages_per_block # [BLOCK_SIZE_T] dst_base = slot * pages_per_block # [BLOCK_SIZE_T] + # Write EVERY destination slot so the output buffer can be torch.empty (no + # memset): valid selected blocks -> their physical pages; all remaining slots + # (padding beyond n_valid, or BLOCK_SIZE_T > max_topk) -> 0 (an in-bounds page + # id; masked out by context_lens at attention time). Avoids the per-call + # torch.zeros memset that dominates at low concurrency. for j in range(pages_per_block): - write_mask = valid & (dst_base + j < BLOCK_SIZE_T * pages_per_block) - tl.store(sbt_row + dst_base + j, base_phys + j, mask=write_mask) + tl.store(sbt_row + dst_base + j, base_phys + j, mask=valid) + # zero the unused tail [n_valid*pages_per_block : width). + n_used = n_valid * pages_per_block + off_w = tl.arange(0, BLOCK_SIZE_T * pages_per_block) + tl.store(sbt_row + off_w, tl.zeros_like(off_w), mask=off_w >= n_used) # true valid token count: full blocks contribute 128 each, tail the remainder. tail_tokens = seq_len - last_blk * sm_block_size @@ -1299,8 +1307,11 @@ def minimax_m3_build_sparse_block_table( batch = topk_idx.shape[1] topk = topk_idx.shape[-1] width = topk * PAGES_PER_SPARSE_BLOCK - sparse_bt = torch.zeros((batch, width), dtype=torch.int32, device=topk_idx.device) - sparse_ctx = torch.zeros((batch,), dtype=torch.int32, device=topk_idx.device) + # Both buffers are FULLY written by the kernel (sparse_bt: every slot incl. + # padding -> 0; sparse_ctx: one entry per program), so torch.empty is safe and + # skips the per-call memset that hurts low-concurrency decode. + sparse_bt = torch.empty((batch, width), dtype=torch.int32, device=topk_idx.device) + sparse_ctx = torch.empty((batch,), dtype=torch.int32, device=topk_idx.device) _build_sparse_block_table_kernel[(batch,)]( topk_idx, block_table, @@ -1333,6 +1344,8 @@ def minimax_m3_sparse_attn_decode_asm( output: torch.Tensor, # [batch, num_heads, head_dim] k_scale: torch.Tensor | None = None, v_scale: torch.Tensor | None = None, + sparse_bt: torch.Tensor | None = None, # prebuilt (fused topk) -> skip build + sparse_ctx: torch.Tensor | None = None, ) -> None: """Block-sparse decode attention via the AITER Gluon paged-attention kernel. @@ -1344,6 +1357,9 @@ def minimax_m3_sparse_attn_decode_asm( ASM kernel at low concurrency (few decode sequences), where it parallelizes over KV partitions to keep the GPU busy. + If ``sparse_bt`` / ``sparse_ctx`` are provided (built fused inside the topk + merge kernel), the standalone compaction launch is skipped. + Requires per-rank num_kv_heads == 1 (the indexer top-k is per-kv-head; one shared block_table cannot express per-kv-head selection) and head_dim == 128. """ @@ -1356,9 +1372,10 @@ def minimax_m3_sparse_attn_decode_asm( ) assert q.shape[-1] == 128, "Gluon paged-attention requires head_dim == 128." - sparse_bt, sparse_ctx = minimax_m3_build_sparse_block_table( - topk_idx, block_table, seq_lens - ) + if sparse_bt is None or sparse_ctx is None: + sparse_bt, sparse_ctx = minimax_m3_build_sparse_block_table( + topk_idx, block_table, seq_lens + ) # Gluon split-KV decode setup (mirrors the standard MHA path in # attention_mha.py::paged_attention_triton). q is [batch, num_heads, 128]; @@ -1466,9 +1483,14 @@ def _build_sparse_block_table_prefill_kernel( base_phys = logical_page * pages_per_block dst_base = slot * pages_per_block + # Write EVERY destination slot so the output buffer can be torch.empty (no + # memset): valid selected blocks -> their physical pages; the unused tail -> + # 0 (in-bounds page id, masked out by context_lens at attention time). for j in range(pages_per_block): - write_mask = valid & (dst_base + j < BLOCK_SIZE_T * pages_per_block) - tl.store(sbt_row + dst_base + j, base_phys + j, mask=write_mask) + tl.store(sbt_row + dst_base + j, base_phys + j, mask=valid) + n_used = n_valid * pages_per_block + off_w = tl.arange(0, BLOCK_SIZE_T * pages_per_block) + tl.store(sbt_row + off_w, tl.zeros_like(off_w), mask=off_w >= n_used) # full blocks contribute 128 each; tail (self-block) contributes p%128 + 1. tail_tokens = causal_len - self_blk * sm_block_size @@ -1501,8 +1523,10 @@ def minimax_m3_build_sparse_block_table_prefill( device = topk_idx.device width = topk * PAGES_PER_SPARSE_BLOCK - sparse_bt = torch.zeros((total_q, width), dtype=torch.int32, device=device) - sparse_ctx = torch.zeros((total_q,), dtype=torch.int32, device=device) + # Fully written by the kernel (every slot incl. padding -> 0; one ctx per + # program), so torch.empty is safe and skips the per-call memset. + sparse_bt = torch.empty((total_q, width), dtype=torch.int32, device=device) + sparse_ctx = torch.empty((total_q,), dtype=torch.int32, device=device) _build_sparse_block_table_prefill_kernel[(total_q,)]( topk_idx, block_table, @@ -1543,6 +1567,8 @@ def minimax_m3_sparse_attn_prefill_asm( v_scale: torch.Tensor | None = None, cu_seqlens_q: torch.Tensor | None = None, # [batch+1] int32, for the fallback prefix_lens: torch.Tensor | None = None, # [batch] int32, for the fallback + sparse_bt: torch.Tensor | None = None, # prebuilt (fused topk) -> skip build + sparse_ctx: torch.Tensor | None = None, ) -> None: """Block-sparse PREFILL via AITER ASM pa_fwd_asm, per-token-as-decode. @@ -1566,23 +1592,24 @@ def minimax_m3_sparse_attn_prefill_asm( total_q = q.shape[0] device = q.device - if query_req_id is None or query_abs_pos is None: - # Sync-free on-device derivation: req_id[n] = #(cu_seqlens_q[1:] <= n), - # abs_pos[n] = prefix_lens[req] + (n - cu_seqlens_q[req]). - assert cu_seqlens_q is not None and prefix_lens is not None - pos = torch.arange(total_q, dtype=torch.int32, device=device) - query_req_id = torch.searchsorted( - cu_seqlens_q[1:].contiguous(), pos, right=True - ).to(torch.int32) - query_abs_pos = ( - prefix_lens[query_req_id] + (pos - cu_seqlens_q[query_req_id]) - ).to(torch.int32) if qo_indptr is None: qo_indptr = torch.arange(total_q + 1, dtype=torch.int32, device=device) - sparse_bt, sparse_ctx = minimax_m3_build_sparse_block_table_prefill( - topk_idx, block_table, query_req_id, query_abs_pos - ) + if sparse_bt is None or sparse_ctx is None: + if query_req_id is None or query_abs_pos is None: + # Sync-free on-device derivation: req_id[n] = #(cu_seqlens_q[1:] <= n), + # abs_pos[n] = prefix_lens[req] + (n - cu_seqlens_q[req]). + assert cu_seqlens_q is not None and prefix_lens is not None + pos = torch.arange(total_q, dtype=torch.int32, device=device) + query_req_id = torch.searchsorted( + cu_seqlens_q[1:].contiguous(), pos, right=True + ).to(torch.int32) + query_abs_pos = ( + prefix_lens[query_req_id] + (pos - cu_seqlens_q[query_req_id]) + ).to(torch.int32) + sparse_bt, sparse_ctx = minimax_m3_build_sparse_block_table_prefill( + topk_idx, block_table, query_req_id, query_abs_pos + ) run_pa_fwd_asm( q=q, k_cache=k_cache, diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 8278dd32b8..bbc9686a37 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -45,7 +45,6 @@ ) from atom.model_ops.minimax_m3.sparse_attn import ( SPARSE_BLOCK_SIZE, - minimax_m3_fused_qknorm_rope_kv_insert_shuffle, minimax_m3_sparse_attn, minimax_m3_sparse_attn_decode, minimax_m3_sparse_attn_decode_asm, @@ -94,11 +93,13 @@ def _rope_theta(config: PretrainedConfig) -> float: def _can_use_fused_minimax_m3_attention_preproc( qkv: torch.Tensor, + use_asm_pa: bool, rotary_emb: nn.Module, *weights: torch.Tensor, ) -> bool: return ( hasattr(aiter, "fused_qknorm_idxrqknorm") + or use_asm_pa and qkv.dim() == 2 and qkv.is_cuda and qkv.dtype in (torch.float16, torch.bfloat16) @@ -449,16 +450,17 @@ def forward( self.num_kv_heads, self.rotary_emb.rotary_dim, self.q_norm.variance_epsilon, - None, - None, - 0, - None, - None, - None, - 0, - q, - None, - None, + None, # index_q_norm_weight + None, # index_k_norm_weight + 0, # num_index_heads + None, # slot_mapping + None, # kv_cache_k + None, # kv_cache_v + None, # index_cache + 0, # block_size + q, # q_out + None, # index_q_out + None, # index_slot_mapping ) _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) attn_output = self.attn(q, k, v) @@ -702,7 +704,10 @@ def _run_prefill_sparse( seq_lens = prefill_metadata.seq_lens prefix_lens = prefill_metadata.context_lens block_tables = prefill_metadata.block_table - topk_idx = minimax_m3_index_topk( + # Fuse the sparse block-table build into the topk kernel for the ASM + # prefill path (num_kv_heads == 1), saving a build launch + round-trip. + fuse_bt = self._use_asm_pa and self.num_kv_heads == 1 + topk_out = minimax_m3_index_topk( index_q, self.index_cache, block_tables, @@ -716,7 +721,12 @@ def _run_prefill_sparse( self.local_blocks, self.num_kv_heads, self.scaling, + emit_sparse_block_table=fuse_bt, ) + if fuse_bt: + topk_idx, sparse_bt, sparse_ctx = topk_out + else: + topk_idx, sparse_bt, sparse_ctx = topk_out, None, None output = torch.empty_like(q) if self._use_asm_pa: minimax_m3_sparse_attn_prefill_asm( @@ -733,6 +743,8 @@ def _run_prefill_sparse( output, k_scale=self.k_scale, v_scale=self.v_scale, + sparse_bt=sparse_bt, + sparse_ctx=sparse_ctx, ) else: minimax_m3_sparse_attn( @@ -758,7 +770,11 @@ def _run_decode_sparse( ) -> torch.Tensor: decode_metadata = sparse_metadata.decode assert decode_metadata is not None - topk_idx = minimax_m3_index_topk_decode( + # When using ASM/gluon decode (num_kv_heads == 1), fuse the sparse + # block-table build into the topk merge kernel (returns sparse_bt/ctx), + # saving a separate build launch + topk_idx round-trip. + fuse_bt = self._use_asm_pa and self.num_kv_heads == 1 + topk_out = minimax_m3_index_topk_decode( index_q, self.index_cache, decode_metadata.block_table, @@ -769,7 +785,12 @@ def _run_decode_sparse( self.local_blocks, self.num_kv_heads, self.scaling, + emit_sparse_block_table=fuse_bt, ) + if fuse_bt: + topk_idx, sparse_bt, sparse_ctx = topk_out + else: + topk_idx, sparse_bt, sparse_ctx = topk_out, None, None output = torch.empty_like(q) if self._use_asm_pa: if self.num_kv_heads != 1: @@ -790,6 +811,8 @@ def _run_decode_sparse( output, k_scale=self.k_scale, v_scale=self.v_scale, + sparse_bt=sparse_bt, + sparse_ctx=sparse_ctx, ) else: minimax_m3_sparse_attn_decode( @@ -833,6 +856,7 @@ def sparse_attention_forward_impl( sparse_metadata = attn_metadata if _can_use_fused_minimax_m3_attention_preproc( qkv, + self._use_asm_pa, self.rotary_emb, self.q_norm.weight, self.k_norm.weight, @@ -848,9 +872,14 @@ def sparse_attention_forward_impl( ) cos_sin_cache = _minimax_m3_cos_sin_cache(self.rotary_emb, qkv) if self._use_asm_pa: - # Triton fallback that writes the main KV cache in page-16 - # SHUFFLE layout (the aiter fused kernel writes plain page-128). - minimax_m3_fused_qknorm_rope_kv_insert_shuffle( + # Fused aiter CUDA kernel writing the main KV cache in page-16 + # SHUFFLE layout (asm_layout=True), matching the page-16 SHUFFLE + # K/V views. ~2-4x faster than the Triton shuffle writer; block_size + # is the SHUFFLE page (16 == kv_cache_k.shape[3]), not 128. + kv_cache_dtype = ( + "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype + ) + aiter.fused_qknorm_idxrqknorm( qkv, self.q_norm.weight, self.k_norm.weight, @@ -867,11 +896,23 @@ def sparse_attention_forward_impl( self.kv_cache_k, self.kv_cache_v, self.index_cache, + self.kv_cache_k.shape[3], # SHUFFLE page size (== ASM_PAGE_SIZE) q, index_q, - self.idx_head_dim, + sparse_metadata.slot_mapping, + kv_cache_dtype=kv_cache_dtype, + k_scale=None, + v_scale=None, + asm_layout=True, ) else: + # page-128 path: the op takes separate K/V caches. Pass the + # key/value slices of the fused [N, 2, block_size, nkv, hd] cache + # (zero-copy views); asm_layout=False selects page-128 addressing. + key_cache, value_cache = self.kv_cache.unbind(1) + kv_cache_dtype = ( + "auto" if self.kv_cache_dtype == "bf16" else self.kv_cache_dtype + ) aiter.fused_qknorm_idxrqknorm( qkv, self.q_norm.weight, @@ -886,15 +927,17 @@ def sparse_attention_forward_impl( self.index_k_norm.weight, self.num_idx_heads, sparse_metadata.slot_mapping, - self.kv_cache, + key_cache, + value_cache, self.index_cache, self.kv_cache.shape[2], q, index_q, sparse_metadata.slot_mapping, - self.kv_cache_dtype, - None, - None, + kv_cache_dtype=kv_cache_dtype, + k_scale=None, + v_scale=None, + asm_layout=False, ) q = q.view(-1, self.num_heads, self.head_dim) index_q = index_q.view(-1, self.num_idx_heads, self.idx_head_dim) @@ -904,47 +947,47 @@ def sparse_attention_forward_impl( output = self._run_decode_sparse(q, index_q, sparse_metadata) return output.view(-1, self.q_size) - q, k, v, index_q, index_k = qkv.split( - [ - self.q_size, - self.kv_size, - self.kv_size, - self.index_q_size, - self.idx_head_dim, - ], - dim=-1, - ) - q, k = _minimax_m3_gemma_qk_norm( - q, - k, - self.q_norm, - self.k_norm, - self.num_heads, - self.num_kv_heads, - self.head_dim, - ) - q, k = self.rotary_emb(positions, q, k) - - index_q, index_k = _minimax_m3_gemma_qk_norm( - index_q, - index_k, - self.index_q_norm, - self.index_k_norm, - self.num_idx_heads, - 1, - self.idx_head_dim, - ) - index_q, index_k = self.index_rotary_emb(positions, index_q, index_k) - - self._insert_kv(k, v, index_k, sparse_metadata.slot_mapping) - - q = q.view(-1, self.num_heads, self.head_dim) - index_q = index_q.view(-1, self.num_idx_heads, self.idx_head_dim) - if getattr(sparse_metadata, "num_prefills", 0) > 0: - output = self._run_prefill_sparse(q, index_q, sparse_metadata) - else: - output = self._run_decode_sparse(q, index_q, sparse_metadata) - return output.view(-1, self.q_size) + # q, k, v, index_q, index_k = qkv.split( + # [ + # self.q_size, + # self.kv_size, + # self.kv_size, + # self.index_q_size, + # self.idx_head_dim, + # ], + # dim=-1, + # ) + # q, k = _minimax_m3_gemma_qk_norm( + # q, + # k, + # self.q_norm, + # self.k_norm, + # self.num_heads, + # self.num_kv_heads, + # self.head_dim, + # ) + # q, k = self.rotary_emb(positions, q, k) + + # index_q, index_k = _minimax_m3_gemma_qk_norm( + # index_q, + # index_k, + # self.index_q_norm, + # self.index_k_norm, + # self.num_idx_heads, + # 1, + # self.idx_head_dim, + # ) + # index_q, index_k = self.index_rotary_emb(positions, index_q, index_k) + + # self._insert_kv(k, v, index_k, sparse_metadata.slot_mapping) + + # q = q.view(-1, self.num_heads, self.head_dim) + # index_q = index_q.view(-1, self.num_idx_heads, self.idx_head_dim) + # if getattr(sparse_metadata, "num_prefills", 0) > 0: + # output = self._run_prefill_sparse(q, index_q, sparse_metadata) + # else: + # output = self._run_decode_sparse(q, index_q, sparse_metadata) + # return output.view(-1, self.q_size) class MiniMaxM3DecoderLayer(nn.Module):