From 4c4a3254587b1be8111129574ef345aabc5d3ce6 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Thu, 18 Jun 2026 14:24:01 +0000 Subject: [PATCH 1/6] mla --- atom/model_ops/attention_mla.py | 152 ++++++++++++++++-- atom/model_ops/attentions/aiter_mla.py | 7 +- .../sglang/models/deepseek_mla_attention.py | 4 +- atom/plugin/vllm/attention/layer_mla.py | 6 +- atom/utils/envs.py | 1 + 5 files changed, 148 insertions(+), 22 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 8349801135..bdd7f908cc 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -12,12 +12,13 @@ from aiter import ( QuantType, concat_and_cache_mla, + concat_and_cache_mla_seg, dtypes, flash_attn_varlen_func, fused_qk_rope_concat_and_cache_mla, + fused_qk_rope_concat_and_cache_mla_seg, get_hip_quant, ) -from aiter.dist.parallel_state import get_dp_group from aiter.mla import mla_decode_fwd, mla_prefill_fwd from aiter.ops.triton.attention.mla import ( mla_decode_fwd as triton_shuffle_mla_decode_fwd, @@ -46,9 +47,17 @@ concat_and_cache_mla = mark_trace( concat_and_cache_mla, prefix="kv_cache", torch_compile=False ) +concat_and_cache_mla_seg = mark_trace( + concat_and_cache_mla_seg, prefix="kv_cache_seg", torch_compile=False +) fused_qk_rope_concat_and_cache_mla = mark_trace( fused_qk_rope_concat_and_cache_mla, prefix="rope_and_kv_cache", torch_compile=False ) +fused_qk_rope_concat_and_cache_mla_seg = mark_trace( + fused_qk_rope_concat_and_cache_mla_seg, + prefix="rope_and_kv_cache", + torch_compile=False, +) mla_prefill_fwd = mark_trace(mla_prefill_fwd, prefix="mla_prefill", torch_compile=False) mla_decode_fwd = mark_trace(mla_decode_fwd, prefix="mla_decode", torch_compile=False) @@ -74,6 +83,20 @@ _MLA_MIN_HEADS = 16 # AITER MLA kernels require at least 16 attention heads +# The fused seg MLA kernels (fused_qk_rope_concat_and_cache_mla_seg + +# concat_and_cache_mla_seg + the gfx1250 mla_decode_fwd asm) share a single +# 
segmented KV cache layout (all tokens' nope packed first, then all tokens' +# pe) and a fixed page size hard-coded in the kernels. +_MLA_SEG_PAGE_SIZE = 64 +# The gfx1250 decode asm consumes an fp8 Q whose per-head row stride is padded +# to 768 bytes (poc_kl pack_q_page1_padded layout). q_out is allocated with this +# padded last dim and sliced to the logical kv_lora_rank + qk_rope_head_dim +# columns; the padding tail is never read by the decode kernel. +_MLA_Q_OUT_PADDED_DIM = 768 +# Dims the fused seg kernels are compiled against (KV_LORA / PE_DIM constexprs). +_MLA_SEG_KV_LORA_RANK = 512 +_MLA_SEG_PE_DIM = 64 + if False: try: from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import ( @@ -205,6 +228,34 @@ def __init__( else None ) self.layer_num = layer_num + # When the triton MLA backend is selected we keep the original + # interleaved KV cache layout (concat_and_cache_mla / + # fused_qk_rope_concat_and_cache_mla) and an unpadded 576-wide q_out; + # only the gfx1250 asm decode path needs the segmented layout + 768 pad. + self.use_triton_mla = bool(envs.ATOM_USE_TRITON_MLA) + # On the non-triton (aiter) path, ATOM_MLA_PAGE_SIZE selects the KV cache + # layout: >1 uses the segmented (paged) seg kernels + padded q_out, while + # ==1 falls back to the original interleaved per-token (page_size=1) + # kernels with an unpadded 576-wide q_out. The triton path never uses seg. + self.use_seg_mla = (not self.use_triton_mla) and envs.ATOM_MLA_PAGE_SIZE > 1 + + def _seg_kv_cache_view(self, kv_cache: torch.Tensor) -> torch.Tensor: + """Reshape the KV cache buffer into the page-level flat seg layout + ``[num_blocks, page_size*(kv_lora_rank + qk_rope_head_dim)]`` that the + seg write kernels expect (they derive page_size from ``stride(0)``). + + The cache is allocated token-major as ``[num_blocks*page_size, ..., entry]`` + (so ``kv_cache.shape[0]`` is the total slot count, not the block count). 
+ A plain view groups every ``page_size`` consecutive token slots into one + block, i.e. slot = block*page_size + offset, which matches slot_mapping + and the page-level view used on the decode side + (``kv_buffer.view(-1, page_size, 1, entry)``). Using + ``kv_cache.view(kv_cache.shape[0], -1)`` here is WRONG: it keeps the + token-level stride (entry), so the kernel derives page_size=1 and writes + an interleaved layout that the page_size=64 decode then misreads.""" + page_size = get_current_atom_config().kv_cache_block_size + entry = self.kv_lora_rank + self.qk_rope_head_dim + return kv_cache.view(-1, page_size * entry) def process_weights_after_loading(self): if is_rocm_aiter_fp4bmm_enabled(): @@ -743,6 +794,14 @@ def _forward_prefill_mla( if self.head_repeat_factor > 1: q = q.repeat_interleave(self.head_repeat_factor, dim=1) + # In the seg path q arrives with a padded per-head row stride + # (_MLA_Q_OUT_PADDED_DIM); slice back to the logical + # kv_lora_rank + qk_rope_head_dim columns. The slice keeps the padded row + # stride, which the asm kernel expects. The triton and non-seg + # (page_size=1) paths use an unpadded 576-wide q_out, so no slicing. 
+ if self.use_seg_mla: + q = q[..., : self.kv_lora_rank + self.qk_rope_head_dim] + o = torch.empty( B, self.padded_num_heads, @@ -763,16 +822,18 @@ def _forward_prefill_mla( max_q_len = 1 if kv_c_and_k_pe_cache.numel() > 0: + page_size = get_current_atom_config().kv_cache_block_size if self.kv_cache_dtype.startswith("fp8"): mla_decode_fwd( q, - kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]), + kv_c_and_k_pe_cache.view(-1, page_size, 1, q.shape[-1]), o, paged_cu_seqlens_q, paged_kv_indptr, paged_kv_indices, kv_last_page_lens, max_q_len, + page_size=page_size, sm_scale=self.scale, q_scale=self._q_scale, kv_scale=self._k_scale, @@ -780,7 +841,7 @@ def _forward_prefill_mla( else: mla_prefill_fwd( q, - kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]), + kv_c_and_k_pe_cache.view(-1, page_size, 1, q.shape[-1]), o, paged_cu_seqlens_q, paged_kv_indptr, @@ -831,6 +892,14 @@ def _forward_decode( if self.head_repeat_factor > 1: q = q.repeat_interleave(self.head_repeat_factor, dim=1) + # In the seg path q arrives with a padded per-head row stride + # (_MLA_Q_OUT_PADDED_DIM); slice back to the logical + # kv_lora_rank + qk_rope_head_dim columns. The slice keeps the padded row + # stride, which the asm kernel expects. The triton and non-seg + # (page_size=1) paths use an unpadded 576-wide q_out, so no slicing. 
+ if self.use_seg_mla: + q = q[..., : self.kv_lora_rank + self.qk_rope_head_dim] + o = torch.empty( B, self.padded_num_heads, @@ -909,8 +978,7 @@ def _forward_decode( paged_kv_indptr = attn_metadata.sparse_kv_indptr paged_kv_indices = self.sparse_kv_indices_buffer - dp_size = get_dp_group().world_size - use_persistent_mode = not (dp_size > 1) + use_persistent_mode = False # Sparse layers in MTP verify use separate persistent metadata # (per-token, max_seqlen_qo=1) while dense layers use normal metadata @@ -939,16 +1007,19 @@ def _forward_decode( reduce_final_map = attn_metadata.reduce_final_map reduce_partial_map = attn_metadata.reduce_partial_map + page_size = get_current_atom_config().kv_cache_block_size + seg_kv_buffer_4d = kv_buffer.view(-1, page_size, 1, q.shape[-1]) mla_decode_fwd( q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), + seg_kv_buffer_4d, o, paged_cu_seqlens_q, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_lens, max_q_len, - num_kv_splits=16, + page_size=page_size, + num_kv_splits=1, # asm passes sm_scale=self.scale, work_meta_data=work_meta_data, work_indptr=work_indptr, @@ -1009,6 +1080,22 @@ def forward_impl( apply_scale=True, shuffled_kv_cache=True, ) + elif self.use_seg_mla: + # Write the KV cache in the segmented layout so the + # decode-phase mla_decode_fwd (which reads seg layout) sees a + # consistent cache for tokens written during prefill. + # kv_cache is flattened to + # [num_blocks, page_size*(kv_lora_rank + qk_rope_head_dim)] so + # the kernel derives page_size from stride(0). 
+ kv_cache_seg = self._seg_kv_cache_view(kv_cache) + concat_and_cache_mla_seg( + k_nope, + k_rope.squeeze(1), + kv_cache_seg, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) else: concat_and_cache_mla( k_nope, @@ -1039,15 +1126,30 @@ def forward_impl( else: q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) - q_out = torch.empty( - ( - q_nope.shape[0], - self.num_heads, - self.kv_lora_rank + self.qk_rope_head_dim, - ), - dtype=attn_metadata.dtype_q, - device=q_nope.device, - ) + if self.use_seg_mla: + # Seg path: allocate q_out with a padded last dim so each head row + # has a 768-byte stride (required by the gfx1250 decode asm). The + # kernel only writes the first kv_lora_rank + qk_rope_head_dim + # columns; the padding tail is left untouched and never read. + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, + _MLA_Q_OUT_PADDED_DIM, + ), + dtype=attn_metadata.dtype_q, + device=q_nope.device, + ) + else: + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=attn_metadata.dtype_q, + device=q_nope.device, + ) if kv_cache.numel() > 0: if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV: shuffled_cache = self._shuffled_kv_view(kv_cache) @@ -1068,6 +1170,24 @@ def forward_impl( q_out=q_out, shuffled_kv_cache=True, ) + elif self.use_seg_mla: + kv_cache_seg = self._seg_kv_cache_view(kv_cache) + fused_qk_rope_concat_and_cache_mla_seg( + q_nope, + q_rope, + k_nope, + k_rope, + # Flat seg layout: [num_blocks, page_size*(kv_lora + pe)]. 
+ kv_cache_seg, + q_out, + attn_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + ) else: fused_qk_rope_concat_and_cache_mla( q_nope, diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index ee522ff6dc..84f3e61e2b 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -81,7 +81,10 @@ def get_impl_cls() -> Type["MLAAttention"]: class AiterMLAMetadataBuilder(CommonAttentionBuilder): def __init__(self, model_runner): - self.block_size = 1 + if envs.ATOM_MLA_PAGE_SIZE > 1: + self.block_size = envs.ATOM_MLA_PAGE_SIZE + else: + self.block_size = 1 if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV: assert model_runner.block_size == 64, ( f"ATOM_USE_TRITON_MLA=1 and ATOM_USE_TRITON_MLA_SHUFFLE_KV=1 expects --block-size 64 " @@ -114,6 +117,7 @@ def __init__(self, model_runner): self.dtype_kv, is_sparse=self.is_sparse, fast_mode=True, + max_split_per_batch=16, ) i32_kwargs = {"dtype": torch.int32, "device": self.device} @@ -195,6 +199,7 @@ def __init__(self, model_runner): self.dtype_kv, is_sparse=True, fast_mode=True, + max_split_per_batch=16 ) mla_metadata["sparse_mtp_work_meta_data"] = torch.empty( smt_wmd_size, dtype=smt_wmd_type, device=self.device diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index e9ffe79f2e..0cda5b6d07 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -157,7 +157,7 @@ def _forward_absorbed( ) -> torch.Tensor: attn = self.owner_attn from aiter import dtypes - from atom.model_ops.attention_mla import fused_qk_rope_concat_and_cache_mla + from atom.model_ops.attention_mla import fused_qk_rope_concat_and_cache_mla_seg from atom.plugin.sglang.models.deepseek_mla_forward import ( 
_get_sglang_radix_attn, mla_absorbed_bmm, @@ -212,7 +212,7 @@ def _forward_absorbed( dtype=q_out_dtype, device=q_nope_out.device, ) - fused_qk_rope_concat_and_cache_mla( + fused_qk_rope_concat_and_cache_mla_seg( q_nope_out, q_pe, k_nope, diff --git a/atom/plugin/vllm/attention/layer_mla.py b/atom/plugin/vllm/attention/layer_mla.py index 70d31592ba..72ef3d29ee 100644 --- a/atom/plugin/vllm/attention/layer_mla.py +++ b/atom/plugin/vllm/attention/layer_mla.py @@ -5,7 +5,7 @@ import aiter import torch -from aiter import dtypes, fused_qk_rope_concat_and_cache_mla +from aiter import dtypes, fused_qk_rope_concat_and_cache_mla_seg from aiter.mla import mla_decode_fwd from aiter.ops.triton import ( batched_gemm_a16wfp4 as _fp4_bmm_module, @@ -968,7 +968,7 @@ def forward_impl( ), device=decode_ql_nope.device, ) - aiter.fused_qk_rope_concat_and_cache_mla( + aiter.fused_qk_rope_concat_and_cache_mla_seg( decode_ql_nope, decode_q_pe, k_c_normed, @@ -1204,7 +1204,7 @@ def forward_impl_sparse( device=ql_nope.device, ) if kv_cache.numel() > 0: - fused_qk_rope_concat_and_cache_mla( + fused_qk_rope_concat_and_cache_mla_seg( ql_nope, q_pe, k_c_normed, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 87b00cba74..a445d6114a 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,6 +42,7 @@ os.getenv("ATOM_USE_TRITON_MLA_SHUFFLE_KV", "0") == "1" ), "ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1", + "ATOM_MLA_PAGE_SIZE": lambda: int(os.getenv("ATOM_MLA_PAGE_SIZE", "1")), # --- Kernel Fusion Toggles --- # fused_compress_attn: switch between Triton (default historical) and a # flydsl drop-in for V4-Pro Compressor (Main BF16 + Indexer FP8) paths. 
From c7bf6afcfd0992dcd70bef75aa77781370b5a96f Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Fri, 19 Jun 2026 06:15:02 +0000 Subject: [PATCH 2/6] revert some non-aiter related changes --- atom/plugin/sglang/models/deepseek_mla_attention.py | 4 ++-- atom/plugin/vllm/attention/layer_mla.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index 0cda5b6d07..e9ffe79f2e 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -157,7 +157,7 @@ def _forward_absorbed( ) -> torch.Tensor: attn = self.owner_attn from aiter import dtypes - from atom.model_ops.attention_mla import fused_qk_rope_concat_and_cache_mla_seg + from atom.model_ops.attention_mla import fused_qk_rope_concat_and_cache_mla from atom.plugin.sglang.models.deepseek_mla_forward import ( _get_sglang_radix_attn, mla_absorbed_bmm, @@ -212,7 +212,7 @@ def _forward_absorbed( dtype=q_out_dtype, device=q_nope_out.device, ) - fused_qk_rope_concat_and_cache_mla_seg( + fused_qk_rope_concat_and_cache_mla( q_nope_out, q_pe, k_nope, diff --git a/atom/plugin/vllm/attention/layer_mla.py b/atom/plugin/vllm/attention/layer_mla.py index 72ef3d29ee..70d31592ba 100644 --- a/atom/plugin/vllm/attention/layer_mla.py +++ b/atom/plugin/vllm/attention/layer_mla.py @@ -5,7 +5,7 @@ import aiter import torch -from aiter import dtypes, fused_qk_rope_concat_and_cache_mla_seg +from aiter import dtypes, fused_qk_rope_concat_and_cache_mla from aiter.mla import mla_decode_fwd from aiter.ops.triton import ( batched_gemm_a16wfp4 as _fp4_bmm_module, @@ -968,7 +968,7 @@ def forward_impl( ), device=decode_ql_nope.device, ) - aiter.fused_qk_rope_concat_and_cache_mla_seg( + aiter.fused_qk_rope_concat_and_cache_mla( decode_ql_nope, decode_q_pe, k_c_normed, @@ -1204,7 +1204,7 @@ def forward_impl_sparse( device=ql_nope.device, ) if kv_cache.numel() > 0: - 
fused_qk_rope_concat_and_cache_mla_seg( + fused_qk_rope_concat_and_cache_mla( ql_nope, q_pe, k_c_normed, From 749aaf46cc29cfc35418640164c5930a1e33fc11 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Fri, 19 Jun 2026 09:42:10 +0000 Subject: [PATCH 3/6] fix format --- atom/model_ops/attentions/aiter_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 84f3e61e2b..3a1e8ce03c 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -199,7 +199,7 @@ def __init__(self, model_runner): self.dtype_kv, is_sparse=True, fast_mode=True, - max_split_per_batch=16 + max_split_per_batch=16, ) mla_metadata["sparse_mtp_work_meta_data"] = torch.empty( smt_wmd_size, dtype=smt_wmd_type, device=self.device From 703d17db543a8bb56fe4ce54d078c51d6bbed34e Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Fri, 19 Jun 2026 12:59:35 +0000 Subject: [PATCH 4/6] fix block-size usage --- atom/model_ops/attention_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index bdd7f908cc..d63cdc936a 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -822,7 +822,7 @@ def _forward_prefill_mla( max_q_len = 1 if kv_c_and_k_pe_cache.numel() > 0: - page_size = get_current_atom_config().kv_cache_block_size + page_size = attn_metadata.block_size if self.kv_cache_dtype.startswith("fp8"): mla_decode_fwd( q, From 05980c39911bc40bb3cbfdbf73b24775b0d6c936 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Fri, 19 Jun 2026 13:08:58 +0000 Subject: [PATCH 5/6] fix block size and persistent mode --- atom/model_ops/attention_mla.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index d63cdc936a..93a3eb653f 100644 --- a/atom/model_ops/attention_mla.py +++ 
b/atom/model_ops/attention_mla.py @@ -19,6 +19,7 @@ fused_qk_rope_concat_and_cache_mla_seg, get_hip_quant, ) +from aiter.dist.parallel_state import get_dp_group from aiter.mla import mla_decode_fwd, mla_prefill_fwd from aiter.ops.triton.attention.mla import ( mla_decode_fwd as triton_shuffle_mla_decode_fwd, @@ -978,7 +979,10 @@ def _forward_decode( paged_kv_indptr = attn_metadata.sparse_kv_indptr paged_kv_indices = self.sparse_kv_indices_buffer - use_persistent_mode = False + dp_size = get_dp_group().world_size + use_persistent_mode = not (dp_size > 1) + if envs.ATOM_MLA_PAGE_SIZE > 1: + use_persistent_mode = False # Sparse layers in MTP verify use separate persistent metadata # (per-token, max_seqlen_qo=1) while dense layers use normal metadata @@ -1007,7 +1011,7 @@ def _forward_decode( reduce_final_map = attn_metadata.reduce_final_map reduce_partial_map = attn_metadata.reduce_partial_map - page_size = get_current_atom_config().kv_cache_block_size + page_size = attn_metadata.block_size seg_kv_buffer_4d = kv_buffer.view(-1, page_size, 1, q.shape[-1]) mla_decode_fwd( q, From 9d7c49bc9065c6a948eb39afba04c8757a5a6e7f Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Fri, 19 Jun 2026 13:32:12 +0000 Subject: [PATCH 6/6] fix block size to env var --- atom/model_ops/attention_mla.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 93a3eb653f..305ef97271 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -1011,7 +1011,10 @@ def _forward_decode( reduce_final_map = attn_metadata.reduce_final_map reduce_partial_map = attn_metadata.reduce_partial_map - page_size = attn_metadata.block_size + # ATOM_MLA_PAGE_SIZE is parsed as an int (default 1) and is never None; + # clamp to 1 to match the block size AiterMLAMetadataBuilder falls back to. + page_size = max(1, envs.ATOM_MLA_PAGE_SIZE) + seg_kv_buffer_4d = kv_buffer.view(-1, page_size, 1, q.shape[-1]) mla_decode_fwd( q,