Skip to content
Open

mla #1280

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 142 additions & 13 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from aiter import (
QuantType,
concat_and_cache_mla,
concat_and_cache_mla_seg,
dtypes,
flash_attn_varlen_func,
fused_qk_rope_concat_and_cache_mla,
fused_qk_rope_concat_and_cache_mla_seg,
get_hip_quant,
)
from aiter.dist.parallel_state import get_dp_group
Expand Down Expand Up @@ -46,9 +48,17 @@
concat_and_cache_mla = mark_trace(
concat_and_cache_mla, prefix="kv_cache", torch_compile=False
)
concat_and_cache_mla_seg = mark_trace(
concat_and_cache_mla_seg, prefix="kv_cache_seg", torch_compile=False
)
fused_qk_rope_concat_and_cache_mla = mark_trace(
fused_qk_rope_concat_and_cache_mla, prefix="rope_and_kv_cache", torch_compile=False
)
fused_qk_rope_concat_and_cache_mla_seg = mark_trace(
fused_qk_rope_concat_and_cache_mla_seg,
prefix="rope_and_kv_cache",
torch_compile=False,
)
mla_prefill_fwd = mark_trace(mla_prefill_fwd, prefix="mla_prefill", torch_compile=False)
mla_decode_fwd = mark_trace(mla_decode_fwd, prefix="mla_decode", torch_compile=False)

Expand All @@ -74,6 +84,20 @@

_MLA_MIN_HEADS = 16 # AITER MLA kernels require at least 16 attention heads

# The fused seg MLA kernels (fused_qk_rope_concat_and_cache_mla_seg +
# concat_and_cache_mla_seg + the gfx1250 mla_decode_fwd asm) share a single
# segmented KV cache layout (all tokens' nope packed first, then all tokens'
# pe) and a fixed page size hard-coded in the kernels.
_MLA_SEG_PAGE_SIZE = 64
# The gfx1250 decode asm consumes an fp8 Q whose per-head row stride is padded
# to 768 bytes (poc_kl pack_q_page1_padded layout). q_out is allocated with this
# padded last dim and sliced to the logical kv_lora_rank + qk_rope_head_dim
# columns; the padding tail is never read by the decode kernel.
_MLA_Q_OUT_PADDED_DIM = 768
# Dims the fused seg kernels are compiled against (KV_LORA / PE_DIM constexprs).
_MLA_SEG_KV_LORA_RANK = 512
_MLA_SEG_PE_DIM = 64

