diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 8349801135..305ef97271 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -12,9 +12,11 @@ from aiter import ( QuantType, concat_and_cache_mla, + concat_and_cache_mla_seg, dtypes, flash_attn_varlen_func, fused_qk_rope_concat_and_cache_mla, + fused_qk_rope_concat_and_cache_mla_seg, get_hip_quant, ) from aiter.dist.parallel_state import get_dp_group @@ -46,9 +48,17 @@ concat_and_cache_mla = mark_trace( concat_and_cache_mla, prefix="kv_cache", torch_compile=False ) +concat_and_cache_mla_seg = mark_trace( + concat_and_cache_mla_seg, prefix="kv_cache_seg", torch_compile=False +) fused_qk_rope_concat_and_cache_mla = mark_trace( fused_qk_rope_concat_and_cache_mla, prefix="rope_and_kv_cache", torch_compile=False ) +fused_qk_rope_concat_and_cache_mla_seg = mark_trace( + fused_qk_rope_concat_and_cache_mla_seg, + prefix="rope_and_kv_cache_seg", + torch_compile=False, +) mla_prefill_fwd = mark_trace(mla_prefill_fwd, prefix="mla_prefill", torch_compile=False) mla_decode_fwd = mark_trace(mla_decode_fwd, prefix="mla_decode", torch_compile=False) @@ -74,6 +84,20 @@ _MLA_MIN_HEADS = 16 # AITER MLA kernels require at least 16 attention heads +# The fused seg MLA kernels (fused_qk_rope_concat_and_cache_mla_seg + +# concat_and_cache_mla_seg + the gfx1250 mla_decode_fwd asm) share a single +# segmented KV cache layout (all tokens' nope packed first, then all tokens' +# pe) and a fixed page size hard-coded in the kernels. +_MLA_SEG_PAGE_SIZE = 64 +# The gfx1250 decode asm consumes an fp8 Q whose per-head row stride is padded +# to 768 bytes (poc_kl pack_q_page1_padded layout). q_out is allocated with this +# padded last dim and sliced to the logical kv_lora_rank + qk_rope_head_dim +# columns; the padding tail is never read by the decode kernel. +_MLA_Q_OUT_PADDED_DIM = 768 +# Dims the fused seg kernels are compiled against (KV_LORA / PE_DIM constexprs). 
+_MLA_SEG_KV_LORA_RANK = 512 +_MLA_SEG_PE_DIM = 64 + if False: try: from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import ( @@ -205,6 +229,34 @@ def __init__( else None ) self.layer_num = layer_num + # When the triton MLA backend is selected we keep the original + # interleaved KV cache layout (concat_and_cache_mla / + # fused_qk_rope_concat_and_cache_mla) and an unpadded 576-wide q_out; + # only the gfx1250 asm decode path needs the segmented layout + 768 pad. + self.use_triton_mla = bool(envs.ATOM_USE_TRITON_MLA) + # On the non-triton (aiter) path, ATOM_MLA_PAGE_SIZE selects the KV cache + # layout: >1 uses the segmented (paged) seg kernels + padded q_out, while + # ==1 falls back to the original interleaved per-token (page_size=1) + # kernels with an unpadded 576-wide q_out. The triton path never uses seg. + self.use_seg_mla = (not self.use_triton_mla) and envs.ATOM_MLA_PAGE_SIZE > 1 + + def _seg_kv_cache_view(self, kv_cache: torch.Tensor) -> torch.Tensor: + """Reshape the KV cache buffer into the page-level flat seg layout + ``[num_blocks, page_size*(kv_lora_rank + qk_rope_head_dim)]`` that the + seg write kernels expect (they derive page_size from ``stride(0)``). + + The cache is allocated token-major as ``[num_blocks*page_size, ..., entry]`` + (so ``kv_cache.shape[0]`` is the total slot count, not the block count). + A plain view groups every ``page_size`` consecutive token slots into one + block, i.e. slot = block*page_size + offset, which matches slot_mapping + and the page-level view used on the decode side + (``kv_buffer.view(-1, page_size, 1, entry)``). 
Using + ``kv_cache.view(kv_cache.shape[0], -1)`` here is WRONG: it keeps the + token-level stride (entry), so the kernel derives page_size=1 and writes + an interleaved layout that the page_size=64 decode then misreads.""" + page_size = get_current_atom_config().kv_cache_block_size + entry = self.kv_lora_rank + self.qk_rope_head_dim + return kv_cache.view(-1, page_size * entry) def process_weights_after_loading(self): if is_rocm_aiter_fp4bmm_enabled(): @@ -743,6 +795,14 @@ def _forward_prefill_mla( if self.head_repeat_factor > 1: q = q.repeat_interleave(self.head_repeat_factor, dim=1) + # In the seg path q arrives with a padded per-head row stride + # (_MLA_Q_OUT_PADDED_DIM); slice back to the logical + # kv_lora_rank + qk_rope_head_dim columns. The slice keeps the padded row + # stride, which the asm kernel expects. The triton and non-seg + # (page_size=1) paths use an unpadded 576-wide q_out, so no slicing. + if self.use_seg_mla: + q = q[..., : self.kv_lora_rank + self.qk_rope_head_dim] + o = torch.empty( B, self.padded_num_heads, @@ -763,16 +823,18 @@ def _forward_prefill_mla( max_q_len = 1 if kv_c_and_k_pe_cache.numel() > 0: + page_size = attn_metadata.block_size if self.kv_cache_dtype.startswith("fp8"): mla_decode_fwd( q, - kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]), + kv_c_and_k_pe_cache.view(-1, page_size, 1, q.shape[-1]), o, paged_cu_seqlens_q, paged_kv_indptr, paged_kv_indices, kv_last_page_lens, max_q_len, + page_size=page_size, sm_scale=self.scale, q_scale=self._q_scale, kv_scale=self._k_scale, @@ -780,7 +842,7 @@ def _forward_prefill_mla( else: mla_prefill_fwd( q, - kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]), + kv_c_and_k_pe_cache.view(-1, page_size, 1, q.shape[-1]), o, paged_cu_seqlens_q, paged_kv_indptr, @@ -831,6 +893,14 @@ def _forward_decode( if self.head_repeat_factor > 1: q = q.repeat_interleave(self.head_repeat_factor, dim=1) + # In the seg path q arrives with a padded per-head row stride + # (_MLA_Q_OUT_PADDED_DIM); slice back to the 
logical + # kv_lora_rank + qk_rope_head_dim columns. The slice keeps the padded row + # stride, which the asm kernel expects. The triton and non-seg + # (page_size=1) paths use an unpadded 576-wide q_out, so no slicing. + if self.use_seg_mla: + q = q[..., : self.kv_lora_rank + self.qk_rope_head_dim] + o = torch.empty( B, self.padded_num_heads, @@ -911,6 +981,8 @@ def _forward_decode( dp_size = get_dp_group().world_size use_persistent_mode = not (dp_size > 1) + if envs.ATOM_MLA_PAGE_SIZE > 1: + use_persistent_mode = False # Sparse layers in MTP verify use separate persistent metadata # (per-token, max_seqlen_qo=1) while dense layers use normal metadata @@ -939,16 +1011,24 @@ def _forward_decode( reduce_final_map = attn_metadata.reduce_final_map reduce_partial_map = attn_metadata.reduce_partial_map + # ATOM_MLA_PAGE_SIZE is parsed with int() and defaults to 1, so it is + # never None and needs no fallback branch. + # TODO: read this from attn_metadata.block_size like the prefill path + # so prefill and decode share one source for the page size. + page_size = envs.ATOM_MLA_PAGE_SIZE + + seg_kv_buffer_4d = kv_buffer.view(-1, page_size, 1, q.shape[-1]) mla_decode_fwd( q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), + seg_kv_buffer_4d, o, paged_cu_seqlens_q, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_lens, max_q_len, - num_kv_splits=16, + page_size=page_size, + num_kv_splits=1, # asm passes sm_scale=self.scale, work_meta_data=work_meta_data, work_indptr=work_indptr, @@ -1009,6 +1089,22 @@ def forward_impl( apply_scale=True, shuffled_kv_cache=True, ) + elif self.use_seg_mla: + # Write the KV cache in the segmented layout so the + # decode-phase mla_decode_fwd (which reads seg layout) sees a + # consistent cache for tokens written during prefill. + # kv_cache is flattened to + # [num_blocks, page_size*(kv_lora_rank + qk_rope_head_dim)] so + # the kernel derives page_size from stride(0). 
+ kv_cache_seg = self._seg_kv_cache_view(kv_cache) + concat_and_cache_mla_seg( + k_nope, + k_rope.squeeze(1), + kv_cache_seg, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) else: concat_and_cache_mla( k_nope, @@ -1039,15 +1135,30 @@ def forward_impl( else: q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) - q_out = torch.empty( - ( - q_nope.shape[0], - self.num_heads, - self.kv_lora_rank + self.qk_rope_head_dim, - ), - dtype=attn_metadata.dtype_q, - device=q_nope.device, - ) + if self.use_seg_mla: + # Seg path: allocate q_out with a padded last dim so each head row + # has a 768-byte stride (required by the gfx1250 decode asm). The + # kernel only writes the first kv_lora_rank + qk_rope_head_dim + # columns; the padding tail is left untouched and never read. + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, + _MLA_Q_OUT_PADDED_DIM, + ), + dtype=attn_metadata.dtype_q, + device=q_nope.device, + ) + else: + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=attn_metadata.dtype_q, + device=q_nope.device, + ) if kv_cache.numel() > 0: if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV: shuffled_cache = self._shuffled_kv_view(kv_cache) @@ -1068,6 +1179,24 @@ def forward_impl( q_out=q_out, shuffled_kv_cache=True, ) + elif self.use_seg_mla: + kv_cache_seg = self._seg_kv_cache_view(kv_cache) + fused_qk_rope_concat_and_cache_mla_seg( + q_nope, + q_rope, + k_nope, + k_rope, + # Flat seg layout: [num_blocks, page_size*(kv_lora + pe)]. 
+ kv_cache_seg, + q_out, + attn_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + ) else: fused_qk_rope_concat_and_cache_mla( q_nope, diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index ee522ff6dc..3a1e8ce03c 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -81,7 +81,10 @@ def get_impl_cls() -> Type["MLAAttention"]: class AiterMLAMetadataBuilder(CommonAttentionBuilder): def __init__(self, model_runner): - self.block_size = 1 + if envs.ATOM_MLA_PAGE_SIZE > 1: + self.block_size = envs.ATOM_MLA_PAGE_SIZE + else: + self.block_size = 1 if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV: assert model_runner.block_size == 64, ( f"ATOM_USE_TRITON_MLA=1 and ATOM_USE_TRITON_MLA_SHUFFLE_KV=1 expects --block-size 64 " @@ -114,6 +117,7 @@ def __init__(self, model_runner): self.dtype_kv, is_sparse=self.is_sparse, fast_mode=True, + max_split_per_batch=16, ) i32_kwargs = {"dtype": torch.int32, "device": self.device} @@ -195,6 +199,7 @@ def __init__(self, model_runner): self.dtype_kv, is_sparse=True, fast_mode=True, + max_split_per_batch=16, ) mla_metadata["sparse_mtp_work_meta_data"] = torch.empty( smt_wmd_size, dtype=smt_wmd_type, device=self.device diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 87b00cba74..a445d6114a 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -42,6 +42,7 @@ os.getenv("ATOM_USE_TRITON_MLA_SHUFFLE_KV", "0") == "1" ), "ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1", + "ATOM_MLA_PAGE_SIZE": lambda: int(os.getenv("ATOM_MLA_PAGE_SIZE", "1")), # --- Kernel Fusion Toggles --- # fused_compress_attn: switch between Triton (default historical) and a # flydsl drop-in for V4-Pro Compressor (Main BF16 + Indexer FP8) paths.