if False:
try:
from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import (
Expand Down Expand Up @@ -205,6 +229,34 @@ def __init__(
else None
)
self.layer_num = layer_num
# When the triton MLA backend is selected we keep the original
# interleaved KV cache layout (concat_and_cache_mla /
# fused_qk_rope_concat_and_cache_mla) and an unpadded 576-wide q_out;
# only the gfx1250 asm decode path needs the segmented layout + 768 pad.
self.use_triton_mla = bool(envs.ATOM_USE_TRITON_MLA)
# On the non-triton (aiter) path, ATOM_MLA_PAGE_SIZE selects the KV cache
# layout: >1 uses the segmented (paged) seg kernels + padded q_out, while
# ==1 falls back to the original interleaved per-token (page_size=1)
# kernels with an unpadded 576-wide q_out. The triton path never uses seg.
self.use_seg_mla = (not self.use_triton_mla) and envs.ATOM_MLA_PAGE_SIZE > 1

def _seg_kv_cache_view(self, kv_cache: torch.Tensor) -> torch.Tensor:
"""Reshape the KV cache buffer into the page-level flat seg layout
``[num_blocks, page_size*(kv_lora_rank + qk_rope_head_dim)]`` that the
seg write kernels expect (they derive page_size from ``stride(0)``).

The cache is allocated token-major as ``[num_blocks*page_size, ..., entry]``
(so ``kv_cache.shape[0]`` is the total slot count, not the block count).
A plain view groups every ``page_size`` consecutive token slots into one
block, i.e. slot = block*page_size + offset, which matches slot_mapping
and the page-level view used on the decode side
(``kv_buffer.view(-1, page_size, 1, entry)``). Using
``kv_cache.view(kv_cache.shape[0], -1)`` here is WRONG: it keeps the
token-level stride (entry), so the kernel derives page_size=1 and writes
an interleaved layout that the page_size=64 decode then misreads."""
page_size = get_current_atom_config().kv_cache_block_size
entry = self.kv_lora_rank + self.qk_rope_head_dim
return kv_cache.view(-1, page_size * entry)

def process_weights_after_loading(self):
if is_rocm_aiter_fp4bmm_enabled():
Expand Down Expand Up @@ -743,6 +795,14 @@ def _forward_prefill_mla(
if self.head_repeat_factor > 1:
q = q.repeat_interleave(self.head_repeat_factor, dim=1)

# In the seg path q arrives with a padded per-head row stride
# (_MLA_Q_OUT_PADDED_DIM); slice back to the logical
# kv_lora_rank + qk_rope_head_dim columns. The slice keeps the padded row
# stride, which the asm kernel expects. The triton and non-seg
# (page_size=1) paths use an unpadded 576-wide q_out, so no slicing.
if self.use_seg_mla:
q = q[..., : self.kv_lora_rank + self.qk_rope_head_dim]

o = torch.empty(
B,
self.padded_num_heads,
Expand All @@ -763,24 +823,26 @@ def _forward_prefill_mla(
max_q_len = 1

if kv_c_and_k_pe_cache.numel() > 0:
page_size = attn_metadata.block_size
if self.kv_cache_dtype.startswith("fp8"):
mla_decode_fwd(
q,
kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]),
kv_c_and_k_pe_cache.view(-1, page_size, 1, q.shape[-1]),
o,
paged_cu_seqlens_q,
paged_kv_indptr,
paged_kv_indices,
kv_last_page_lens,
max_q_len,
page_size=page_size,
sm_scale=self.scale,
q_scale=self._q_scale,
kv_scale=self._k_scale,
)
else:
mla_prefill_fwd(
q,
kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]),
kv_c_and_k_pe_cache.view(-1, page_size, 1, q.shape[-1]),
o,
paged_cu_seqlens_q,
paged_kv_indptr,
Expand Down Expand Up @@ -831,6 +893,14 @@ def _forward_decode(
if self.head_repeat_factor > 1:
q = q.repeat_interleave(self.head_repeat_factor, dim=1)

# In the seg path q arrives with a padded per-head row stride
# (_MLA_Q_OUT_PADDED_DIM); slice back to the logical
# kv_lora_rank + qk_rope_head_dim columns. The slice keeps the padded row
# stride, which the asm kernel expects. The triton and non-seg
# (page_size=1) paths use an unpadded 576-wide q_out, so no slicing.
if self.use_seg_mla:
q = q[..., : self.kv_lora_rank + self.qk_rope_head_dim]

o = torch.empty(
B,
self.padded_num_heads,
Expand Down Expand Up @@ -911,6 +981,8 @@ def _forward_decode(

dp_size = get_dp_group().world_size
use_persistent_mode = not (dp_size > 1)
if envs.ATOM_MLA_PAGE_SIZE > 1:
use_persistent_mode = False

# Sparse layers in MTP verify use separate persistent metadata
# (per-token, max_seqlen_qo=1) while dense layers use normal metadata
Expand Down Expand Up @@ -939,16 +1011,24 @@ def _forward_decode(
reduce_final_map = attn_metadata.reduce_final_map
reduce_partial_map = attn_metadata.reduce_partial_map

# TODO refactor this
if envs.ATOM_MLA_PAGE_SIZE is not None:
page_size = envs.ATOM_MLA_PAGE_SIZE
else:
page_size = 1

seg_kv_buffer_4d = kv_buffer.view(-1, page_size, 1, q.shape[-1])
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
seg_kv_buffer_4d,
o,
paged_cu_seqlens_q,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_lens,
max_q_len,
num_kv_splits=16,
page_size=page_size,
num_kv_splits=1, # asm passes
sm_scale=self.scale,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
Expand Down Expand Up @@ -1009,6 +1089,22 @@ def forward_impl(
apply_scale=True,
shuffled_kv_cache=True,
)
elif self.use_seg_mla:
# Write the KV cache in the segmented layout so the
# decode-phase mla_decode_fwd (which reads seg layout) sees a
# consistent cache for tokens written during prefill.
# kv_cache is flattened to
# [num_blocks, page_size*(kv_lora_rank + qk_rope_head_dim)] so
# the kernel derives page_size from stride(0).
kv_cache_seg = self._seg_kv_cache_view(kv_cache)
concat_and_cache_mla_seg(
k_nope,
k_rope.squeeze(1),
kv_cache_seg,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=self._k_scale,
)
else:
concat_and_cache_mla(
k_nope,
Expand Down Expand Up @@ -1039,15 +1135,30 @@ def forward_impl(
else:
q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale)

q_out = torch.empty(
(
q_nope.shape[0],
self.num_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=attn_metadata.dtype_q,
device=q_nope.device,
)
if self.use_seg_mla:
# Seg path: allocate q_out with a padded last dim so each head row
# has a 768-byte stride (required by the gfx1250 decode asm). The
# kernel only writes the first kv_lora_rank + qk_rope_head_dim
# columns; the padding tail is left untouched and never read.
q_out = torch.empty(
(
q_nope.shape[0],
self.num_heads,
_MLA_Q_OUT_PADDED_DIM,
),
dtype=attn_metadata.dtype_q,
device=q_nope.device,
)
else:
q_out = torch.empty(
(
q_nope.shape[0],
self.num_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=attn_metadata.dtype_q,
device=q_nope.device,
)
if kv_cache.numel() > 0:
if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV:
shuffled_cache = self._shuffled_kv_view(kv_cache)
Expand All @@ -1068,6 +1179,24 @@ def forward_impl(
q_out=q_out,
shuffled_kv_cache=True,
)
elif self.use_seg_mla:
kv_cache_seg = self._seg_kv_cache_view(kv_cache)
fused_qk_rope_concat_and_cache_mla_seg(
q_nope,
q_rope,
k_nope,
k_rope,
# Flat seg layout: [num_blocks, page_size*(kv_lora + pe)].
kv_cache_seg,
q_out,
attn_metadata.slot_mapping,
self._k_scale,
self._q_scale,
positions,
self.rotary_emb.cos_cache,
self.rotary_emb.sin_cache,
is_neox=self.rotary_emb.is_neox_style,
)
else:
fused_qk_rope_concat_and_cache_mla(
q_nope,
Expand Down
7 changes: 6 additions & 1 deletion atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def get_impl_cls() -> Type["MLAAttention"]:

class AiterMLAMetadataBuilder(CommonAttentionBuilder):
def __init__(self, model_runner):
self.block_size = 1
if envs.ATOM_MLA_PAGE_SIZE > 1:
self.block_size = envs.ATOM_MLA_PAGE_SIZE
else:
self.block_size = 1
if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV:
assert model_runner.block_size == 64, (
f"ATOM_USE_TRITON_MLA=1 and ATOM_USE_TRITON_MLA_SHUFFLE_KV=1 expects --block-size 64 "
Expand Down Expand Up @@ -114,6 +117,7 @@ def __init__(self, model_runner):
self.dtype_kv,
is_sparse=self.is_sparse,
fast_mode=True,
max_split_per_batch=16,
)
i32_kwargs = {"dtype": torch.int32, "device": self.device}

Expand Down Expand Up @@ -195,6 +199,7 @@ def __init__(self, model_runner):
self.dtype_kv,
is_sparse=True,
fast_mode=True,
max_split_per_batch=16,
)
mla_metadata["sparse_mtp_work_meta_data"] = torch.empty(
smt_wmd_size, dtype=smt_wmd_type, device=self.device
Expand Down
1 change: 1 addition & 0 deletions atom/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
os.getenv("ATOM_USE_TRITON_MLA_SHUFFLE_KV", "0") == "1"
),
"ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1",
"ATOM_MLA_PAGE_SIZE": lambda: int(os.getenv("ATOM_MLA_PAGE_SIZE", "1")),
# --- Kernel Fusion Toggles ---
# fused_compress_attn: switch between Triton (default historical) and a
# flydsl drop-in for V4-Pro Compressor (Main BF16 + Indexer FP8) paths.
Expand Down
Loading