diff --git a/atom/plugin/config.py b/atom/plugin/config.py index ae72b4fad2..17787ccd13 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -409,7 +409,7 @@ def _generate_atom_config_from_rtpllm_config(config: Any): return Config( model=rtpllm_model_config.ckpt_path, - max_num_batched_tokens=max(16384, max_generate_batch_size), + max_num_batched_tokens=max(max_model_len, max_generate_batch_size), max_num_seqs=max_generate_batch_size, max_model_len=max_model_len, gpu_memory_utilization=0.9, diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index 191e9a8bad..41b892c474 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -98,6 +98,7 @@ def _prepare_model_atom_sglang( def _prepare_model_atom_rtpllm( config: Any, atom_config: Any, + model_arch: str, model_cls: Any, set_attn_cls: Any, init_aiter_dist: Any, @@ -120,6 +121,12 @@ def _prepare_model_atom_rtpllm( ) set_attn_cls() + if model_arch == "GlmMoeDsaForCausalLM": + from atom.plugin.rtpllm.attention_backend import ( + apply_attention_mla_rtpllm_patch, + ) + + apply_attention_mla_rtpllm_patch() # init aiter dist for using aiter custom collective ops init_aiter_dist(config=atom_config) @@ -172,6 +179,7 @@ def prepare_model(config: Any, engine: str): return _prepare_model_atom_rtpllm( config, atom_config, + model_arch, model_cls, set_attn_cls, init_aiter_dist, diff --git a/atom/plugin/rtpllm/__init__.py b/atom/plugin/rtpllm/__init__.py index e69de29bb2..eee9517201 100644 --- a/atom/plugin/rtpllm/__init__.py +++ b/atom/plugin/rtpllm/__init__.py @@ -0,0 +1,7 @@ +"""RTP-LLM plugin helpers. + +Keep the package root import side-effect free. RTP external model registration +is triggered by importing ``atom.plugin.rtpllm.models``. 
+""" + +__all__: list[str] = [] diff --git a/atom/plugin/rtpllm/attention_backend/__init__.py b/atom/plugin/rtpllm/attention_backend/__init__.py index 1afa3cdb59..0e7f68318a 100644 --- a/atom/plugin/rtpllm/attention_backend/__init__.py +++ b/atom/plugin/rtpllm/attention_backend/__init__.py @@ -1,10 +1,37 @@ -from .attention_gdn import apply_attention_gdn_rtpllm_patch -from .attention_switch import apply_attention_mha_rtpllm_patch -from .rtp_full_attention import AttentionForRTPLLM, RTPFullAttention +from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch +from .rtp_sparse_mla_backend import RTPSparseMlaBackend + + +def __getattr__(name): + if name == "AttentionForRTPLLM": + from .rtp_full_attention import AttentionForRTPLLM + + return AttentionForRTPLLM + if name == "RTPFullAttention": + from .rtp_full_attention import RTPFullAttention + + return RTPFullAttention + if name == "RTPAttention": + from .rtp_full_attention import RTPFullAttention + + return RTPFullAttention + if name == "apply_attention_gdn_rtpllm_patch": + from .attention_gdn import apply_attention_gdn_rtpllm_patch + + return apply_attention_gdn_rtpllm_patch + if name == "apply_attention_mha_rtpllm_patch": + from .attention_switch import apply_attention_mha_rtpllm_patch + + return apply_attention_mha_rtpllm_patch + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "AttentionForRTPLLM", "RTPFullAttention", + "RTPMLAAttention", + "RTPSparseMlaBackend", "apply_attention_gdn_rtpllm_patch", "apply_attention_mha_rtpllm_patch", + "apply_attention_mla_rtpllm_patch", ] diff --git a/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py new file mode 100644 index 0000000000..c6c3857f68 --- /dev/null +++ b/atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py @@ -0,0 +1,253 @@ +"""RTP-style MLA adapter for GLM5 rtp-llm plugin mode.""" + +from __future__ import annotations + +import 
inspect +from types import MethodType +from typing import Optional + +import torch + + +def _resolve_index_topk(attn) -> int: + for obj, attr in ( + (getattr(attn, "indexer", None), "index_topk"), + (getattr(attn, "indexer", None), "topk_tokens"), + (attn, "index_topk"), + (getattr(attn, "config", None), "index_topk"), + ): + value = getattr(obj, attr, None) if obj is not None else None + if value is not None: + return int(value) + raise AttributeError("GLM5 RTP MLA indexer requires index_topk/topk_tokens") + + +def _get_topk_indices_buffer(attn) -> torch.Tensor: + indexer = getattr(attn, "indexer", None) + buffer = ( + getattr(indexer, "topk_indices_buffer", None) if indexer is not None else None + ) + if buffer is None: + buffer = getattr(attn, "topk_indices_buffer", None) + if buffer is None: + buffer = getattr(attn, "_topk_indices_buffer", None) + if buffer is None: + raise AttributeError("GLM5 RTP MLA indexer requires topk_indices_buffer") + return buffer + + +def _should_emit_topk_indices(attn) -> bool: + try: + from atom.utils.forward_context import get_forward_context + + forward_context = get_forward_context() + except Exception: + return True + + context = getattr(forward_context, "context", None) + if getattr(context, "is_dummy_run", False): + return False + return True + + +def _use_rtp_sparse_attn_indexer(indexer: object | None) -> None: + if indexer is None or not hasattr(indexer, "sparse_attn_indexer_impl"): + return + __import__("atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend") + indexer.sparse_attn_indexer_impl = torch.ops.aiter.rtp_sparse_attn_indexer + if getattr(indexer, "_atom_rtp_topk_buffer_patched", False) or not hasattr( + indexer, "forward" + ): + return + original_forward = indexer.forward + + def _forward_with_topk_buffer(self, hidden_states, *args, **kwargs): + num_tokens = int(hidden_states.shape[0]) + topk_tokens = getattr(self, "topk_tokens", None) + if topk_tokens is None: + topk_tokens = getattr(self, "index_topk") + 
topk_tokens = int(topk_tokens) + buffer = getattr(self, "topk_indices_buffer", None) + needs_new_buffer = ( + buffer is None + or buffer.dim() != 2 + or buffer.device != hidden_states.device + or int(buffer.shape[0]) < num_tokens + or int(buffer.shape[1]) < topk_tokens + ) + if needs_new_buffer: + buffer = torch.empty( + num_tokens, + topk_tokens, + dtype=torch.int32, + device=hidden_states.device, + ) + self.topk_indices_buffer = buffer + self.sparse_kv_indices_buffer = self.topk_indices_buffer + return original_forward(hidden_states, *args, **kwargs) + + indexer.forward = MethodType(_forward_with_topk_buffer, indexer) + indexer._atom_rtp_topk_buffer_patched = True + + +class RTPMLAAttention: + """RTP MLA adapter for the native GLM5 MLA call contract.""" + + use_mla = True + + def __init__(self, *args, **kwargs) -> None: + self.args = args + self.kwargs = kwargs + mla_modules = kwargs.get("mla_modules") + self.mla_modules = mla_modules + self.q_proj = getattr(mla_modules, "q_proj", None) + self.o_proj = getattr(mla_modules, "o_proj", None) + self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) + self.indexer = getattr(mla_modules, "indexer", None) + _use_rtp_sparse_attn_indexer(self.indexer) + self.qk_head_dim = getattr(mla_modules, "qk_head_dim", None) + self.v_head_dim = getattr(mla_modules, "v_head_dim", None) + self.q_lora_rank = getattr(mla_modules, "q_lora_rank", None) + self.kv_lora_rank = getattr(mla_modules, "kv_lora_rank", None) + self.num_heads = getattr(mla_modules, "num_heads", None) + self.num_local_heads = getattr(mla_modules, "num_local_heads", self.num_heads) + self.index_topk = getattr(mla_modules, "index_topk", None) + self.topk_indices_buffer = ( + getattr(self.indexer, "topk_indices_buffer", None) + if self.indexer is not None + else None + ) + injected_backend = kwargs.get("sparse_backend") + if injected_backend is not None: + self.sparse_backend = injected_backend + elif mla_modules is not None: + from 
atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import ( + RTPSparseMlaBackend, + ) + + self.sparse_backend = RTPSparseMlaBackend( + v_head_dim=mla_modules.v_head_dim, + mla_modules=mla_modules, + scale=kwargs.get("scale"), + ) + else: + self.sparse_backend = None + self.kv_cache = kwargs.get("kv_cache") + self.layer_id = int(kwargs.get("layer_id", kwargs.get("layer_num", 0))) + self._sparse_backend_accepts_positions = ( + self._backend_accepts_positions(self.sparse_backend) + if self.sparse_backend is not None + else False + ) + + @staticmethod + def _backend_accepts_positions(backend: object) -> bool: + try: + signature = inspect.signature(backend.forward) + except (AttributeError, TypeError, ValueError): + return False + return "positions" in signature.parameters or any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + + def _project_query( + self, query: torch.Tensor, q_scale: Optional[torch.Tensor] + ) -> tuple[torch.Tensor, bool]: + if query.ndim == 3: + return query, False + if self.q_proj is None: + return query, False + + q = self.q_proj(query, q_scale) + if q.ndim == 3: + return q, True + + num_heads = ( + self.num_local_heads if self.num_local_heads is not None else self.num_heads + ) + if num_heads is None: + if self.qk_head_dim is None: + raise AttributeError("GLM5 RTP MLA native contract requires num_heads") + num_heads = q.shape[-1] // int(self.qk_head_dim) + if self.qk_head_dim is None: + self.qk_head_dim = q.shape[-1] // int(num_heads) + return q.reshape(-1, int(num_heads), int(self.qk_head_dim)), True + + def _resolve_topk_indices( + self, + query: torch.Tensor, + q_scale: Optional[torch.Tensor], + positions: Optional[torch.Tensor], + explicit_topk_indices: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + if explicit_topk_indices is not None: + return explicit_topk_indices + if self.indexer is None: + return None + + if not _should_emit_topk_indices(self): + return None + 
index_topk = _resolve_index_topk(self) + return _get_topk_indices_buffer(self)[: query.shape[0], :index_topk] + + def forward( + self, + query: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + positions: Optional[torch.Tensor] = None, + q_scale: Optional[torch.Tensor] = None, + topk_indices: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + if self.sparse_backend is None: + raise NotImplementedError( + "RTPMLAAttention requires an attention backend for contract execution" + ) + q, native_projected = self._project_query(query, q_scale) + topk_indices = self._resolve_topk_indices( + query, + q_scale, + positions, + kwargs.get("topk_indices", topk_indices), + ) + forward_kwargs = {"topk_indices": topk_indices} + if self._sparse_backend_accepts_positions: + forward_kwargs["positions"] = positions + attn_output = self.sparse_backend.forward( + q, + compressed_kv, + k_pe, + self.kv_cache, + self.layer_id, + **forward_kwargs, + ) + if native_projected and self.o_proj is not None: + attn_output = attn_output.reshape(attn_output.shape[0], -1).contiguous() + return self.o_proj(attn_output) + return attn_output + + __call__ = forward + + +def apply_attention_mla_rtpllm_patch() -> None: + """Switch ATOM's generic Attention symbol to the RTP MLA adapter.""" + + import importlib + import sys + + ops = importlib.import_module("atom.model_ops") + base_attention = importlib.import_module("atom.model_ops.base_attention") + + ops.RTPMLAAttention = RTPMLAAttention + ops.Attention = RTPMLAAttention + base_attention.Attention = RTPMLAAttention + + deepseek_v2 = sys.modules.get("atom.models.deepseek_v2") + if deepseek_v2 is None: + try: + import atom.models.deepseek_v2 as deepseek_v2 + except (ImportError, ModuleNotFoundError): + return + deepseek_v2.Attention = RTPMLAAttention diff --git a/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py new file mode 100644 index 
0000000000..263863031e --- /dev/null +++ b/atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py @@ -0,0 +1,1908 @@ +"""Sparse MLA backend for GLM5 rtp-llm plugin mode.""" + +from __future__ import annotations + +import importlib +import inspect +import os +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from atom.utils.custom_register import direct_register_custom_op + + +class _SparseUnavailable(RuntimeError): + pass + + +def _resolve_plugin_sparse_index_converter(): + """Resolve the plugin-style request-local topk to global KV index converter.""" + errors: list[str] = [] + for module_name in ( + # Compatibility import path used by earlier plugin layouts. + "atom.plugin.attention_mla_sparse", + # Current plugin helper location with the same call signature. + "atom.plugin.vllm.attention.layer_sparse_mla", + ): + try: + module = importlib.import_module(module_name) + return getattr(module, "triton_convert_req_index_to_global_index") + except Exception as exc: + errors.append(f"{module_name}: {exc}") + raise _SparseUnavailable( + "plugin sparse MLA index converter unavailable; " + "; ".join(errors) + ) + + +@dataclass +class _AbsorbedWeights: + w_kc: torch.Tensor + w_vc: torch.Tensor + + +@dataclass +class _AtomSparseMetadata: + qo_indptr: torch.Tensor + paged_kv_indptr: torch.Tensor + paged_kv_indices: torch.Tensor + paged_kv_last_page_len: torch.Tensor + work_meta_data: torch.Tensor + work_indptr: torch.Tensor + work_info_set: torch.Tensor + reduce_indptr: torch.Tensor + reduce_final_map: torch.Tensor + reduce_partial_map: torch.Tensor + padded_num_heads: int + head_repeat_factor: int + page_size: int + + +class _LightweightSparseMlaImpl: + """Lightweight implementation for unit tests and explicit dependency injection.""" + + def __init__(self, v_head_dim: int) -> None: + self.v_head_dim = int(v_head_dim) + self.calls = [] + + def forward( + self, + q: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: 
torch.Tensor, + kv_cache: object, + layer_id: int, + *, + topk_indices: torch.Tensor, + attn_metadata: object, + positions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self.calls.append( + { + "q": q, + "compressed_kv": compressed_kv, + "k_pe": k_pe, + "kv_cache": kv_cache, + "layer_id": layer_id, + "topk_indices": topk_indices, + "attn_metadata": attn_metadata, + "positions": positions, + } + ) + return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim)) + + +class _RealSparseMlaImpl: + """Runtime sparse MLA adapter for ATOM-owned GLM5 weights and RTP KV cache.""" + + def __init__( + self, + *, + mla_modules: Any, + v_head_dim: int, + scale: Optional[float] = None, + ) -> None: + self.mla_modules = mla_modules + self.v_head_dim = int(v_head_dim) + self.kv_lora_rank = int(getattr(mla_modules, "kv_lora_rank")) + self.qk_nope_head_dim = int(getattr(mla_modules, "qk_nope_head_dim")) + self.qk_rope_head_dim = int(getattr(mla_modules, "qk_rope_head_dim")) + self.num_heads = int(getattr(mla_modules, "num_heads", 0) or 0) + self.rotary_emb = getattr(mla_modules, "rotary_emb", None) + self.kv_b_proj = getattr(mla_modules, "kv_b_proj", None) + self.scale = ( + float(scale) + if scale is not None + else float((self.qk_nope_head_dim + self.qk_rope_head_dim) ** -0.5) + ) + self._absorbed_weights: _AbsorbedWeights | None = None + self._cache_write_scale: dict[torch.device, torch.Tensor] = {} + self._cg_sparse_bufs: dict[str, torch.Tensor] | None = None + self._cg_workspace_signature: tuple[Any, ...] 
| None = None + self._enable_sparse_validate = ( + os.getenv("ATOM_RTP_GLM5_SPARSE_VALIDATE", "0") == "1" + ) + + @staticmethod + def _validate_sparse_index_contract( + *, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + num_tokens: int, + page_size: int, + max_slots: int, + ) -> None: + if int(paged_kv_indptr.numel()) != num_tokens + 1: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA invalid paged_kv_indptr length " + f"(got={int(paged_kv_indptr.numel())}, expected={num_tokens + 1})." + ) + if int(paged_kv_indptr[0].item()) != 0: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA paged_kv_indptr[0] must be 0, " + f"got {int(paged_kv_indptr[0].item())}." + ) + if num_tokens > 0: + deltas = paged_kv_indptr[1:] - paged_kv_indptr[:-1] + if bool((deltas < 0).any().item()): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA paged_kv_indptr must be non-decreasing." + ) + used = int(paged_kv_indptr[-1].item()) + if used < 0 or used > int(paged_kv_indices.numel()): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA paged_kv_indptr[-1] out of range " + f"(used={used}, capacity={int(paged_kv_indices.numel())})." + ) + if used == 0: + return + used_indices = paged_kv_indices[:used] + min_index = int(used_indices.min().item()) + max_index = int(used_indices.max().item()) + if min_index < 0 or max_index >= max_slots: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA produced out-of-range paged_kv_indices " + f"(min={min_index}, max={max_index}, slots={max_slots}, " + f"page_size={page_size})." + ) + + @staticmethod + def _validate_sparse_last_page_contract( + *, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + num_tokens: int, + page_size: int, + ) -> None: + if int(paged_kv_last_page_len.numel()) != int(num_tokens): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA invalid paged_kv_last_page_len length " + f"(got={int(paged_kv_last_page_len.numel())}, expected={int(num_tokens)})." 
+ ) + if num_tokens <= 0: + return + deltas = paged_kv_indptr[1:] - paged_kv_indptr[:-1] + active_mask = deltas > 0 + if not bool(active_mask.any().item()): + return + active_last_page_len = paged_kv_last_page_len[active_mask] + min_last_page_len = int(active_last_page_len.min().item()) + max_last_page_len = int(active_last_page_len.max().item()) + if min_last_page_len < 1 or max_last_page_len > int(page_size): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA invalid paged_kv_last_page_len range " + f"(min={min_last_page_len}, max={max_last_page_len}, " + f"page_size={int(page_size)})." + ) + if int(page_size) == 1 and bool((active_last_page_len != 1).any().item()): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA expects paged_kv_last_page_len==1 when page_size=1." + ) + + @staticmethod + def _kv_token_slot_capacity(kv_cache_base: torch.Tensor) -> int: + if kv_cache_base.ndim <= 0: + return 0 + latent_dim = int(kv_cache_base.shape[-1]) if kv_cache_base.ndim >= 1 else 0 + if latent_dim <= 0: + return 0 + return int(kv_cache_base.numel() // latent_dim) + + def _infer_num_heads(self, q: torch.Tensor) -> int: + num_heads = int(q.shape[1]) + if self.num_heads != num_heads: + self.num_heads = num_heads + return num_heads + + def _infer_num_heads_from_weight(self, fallback: int) -> int: + try: + weight = self._read_kv_b_proj_weight() + except Exception: + return int(fallback) + per_head_dim = int(self.qk_nope_head_dim + self.v_head_dim) + if per_head_dim <= 0 or weight.ndim != 2: + return int(fallback) + for dim in weight.shape: + dim_i = int(dim) + if dim_i > 0 and dim_i % per_head_dim == 0: + candidate = dim_i // per_head_dim + if candidate > 0: + return max(int(fallback), int(candidate)) + return int(fallback) + + def _read_kv_b_proj_weight(self) -> torch.Tensor: + if self.kv_b_proj is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA requires kv_b_proj.") + try: + from atom.model_ops.utils import get_and_maybe_dequant_weights + + weight = 
get_and_maybe_dequant_weights(self.kv_b_proj) + except Exception: + weight = getattr(self.kv_b_proj, "weight", None) + if not isinstance(weight, torch.Tensor): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA cannot read kv_b_proj.weight." + ) + if weight.dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + ): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA needs dequantized kv_b_proj weights for " + "the current adapter." + ) + return weight + + def _get_absorbed_weights(self, q: torch.Tensor) -> _AbsorbedWeights: + cached = self._absorbed_weights + if cached is not None and cached.w_kc.device == q.device: + return cached + + weight = self._read_kv_b_proj_weight().to(device=q.device) + num_heads = self._infer_num_heads(q) + expected_out = num_heads * (self.qk_nope_head_dim + self.v_head_dim) + if weight.ndim != 2: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA got invalid kv_b_proj weight shape {tuple(weight.shape)}." + ) + if ( + int(weight.shape[0]) == expected_out + and int(weight.shape[1]) == self.kv_lora_rank + ): + kv_b_weight = weight.T.contiguous() + elif ( + int(weight.shape[1]) == expected_out + and int(weight.shape[0]) == self.kv_lora_rank + ): + kv_b_weight = weight.contiguous() + else: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA kv_b_proj weight shape mismatch " + f"(got={tuple(weight.shape)}, expected_out={expected_out}, " + f"kv_lora_rank={self.kv_lora_rank})." 
+ ) + + kv_b_weight = kv_b_weight.view( + self.kv_lora_rank, + num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + w_uk, w_uv = kv_b_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + absorbed = _AbsorbedWeights( + w_kc=w_uk.permute(1, 2, 0).contiguous(), + w_vc=w_uv.permute(1, 0, 2).contiguous(), + ) + self._absorbed_weights = absorbed + return absorbed + + def _apply_rope( + self, + q: torch.Tensor, + k_pe: torch.Tensor, + positions: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + rope_dim = int(self.qk_rope_head_dim) + if rope_dim == 0: + return q, k_pe + if self.rotary_emb is None: + raise _SparseUnavailable("GLM5 RTP sparse MLA requires rotary_emb.") + if positions is None or int(positions.numel()) != int(q.shape[0]): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires per-token positions for RoPE " + f"(positions={None if positions is None else int(positions.numel())}, " + f"tokens={int(q.shape[0])})." + ) + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture: + if self._cg_sparse_bufs is None: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires RoPE buffers." + ) + if positions.device != q.device or positions.dtype != torch.long: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int64 positions on device." + ) + if not positions.is_contiguous(): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires contiguous positions." 
+ ) + q_rope = self._cg_sparse_bufs["q_rope"][ + : q.shape[0], : q.shape[1], : q.shape[2] + ] + q_rope.copy_(q) + if k_pe.dim() == 2: + k_pe_rope = self._cg_sparse_bufs["k_pe_rope_2d"][ + : k_pe.shape[0], : k_pe.shape[1] + ] + elif k_pe.dim() == 3 and int(k_pe.shape[1]) == 1: + k_pe_rope = self._cg_sparse_bufs["k_pe_rope_3d"][ + : k_pe.shape[0], : k_pe.shape[1], : k_pe.shape[2] + ] + elif k_pe.dim() == 3: + k_pe_rope = self._cg_sparse_bufs["k_pe_rope_heads"][ + : k_pe.shape[0], : k_pe.shape[1], : k_pe.shape[2] + ] + else: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA capture got invalid k_pe ndim={k_pe.dim()}." + ) + k_pe_rope.copy_(k_pe) + rope_positions = positions.view(-1) + else: + q_rope = q.clone() + k_pe_rope = k_pe.clone() + rope_positions = positions.reshape(-1).to(device=q.device, dtype=torch.long) + rotated_q_pe, rotated_k_pe = self.rotary_emb( + rope_positions, + q_rope[..., -rope_dim:], + k_pe_rope, + ) + q_rope[..., -rope_dim:] = rotated_q_pe + return q_rope, rotated_k_pe + + def _cache_dtype_name(self, kv_cache_base: torch.Tensor) -> str: + fp8_dtypes = { + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + torch.uint8, + ) + if dtype is not None + } + if kv_cache_base.dtype not in fp8_dtypes: + return "auto" + # RTP allocates GLM5 FP8 MLA KV cache in the aiter 576-byte/token layout. 
+ return "fp8" + + def _write_current_to_cache( + self, + *, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: Any, + attn_metadata: Any, + ) -> torch.Tensor: + kv_cache_base = getattr(kv_cache, "kv_cache_base", None) + if not isinstance(kv_cache_base, torch.Tensor) or kv_cache_base.numel() == 0: + raise _SparseUnavailable("GLM5 RTP sparse MLA requires kv_cache_base.") + slot_mapping = getattr(attn_metadata, "slot_mapping", None) + if slot_mapping is None: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + slot_mapping = getattr(plugin_metadata, "slot_mapping", None) + if not isinstance(slot_mapping, torch.Tensor): + raise _SparseUnavailable("GLM5 RTP sparse MLA requires slot_mapping.") + try: + from aiter import concat_and_cache_mla + except Exception as exc: + raise _SparseUnavailable( + f"aiter.concat_and_cache_mla unavailable: {exc}" + ) from exc + + scale = self._cache_write_scale.get(compressed_kv.device) + if scale is None: + scale = torch.tensor(1.0, dtype=torch.float32, device=compressed_kv.device) + self._cache_write_scale[compressed_kv.device] = scale + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture: + if ( + slot_mapping.device != compressed_kv.device + or slot_mapping.dtype != torch.int64 + ): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int64 slot_mapping on device." 
+ ) + slot_mapping_for_cache = slot_mapping + else: + slot_mapping_for_cache = slot_mapping.to( + device=compressed_kv.device, dtype=torch.int64 + ) + try: + concat_and_cache_mla( + compressed_kv, + k_pe, + kv_cache_base, + slot_mapping_for_cache, + kv_cache_dtype=self._cache_dtype_name(kv_cache_base), + scale=scale, + ) + except Exception as exc: + raise _SparseUnavailable(f"concat_and_cache_mla failed: {exc}") from exc + return kv_cache_base + + @staticmethod + def _build_req_id_per_token( + attn_metadata: Any, + num_tokens: int, + device: torch.device, + ) -> torch.Tensor: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + req_id = getattr(plugin_metadata, "req_id_per_token", None) + if isinstance(req_id, torch.Tensor) and int(req_id.numel()) >= num_tokens: + return req_id[:num_tokens].to(device=device, dtype=torch.int32) + query_start_loc = getattr(plugin_metadata, "query_start_loc", None) + if query_start_loc is None: + query_start_loc = getattr(plugin_metadata, "rtp_cu_seqlens_q", None) + if query_start_loc is None: + query_start_loc = getattr(attn_metadata, "cu_seqlens_q", None) + if ( + isinstance(query_start_loc, torch.Tensor) + and int(query_start_loc.numel()) >= 2 + ): + qsl = query_start_loc.to(device=device, dtype=torch.int64) + lengths = qsl[1:] - qsl[:-1] + return torch.repeat_interleave( + torch.arange(int(lengths.numel()), device=device, dtype=torch.int32), + lengths, + )[:num_tokens].contiguous() + return torch.arange(num_tokens, device=device, dtype=torch.int32) + + @staticmethod + def _block_table(attn_metadata: Any, device: torch.device) -> torch.Tensor: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + block_table = getattr(plugin_metadata, "block_table", None) + if block_table is None: + block_table = getattr(attn_metadata, "block_tables", None) + if not isinstance(block_table, torch.Tensor): + raise _SparseUnavailable("GLM5 RTP sparse MLA requires block_table.") + if block_table.ndim == 1: + block_table 
= block_table.unsqueeze(0) + return block_table.to(device=device, dtype=torch.int32) + + @staticmethod + def _convert_topk_to_global( + *, + topk_indices: torch.Tensor, + attn_metadata: Any, + block_size: int, + ) -> torch.Tensor: + if int(block_size) <= 0: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA requires positive block_size, got {block_size}." + ) + num_tokens, topk = topk_indices.shape + device = topk_indices.device + block_table = _RealSparseMlaImpl._block_table(attn_metadata, device) + req_id = _RealSparseMlaImpl._build_req_id_per_token( + attn_metadata, num_tokens, device + ).to(dtype=torch.long) + token_indices = topk_indices.to(device=device, dtype=torch.long) + valid = token_indices >= 0 + block_cols = torch.div( + torch.clamp(token_indices, min=0), + int(block_size), + rounding_mode="floor", + ) + offsets = torch.remainder(torch.clamp(token_indices, min=0), int(block_size)) + valid = ( + valid & (req_id[:, None] >= 0) & (req_id[:, None] < block_table.shape[0]) + ) + valid = valid & (block_cols >= 0) & (block_cols < block_table.shape[1]) + safe_req = torch.clamp(req_id, min=0, max=max(int(block_table.shape[0]) - 1, 0)) + safe_cols = torch.clamp( + block_cols, min=0, max=max(int(block_table.shape[1]) - 1, 0) + ) + block_ids = block_table.to(dtype=torch.long)[safe_req[:, None], safe_cols] + valid = valid & (block_ids >= 0) + global_indices = block_ids * int(block_size) + offsets + return torch.where(valid, global_indices, torch.zeros_like(global_indices)).to( + dtype=torch.int32 + ) + + @staticmethod + def _aiter_dtype_for_tensor(tensor: torch.Tensor) -> Any: + try: + from aiter import dtypes + except Exception as exc: + raise _SparseUnavailable(f"aiter dtypes unavailable: {exc}") from exc + + fp8_dtypes = { + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e5m2fnuz", None), + torch.uint8, + getattr(dtypes, "fp8", None), + ) + 
if dtype is not None + } + if tensor.dtype in fp8_dtypes: + return dtypes.fp8 + if tensor.dtype == torch.float16: + return dtypes.d_dtypes["fp16"] + return dtypes.d_dtypes["bf16"] + + @staticmethod + def _aiter_dtype_for_torch_dtype( + dtype: torch.dtype, *, assume_fp8: bool = False + ) -> Any: + try: + from aiter import dtypes + except Exception as exc: + raise _SparseUnavailable(f"aiter dtypes unavailable: {exc}") from exc + if assume_fp8: + return dtypes.fp8 + if dtype == torch.float16: + return dtypes.d_dtypes["fp16"] + return dtypes.d_dtypes["bf16"] + + def _resolve_topk_for_prewarm(self) -> int: + for obj, attr in ( + (getattr(self.mla_modules, "indexer", None), "index_topk"), + (getattr(self.mla_modules, "indexer", None), "topk_tokens"), + (self.mla_modules, "index_topk"), + (getattr(self.mla_modules, "config", None), "index_topk"), + ): + value = getattr(obj, attr, None) if obj is not None else None + if value is not None: + return int(value) + return 2048 + + @staticmethod + def _metadata_token_budget(*, num_tokens: int, topk: int) -> int: + # Sparse decode can materialize up to num_tokens * topk ragged entries. + # Use this upper bound to avoid undersized work/reduce metadata buffers. + return max(int(num_tokens) * max(int(topk), 1), 1) + + @staticmethod + def _validate_capture_sparse_buffer_capacity( + *, + sparse_bufs: dict[str, torch.Tensor], + num_tokens: int, + topk: int, + ) -> None: + needed_indices = int(num_tokens) * int(topk) + if int(sparse_bufs["paged_kv_indices"].numel()) < needed_indices: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture paged_kv_indices buffer is too small " + f"(buffer={int(sparse_bufs['paged_kv_indices'].numel())}, " + f"required={needed_indices})." + ) + if int(sparse_bufs["qo_indptr"].numel()) < int(num_tokens) + 1: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture qo_indptr buffer is too small." 
def prewarm_for_cuda_graph(
    self,
    *,
    max_num_tokens: int,
    max_seq_len: int,
    query_dtype: torch.dtype,
    device: torch.device,
) -> None:
    """Preallocate the buffers that the capture path reads from ``_cg_sparse_bufs``.

    On success this sets ``self.num_heads``, fills ``self._cg_sparse_bufs``
    with index, query, output and AITER work/reduce metadata tensors sized for
    ``max_num_tokens``, stores a float32 ``1.0`` in
    ``self._cache_write_scale[device]`` and records the sizing tuple in
    ``self._cg_workspace_signature``.

    Args:
        max_num_tokens: Token capacity of every per-token buffer. Values
            ``<= 0`` make the call return without allocating anything (the
            ``aiter`` import is still attempted first).
        max_seq_len: Unused; deleted on entry.
        query_dtype: dtype of the query-side and output buffers.
        device: Device on which every buffer is allocated.

    Raises:
        _SparseUnavailable: If ``aiter`` (``dtypes`` /
            ``get_mla_metadata_info_v1``) cannot be imported.
    """
    del max_seq_len
    try:
        from aiter import dtypes, get_mla_metadata_info_v1
    except Exception as exc:
        raise _SparseUnavailable(
            f"aiter metadata prewarm unavailable: {exc}"
        ) from exc

    max_tokens = int(max_num_tokens)
    if max_tokens <= 0:
        return
    # Head count: prefer the value already stored on self, then the module's
    # ``num_local_heads``; 0 means "not known yet".
    num_heads = int(
        self.num_heads or getattr(self.mla_modules, "num_local_heads", 0) or 0
    )
    if num_heads <= 0:
        # Lazily inferred in eager path; graph capture needs a stable budget.
        num_heads = int(getattr(self.mla_modules, "num_heads", 0) or 1)
    num_heads = self._infer_num_heads_from_weight(num_heads)
    self.num_heads = num_heads
    # Pad to at least 16 heads, rounded up to a multiple of num_heads so the
    # padded count is an integer repeat of the real heads.
    # NOTE(review): the reason for the 16-head floor is not visible here —
    # presumably a kernel requirement; confirm.
    padded_num_heads = max(num_heads, 16)
    if padded_num_heads % num_heads != 0:
        padded_num_heads = (
            (padded_num_heads + num_heads - 1) // num_heads
        ) * num_heads
    topk = self._resolve_topk_for_prewarm()
    latent_dim = self.kv_lora_rank + self.qk_rope_head_dim
    q_dtype = self._aiter_dtype_for_torch_dtype(query_dtype)
    # assume_fp8=True makes the helper return dtypes.fp8 regardless of
    # query_dtype, i.e. the metadata is sized for an fp8 KV cache.
    kv_dtype = self._aiter_dtype_for_torch_dtype(query_dtype, assume_fp8=True)
    metadata_budget_tokens = self._metadata_token_budget(
        num_tokens=max_tokens, topk=topk
    )
    # Ask AITER for the (size, dtype) of each of its six metadata buffers.
    (
        (work_meta_data_size, work_meta_data_type),
        (work_indptr_size, work_indptr_type),
        (work_info_set_size, work_info_set_type),
        (reduce_indptr_size, reduce_indptr_type),
        (reduce_final_map_size, reduce_final_map_type),
        (reduce_partial_map_size, reduce_partial_map_type),
    ) = get_mla_metadata_info_v1(
        metadata_budget_tokens,
        1,
        padded_num_heads,
        q_dtype,
        kv_dtype,
        is_sparse=True,
        fast_mode=True,
    )
    self._cg_sparse_bufs = {
        # --- ragged index buffers (int32) ---
        "qo_indptr": torch.arange(max_tokens + 1, device=device, dtype=torch.int32),
        "sparse_seqlen": torch.empty(max_tokens, device=device, dtype=torch.int32),
        "paged_kv_indptr": torch.empty(
            max_tokens + 1, device=device, dtype=torch.int32
        ),
        "paged_kv_last_page_len": torch.ones(
            max_tokens, device=device, dtype=torch.int32
        ),
        "paged_kv_indices": torch.empty(
            max_tokens * topk, device=device, dtype=torch.int32
        ),
        # --- query / rope staging buffers (query_dtype) ---
        "q_rope": torch.empty(
            max_tokens,
            num_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            device=device,
            dtype=query_dtype,
        ),
        "k_pe_rope_2d": torch.empty(
            max_tokens, self.qk_rope_head_dim, device=device, dtype=query_dtype
        ),
        "k_pe_rope_3d": torch.empty(
            max_tokens, 1, self.qk_rope_head_dim, device=device, dtype=query_dtype
        ),
        "k_pe_rope_heads": torch.empty(
            max_tokens,
            num_heads,
            self.qk_rope_head_dim,
            device=device,
            dtype=query_dtype,
        ),
        # Head-major layout: (num_heads, max_tokens, kv_lora_rank).
        "q_latent_nope_t": torch.empty(
            num_heads,
            max_tokens,
            self.kv_lora_rank,
            device=device,
            dtype=query_dtype,
        ),
        "q_latent": torch.empty(
            max_tokens, num_heads, latent_dim, device=device, dtype=query_dtype
        ),
        # Kernel-facing query buffers use the padded head count.
        "q_for_kernel": torch.empty(
            max_tokens,
            padded_num_heads,
            latent_dim,
            device=device,
            dtype=query_dtype,
        ),
        "q_for_kernel_fp8": torch.empty(
            max_tokens,
            padded_num_heads,
            latent_dim,
            device=device,
            dtype=dtypes.fp8,
        ),
        # --- output buffers ---
        "latent_output": torch.empty(
            max_tokens,
            padded_num_heads,
            self.kv_lora_rank,
            device=device,
            dtype=query_dtype,
        ),
        # Head-major layout: (num_heads, max_tokens, v_head_dim).
        "final_output_t": torch.empty(
            num_heads, max_tokens, self.v_head_dim, device=device, dtype=query_dtype
        ),
        # --- AITER work/reduce metadata, sized by get_mla_metadata_info_v1 ---
        "work_meta_data": torch.empty(
            work_meta_data_size, dtype=work_meta_data_type, device=device
        ),
        "work_indptr": torch.empty(
            work_indptr_size, dtype=work_indptr_type, device=device
        ),
        "work_info_set": torch.empty(
            work_info_set_size, dtype=work_info_set_type, device=device
        ),
        "reduce_indptr": torch.empty(
            reduce_indptr_size, dtype=reduce_indptr_type, device=device
        ),
        "reduce_final_map": torch.empty(
            reduce_final_map_size, dtype=reduce_final_map_type, device=device
        ),
        "reduce_partial_map": torch.empty(
            reduce_partial_map_size, dtype=reduce_partial_map_type, device=device
        ),
    }
    # paged_kv_indptr was allocated uninitialized; start it from all zeros.
    self._cg_sparse_bufs["paged_kv_indptr"].zero_()
    self._cache_write_scale[device] = torch.tensor(
        1.0, dtype=torch.float32, device=device
    )
    # Record what the workspace was sized for.
    self._cg_workspace_signature = (
        max_tokens,
        padded_num_heads,
        topk,
        query_dtype,
        device,
    )
torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 query_start_loc on device." + ) + else: + query_start_loc = query_start_loc.to( + device=device, dtype=torch.int32 + ).contiguous() + + seq_lens = getattr(plugin_metadata, "seq_lens", None) + if seq_lens is None: + seq_lens = getattr(attn_metadata, "context_lens", None) + if not isinstance(seq_lens, torch.Tensor) or int(seq_lens.numel()) + 1 != int( + query_start_loc.numel() + ): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires seq_lens per request." + ) + if in_capture: + if seq_lens.device != device or seq_lens.dtype != torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 seq_lens on device." + ) + else: + seq_lens = seq_lens.to(device=device, dtype=torch.int32).contiguous() + + if in_capture: + if not isinstance(cg_bufs, dict) or sparse_bufs is None: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires prewarmed buffers." + ) + req_id = cg_bufs.get("seq_id_i32", None) + if not isinstance(req_id, torch.Tensor): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires prewarmed seq_id_i32." + ) + req_id = req_id[:num_tokens] + block_table = getattr(plugin_metadata, "block_table", None) + if not isinstance(block_table, torch.Tensor): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires block_table." + ) + if block_table.device != device or block_table.dtype != torch.int32: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 block_table on device." + ) + topk_indices_i32 = topk_indices + if ( + topk_indices_i32.device != device + or topk_indices_i32.dtype != torch.int32 + ): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires int32 topk_indices on device." + ) + if not topk_indices_i32.is_contiguous(): + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires contiguous topk_indices." 
+ ) + self._validate_capture_sparse_buffer_capacity( + sparse_bufs=sparse_bufs, + num_tokens=num_tokens, + topk=topk, + ) + sparse_seqlen = sparse_bufs["sparse_seqlen"][:num_tokens] + torch.clamp(seq_lens[:num_tokens], min=0, max=topk, out=sparse_seqlen) + max_query_len_for_sparse = 1 + else: + req_id = self._build_req_id_per_token(attn_metadata, num_tokens, device).to( + dtype=torch.int32 + ) + block_table = self._block_table(attn_metadata, device).to(dtype=torch.int32) + topk_indices_i32 = topk_indices.to( + device=device, dtype=torch.int32 + ).contiguous() + # Keep prefill aligned with ATOM sparse metadata contract: token-ragged + # representation always uses max_q_len=1. + max_query_len_for_sparse = 1 + # Derive sparse lengths directly from indexer output validity. This is + # robust for chunked prefill where seq_lens may be chunk-local. + sparse_seqlen = torch.sum(topk_indices_i32 >= 0, dim=1, dtype=torch.int32) + + if in_capture: + qo_indptr = sparse_bufs["qo_indptr"][: num_tokens + 1] + paged_kv_indptr = sparse_bufs["paged_kv_indptr"][: num_tokens + 1] + paged_kv_indptr[0].zero_() + paged_kv_last_page_len = sparse_bufs["paged_kv_last_page_len"][:num_tokens] + paged_kv_indices = sparse_bufs["paged_kv_indices"][: num_tokens * topk] + else: + eager_sig = ( + int(num_tokens), + int(topk), + str(device), + ) + cached_eager = getattr(plugin_metadata, "_rtp_sparse_eager_workspace", None) + if ( + isinstance(cached_eager, dict) + and cached_eager.get("signature") == eager_sig + ): + qo_indptr = cached_eager["qo_indptr"] + paged_kv_indptr = cached_eager["paged_kv_indptr"] + paged_kv_last_page_len = cached_eager["paged_kv_last_page_len"] + paged_kv_indices = cached_eager["paged_kv_indices"] + else: + qo_indptr = torch.empty( + num_tokens + 1, device=device, dtype=torch.int32 + ) + paged_kv_indptr = torch.empty( + num_tokens + 1, device=device, dtype=torch.int32 + ) + paged_kv_last_page_len = torch.empty( + num_tokens, device=device, dtype=torch.int32 + ) + 
paged_kv_indices = torch.empty( + num_tokens * topk, device=device, dtype=torch.int32 + ) + try: + plugin_metadata._rtp_sparse_eager_workspace = { + "signature": eager_sig, + "qo_indptr": qo_indptr, + "paged_kv_indptr": paged_kv_indptr, + "paged_kv_last_page_len": paged_kv_last_page_len, + "paged_kv_indices": paged_kv_indices, + } + except Exception: + pass + qo_indptr.copy_( + torch.arange(num_tokens + 1, device=device, dtype=torch.int32) + ) + paged_kv_indptr.zero_() + paged_kv_last_page_len.fill_(1) + torch.cumsum(sparse_seqlen, dim=0, out=paged_kv_indptr[1:]) + + if not in_capture and int(block_size) <= 0: + raise _SparseUnavailable( + f"GLM5 RTP sparse MLA requires positive block_size, got {block_size}." + ) + + triton_convert_req_index_to_global_index( + req_id, + block_table, + topk_indices_i32, + paged_kv_indptr, + paged_kv_indices, + BLOCK_SIZE=int(block_size), + NUM_TOPK_TOKENS=topk, + ) + + padded_num_heads = max(num_heads, 16) + if padded_num_heads % num_heads != 0: + padded_num_heads = ( + (padded_num_heads + num_heads - 1) // num_heads + ) * num_heads + head_repeat_factor = padded_num_heads // num_heads + q_dtype = self._aiter_dtype_for_tensor(q_latent) + kv_dtype = self._aiter_dtype_for_tensor(kv_cache_base) + reuse_eager_metadata = False + if in_capture: + work_meta_data = sparse_bufs["work_meta_data"] + work_indptr = sparse_bufs["work_indptr"] + work_info_set = sparse_bufs["work_info_set"] + reduce_indptr = sparse_bufs["reduce_indptr"] + reduce_final_map = sparse_bufs["reduce_final_map"] + reduce_partial_map = sparse_bufs["reduce_partial_map"] + else: + eager_meta_sig = ( + int(num_tokens), + int(topk), + int(padded_num_heads), + str(q_dtype), + str(kv_dtype), + str(device), + ) + cached_eager_meta = getattr( + plugin_metadata, "_rtp_sparse_eager_meta_workspace", None + ) + if ( + isinstance(cached_eager_meta, dict) + and cached_eager_meta.get("signature") == eager_meta_sig + ): + work_meta_data = cached_eager_meta["work_meta_data"] + work_indptr = 
cached_eager_meta["work_indptr"] + work_info_set = cached_eager_meta["work_info_set"] + reduce_indptr = cached_eager_meta["reduce_indptr"] + reduce_final_map = cached_eager_meta["reduce_final_map"] + reduce_partial_map = cached_eager_meta["reduce_partial_map"] + reuse_eager_metadata = bool( + cached_eager_meta.get("metadata_ready", False) + ) + else: + metadata_budget_tokens = self._metadata_token_budget( + num_tokens=num_tokens, topk=topk + ) + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_mla_metadata_info_v1( + metadata_budget_tokens, + 1, + padded_num_heads, + q_dtype, + kv_dtype, + is_sparse=True, + fast_mode=True, + ) + work_meta_data = torch.empty( + work_meta_data_size, dtype=work_meta_data_type, device=device + ) + work_indptr = torch.empty( + work_indptr_size, dtype=work_indptr_type, device=device + ) + work_info_set = torch.empty( + work_info_set_size, dtype=work_info_set_type, device=device + ) + reduce_indptr = torch.empty( + reduce_indptr_size, dtype=reduce_indptr_type, device=device + ) + reduce_final_map = torch.empty( + reduce_final_map_size, dtype=reduce_final_map_type, device=device + ) + reduce_partial_map = torch.empty( + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=device, + ) + try: + plugin_metadata._rtp_sparse_eager_meta_workspace = { + "signature": eager_meta_sig, + "work_meta_data": work_meta_data, + "work_indptr": work_indptr, + "work_info_set": work_info_set, + "reduce_indptr": reduce_indptr, + "reduce_final_map": reduce_final_map, + "reduce_partial_map": reduce_partial_map, + "metadata_ready": False, + } + except Exception: + pass + capture_meta_sig = ( + int(num_tokens), + int(topk), + int(padded_num_heads), + str(q_dtype), + str(kv_dtype), + str(device), + ) + 
reuse_capture_metadata = False + if in_capture: + cached_capture_meta = getattr( + plugin_metadata, "_rtp_sparse_capture_meta_workspace", None + ) + if ( + isinstance(cached_capture_meta, dict) + and cached_capture_meta.get("signature") == capture_meta_sig + ): + work_meta_data = cached_capture_meta["work_meta_data"] + work_indptr = cached_capture_meta["work_indptr"] + work_info_set = cached_capture_meta["work_info_set"] + reduce_indptr = cached_capture_meta["reduce_indptr"] + reduce_final_map = cached_capture_meta["reduce_final_map"] + reduce_partial_map = cached_capture_meta["reduce_partial_map"] + reuse_capture_metadata = True + kv_token_slots = self._kv_token_slot_capacity(kv_cache_base) + page_size = 1 + max_page_slots = int(kv_token_slots) + + if in_capture and int(paged_kv_indices.numel()) > 0: + # Capture path cannot run host-synced range checks; clamp indices into + # the current kv slot range to avoid kernel-side OOB accesses. + paged_kv_indices.clamp_(min=0, max=max(int(max_page_slots) - 1, 0)) + + if not in_capture and self._enable_sparse_validate: + self._validate_sparse_index_contract( + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + num_tokens=num_tokens, + page_size=page_size, + max_slots=max_page_slots, + ) + + if not reuse_capture_metadata and not reuse_eager_metadata: + get_mla_metadata_v1( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + padded_num_heads, + 1, + True, + work_meta_data, + work_info_set, + work_indptr, + reduce_indptr, + reduce_final_map, + reduce_partial_map, + page_size=page_size, + kv_granularity=16, + max_seqlen_qo=max_query_len_for_sparse, + uni_seqlen_qo=max_query_len_for_sparse, + fast_mode=True, + dtype_q=q_dtype, + dtype_kv=kv_dtype, + ) + if not in_capture: + cached_eager_meta = getattr( + plugin_metadata, "_rtp_sparse_eager_meta_workspace", None + ) + if isinstance(cached_eager_meta, dict): + cached_eager_meta["metadata_ready"] = True + if in_capture: + 
def _run_aiter_sparse_decode(
    self,
    *,
    q_latent: torch.Tensor,
    kv_cache_base: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: Any,
    block_size: int,
) -> torch.Tensor:
    """Run AITER ``mla_decode_fwd`` over the sparse top-k selection.

    Builds the sparse metadata via ``_build_atom_sparse_metadata``, repeats
    query heads up to ``padded_num_heads`` when required, converts the query
    to fp8 when ``_cache_dtype_name(kv_cache_base)`` is ``"fp8"``, launches
    the kernel, and finally keeps one output row per original head.

    Args:
        q_latent: Query unpacked as ``(num_tokens, num_heads, latent_dim)``.
        kv_cache_base: KV cache; reshaped to ``(-1, 1, 1, latent_dim)`` for
            the kernel.
        topk_indices: Forwarded to ``_build_atom_sparse_metadata``.
        attn_metadata: Forwarded to ``_build_atom_sparse_metadata``.
        block_size: Forwarded to ``_build_atom_sparse_metadata``.

    Returns:
        Latent attention output of shape
        ``(num_tokens, num_heads, self.kv_lora_rank)``. During stream capture
        this can be a strided view into ``_cg_sparse_bufs["latent_output"]``.

    Raises:
        _SparseUnavailable: If the ``aiter`` imports fail, or if anything in
            the validation / kernel-launch block raises (reported as
            ``mla_decode_fwd failed``). Exceptions from
            ``_build_atom_sparse_metadata`` propagate unchanged.
    """
    try:
        from aiter.mla import mla_decode_fwd
    except Exception as exc:
        raise _SparseUnavailable(
            f"aiter.mla_decode_fwd unavailable: {exc}"
        ) from exc

    num_tokens, num_heads, latent_dim = q_latent.shape
    sparse_meta = self._build_atom_sparse_metadata(
        q_latent=q_latent,
        kv_cache_base=kv_cache_base,
        topk_indices=topk_indices,
        attn_metadata=attn_metadata,
        block_size=block_size,
    )
    in_capture = torch.cuda.is_current_stream_capturing()
    page_size = 1
    if sparse_meta.head_repeat_factor > 1:
        if in_capture and self._cg_sparse_bufs is not None:
            q_for_kernel = self._cg_sparse_bufs["q_for_kernel"][
                :num_tokens, : sparse_meta.padded_num_heads, :
            ]
            # Capture path: use one broadcasted copy to fill repeated heads,
            # avoiding per-repeat slice copies in the decode hot path.
            q_for_kernel.view(
                num_tokens,
                num_heads,
                sparse_meta.head_repeat_factor,
                latent_dim,
            ).copy_(q_latent.unsqueeze(2))
        else:
            # Eager path: each head is repeated head_repeat_factor times in
            # place, so repeated copies of a head sit next to each other.
            q_for_kernel = (
                q_latent.unsqueeze(2)
                .expand(-1, -1, sparse_meta.head_repeat_factor, -1)
                .reshape(num_tokens, sparse_meta.padded_num_heads, latent_dim)
            )
    else:
        q_for_kernel = q_latent
    # Output dtype is taken before any fp8 conversion of the query below.
    output_dtype = q_for_kernel.dtype
    if in_capture and self._cg_sparse_bufs is not None:
        output = self._cg_sparse_bufs["latent_output"][
            :num_tokens, : sparse_meta.padded_num_heads, :
        ]
    else:
        output = torch.empty(
            (num_tokens, sparse_meta.padded_num_heads, self.kv_lora_rank),
            dtype=output_dtype,
            device=q_latent.device,
        )
    fp8_scale_kwargs = {}
    if self._cache_dtype_name(kv_cache_base) == "fp8":
        # One cached scale tensor per device, created as 1.0 on first use;
        # the same tensor is passed as both q_scale and kv_scale.
        kv_scale = self._cache_write_scale.get(kv_cache_base.device)
        if kv_scale is None:
            kv_scale = torch.tensor(
                1.0, dtype=torch.float32, device=kv_cache_base.device
            )
            self._cache_write_scale[kv_cache_base.device] = kv_scale
        fp8_scale_kwargs = {"q_scale": kv_scale, "kv_scale": kv_scale}
        try:
            from aiter import dtypes
        except Exception as exc:
            raise _SparseUnavailable(f"aiter dtypes unavailable: {exc}") from exc
        if in_capture and self._cg_sparse_bufs is not None:
            # Capture path converts by copying into the preallocated fp8
            # buffer instead of allocating a new tensor.
            q_for_kernel_fp8 = self._cg_sparse_bufs["q_for_kernel_fp8"][
                :num_tokens, : sparse_meta.padded_num_heads, :
            ]
            q_for_kernel_fp8.copy_(q_for_kernel)
            q_for_kernel = q_for_kernel_fp8
        else:
            q_for_kernel = q_for_kernel.to(dtype=dtypes.fp8)
    try:
        # page_size is 1, so the cache is viewed as one latent row per slot.
        kv_buffer = kv_cache_base.reshape(-1, 1, 1, latent_dim)
        # Contract checks run only outside capture and only when enabled.
        if (
            not in_capture
            and self._enable_sparse_validate
            and int(sparse_meta.paged_kv_indices.numel()) > 0
        ):
            self._validate_sparse_index_contract(
                paged_kv_indptr=sparse_meta.paged_kv_indptr,
                paged_kv_indices=sparse_meta.paged_kv_indices,
                num_tokens=num_tokens,
                page_size=page_size,
                max_slots=int(kv_buffer.shape[0]),
            )
            self._validate_sparse_last_page_contract(
                paged_kv_indptr=sparse_meta.paged_kv_indptr,
                paged_kv_last_page_len=sparse_meta.paged_kv_last_page_len,
                num_tokens=num_tokens,
                page_size=page_size,
            )
        mla_decode_fwd(
            q_for_kernel,
            kv_buffer,
            output,
            sparse_meta.qo_indptr,
            sparse_meta.paged_kv_indptr,
            sparse_meta.paged_kv_indices,
            sparse_meta.paged_kv_last_page_len,
            1,
            sm_scale=self.scale,
            page_size=page_size,
            work_meta_data=sparse_meta.work_meta_data,
            work_indptr=sparse_meta.work_indptr,
            work_info_set=sparse_meta.work_info_set,
            reduce_indptr=sparse_meta.reduce_indptr,
            reduce_final_map=sparse_meta.reduce_final_map,
            reduce_partial_map=sparse_meta.reduce_partial_map,
            **fp8_scale_kwargs,
        )
    except Exception as exc:
        # Any failure in this block, including the validators above, is
        # reported to the caller as a kernel failure.
        raise _SparseUnavailable(f"mla_decode_fwd failed: {exc}") from exc
    if sparse_meta.head_repeat_factor > 1:
        # Keep the first copy of every repeated head.
        output = output[:, :: sparse_meta.head_repeat_factor, :]
        if not in_capture:
            output = output.contiguous()
    return output
+ ) + if q_nope.dtype != absorbed.w_kc.dtype: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires q_nope dtype to match absorbed weights." + ) + q_latent_nope_t = self._cg_sparse_bufs["q_latent_nope_t"][ + : q.shape[1], : q.shape[0], : + ] + torch.bmm(q_nope.transpose(0, 1), absorbed.w_kc, out=q_latent_nope_t) + q_latent_nope = q_latent_nope_t.transpose(0, 1) + q_latent = self._cg_sparse_bufs["q_latent"][ + : q.shape[0], + : q.shape[1], + : self.kv_lora_rank + self.qk_rope_head_dim, + ] + else: + q_latent_nope = torch.bmm( + q_nope.transpose(0, 1).to(dtype=absorbed.w_kc.dtype), + absorbed.w_kc, + ).transpose(0, 1) + q_latent = torch.empty( + q.shape[0], + q.shape[1], + self.kv_lora_rank + self.qk_rope_head_dim, + dtype=q_latent_nope.dtype, + device=q.device, + ) + q_latent[..., : self.kv_lora_rank] = q_latent_nope + if self.qk_rope_head_dim > 0: + q_latent[..., self.kv_lora_rank :] = q_rope[ + ..., -self.qk_rope_head_dim : + ].to(dtype=q_latent.dtype) + + block_size = int(getattr(attn_metadata, "rtp_seq_size_per_block", 0) or 0) + if block_size <= 0: + plugin_metadata = getattr(attn_metadata, "plugin_metadata", None) + block_size = int(getattr(plugin_metadata, "sparse_block_size", 0) or 0) + if block_size <= 0: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA requires physical block size." + ) + latent_output = self._run_aiter_sparse_decode( + q_latent=q_latent, + kv_cache_base=kv_cache_base, + topk_indices=topk_indices, + attn_metadata=attn_metadata, + block_size=block_size, + ) + if in_capture: + if latent_output.dtype != absorbed.w_vc.dtype: + raise _SparseUnavailable( + "GLM5 RTP sparse MLA capture requires latent output dtype to match absorbed weights." 
class RTPSparseMlaBackend:
    """Front-end that routes GLM5 RTP sparse MLA calls to a sparse implementation.

    The implementation is picked once, at construction time:

    * an explicitly injected ``sparse_impl`` always wins;
    * otherwise ``_RealSparseMlaImpl`` is built when ``mla_modules`` exposes
      every attribute it needs;
    * otherwise ``_LightweightSparseMlaImpl`` is installed as a stand-in.

    ``forward`` never falls back to dense attention: when sparse execution is
    not possible it raises ``_SparseUnavailable``.
    """

    def __init__(
        self,
        *,
        sparse_impl: Optional[object] = None,
        v_head_dim: Optional[int] = None,
        mla_modules: Optional[object] = None,
        scale: Optional[float] = None,
    ) -> None:
        """Resolve ``v_head_dim`` and select the sparse implementation.

        Raises:
            ValueError: If ``v_head_dim`` is not given and cannot be read from
                ``mla_modules``.
        """
        resolved_v_dim = v_head_dim
        if resolved_v_dim is None:
            has_v_dim = mla_modules is not None and hasattr(
                mla_modules, "v_head_dim"
            )
            if not has_v_dim:
                raise ValueError(
                    "RTPSparseMlaBackend requires v_head_dim or mla_modules.v_head_dim."
                )
            resolved_v_dim = mla_modules.v_head_dim
        self.v_head_dim = int(resolved_v_dim)

        if sparse_impl is not None:
            # Explicit injection: taken as-is and never treated as lightweight.
            self.sparse_impl = sparse_impl
            self._uses_lightweight_impl = False
        else:
            required_attrs = (
                "kv_lora_rank",
                "qk_nope_head_dim",
                "qk_rope_head_dim",
                "kv_b_proj",
                "rotary_emb",
            )
            can_build_real = mla_modules is not None
            if can_build_real:
                for attr_name in required_attrs:
                    if not hasattr(mla_modules, attr_name):
                        can_build_real = False
                        break
            if can_build_real:
                self.sparse_impl = _RealSparseMlaImpl(
                    mla_modules=mla_modules,
                    v_head_dim=self.v_head_dim,
                    scale=scale,
                )
                self._uses_lightweight_impl = False
            else:
                self.sparse_impl = _LightweightSparseMlaImpl(self.v_head_dim)
                self._uses_lightweight_impl = True

        # Cached once so forward() does not inspect the signature per call.
        self._sparse_impl_accepts_positions = self._impl_accepts_positions(
            self.sparse_impl
        )

    def prepare_cuda_graph(self, attn_inputs) -> None:  # noqa: ANN001
        """Accept ``attn_inputs`` and do nothing with it."""
        del attn_inputs

    def prewarm_for_cuda_graph(
        self,
        *,
        max_num_tokens: int,
        max_seq_len: int,
        query_dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """Forward the prewarm request to the implementation if it has the hook."""
        prewarm_hook = getattr(self.sparse_impl, "prewarm_for_cuda_graph", None)
        if not callable(prewarm_hook):
            return
        prewarm_hook(
            max_num_tokens=max_num_tokens,
            max_seq_len=max_seq_len,
            query_dtype=query_dtype,
            device=device,
        )

    @staticmethod
    def _get_attn_metadata() -> object:
        """Return ``attn_metadata`` from the forward context, or ``None`` on any failure."""
        try:
            from atom.utils.forward_context import get_forward_context

            current_context = get_forward_context()
            metadata = getattr(current_context, "attn_metadata", None)
        except Exception:
            return None
        return metadata

    @staticmethod
    def _validate_topk_indices(q: torch.Tensor, topk_indices: torch.Tensor) -> None:
        """Check rank, dtype and row count of ``topk_indices`` against ``q``.

        Raises:
            ValueError: If ``topk_indices`` is not rank-2, is not ``int32``, or
                its first dimension differs from ``q``'s.
        """
        if topk_indices.ndim != 2:
            shape_seen = tuple(topk_indices.shape)
            raise ValueError(
                "Expected topk_indices to be rank-2 [T,K], "
                f"got shape {shape_seen}"
            )
        if topk_indices.dtype != torch.int32:
            raise ValueError(
                f"Expected topk_indices dtype torch.int32, got {topk_indices.dtype}"
            )
        index_rows = topk_indices.shape[0]
        query_rows = q.shape[0]
        if index_rows != query_rows:
            raise ValueError(
                "Expected topk_indices first dimension to match q tokens, "
                f"got {index_rows} and {query_rows}"
            )

    @staticmethod
    def _impl_accepts_positions(impl: object) -> bool:
        """Tell whether ``impl.forward`` can take a ``positions`` keyword.

        True when the signature names ``positions`` or has a ``**kwargs``
        parameter; False when the signature cannot be obtained.
        """
        try:
            params = inspect.signature(impl.forward).parameters
        except (AttributeError, TypeError, ValueError):
            return False
        if "positions" in params:
            return True
        for param in params.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return True
        return False

    def forward(
        self,
        q: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: object,
        layer_id: int,
        topk_indices: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run sparse MLA through the selected implementation.

        Returns zeros of shape ``(q.shape[0], q.shape[1], v_head_dim)`` during
        dummy warmup, and also when ``topk_indices`` is missing while the
        lightweight stand-in is installed.

        Raises:
            _SparseUnavailable: If ``topk_indices`` is missing on a
                non-lightweight implementation, if the implementation is
                lightweight or has no callable ``forward``, or if the
                implementation itself raises ``_SparseUnavailable``.
            ValueError: If ``topk_indices`` fails ``_validate_topk_indices``.
        """
        attn_metadata = self._get_attn_metadata()
        plugin_metadata = getattr(attn_metadata, "plugin_metadata", None)
        if getattr(plugin_metadata, "is_dummy_warmup", False):
            return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim))

        if topk_indices is None:
            if not self._uses_lightweight_impl:
                raise _SparseUnavailable(
                    "GLM5 RTP sparse MLA requires topk_indices; refusing dense fallback."
                )
            return q.new_zeros((q.shape[0], q.shape[1], self.v_head_dim))

        self._validate_topk_indices(q, topk_indices)
        runnable = (not self._uses_lightweight_impl) and callable(
            getattr(self.sparse_impl, "forward", None)
        )
        if not runnable:
            raise _SparseUnavailable(
                "GLM5 RTP sparse MLA is unavailable; refusing dense fallback."
            )

        impl_kwargs = {
            "topk_indices": topk_indices,
            "attn_metadata": attn_metadata,
        }
        if self._sparse_impl_accepts_positions:
            impl_kwargs["positions"] = positions
        try:
            return self.sparse_impl.forward(
                q,
                compressed_kv,
                k_pe,
                kv_cache,
                layer_id,
                **impl_kwargs,
            )
        except _SparseUnavailable as exc:
            raise _SparseUnavailable(
                "GLM5 RTP sparse MLA unavailable; dense fallback is disabled. "
                f"root_cause={exc}"
            ) from exc
+ ) + + runner_block_size = int(get_current_atom_config().kv_cache_block_size) + kv_cache = kv_cache.view(-1, runner_block_size, kv_cache.shape[-1]) + + if use_qk_rope_cache_fusion: + q_bf16 = q_input + q_fp8 = torch.empty_like(q_bf16, dtype=dtypes.fp8) + weights_out = torch.empty( + weights.shape, device=weights.device, dtype=torch.float32 + ) + indexer_qk_rope_quant_and_cache( + q_bf16, + q_fp8, + weights, + weights_out, + k, + kv_cache, + slot_mapping, + k_norm_weight, + k_norm_bias, + positions, + cos_cache, + sin_cache, + k_norm_eps, + quant_block_size, + scale_fmt, + weights_scale, + preshuffle=True, + is_neox=is_neox_style, + ) + weights = weights_out + else: + q_fp8 = q_input + indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + preshuffle=True, + ) + + is_prefill = bool(getattr(context, "is_prefill", False)) + max_seqlen_k = int(getattr(attn_metadata, "max_seqlen_k", 0) or 0) + if is_prefill and max_seqlen_k <= int(topk_tokens): + return weights + + if is_prefill: + total_seq_lens = int(hidden_states.shape[0]) + has_cached = bool(getattr(attn_metadata, "has_cached", False)) + total_kv = ( + int(getattr(attn_metadata, "total_kv", total_seq_lens)) + if has_cached + else total_seq_lens + ) + k_fp8 = torch.empty([total_kv, head_dim], device=k.device, dtype=dtypes.fp8) + k_scale = torch.empty([total_kv, 1], device=k.device, dtype=torch.float32) + block_tables = getattr(attn_metadata, "block_tables", None) + cu_seqlens_q = getattr(attn_metadata, "cu_seqlens_q", None) + if block_tables is None or cu_seqlens_q is None: + raise _SparseUnavailable( + "RTP sparse prefill indexer requires block_tables and cu_seqlens_q." + ) + cu_seqlens_k = ( + getattr(attn_metadata, "cu_seqlens_k", None) if has_cached else cu_seqlens_q + ) + if cu_seqlens_k is None: + raise _SparseUnavailable( + "RTP sparse prefill indexer requires cu_seqlens_k." 
+ ) + cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale.view(dtypes.fp8), + block_tables, + cu_seqlens_k, + preshuffle=True, + ) + cu_seqlen_ks = getattr(attn_metadata, "cu_seqlen_ks", None) + cu_seqlen_ke = getattr(attn_metadata, "cu_seqlen_ke", None) + if cu_seqlen_ks is None or cu_seqlen_ke is None: + raise _SparseUnavailable( + "RTP sparse prefill indexer requires cu_seqlen_ks/cu_seqlen_ke." + ) + num_decode_tokens = 0 + logits = fp8_mqa_logits( + Q=q_fp8[num_decode_tokens:num_tokens], + KV=k_fp8, + kv_scales=k_scale, + weights=weights[num_decode_tokens:num_tokens], + cu_starts=cu_seqlen_ks, + cu_ends=cu_seqlen_ke, + ) + top_k_per_row_prefill( + logits=logits, + rowStarts=cu_seqlen_ks, + rowEnds=cu_seqlen_ke, + indices=topk_indices[num_decode_tokens:num_tokens, :topk_tokens], + values=None, + numRows=logits.shape[0], + stride0=logits.stride(0), + stride1=logits.stride(1), + ) + return weights + + max_seqlen_q = int(getattr(attn_metadata, "max_seqlen_q", 1) or 1) + num_decode_tokens = int(context.batch_size) * max_seqlen_q + kv_cache_for_logits = kv_cache.unsqueeze(-2) + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + int(context.batch_size), -1, *q_fp8.shape[1:] + ) + batch_size, next_n, _heads, _dim = padded_q_fp8_decode_tokens.shape + logits = torch.empty( + [batch_size * next_n, int(max_model_len)], + dtype=torch.float32, + device=hidden_states.device, + ) + context_lens = getattr(attn_metadata, "context_lens", None) + block_tables = getattr(attn_metadata, "block_tables", None) + if context_lens is None or block_tables is None: + raise _SparseUnavailable( + "RTP sparse decode indexer requires context_lens and block_tables." 
def rtp_sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_input: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
    k_norm_weight: torch.Tensor,
    k_norm_bias: torch.Tensor,
    k_norm_eps: float,
    positions: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    weights_scale: float,
    is_neox_style: bool,
    use_qk_rope_cache_fusion: bool,
) -> torch.Tensor:
    """Custom-op entry point for the RTP sparse-attention indexer.

    Dispatch order:

    1. **Short prefill** - the forward context reports prefill and
       ``max_seqlen_k`` fits inside the writable top-k width
       (``min(topk_tokens, topk_indices_buffer.shape[1])``). Each row of
       ``topk_indices_buffer`` is filled with the causal history
       ``0..position`` followed by ``-1`` padding, and ``weights`` is
       returned unchanged.
    2. **Forward context available** - delegate to
       ``_run_rtp_sparse_attn_indexer_topk_only``.
    3. **No forward context** - delegate to
       ``atom.models.deepseek_v2.sparse_attn_indexer``.

    Args:
        hidden_states: Only ``shape[0]`` (the token count) is read here.
        topk_tokens: Number of top-k columns the indexer writes per token.
        topk_indices_buffer: 2-D buffer mutated in place.
        positions: Per-token positions, flattened and cast to int32 for the
            short-prefill fill.
        weights: Returned as-is on the short-prefill path.

        Every other argument is forwarded untouched to the delegate.

    Returns:
        ``weights`` on the short-prefill path, otherwise whatever the delegate
        returns.
    """
    try:
        from atom.utils.forward_context import get_forward_context

        forward_context = get_forward_context()
    except Exception:
        # Best effort: without a forward context we fall through to the
        # upstream indexer at the bottom of this function.
        forward_context = None
    context = getattr(forward_context, "context", None)
    attn_metadata = getattr(forward_context, "attn_metadata", None)
    # For short prefill (ctx <= topk width), DeepSeek indexer returns early and
    # doesn't write topk buffer. Emit causal full-history indices to keep sparse
    # path valid.
    if (
        context is not None
        and bool(getattr(context, "is_prefill", False))
        and attn_metadata is not None
        and topk_indices_buffer is not None
        and topk_indices_buffer.dim() == 2
        and positions is not None
    ):
        max_seqlen_k = int(getattr(attn_metadata, "max_seqlen_k", 0) or 0)
        # Only this many columns can hold the full history: the indexer
        # selects ``topk_tokens`` entries and the buffer cannot hold more than
        # its own width. Using the buffer width alone for the guard (as
        # before) either truncated the history to the first ``topk_tokens``
        # keys or failed in ``copy_`` on a shape mismatch.
        fill_width = min(int(topk_tokens), int(topk_indices_buffer.shape[1]))
        if 0 < max_seqlen_k <= fill_width:
            num_tokens = int(hidden_states.shape[0])
            if num_tokens > 0:
                buffer_device = topk_indices_buffer.device
                positions_i32 = positions.to(
                    device=buffer_device, dtype=torch.int32
                ).view(-1)
                # Row i may attend to columns [0, position_i].
                row_limits = (
                    (positions_i32 + 1).clamp(min=0, max=fill_width).view(-1, 1)
                )
                col_ids = torch.arange(
                    fill_width,
                    device=buffer_device,
                    dtype=torch.int32,
                ).view(1, -1)
                causal_topk = torch.where(
                    col_ids < row_limits,
                    col_ids.expand(num_tokens, fill_width),
                    torch.full(
                        (num_tokens, fill_width),
                        -1,
                        device=buffer_device,
                        dtype=torch.int32,
                    ),
                )
                topk_indices_buffer[:num_tokens, :fill_width].copy_(causal_topk)
            # NOTE(review): returning here skips the K quantize-and-cache step
            # that _run_rtp_sparse_attn_indexer_topk_only performs before its
            # own short-prefill return - confirm the indexer K cache is
            # populated elsewhere for short prompts.
            return weights

    if context is not None and attn_metadata is not None:
        return _run_rtp_sparse_attn_indexer_topk_only(
            hidden_states,
            kv_cache,
            q_input,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
            k_norm_weight,
            k_norm_bias,
            k_norm_eps,
            positions,
            cos_cache,
            sin_cache,
            weights_scale,
            is_neox_style,
            use_qk_rope_cache_fusion,
            context,
            attn_metadata,
        )

    from atom.models.deepseek_v2 import sparse_attn_indexer

    return sparse_attn_indexer(
        hidden_states,
        k_cache_prefix,
        kv_cache,
        q_input,
        k,
        weights,
        quant_block_size,
        scale_fmt,
        topk_tokens,
        head_dim,
        max_model_len,
        total_seq_lens,
        topk_indices_buffer,
        k_norm_weight,
        k_norm_bias,
        k_norm_eps,
        positions,
        cos_cache,
        sin_cache,
        weights_scale,
        is_neox_style,
        use_qk_rope_cache_fusion,
    )
hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_input: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, + k_norm_weight: torch.Tensor, + k_norm_bias: torch.Tensor, + k_norm_eps: float, + positions: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + weights_scale: float, + is_neox_style: bool, + use_qk_rope_cache_fusion: bool, +) -> torch.Tensor: + from atom.models.deepseek_v2 import sparse_attn_indexer_fake + + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_input, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + k_norm_weight, + k_norm_bias, + k_norm_eps, + positions, + cos_cache, + sin_cache, + weights_scale, + is_neox_style, + use_qk_rope_cache_fusion, + ) + + +direct_register_custom_op( + op_name="rtp_sparse_attn_indexer", + op_func=rtp_sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=rtp_sparse_attn_indexer_fake, +) diff --git a/atom/plugin/rtpllm/models/__init__.py b/atom/plugin/rtpllm/models/__init__.py index ecaaacb47b..c99f5363fd 100644 --- a/atom/plugin/rtpllm/models/__init__.py +++ b/atom/plugin/rtpllm/models/__init__.py @@ -1,3 +1,19 @@ -from .base_model_wrapper import ATOMQwen35Moe +try: + from .base_model_wrapper import ATOMGlm5Moe, ATOMQwen35Moe +except ModuleNotFoundError as exc: + if not (exc.name or "").startswith("rtp_llm"): + raise + ATOMGlm5Moe = None + ATOMQwen35Moe = None +else: + try: + from atom.models.deepseek_v2 import GlmMoeDsaForCausalLM + from atom.plugin.register import _ATOM_SUPPORTED_MODELS + except ImportError: + # Unit tests may stub partial module trees and intentionally skip + # full model imports. Keep wrapper symbols importable in that case. 
+ pass + else: + _ATOM_SUPPORTED_MODELS.setdefault("GlmMoeDsaForCausalLM", GlmMoeDsaForCausalLM) -__all__ = ["ATOMQwen35Moe"] +__all__ = ["ATOMGlm5Moe", "ATOMQwen35Moe"] diff --git a/atom/plugin/rtpllm/models/base_model_wrapper.py b/atom/plugin/rtpllm/models/base_model_wrapper.py index d952107b95..b0aed863a6 100644 --- a/atom/plugin/rtpllm/models/base_model_wrapper.py +++ b/atom/plugin/rtpllm/models/base_model_wrapper.py @@ -14,6 +14,7 @@ register_model, ) +from atom.plugin.rtpllm.models.glm5 import ATOMGlm5Moe from atom.plugin.rtpllm.models.qwen3_5 import ATOMQwen35Moe @@ -28,4 +29,12 @@ def _register_atom_qwen35_moe() -> None: _hf_architecture_2_ft["Qwen3_5MoeForConditionalGeneration"] = "qwen35_moe" +def _register_atom_glm5_moe() -> None: + """Register ATOM's rtp-llm model hook for GLM5.""" + register_model("atom_glm5_moe", ATOMGlm5Moe, []) + _model_factory["glm_5"] = ATOMGlm5Moe + _hf_architecture_2_ft["GlmMoeDsaForCausalLM"] = "glm_5" + + _register_atom_qwen35_moe() +_register_atom_glm5_moe() diff --git a/atom/plugin/rtpllm/models/glm5.py b/atom/plugin/rtpllm/models/glm5.py new file mode 100644 index 0000000000..41c1b86131 --- /dev/null +++ b/atom/plugin/rtpllm/models/glm5.py @@ -0,0 +1,789 @@ +"""GLM5 wrapper for rtp-llm external model loading.""" + +from __future__ import annotations + +import logging +import os +from typing import Any + +import torch +from rtp_llm.config.model_config import ModelConfig +from rtp_llm.model_loader.model_weight_info import ModelWeights +from rtp_llm.models.deepseek_v2 import DeepSeekV2 +from rtp_llm.models_py.model_desc.module_base import GptModelBase +from rtp_llm.ops import ParallelismConfig +from rtp_llm.ops.compute_ops import PyModelInputs, PyModelOutputs +from rtp_llm.utils.model_weight import W + +logger = logging.getLogger("atom.plugin.rtpllm.models") + +# Patched in tests; lazily imported in runtime to keep module import lightweight. 
+RTPForwardContext = None + + +class _NoopWeightManager: + def update(self, req): # noqa: ANN001 + return None + + +class _NoopModelWeightsLoader: + _py_eplb = None + + def load_lora_weights(self, adapter_name, lora_path, device): # noqa: ANN001 + logger.warning( + "No-op model_weights_loader received load_lora_weights(%s, %s, %s); " + "external plugin mode uses ATOM model weights path only.", + adapter_name, + lora_path, + device, + ) + return None + + +class _ATOMGlm5AttnPyObj: + """Container returned to RTP CudaGraphRunner for replay-time hooks.""" + + def __init__(self, runtime: "_ATOMGlm5MoeRuntime") -> None: + self._runtime = runtime + self.is_cuda_graph = False + self._rtp_mla_layers: list[Any] = [] + self._rtp_sparse_mla_backends: list[Any] = [] + self._collect_mla_layers() + + @staticmethod + def _append_unique(items: list[Any], value: Any) -> None: + if value is not None and all(value is not item for item in items): + items.append(value) + + def _collect_mla_layers(self) -> None: + try: + from atom.plugin.rtpllm.attention_backend import ( + RTPMLAAttention, + RTPSparseMlaBackend, + ) + except (ImportError, ModuleNotFoundError): + RTPMLAAttention = None + RTPSparseMlaBackend = None + + candidates: list[Any] = [] + _, _, mla_layer_map = self._runtime._rtp_layer_maps + candidates.extend(mla_layer_map.values()) + for module in self._runtime.model.modules(): + candidates.append(module) + mla_attn = getattr(module, "mla_attn", None) + if mla_attn is not None: + candidates.append(mla_attn) + + for candidate in candidates: + if RTPMLAAttention is not None and isinstance(candidate, RTPMLAAttention): + self._append_unique(self._rtp_mla_layers, candidate) + backend = getattr(candidate, "sparse_backend", None) + else: + backend = getattr(candidate, "sparse_backend", None) + if ( + backend is None + and RTPSparseMlaBackend is not None + and isinstance(candidate, RTPSparseMlaBackend) + ): + backend = candidate + + if RTPSparseMlaBackend is not None and isinstance( + 
backend, RTPSparseMlaBackend + ): + self._append_unique(self._rtp_sparse_mla_backends, backend) + + @property + def fmha_params(self): + return None + + def prepare_cuda_graph(self, attn_inputs) -> None: # noqa: ANN001 + for layer in self._rtp_mla_layers: + prepare = getattr(layer, "prepare_cuda_graph", None) + if callable(prepare): + prepare(attn_inputs) + for backend in self._rtp_sparse_mla_backends: + prepare = getattr(backend, "prepare_cuda_graph", None) + if callable(prepare): + prepare(attn_inputs) + + +class _ATOMGlm5MoeRuntime(GptModelBase): + """rtp-llm runtime adapter backed by an ATOM GLM5 model.""" + + def __init__( + self, + model_config: ModelConfig, + parallelism_config: ParallelismConfig, + weights: ModelWeights, + max_generate_batch_size: int, + atom_model: Any, + fmha_config=None, + py_hw_kernel_config=None, + device_resource_config=None, + ) -> None: + super().__init__( + model_config, + parallelism_config, + weights, + max_generate_batch_size=max_generate_batch_size, + fmha_config=fmha_config, + py_hw_kernel_config=py_hw_kernel_config, + device_resource_config=device_resource_config, + ) + self.model = atom_model + first_param = next(iter(self.model.parameters()), None) + if first_param is not None: + self._model_device = first_param.device + self._model_dtype = first_param.dtype + else: + self._model_device = torch.device("cpu") + self._model_dtype = torch.get_default_dtype() + forward_context_cls = self._get_forward_context_cls() + self._rtp_layer_maps = forward_context_cls.collect_layer_maps(model=self.model) + self._rtp_kv_cache_data: dict | None = None + self._rtp_kv_cache_signature: tuple | None = None + self._rtp_layer_group_map: dict[int, int] | None = None + self._rtp_layer_group_map_signature: tuple | None = None + decode_caps = getattr(py_hw_kernel_config, "decode_capture_batch_sizes", None) + if decode_caps: + self._cg_max_num_tokens: int = min( + int(max(decode_caps)), int(max_generate_batch_size) + ) + else: + 
self._cg_max_num_tokens: int = int(max_generate_batch_size) + self._cg_max_seq_len: int = int( + getattr(model_config, "max_seq_len", 0) + or getattr(model_config, "max_position_embeddings", 0) + or 32768 + ) + self._atom_attn_pyobj: _ATOMGlm5AttnPyObj | None = None + self._cg_layers_prewarmed: bool = False + + def load_weights(self): + return None + + def prepare_fmha_impl( + self, inputs: PyModelInputs, is_cuda_graph: bool = False + ) -> _ATOMGlm5AttnPyObj: + if self._atom_attn_pyobj is None: + self._atom_attn_pyobj = _ATOMGlm5AttnPyObj(self) + self._atom_attn_pyobj.is_cuda_graph = bool(is_cuda_graph) + if bool(is_cuda_graph): + inputs.attention_inputs.is_cuda_graph = True + self._ensure_cuda_graph_prewarmed() + return self._atom_attn_pyobj + + def _ensure_cuda_graph_prewarmed(self) -> None: + if self._cg_layers_prewarmed: + return + if self._atom_attn_pyobj is None: + return + + max_num_tokens = int(self._cg_max_num_tokens) + max_seq_len = int(self._cg_max_seq_len) + if max_num_tokens <= 0 or max_seq_len <= 0: + logger.warning( + "ATOM GLM5 cuda-graph prewarm skipped: invalid budget " + "(max_num_tokens=%d, max_seq_len=%d)", + max_num_tokens, + max_seq_len, + ) + return + + device = self._get_model_device() + dtype = self._get_model_dtype() + kv_cache = getattr(self, "kv_cache", None) + seq_size_per_block = ( + int(getattr(kv_cache, "seq_size_per_block", 0)) + or int(os.getenv("SEQ_SIZE_PER_BLOCK", "0") or 0) + or 1 + ) + kernel_seq_size_per_block = ( + int(getattr(kv_cache, "kernel_seq_size_per_block", 0)) + or int(os.getenv("KERNEL_SEQ_SIZE_PER_BLOCK", "0") or 0) + or seq_size_per_block + ) + physical_max_blocks = ( + int(max_seq_len) + seq_size_per_block - 1 + ) // seq_size_per_block + 1 + recovered_physical_max_blocks = ( + int(max_seq_len) + seq_size_per_block - 1 + ) // seq_size_per_block + indexer_max_blocks = ( + int(max_seq_len) + kernel_seq_size_per_block - 1 + ) // kernel_seq_size_per_block + 1 + block_table_max_blocks = max(physical_max_blocks, 
indexer_max_blocks) + + for backend in self._atom_attn_pyobj._rtp_sparse_mla_backends: + prewarm = getattr(backend, "prewarm_for_cuda_graph", None) + if callable(prewarm): + prewarm( + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + query_dtype=dtype, + device=device, + ) + self._cg_meta_bufs: dict[str, torch.Tensor] = { + "query_start_loc": torch.arange( + 0, max_num_tokens + 1, device=device, dtype=torch.int32 + ), + "seq_id": torch.arange(0, max_num_tokens, device=device, dtype=torch.int64), + "seq_id_i32": torch.arange( + 0, max_num_tokens, device=device, dtype=torch.int32 + ), + "positions_i32": torch.empty( + max_num_tokens, device=device, dtype=torch.int32 + ), + "positions_i64": torch.empty( + max_num_tokens, device=device, dtype=torch.int64 + ), + "block_col": torch.empty(max_num_tokens, device=device, dtype=torch.int32), + "block_col_i64": torch.empty( + max_num_tokens, device=device, dtype=torch.int64 + ), + "slot_base": torch.empty(max_num_tokens, device=device, dtype=torch.int32), + "token_offset": torch.empty( + max_num_tokens, device=device, dtype=torch.int32 + ), + "slot_mapping": torch.empty( + max_num_tokens, device=device, dtype=torch.int64 + ), + "seq_lens_i32": torch.empty( + max_num_tokens, device=device, dtype=torch.int32 + ), + "physical_block_table_i32": torch.empty( + max_num_tokens, + recovered_physical_max_blocks, + device=device, + dtype=torch.int32, + ), + "block_table_i32": torch.empty( + max_num_tokens, block_table_max_blocks, device=device, dtype=torch.int32 + ), + "indexer_block_table_i32": torch.empty( + max_num_tokens, indexer_max_blocks, device=device, dtype=torch.int32 + ), + } + self._cg_layers_prewarmed = True + logger.info( + "ATOM GLM5 cuda-graph prewarmed " + "(max_num_tokens=%d, max_seq_len=%d, sparse_layers=%d, " + "physical_block_table_i32[%dx%d], block_table_i32[%dx%d], " + "indexer_block_table_i32[%dx%d])", + max_num_tokens, + max_seq_len, + len(self._atom_attn_pyobj._rtp_sparse_mla_backends), + 
max_num_tokens, + recovered_physical_max_blocks, + max_num_tokens, + block_table_max_blocks, + max_num_tokens, + indexer_max_blocks, + ) + + @staticmethod + def _get_forward_context_cls(): + global RTPForwardContext + if RTPForwardContext is None: + from atom.plugin.rtpllm.utils import ( + RTPForwardMLAContext as _RTPForwardContext, + ) + + RTPForwardContext = _RTPForwardContext + return RTPForwardContext + + def _get_model_device(self) -> torch.device: + return self._model_device + + def _get_model_dtype(self) -> torch.dtype: + return self._model_dtype + + def _get_token_num( + self, inputs: PyModelInputs, input_ids: torch.Tensor | None + ) -> int: + if input_ids is not None and input_ids.numel() > 0: + return int(input_ids.numel()) + input_hiddens = getattr(inputs, "input_hiddens", None) + if input_hiddens is not None and input_hiddens.numel() > 0: + return int(input_hiddens.shape[0]) + return 0 + + @staticmethod + def _build_token_positions( + input_lengths: torch.Tensor, + starts: torch.Tensor, + ) -> torch.Tensor | None: + token_starts = torch.repeat_interleave(starts, input_lengths) + if token_starts.numel() == 0: + return None + per_seq_base = input_lengths.cumsum(dim=0) - input_lengths + token_ordinal = ( + torch.cumsum( + torch.repeat_interleave(torch.ones_like(input_lengths), input_lengths), + dim=0, + ) + - 1 + ) + token_ordinal = token_ordinal - torch.repeat_interleave( + per_seq_base, input_lengths + ) + return (token_starts + token_ordinal).to(dtype=torch.int32).contiguous() + + def _build_positions_from_attention_inputs( + self, attn_inputs: Any, model_device: torch.device + ) -> torch.Tensor | None: + if attn_inputs is None: + return None + + input_lengths = getattr(attn_inputs, "input_lengths", None) + if input_lengths is None or input_lengths.numel() == 0: + return None + input_lengths_i32 = input_lengths.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + + is_prefill = bool(getattr(attn_inputs, "is_prefill", 
False)) + if is_prefill: + prefix_lengths = getattr(attn_inputs, "prefix_lengths", None) + if prefix_lengths is None or prefix_lengths.numel() == 0: + return None + prefix_lengths_i32 = prefix_lengths.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + if int(prefix_lengths_i32.numel()) < int(input_lengths_i32.numel()): + return None + starts = prefix_lengths_i32[: int(input_lengths_i32.numel())] + return self._build_token_positions(input_lengths_i32, starts) + + sequence_lengths_plus_1 = getattr( + attn_inputs, "sequence_lengths_plus_1_d", None + ) + if sequence_lengths_plus_1 is not None and sequence_lengths_plus_1.numel() > 0: + seq_plus_one_i32 = sequence_lengths_plus_1.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + if int(seq_plus_one_i32.numel()) < int(input_lengths_i32.numel()): + return None + starts = ( + seq_plus_one_i32[: int(input_lengths_i32.numel())] - input_lengths_i32 + ) + return self._build_token_positions(input_lengths_i32, starts) + + sequence_lengths = getattr(attn_inputs, "sequence_lengths", None) + if sequence_lengths is None or sequence_lengths.numel() == 0: + return None + sequence_lengths_i32 = sequence_lengths.to( + device=model_device, dtype=torch.int32, non_blocking=True + ).contiguous() + if int(sequence_lengths_i32.numel()) < int(input_lengths_i32.numel()): + return None + starts = ( + sequence_lengths_i32[: int(input_lengths_i32.numel())] + - input_lengths_i32 + + 1 + ) + return self._build_token_positions(input_lengths_i32, starts) + + def _build_graph_decode_positions( + self, attn_inputs: Any, model_device: torch.device + ) -> torch.Tensor | None: + sequence_lengths_plus_1 = getattr( + attn_inputs, "sequence_lengths_plus_1_d", None + ) + if sequence_lengths_plus_1 is None or sequence_lengths_plus_1.numel() == 0: + return None + input_lengths = getattr(attn_inputs, "input_lengths", None) + if input_lengths is None or input_lengths.numel() == 0: + return None + 
num_tokens = int(input_lengths.numel()) + seq_plus_one_i32 = sequence_lengths_plus_1.to( + device=model_device, dtype=torch.int32, non_blocking=True + ) + if int(seq_plus_one_i32.numel()) < num_tokens: + return None + cg_bufs = getattr(self, "_cg_meta_bufs", None) + if isinstance(cg_bufs, dict): + positions_buf = cg_bufs.get("positions_i32") + if ( + isinstance(positions_buf, torch.Tensor) + and int(positions_buf.numel()) >= num_tokens + ): + positions_i32 = positions_buf[:num_tokens] + torch.sub(seq_plus_one_i32[:num_tokens], 1, out=positions_i32) + positions_i64_buf = cg_bufs.get("positions_i64") + if ( + isinstance(positions_i64_buf, torch.Tensor) + and int(positions_i64_buf.numel()) >= num_tokens + ): + positions_i64 = positions_i64_buf[:num_tokens] + positions_i64.copy_(positions_i32) + return positions_i64 + return positions_i32 + return (seq_plus_one_i32[:num_tokens] - 1).to(dtype=torch.long).contiguous() + + def _extract_combo_positions( + self, inputs: PyModelInputs, model_device: torch.device + ) -> torch.Tensor | None: + bert_inputs = getattr(inputs, "bert_embedding_inputs", None) + if bert_inputs is None: + return None + combo_position_ids = getattr(bert_inputs, "combo_position_ids", None) + if combo_position_ids is None or combo_position_ids.numel() == 0: + return None + return combo_position_ids.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() + + def _extract_positions( + self, inputs: PyModelInputs, model_device: torch.device, token_num: int + ) -> torch.Tensor: + attn_inputs = getattr(inputs, "attention_inputs", None) + if attn_inputs is None: + raise ValueError( + "GLM5 RTP plugin requires inputs.attention_inputs to provide position metadata." + ) + positions = None + graph_decode = bool(getattr(attn_inputs, "is_cuda_graph", False)) and not bool( + getattr(attn_inputs, "is_prefill", False) + ) + if graph_decode: + # RTP CudaGraphRunner refreshes sequence_lengths_plus_1_d before + # replay, but not position_ids. 
Build decode positions from the + # refreshed RTP length tensors so RoPE advances on every replay. + positions = self._build_graph_decode_positions( + attn_inputs=attn_inputs, + model_device=model_device, + ) + if positions is None or positions.numel() == 0: + positions = getattr(attn_inputs, "position_ids", None) + if positions is None or positions.numel() == 0: + positions = self._extract_combo_positions( + inputs=inputs, model_device=model_device + ) + if positions is None or positions.numel() == 0: + positions = self._build_positions_from_attention_inputs( + attn_inputs=attn_inputs, + model_device=model_device, + ) + if positions is None or positions.numel() == 0: + raise ValueError( + "GLM5 RTP plugin requires real position metadata from attention_inputs." + ) + if torch.cuda.is_current_stream_capturing(): + if positions.device != model_device: + raise RuntimeError( + "GLM5 RTP cuda-graph capture requires positions on model device." + ) + positions = positions.contiguous() + else: + positions = positions.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() + if not torch.cuda.is_current_stream_capturing(): + pos_tokens = ( + int(positions.shape[-1]) + if positions.dim() > 0 + else int(positions.numel()) + ) + if token_num > 0 and pos_tokens != token_num: + rebuilt_positions = self._build_positions_from_attention_inputs( + attn_inputs=attn_inputs, + model_device=model_device, + ) + rebuilt_tokens = ( + int(rebuilt_positions.shape[-1]) + if rebuilt_positions is not None and rebuilt_positions.dim() > 0 + else ( + int(rebuilt_positions.numel()) + if rebuilt_positions is not None + else -1 + ) + ) + if rebuilt_positions is not None and rebuilt_tokens == token_num: + positions = rebuilt_positions.to( + device=model_device, dtype=torch.long, non_blocking=True + ).contiguous() + elif pos_tokens > token_num: + positions = positions[..., -token_num:].contiguous() + else: + raise ValueError( + "GLM5 RTP plugin position_ids/token_num mismatch " 
+ f"(position_ids_tokens={pos_tokens}, token_num={token_num})." + ) + return positions + + def forward( + self, inputs: PyModelInputs, fmha_impl=None + ) -> PyModelOutputs: # noqa: ANN001 + is_cuda_graph = bool(getattr(fmha_impl, "is_cuda_graph", False)) + if is_cuda_graph: + inputs.attention_inputs.is_cuda_graph = True + model_device = self._get_model_device() + model_dtype = self._get_model_dtype() + input_ids = inputs.input_ids + inputs_embeds = None + + if ( + input_ids is not None + and input_ids.numel() > 0 + and input_ids.device != model_device + ): + input_ids = input_ids.to(device=model_device, non_blocking=True) + token_num = self._get_token_num(inputs=inputs, input_ids=input_ids) + positions = self._extract_positions( + inputs=inputs, model_device=model_device, token_num=token_num + ) + if is_cuda_graph and token_num > 0: + positions = positions[:token_num] + if input_ids is None or input_ids.numel() == 0: + inputs_embeds = inputs.input_hiddens + if ( + inputs_embeds is not None + and inputs_embeds.numel() > 0 + and inputs_embeds.device != model_device + ): + inputs_embeds = inputs_embeds.to(device=model_device, non_blocking=True) + if ( + inputs_embeds is not None + and inputs_embeds.numel() > 0 + and inputs_embeds.dtype != model_dtype + ): + inputs_embeds = inputs_embeds.to(dtype=model_dtype) + + forward_context_cls = self._get_forward_context_cls() + with forward_context_cls.bind( + model=self.model, + runtime=self, + inputs=inputs, + positions=positions, + layer_maps=self._rtp_layer_maps, + cg_max_seq_len=int(self._cg_max_seq_len), + cg_bufs=getattr(self, "_cg_meta_bufs", None), + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=None, + inputs_embeds=inputs_embeds, + ) + return PyModelOutputs(hidden_states) + + +class ATOMGlm5Moe(DeepSeekV2): + """GLM5 model class that starts ATOM runtime in rtp-llm plugin mode.""" + + @staticmethod + def _is_external_plugin_mode() -> bool: + modules = 
os.getenv("RTP_LLM_EXTERNAL_MODEL_PACKAGES", "") + return "atom.plugin.rtpllm.models" in modules + + @classmethod + def _create_config(cls, ckpt_path: str): + config = super()._create_config(ckpt_path) + # ATOM sparse MLA reads the FP8 KV cache through aiter's 576-token layout. + config.attn_config.mla_use_aiter_fp8_layout = True + return config + + def support_cuda_graph(self) -> bool: + if os.getenv("ENABLE_CUDA_GRAPH", "1") == "0": + logger.info("ENABLE_CUDA_GRAPH=0 - ATOMGlm5Moe forces eager forward.") + return False + return True + + @staticmethod + def _make_glm5_hf_mapper(): + from atom.model_loader.loader import WeightsMapper + + return WeightsMapper( + orig_to_new_prefix={}, + orig_to_new_substr={ + "indexers_proj.": "indexer.weights_proj.", + }, + ) + + @staticmethod + def _get_named_parameters(atom_model: Any) -> dict[str, torch.Tensor]: + if atom_model is None or not hasattr(atom_model, "named_parameters"): + return {} + return { + name: param + for name, param in atom_model.named_parameters(recurse=True) + if param is not None + } + + @staticmethod + def _first_param( + params: dict[str, torch.Tensor], candidates: tuple[str, ...] + ) -> torch.Tensor | None: + for name in candidates: + param = params.get(name) + if param is not None: + return param + return None + + def _inject_rtp_projection_weights(self, atom_model: Any) -> None: + params = self._get_named_parameters(atom_model) + if not params: + logger.warning( + "Skip GLM5 RTP projection weight injection because atom_model has no named parameters." 
+ ) + return + + required = { + W.lm_head: ( + "language_model.lm_head.weight", + "lm_head.weight", + ), + W.embedding: ( + "language_model.model.embed_tokens.weight", + "model.embed_tokens.weight", + ), + W.final_ln_gamma: ( + "language_model.model.norm.weight", + "model.norm.weight", + ), + } + missing = [] + for weight_name, candidates in required.items(): + param = self._first_param(params, candidates) + if param is None: + missing.append((weight_name, candidates)) + continue + self.weight.set_global_weight(weight_name, param.detach()) + logger.info( + "Injected GLM5 runtime %s for RTP: %s", + weight_name, + tuple(param.shape), + ) + if missing: + details = ", ".join( + f"{weight_name} candidates={candidates}" + for weight_name, candidates in missing + ) + raise ValueError( + f"Cannot locate GLM5 RTP runtime projection weights: {details}" + ) + + def _assert_norm_weights_loaded(self, atom_model: Any) -> None: + params = self._get_named_parameters(atom_model) + if not params: + logger.warning( + "Skip GLM5 norm weight validation because atom_model has no named parameters." + ) + return + norm_w = self._first_param( + params, + ( + "language_model.model.layers.0.input_layernorm.weight", + "model.layers.0.input_layernorm.weight", + ), + ) + if norm_w is None: + raise ValueError( + "Cannot locate GLM5 layer-0 input_layernorm.weight after ATOM load in RTP plugin mode." + ) + norm_w_cpu = norm_w.detach().float().reshape(-1).cpu() + if norm_w_cpu.numel() == 0 or bool(torch.all(norm_w_cpu == 0)): + raise ValueError( + "Loaded GLM5 layer-0 input_layernorm.weight is all zeros; " + "refusing to run with default values." 
+ ) + + def load(self, skip_python_model: bool = False): + if self._is_external_plugin_mode(): + self.device = self._get_device_str() + self.weight = ModelWeights( + num_layers=self.model_config.num_layers, + device=self.device, + dtype=self.model_config.compute_dtype, + ) + self.model_weights_loader = _NoopModelWeightsLoader() + self.py_eplb = self.model_weights_loader._py_eplb + self.weight_manager = _NoopWeightManager() + if skip_python_model: + logger.info( + "External plugin mode: skip ATOM GLM5 python model creation as requested" + ) + return + self._create_python_model() + logger.info( + "External plugin mode: use ATOM GLM5 loading path and skip native load" + ) + return + + super().load(skip_python_model=skip_python_model) + + def _create_python_model(self): + if not self._is_external_plugin_mode(): + return super()._create_python_model() + + import atom + from atom.model_loader.loader import load_model_in_plugin_mode + + prepare_model = getattr(atom, "prepare_model", None) + if prepare_model is None: + from atom.plugin.prepare import prepare_model + + target_device = torch.device( + self.device if getattr(self, "device", None) else "cuda" + ) + target_dtype = self.model_config.compute_dtype + old_default_dtype = torch.get_default_dtype() + try: + old_default_device = torch.get_default_device() + except Exception: + old_default_device = None + + torch.set_default_device(target_device) + if target_dtype in { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + }: + torch.set_default_dtype(target_dtype) + + try: + atom_model = prepare_model(config=self, engine="rtpllm") + if atom_model is None: + raise ValueError("ATOM failed to create GLM5 model for rtp-llm plugin") + + if hasattr(atom_model, "to"): + atom_model = atom_model.to(target_device) + + atom_config = getattr(atom_model, "atom_config", None) + if atom_config is None: + atom_config = getattr( + getattr(atom_model, "model", None), "atom_config", None + ) + if atom_config is None: + # 
Unit tests may use mocked ATOM models; real loading must expose atom_config. + atom_config = getattr(self, "atom_config", None) + + load_model_in_plugin_mode( + model=atom_model, + config=atom_config, + prefix="model.", + weights_mapper=self._make_glm5_hf_mapper(), + ) + self._assert_norm_weights_loaded(atom_model) + self._inject_rtp_projection_weights(atom_model) + finally: + torch.set_default_dtype(old_default_dtype) + if old_default_device is not None: + torch.set_default_device(old_default_device) + else: + torch.set_default_device("cpu") + + self.py_model = _ATOMGlm5MoeRuntime( + model_config=self.model_config, + parallelism_config=self.parallelism_config, + weights=self.weight, + max_generate_batch_size=self.max_generate_batch_size, + fmha_config=self.fmha_config, + py_hw_kernel_config=self.hw_kernel_config, + device_resource_config=self.device_resource_config, + atom_model=atom_model, + ) + logger.info("Created ATOM GLM5 runtime for rtp-llm plugin mode") + return self.py_model diff --git a/atom/plugin/rtpllm/models/qwen3_5.py b/atom/plugin/rtpllm/models/qwen3_5.py index eb41294831..dbe21cc903 100644 --- a/atom/plugin/rtpllm/models/qwen3_5.py +++ b/atom/plugin/rtpllm/models/qwen3_5.py @@ -13,18 +13,7 @@ from rtp_llm.ops.compute_ops import PyModelInputs, PyModelOutputs from rtp_llm.utils.model_weight import W -from atom.model_loader.loader import WeightsMapper -from atom.models.qwen3_5 import ( - detect_fused_expert_format, - get_fused_expert_mapping, - load_fused_expert_weights, -) -from atom.plugin.rtpllm.attention_backend import ( - apply_attention_gdn_rtpllm_patch, - apply_attention_mha_rtpllm_patch, -) from atom.plugin.rtpllm.models.qwen3_next import apply_qwen3_next_rtpllm_patch -from atom.plugin.rtpllm.utils import RTPForwardContext logger = logging.getLogger("atom.plugin.rtpllm.models") @@ -126,8 +115,13 @@ def __init__( ) self._model_device = first_param.device self._model_dtype = first_param.dtype + from atom.plugin.rtpllm.utils import 
RTPForwardQwen35HybridContext + + self._rtp_forward_context_cls = RTPForwardQwen35HybridContext # Cache module layer maps once to avoid per-forward model.modules() traversal. - self._rtp_layer_maps = RTPForwardContext.collect_layer_maps(model=self.model) + self._rtp_layer_maps = self._rtp_forward_context_cls.collect_layer_maps( + model=self.model + ) # Lazy-built in forward_context; invalidated by kv buffer signature change. self._rtp_kv_cache_data: dict | None = None self._rtp_kv_cache_signature: tuple | None = None @@ -384,9 +378,17 @@ def _ensure_cuda_graph_prewarmed(self) -> None: or int(getattr(kv_cache, "seq_size_per_block", 0)) or 1 ) + seq_size_per_block = ( + int(getattr(kv_cache, "seq_size_per_block", 0)) + or kernel_seq_size_per_block + or 1 + ) max_blocks = ( int(max_seq_len) + kernel_seq_size_per_block - 1 ) // kernel_seq_size_per_block + 1 + physical_max_blocks = ( + int(max_seq_len) + seq_size_per_block - 1 + ) // seq_size_per_block # query_start_loc for decode: always [0, 1, 2, ..., bs], i.e. arange(bs+1). # seq_id for decode slot_mapping: seq_id[i] == i, i.e. arange(bs). 
self._cg_meta_bufs: dict = { @@ -394,6 +396,7 @@ def _ensure_cuda_graph_prewarmed(self) -> None: 0, max_bs + 1, device=device, dtype=torch.int32 ), "seq_id": torch.arange(0, max_bs, device=device, dtype=torch.int64), + "seq_id_i32": torch.arange(0, max_bs, device=device, dtype=torch.int32), "block_col": torch.empty(max_bs, device=device, dtype=torch.int32), "block_col_i64": torch.empty(max_bs, device=device, dtype=torch.int64), "slot_base": torch.empty(max_bs, device=device, dtype=torch.int32), @@ -403,12 +406,16 @@ def _ensure_cuda_graph_prewarmed(self) -> None: "block_table_i32": torch.empty( max_bs, max_blocks, device=device, dtype=torch.int32 ), + "physical_block_table_i32": torch.empty( + max_bs, max(physical_max_blocks, 1), device=device, dtype=torch.int32 + ), } self._cg_layers_prewarmed = True logger.info( "ATOM RTPFullAttention cuda-graph prewarmed for %d layers " "(max_num_tokens=%d, max_seq_len=%d, rtp_kv_heads=%s, " - "meta_bufs: query_start_loc[%d], slot_mapping[%d], block_table_i32[%dx%d])", + "meta_bufs: query_start_loc[%d], slot_mapping[%d], block_table_i32[%dx%d], " + "physical_block_table_i32[%dx%d])", len(self._atom_attn_pyobj._rtp_full_attn_layers), max_num_tokens, max_seq_len, @@ -417,6 +424,8 @@ def _ensure_cuda_graph_prewarmed(self) -> None: max_bs, max_bs, max_blocks, + max_bs, + max(physical_max_blocks, 1), ) def forward(self, inputs: PyModelInputs, fmha_impl: Any = None) -> PyModelOutputs: @@ -452,7 +461,7 @@ def forward(self, inputs: PyModelInputs, fmha_impl: Any = None) -> PyModelOutput ): inputs_embeds = inputs_embeds.to(dtype=model_dtype) - with RTPForwardContext.bind( + with self._rtp_forward_context_cls.bind( model=self.model, runtime=self, inputs=inputs, @@ -570,7 +579,9 @@ def support_cuda_graph(self) -> bool: return True @staticmethod - def _make_qwen35_hf_mapper() -> WeightsMapper: + def _make_qwen35_hf_mapper(): + from atom.model_loader.loader import WeightsMapper + # Keep loading on outer text-only wrapper so 
packed_modules_mapping works. # Normalize checkpoint prefixes to match wrapper's weights_mapping rules. return WeightsMapper( @@ -738,6 +749,12 @@ def _load_fused_expert_weights_for_qwen35( shard_id: str, num_experts: int, ) -> bool: + from atom.models.qwen3_5 import ( + detect_fused_expert_format, + get_fused_expert_mapping, + load_fused_expert_weights, + ) + if not detect_fused_expert_format(original_name): return False mapping = get_fused_expert_mapping() @@ -755,6 +772,11 @@ def _load_fused_expert_weights_for_qwen35( try: # Keep RTP-specific behavior in plugin path only. _set_framework_backbone("rtpllm") + from atom.plugin.rtpllm.attention_backend import ( + apply_attention_gdn_rtpllm_patch, + apply_attention_mha_rtpllm_patch, + ) + apply_attention_gdn_rtpllm_patch() apply_attention_mha_rtpllm_patch() apply_qwen3_next_rtpllm_patch() diff --git a/atom/plugin/rtpllm/utils/__init__.py b/atom/plugin/rtpllm/utils/__init__.py index 7a85bec249..d82cf33c0e 100644 --- a/atom/plugin/rtpllm/utils/__init__.py +++ b/atom/plugin/rtpllm/utils/__init__.py @@ -1,3 +1,11 @@ -from .forward_context import RTPForwardContext +from .forward_context import ( + RTPForwardContext, + RTPForwardMLAContext, + RTPForwardQwen35HybridContext, +) -__all__ = ["RTPForwardContext"] +__all__ = [ + "RTPForwardContext", + "RTPForwardMLAContext", + "RTPForwardQwen35HybridContext", +] diff --git a/atom/plugin/rtpllm/utils/forward_context.py b/atom/plugin/rtpllm/utils/forward_context.py index 2566966d97..0e536ace82 100644 --- a/atom/plugin/rtpllm/utils/forward_context.py +++ b/atom/plugin/rtpllm/utils/forward_context.py @@ -1,16 +1,33 @@ from __future__ import annotations -import os from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Dict, Iterator, Tuple import torch +from aiter import dtypes + +try: + import triton + import triton.language as tl +except (ImportError, ModuleNotFoundError): + triton = None + tl = None from atom.config import KVCacheTensor, 
get_current_atom_config from atom.model_ops.attention_gdn import GatedDeltaNet -from atom.model_ops.attention_mha import PagedAttentionImpl -from atom.model_ops.paged_attention import Attention as PagedAttention + +try: + from atom.model_ops.attention_mha import PagedAttentionImpl +except (ImportError, ModuleNotFoundError): + PagedAttentionImpl = type("PagedAttentionImpl", (), {}) +try: + from atom.model_ops.paged_attention import Attention as PagedAttention +except (ImportError, ModuleNotFoundError): + try: + from atom.model_ops.paged_attention import PagedAttention + except (ImportError, ModuleNotFoundError): + PagedAttention = type("PagedAttention", (), {}) from atom.model_ops.attentions.gdn_attn import ( GDNAttentionMetadata, compute_causal_conv1d_metadata, @@ -61,9 +78,49 @@ class AiterFlashAttentionMetadataForPluginMode: context: Any = None +if triton is not None: + + @triton.jit + def _expand_block_table_for_atom_indexer_kernel( + block_table, + output, + num_cols: tl.constexpr, + output_cols: tl.constexpr, + block_ratio: tl.constexpr, + BLOCK_RATIO: tl.constexpr, + ): + row = tl.program_id(0) + col = tl.program_id(1) + offsets = tl.arange(0, BLOCK_RATIO) + value = tl.load(block_table + row * num_cols + col) + expanded = value * block_ratio + offsets + expanded = tl.where(value >= 0, expanded, -1) + tl.store(output + row * output_cols + col * block_ratio + offsets, expanded) + + @triton.jit + def _recover_physical_block_table_from_kernel_kernel( + kernel_block_table, + output, + kernel_cols: tl.constexpr, + physical_cols: tl.constexpr, + block_ratio: tl.constexpr, + ): + row = tl.program_id(0) + col = tl.program_id(1) + kernel_col = col * block_ratio + value = tl.load( + kernel_block_table + row * kernel_cols + kernel_col, + mask=kernel_col < kernel_cols, + other=-1, + ) + physical = value // block_ratio + physical = tl.where(value >= 0, physical, -1) + tl.store(output + row * physical_cols + col, physical) + + @dataclass(frozen=True) class 
RTPForwardContext: - gdn_metadata: GDNAttentionMetadata + gdn_metadata: GDNAttentionMetadata | None attn_metadata: AttentionMetaData rtp_attn_inputs: Any rtp_seq_size_per_block: int @@ -73,7 +130,8 @@ class RTPForwardContext: layer_group_map: Dict[int, int] context: Context num_tokens: int - LayerMaps = tuple[Dict[int, GatedDeltaNet], Dict[int, Any]] + mla_layer_map: Dict[int, Any] + LayerMaps = tuple[Dict[int, GatedDeltaNet], Dict[int, Any], Dict[int, Any]] @staticmethod def _non_empty_int32( @@ -305,6 +363,76 @@ def _select_block_table_for_layer( return by_group[gid] return getattr(attn_inputs, "kv_cache_kernel_block_id_device", None) + @staticmethod + def _recover_physical_block_table_from_kernel( + kernel_block_table: torch.Tensor, + *, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None = None, + ) -> torch.Tensor: + if ( + kernel_seq_size_per_block <= 0 + or seq_size_per_block <= 0 + or seq_size_per_block == kernel_seq_size_per_block + ): + return kernel_block_table + if seq_size_per_block % kernel_seq_size_per_block != 0: + raise ValueError( + "RTP plugin cannot recover physical block_table from kernel block_table: " + f"seq_size_per_block={seq_size_per_block}, " + f"kernel_seq_size_per_block={kernel_seq_size_per_block}." 
+ ) + if kernel_block_table.dim() == 1: + kernel_block_table = kernel_block_table.unsqueeze(0) + if kernel_block_table.dim() != 2: + raise ValueError( + "RTP plugin invalid kernel block_table shape for physical recovery: " + f"{tuple(kernel_block_table.shape)}" + ) + block_ratio = int(seq_size_per_block // kernel_seq_size_per_block) + bs_now = int(kernel_block_table.shape[0]) + kernel_cols = int(kernel_block_table.shape[1]) + if kernel_cols < block_ratio or kernel_cols % block_ratio != 0: + return kernel_block_table.to( + device=kernel_block_table.device, dtype=torch.int32, non_blocking=True + ).contiguous() + physical_cols = (kernel_cols + block_ratio - 1) // block_ratio + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture and cg_bufs is not None: + if triton is None: + raise RuntimeError( + "RTP plugin cuda-graph capture requires Triton for capture-safe " + "physical block_table recovery." + ) + out_buf = cg_bufs.get("physical_block_table_i32") + if not isinstance(out_buf, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed physical_block_table_i32." + ) + if int(out_buf.shape[0]) < bs_now or int(out_buf.shape[1]) < physical_cols: + raise RuntimeError( + "RTP plugin prewarmed block_table_i32 buffer is too small for " + "physical recovery " + f"(buffer={tuple(out_buf.shape)}, required=({bs_now}, {physical_cols}))." 
+ ) + out_view = out_buf[:bs_now, :physical_cols] + _recover_physical_block_table_from_kernel_kernel[(bs_now, physical_cols)]( + kernel_block_table, + out_view, + kernel_cols, + physical_cols, + block_ratio, + ) + return out_view + + sampled = kernel_block_table[:, : physical_cols * block_ratio : block_ratio] + recovered = torch.div(sampled, block_ratio, rounding_mode="floor") + recovered = torch.where(sampled >= 0, recovered, sampled) + return recovered.to( + device=kernel_block_table.device, dtype=torch.int32, non_blocking=True + ).contiguous() + @staticmethod def _build_layer_group_map(attn_inputs: Any) -> Dict[int, int]: layer_to_group = getattr(attn_inputs, "kv_cache_layer_to_group", None) @@ -476,10 +604,9 @@ def _build_gdn_metadata( def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: """Build kernel seq_lens using RTP-native field priority. - Non-cuda-graph decode keeps the pre-cuda-graph field priority: - sequence_lengths_plus_1_d first, then sequence_lengths + input_lengths. - Cuda-graph warmup/replay keeps the graph-safe priority introduced for - dummy inputs. + Decode uses RTP's canonical sequence_lengths_plus_1_d first in both + eager and CUDA-graph paths. This keeps context_lens aligned with the + block-table slot/state-index calculation during graph replay. """ input_lengths = RTPForwardContext._non_empty_int32( getattr(attn_inputs, "input_lengths", None), @@ -491,6 +618,20 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: ) is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) if is_prefill: + # For chunked prefill, prefix_lengths can remain per-chunk while + # sequence_lengths_plus_1_d tracks the true cumulative context length. 
+ sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, + ) + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return sequence_lengths_plus_1.contiguous() prefix_lengths = RTPForwardContext._non_empty_int32( getattr(attn_inputs, "prefix_lengths_d", None), device=device, @@ -512,22 +653,18 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: ) return (prefix_lengths + input_lengths).contiguous() - non_cuda_graph_mode = not torch.cuda.is_current_stream_capturing() and not bool( - getattr(attn_inputs, "is_cuda_graph", False) + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, ) - if non_cuda_graph_mode: - sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), - device=device, - ) - if sequence_lengths_plus_1 is not None: - if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): - raise ValueError( - "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " - f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " - f"input_lengths={int(input_lengths.numel())})." - ) - return sequence_lengths_plus_1.contiguous() + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." 
+ ) + return sequence_lengths_plus_1.contiguous() sequence_lengths = RTPForwardContext._non_empty_int32( getattr(attn_inputs, "sequence_lengths", None), @@ -544,20 +681,6 @@ def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: # real context length is sequence_lengths + input_lengths. return (sequence_lengths + input_lengths).contiguous() - if not non_cuda_graph_mode: - sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( - getattr(attn_inputs, "sequence_lengths_plus_1_d", None), - device=device, - ) - if sequence_lengths_plus_1 is not None: - if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): - raise ValueError( - "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " - f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " - f"input_lengths={int(input_lengths.numel())})." - ) - return sequence_lengths_plus_1.contiguous() - raise ValueError( "RTP decode requires attention_inputs.sequence_lengths_plus_1_d or " "sequence_lengths for seq_lens." @@ -646,18 +769,7 @@ def _build_slot_mapping( "RTP plugin block_table/query_start_loc batch mismatch " f"(block_table={int(bt.shape[0])}, batch={batch_size})." ) - validate_slot_mapping = os.getenv("ATOM_VALIDATE_SLOT_MAPPING", "0") == "1" - if validate_slot_mapping and int(qsl[-1].item()) != num_tokens: - raise ValueError( - "RTP plugin query_start_loc/positions token mismatch " - f"(query_start_loc[-1]={int(qsl[-1].item())}, positions={num_tokens})." - ) - lengths = qsl[1:] - qsl[:-1] - if validate_slot_mapping and torch.any(lengths <= 0): - raise ValueError( - "RTP plugin query_start_loc contains non-positive sequence length." - ) if in_capture and cg_bufs is not None: # Zero-alloc path: use pre-allocated buffers so captured GPU ops # reference stable addresses that stay alive through replay. 
@@ -694,29 +806,14 @@ def _build_slot_mapping( torch.arange(batch_size, device=device, dtype=torch.int64), lengths.to(dtype=torch.int64), ) - if validate_slot_mapping and int(seq_id.numel()) != num_tokens: - raise ValueError( - "RTP plugin internal seq_id construction mismatch for slot_mapping." - ) block_col = torch.div( pos_i32, int(seq_size_per_block), rounding_mode="floor", ) - if validate_slot_mapping and ( - torch.any(block_col < 0) or torch.any(block_col >= bt.shape[1]) - ): - raise ValueError( - "RTP plugin block-table index out of range for full-attn slot_mapping " - f"(max_col={int(bt.shape[1]) - 1})." - ) slot_base = bt[seq_id, block_col.to(dtype=torch.int64)] - if validate_slot_mapping and torch.any(slot_base < 0): - raise ValueError( - "RTP plugin resolved padded/invalid (-1) block slot for full-attn slot_mapping." - ) token_offset = torch.remainder(pos_i32, int(seq_size_per_block)) slot_mapping = slot_base * int(seq_size_per_block) + token_offset return slot_mapping.to(dtype=torch.int64).contiguous() @@ -791,29 +888,213 @@ def _build_query_start_loc_for_plugin( ) @staticmethod + def _build_req_id_per_token( + *, + query_start_loc: torch.Tensor, + num_tokens: int, + device: torch.device, + cg_bufs: dict | None = None, + ) -> torch.Tensor: + batch_size = int(query_start_loc.numel()) - 1 + if batch_size <= 0: + raise ValueError( + "RTP plugin cannot build req_id_per_token for empty batch." + ) + in_capture = torch.cuda.is_current_stream_capturing() + if cg_bufs is not None and "seq_id_i32" in cg_bufs: + seq_id_i32 = cg_bufs["seq_id_i32"] + if not isinstance(seq_id_i32, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed seq_id_i32 tensor." + ) + if int(seq_id_i32.shape[0]) < int(num_tokens): + raise RuntimeError( + "RTP plugin prewarmed seq_id_i32 buffer is too small " + f"(buffer={int(seq_id_i32.shape[0])}, required={int(num_tokens)})." 
+ ) + if seq_id_i32.device != device or seq_id_i32.dtype != torch.int32: + raise RuntimeError( + "RTP plugin capture requires seq_id_i32 to be int32 on model device." + ) + if not seq_id_i32.is_contiguous(): + raise RuntimeError( + "RTP plugin capture requires seq_id_i32 to be contiguous." + ) + return seq_id_i32[:num_tokens] + if in_capture: + raise RuntimeError( + "RTP plugin capture requires prewarmed seq_id_i32 for req_id_per_token." + ) + if int(num_tokens) == 0: + return torch.empty((0,), dtype=torch.int32, device=device) + lengths = (query_start_loc[1:] - query_start_loc[:-1]).to(dtype=torch.int64) + if not torch.cuda.is_current_stream_capturing() and int( + lengths.sum().item() + ) != int(num_tokens): + raise ValueError( + "RTP plugin query_start_loc/num_tokens mismatch for req_id_per_token " + f"(query_start_loc[-1]={int(query_start_loc[-1].item())}, " + f"num_tokens={int(num_tokens)})." + ) + return torch.repeat_interleave( + torch.arange(batch_size, device=device, dtype=torch.int32), + lengths, + ).contiguous() + + @staticmethod + def _expand_block_table_for_atom_indexer( + block_table: torch.Tensor, + *, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + ) -> torch.Tensor: + if ( + kernel_seq_size_per_block <= 0 + or seq_size_per_block <= 0 + or seq_size_per_block == kernel_seq_size_per_block + ): + return block_table + if seq_size_per_block % kernel_seq_size_per_block != 0: + raise ValueError( + "RTP plugin cannot expand block_table for ATOM indexer: " + f"seq_size_per_block={seq_size_per_block}, " + f"kernel_seq_size_per_block={kernel_seq_size_per_block}." 
+ ) + block_ratio = int(seq_size_per_block // kernel_seq_size_per_block) + offsets = torch.arange( + block_ratio, device=block_table.device, dtype=torch.int32 + ) + base = block_table.to(dtype=torch.int32) + expanded = base.unsqueeze(-1) * block_ratio + offsets + expanded = torch.where(base.unsqueeze(-1) >= 0, expanded, -1) + return expanded.reshape(base.shape[0], base.shape[1] * block_ratio).contiguous() + + @staticmethod + def _expand_block_table_for_atom_indexer_capture( + block_table: torch.Tensor, + *, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict, + ) -> torch.Tensor: + if ( + kernel_seq_size_per_block <= 0 + or seq_size_per_block <= 0 + or seq_size_per_block == kernel_seq_size_per_block + ): + return block_table + if seq_size_per_block % kernel_seq_size_per_block != 0: + raise ValueError( + "RTP plugin cannot expand block_table for ATOM indexer: " + f"seq_size_per_block={seq_size_per_block}, " + f"kernel_seq_size_per_block={kernel_seq_size_per_block}." + ) + if triton is None: + raise RuntimeError( + "RTP plugin cuda-graph capture requires Triton for capture-safe " + "ATOM indexer block_table expansion." + ) + out_buf = cg_bufs.get("indexer_block_table_i32") + if not isinstance(out_buf, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed indexer_block_table_i32." + ) + block_ratio = int(seq_size_per_block // kernel_seq_size_per_block) + bs_now = int(block_table.shape[0]) + cols_now = int(block_table.shape[1]) + expanded_cols = cols_now * block_ratio + if int(out_buf.shape[0]) < bs_now or int(out_buf.shape[1]) < expanded_cols: + raise RuntimeError( + "RTP plugin prewarmed indexer_block_table_i32 buffer is too small " + f"(buffer={tuple(out_buf.shape)}, required=({bs_now}, {expanded_cols}))." 
+ ) + out_view = out_buf[:bs_now, :expanded_cols] + _expand_block_table_for_atom_indexer_kernel[(bs_now, cols_now)]( + block_table, + out_view, + cols_now, + expanded_cols, + block_ratio, + BLOCK_RATIO=block_ratio, + ) + return out_view + + @classmethod + def _build_indexer_block_tables( + cls, + *, + block_table_i32: torch.Tensor, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_max_seq_len: int, + in_capture: bool, + cg_bufs: dict | None, + ) -> torch.Tensor: + del ( + cls, + seq_size_per_block, + kernel_seq_size_per_block, + cg_max_seq_len, + in_capture, + cg_bufs, + ) + # Base path (e.g. Qwen3.5): keep compact physical table layout and do not + # expand to indexer granularity. + return block_table_i32 + + @classmethod + def _resolve_plugin_block_table( + cls, + *, + attn_inputs: Any, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None, + in_capture: bool, + ) -> torch.Tensor | None: + physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) + if physical_block_table is not None and physical_block_table.numel() > 0: + return physical_block_table + kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs) + if kernel_block_table is None or kernel_block_table.numel() == 0: + return None + return cls._recover_physical_block_table_from_kernel( + kernel_block_table, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + ) + + @classmethod def _build_plugin_attention_metadata( + cls, *, attn_inputs: Any, positions: torch.Tensor, seq_size_per_block: int, + kernel_seq_size_per_block: int = 0, cg_max_seq_len: int = 0, cg_bufs: dict | None = None, ) -> AttentionMetaData: - block_table = RTPForwardContext._select_block_table_for_layer( + in_capture = torch.cuda.is_current_stream_capturing() + block_table = cls._resolve_plugin_block_table( attn_inputs=attn_inputs, + seq_size_per_block=int(seq_size_per_block), + 
kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + in_capture=in_capture, ) if block_table is None or block_table.numel() == 0: raise ValueError( - "RTP plugin requires kv_cache_kernel_block_id_device for plugin attention metadata." + "RTP plugin requires kv_cache_block_id_device for plugin attention metadata." ) device = positions.device is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) - in_capture = torch.cuda.is_current_stream_capturing() if in_capture and cg_bufs is None: raise RuntimeError( "RTP plugin capture requires prewarmed cg_bufs; metadata fallback path is disabled." ) - seq_lens = RTPForwardContext._build_seq_lens(attn_inputs, device=device) + seq_lens = cls._build_seq_lens(attn_inputs, device=device) if in_capture and cg_bufs is not None: bs_now = int(seq_lens.shape[0]) seq_lens_buf = cg_bufs["seq_lens_i32"] @@ -837,22 +1118,42 @@ def _build_plugin_attention_metadata( # slice here so slot_mapping and num_actual_tokens are correctly sized. if in_capture and not is_prefill: positions = positions[:batch_size] + if positions.dtype != torch.int32: + positions_i32_buf = cg_bufs.get("positions_i32") + if not isinstance(positions_i32_buf, torch.Tensor): + raise RuntimeError( + "RTP plugin capture requires prewarmed positions_i32 buffer." + ) + if int(positions_i32_buf.shape[0]) < batch_size: + raise RuntimeError( + "RTP plugin prewarmed positions_i32 buffer is too small " + f"(buffer={int(positions_i32_buf.shape[0])}, required={batch_size})." 
+ ) + positions_i32 = positions_i32_buf[:batch_size] + positions_i32.copy_(positions, non_blocking=True) + positions = positions_i32 num_actual_tokens = int(positions.numel()) - query_start_loc = RTPForwardContext._build_query_start_loc_for_plugin( + query_start_loc = cls._build_query_start_loc_for_plugin( attn_inputs=attn_inputs, seq_lens=seq_lens, num_tokens=num_actual_tokens, device=device, cg_bufs=cg_bufs, ) - slot_mapping = RTPForwardContext._build_slot_mapping( + slot_mapping = cls._build_slot_mapping( positions=positions, query_start_loc=query_start_loc, block_table=block_table, seq_size_per_block=seq_size_per_block, cg_bufs=cg_bufs, ) + req_id_per_token = cls._build_req_id_per_token( + query_start_loc=query_start_loc, + num_tokens=num_actual_tokens, + device=device, + cg_bufs=cg_bufs if in_capture else None, + ) is_dummy_warmup = False if in_capture: @@ -916,22 +1217,30 @@ def _build_plugin_attention_metadata( in_capture = torch.cuda.is_current_stream_capturing() if in_capture and cg_bufs is not None: - # Zero-alloc capture path: always route through prewarmed block_table_i32. - bt_buf = cg_bufs["block_table_i32"] - bs_now = int(block_table.shape[0]) - cols_now = int(block_table.shape[1]) - if int(bt_buf.shape[0]) < bs_now or int(bt_buf.shape[1]) < cols_now: + # Capture must keep the compact physical table layout. Copying into a + # wider prewarmed table and slicing columns would create a strided view + # that the downstream Triton expand kernel does not understand. + if block_table.dtype != torch.int32: raise RuntimeError( - "RTP plugin prewarmed block_table_i32 buffer is too small " - f"(buffer={tuple(bt_buf.shape)}, required=({bs_now}, {cols_now}))." + "RTP plugin capture requires block_table to be int32 to avoid allocation." 
) - bt_view = bt_buf[:bs_now, :cols_now] - bt_view.copy_(block_table, non_blocking=True) - block_table_i32 = bt_view + if not block_table.is_contiguous(): + raise RuntimeError( + "RTP plugin capture requires block_table to be contiguous to avoid allocation." + ) + block_table_i32 = block_table else: block_table_i32 = block_table.to( device=device, dtype=torch.int32, non_blocking=True ).contiguous() + indexer_block_table_i32 = cls._build_indexer_block_tables( + block_table_i32=block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_max_seq_len=int(cg_max_seq_len), + in_capture=in_capture, + cg_bufs=cg_bufs, + ) plugin_md = AiterFlashAttentionMetadataForPluginMode( num_actual_tokens=num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, @@ -957,6 +1266,29 @@ def _build_plugin_attention_metadata( ) # Prefill-only fields shared across all full-attn layers in the step. plugin_md.rtp_cu_seqlens_q = query_start_loc + plugin_md.req_id_per_token = req_id_per_token + plugin_md.topk_tokens = 0 + plugin_md.sparse_block_size = int(seq_size_per_block) + plugin_md.cg_bufs = cg_bufs + cu_seqlen_ks = None + cu_seqlen_ke = None + if is_prefill: + prefill_lengths = (query_start_loc[1:] - query_start_loc[:-1]).to( + dtype=torch.int64 + ) + if in_capture and cg_bufs is not None and "seq_id" in cg_bufs: + seq_id_for_span = cg_bufs["seq_id"][:num_actual_tokens] + else: + seq_id_for_span = torch.repeat_interleave( + torch.arange(batch_size, device=device, dtype=torch.int64), + prefill_lengths, + ) + cu_seqlen_ks = ( + query_start_loc[:-1][seq_id_for_span].to(dtype=torch.int32).contiguous() + ) + cu_seqlen_ke = ( + torch.arange(num_actual_tokens, device=device, dtype=torch.int32) + 1 + ).contiguous() # Mark dummy probe (RTP initCapture's "forward for output datatype" feeds # all-zero seq_lens/block_tables); RTPFullAttention short-circuits to zeros. 
plugin_md.is_dummy_warmup = bool(is_dummy_warmup) @@ -973,11 +1305,17 @@ def _build_plugin_attention_metadata( else: plugin_md.rtp_has_prefix = False attn_metadata = AttentionMetaData( + cu_seqlens_q=query_start_loc, + cu_seqlens_k=query_start_loc, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - block_tables=plugin_md.block_table, + block_tables=indexer_block_table_i32, slot_mapping=slot_mapping, context_lens=seq_lens, + cu_seqlen_ks=cu_seqlen_ks, + cu_seqlen_ke=cu_seqlen_ke, + has_cached=False, + total_kv=int(num_actual_kv_tokens), ) attn_metadata.plugin_metadata = plugin_md return attn_metadata @@ -986,7 +1324,17 @@ def _build_plugin_attention_metadata( def collect_layer_maps(model: Any) -> LayerMaps: gdn_layer_map: Dict[int, GatedDeltaNet] = {} full_attn_layer_map: Dict[int, Any] = {} + mla_layer_map: Dict[int, Any] = {} rtp_attention_cls: type[Any] | None = None + rtp_mla_attention_cls: type[Any] | None = None + try: + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import ( + RTPMLAAttention, + ) + + rtp_mla_attention_cls = RTPMLAAttention + except (ImportError, ModuleNotFoundError): + rtp_mla_attention_cls = None try: from atom.plugin.rtpllm.attention_backend import AttentionForRTPLLM @@ -997,6 +1345,20 @@ def collect_layer_maps(model: Any) -> LayerMaps: for module in model.modules(): if isinstance(module, GatedDeltaNet): gdn_layer_map[int(module.layer_num)] = module + elif ( + getattr(module, "indexer", None) is not None + and getattr(module, "mla_attn", None) is not None + and getattr(module, "layer_num", None) is not None + ): + mla_layer_map[int(module.layer_num)] = module + elif rtp_mla_attention_cls is not None and isinstance( + module, rtp_mla_attention_cls + ): + layer_num = getattr(module, "layer_id", None) + if layer_num is None: + layer_num = getattr(module, "layer_num", None) + if layer_num is not None and int(layer_num) not in mla_layer_map: + mla_layer_map[int(layer_num)] = module elif isinstance(module, (PagedAttention, 
PagedAttentionImpl)) or ( rtp_attention_cls is not None and isinstance(module, rtp_attention_cls) ): @@ -1006,7 +1368,7 @@ def collect_layer_maps(model: Any) -> LayerMaps: layer_num = getattr(module, "layer_num", None) if layer_num is not None: full_attn_layer_map[int(layer_num)] = module - return gdn_layer_map, full_attn_layer_map + return gdn_layer_map, full_attn_layer_map, mla_layer_map @staticmethod def _build_kv_cache_tensors( @@ -1016,9 +1378,9 @@ def _build_kv_cache_tensors( if runtime.kv_cache is None: raise ValueError("RTP plugin requires initialized kv_cache for ATOM model.") - gdn_layer_map, full_attn_layer_map = layer_maps + gdn_layer_map, full_attn_layer_map, mla_layer_map = layer_maps - if not gdn_layer_map and not full_attn_layer_map: + if not gdn_layer_map and not full_attn_layer_map and not mla_layer_map: return {} cache_tensors: Dict[str, KVCacheTensor] = {} @@ -1109,6 +1471,31 @@ def _build_kv_cache_tensors( k_scale=None, v_scale=None, ) + # Build MLA cache references separately from full attention. MLA adapters + # own their kv_cache pointer and refresh it in bind() for every forward. + for layer_num in mla_layer_map.keys(): + layer_key = f"layer_{layer_num}" + if layer_key in cache_tensors: + continue + + layer_cache = runtime.kv_cache.get_layer_cache(layer_num) + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if kv_cache_base is None: + raise ValueError( + f"Layer {layer_num} kv_cache_base is missing for MLA cache." + ) + if kv_cache_base.dim() < 1: + raise ValueError( + f"Layer {layer_num} MLA kv_cache_base has invalid shape " + f"{tuple(kv_cache_base.shape)}." 
+ ) + cache_tensors[layer_key] = KVCacheTensor( + layer_num=layer_num, + k_cache=layer_cache, + v_cache=None, + k_scale=None, + v_scale=None, + ) return cache_tensors @staticmethod @@ -1118,10 +1505,12 @@ def _kv_cache_signature( ) -> Tuple[Any, ...]: if runtime.kv_cache is None: return ("no_kv_cache",) - gdn_layer_map, full_attn_layer_map = layer_maps + gdn_layer_map, full_attn_layer_map, mla_layer_map = layer_maps signature: list[Any] = [id(runtime.kv_cache)] all_layer_nums = sorted( - set(gdn_layer_map.keys()) | set(full_attn_layer_map.keys()) + set(gdn_layer_map.keys()) + | set(full_attn_layer_map.keys()) + | set(mla_layer_map.keys()) ) for layer_num in all_layer_nums: layer_cache = runtime.kv_cache.get_layer_cache(layer_num) @@ -1136,6 +1525,16 @@ def _kv_cache_signature( int(kv_cache_base.numel()), ) ) + kv_scale_base = getattr(layer_cache, "kv_scale_base", None) + if kv_scale_base is not None and kv_scale_base.numel() > 0: + signature.append( + ( + int(layer_num), + "scale", + int(kv_scale_base.data_ptr()), + int(kv_scale_base.numel()), + ) + ) return tuple(signature) @classmethod @@ -1166,6 +1565,8 @@ def build( if kernel_seq_size_per_block <= 0: kernel_seq_size_per_block = int(seq_size_per_block) state_indices_cache: Dict[tuple[int, bool], torch.Tensor] = {} + resolved_layer_maps = layer_maps or cls.collect_layer_maps(model) + gdn_layer_map, _, _ = resolved_layer_maps layer_group_map_signature = cls._layer_group_map_signature(attn_inputs) layer_group_map = getattr(runtime, "_rtp_layer_group_map", None) cached_layer_group_map_signature = getattr( @@ -1178,27 +1579,29 @@ def build( layer_group_map = cls._build_layer_group_map(attn_inputs) runtime._rtp_layer_group_map = layer_group_map runtime._rtp_layer_group_map_signature = layer_group_map_signature - gdn_metadata = cls._build_gdn_metadata( - attn_inputs, - seq_size_per_block=seq_size_per_block, - num_tokens=int(positions.numel()), - state_indices_cache=state_indices_cache, - 
layer_group_map=layer_group_map, - ) - # Keep raw RTP attention inputs in metadata so GDN can resolve per-layer - # block-map/state-index semantics (same idea as RTP's select_block_map_for_layer). - gdn_metadata.rtp_attn_inputs = attn_inputs - gdn_metadata.rtp_seq_size_per_block = int(seq_size_per_block) - gdn_metadata.rtp_state_indices_cache = state_indices_cache - gdn_metadata.rtp_layer_group_map = layer_group_map + gdn_metadata = None + if gdn_layer_map: + gdn_metadata = cls._build_gdn_metadata( + attn_inputs, + seq_size_per_block=seq_size_per_block, + num_tokens=int(positions.numel()), + state_indices_cache=state_indices_cache, + layer_group_map=layer_group_map, + ) + # Keep raw RTP attention inputs in metadata so GDN can resolve per-layer + # block-map/state-index semantics (same idea as RTP's select_block_map_for_layer). + gdn_metadata.rtp_attn_inputs = attn_inputs + gdn_metadata.rtp_seq_size_per_block = int(seq_size_per_block) + gdn_metadata.rtp_state_indices_cache = state_indices_cache + gdn_metadata.rtp_layer_group_map = layer_group_map attn_metadata = cls._build_plugin_attention_metadata( attn_inputs=attn_inputs, positions=positions, - seq_size_per_block=kernel_seq_size_per_block, + seq_size_per_block=seq_size_per_block, + kernel_seq_size_per_block=kernel_seq_size_per_block, cg_max_seq_len=int(cg_max_seq_len), cg_bufs=cg_bufs, ) - resolved_layer_maps = layer_maps or cls.collect_layer_maps(model) kv_cache_signature = cls._kv_cache_signature( runtime=runtime, layer_maps=resolved_layer_maps, @@ -1234,30 +1637,119 @@ def build( layer_group_map=layer_group_map, context=context, num_tokens=int(positions.numel()), + mla_layer_map=cls._resolve_mla_layer_map(resolved_layer_maps), ) @classmethod - @contextmanager - def bind( - cls, + def _resolve_mla_layer_map(cls, layer_maps: LayerMaps) -> Dict[int, Any]: + del cls, layer_maps + return {} + + @staticmethod + def _build_fallback_indexer_cache( *, - model: Any, - runtime: Any, - inputs: Any, - positions: 
torch.Tensor, - layer_maps: LayerMaps | None = None, - cg_max_seq_len: int = 0, - cg_bufs: dict | None = None, - ) -> Iterator[None]: - forward_context = cls.build( - model=model, - runtime=runtime, - inputs=inputs, - positions=positions, - layer_maps=layer_maps, - cg_max_seq_len=cg_max_seq_len, - cg_bufs=cg_bufs, - ) + cache_owner: Any, + layer_cache: Any, + indexer: Any, + block_size: int, + ) -> torch.Tensor | None: + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if kv_cache_base is None or kv_cache_base.dim() == 0: + return None + index_dim = int(getattr(indexer, "head_dim", 0) or 0) + 4 + if index_dim <= 4: + return None + aligned_dim = ((index_dim + 15) // 16) * 16 + num_tokens = int(kv_cache_base.shape[0]) * block_size + cached = getattr(cache_owner, "_rtp_indexer_kv_cache", None) + expected_shape = (num_tokens, 1, aligned_dim) + if ( + cached is None + or tuple(cached.shape) != expected_shape + or cached.device != kv_cache_base.device + or cached.dtype != dtypes.fp8 + ): + cached = torch.empty( + expected_shape, + device=kv_cache_base.device, + dtype=dtypes.fp8, + ) + setattr(cache_owner, "_rtp_indexer_kv_cache", cached) + return cached + + @staticmethod + def _attach_mla_layer_caches( + forward_context: "RTPForwardContext", + ) -> tuple[list[tuple[Any, str, Any]], list[tuple[list[Any], int, Any]]]: + restore_attrs: list[tuple[Any, str, Any]] = [] + restore_indices: list[tuple[list[Any], int, Any]] = [] + for layer_num, layer in forward_context.mla_layer_map.items(): + cache_tensor = forward_context.kv_cache_data.get(f"layer_{layer_num}") + if cache_tensor is None: + continue + cache_owner = getattr(layer, "mla_attn", layer) + restore_attrs.append( + (cache_owner, "kv_cache", getattr(cache_owner, "kv_cache", None)) + ) + cache_owner.kv_cache = cache_tensor.k_cache + indexer = getattr(layer, "indexer", None) + if indexer is None: + indexer = getattr(cache_owner, "indexer", None) + indexer_cache = getattr(indexer, "k_cache", None) + 
indexer_kv_cache = getattr(indexer_cache, "kv_cache", None) + if not isinstance(indexer_kv_cache, list) or not indexer_kv_cache: + continue + layer_cache = cache_tensor.k_cache + kv_cache_base = getattr(layer_cache, "kv_cache_base", None) + if kv_cache_base is None or kv_cache_base.dim() == 0: + continue + block_size = int( + getattr(forward_context, "rtp_seq_size_per_block", 0) + or getattr(forward_context, "rtp_kernel_seq_size_per_block", 0) + or getattr(get_current_atom_config(), "kv_cache_block_size", 0) + ) + if block_size <= 0: + raise ValueError( + "RTP plugin requires positive block_size for MLA indexer cache " + f"(layer={layer_num}, rtp_seq_size_per_block=" + f"{getattr(forward_context, 'rtp_seq_size_per_block', 0)}, " + "rtp_kernel_seq_size_per_block=" + f"{getattr(forward_context, 'rtp_kernel_seq_size_per_block', 0)})." + ) + indexer_cache_tensor = RTPForwardContext._build_fallback_indexer_cache( + cache_owner=cache_owner, + layer_cache=layer_cache, + indexer=indexer, + block_size=block_size, + ) + if indexer_cache_tensor is None: + continue + restore_indices.append((indexer_kv_cache, 0, indexer_kv_cache[0])) + indexer_kv_cache[0] = indexer_cache_tensor + return restore_attrs, restore_indices + + @classmethod + @contextmanager + def bind( + cls, + *, + model: Any, + runtime: Any, + inputs: Any, + positions: torch.Tensor, + layer_maps: LayerMaps | None = None, + cg_max_seq_len: int = 0, + cg_bufs: dict | None = None, + ) -> Iterator[None]: + forward_context = cls.build( + model=model, + runtime=runtime, + inputs=inputs, + positions=positions, + layer_maps=layer_maps, + cg_max_seq_len=cg_max_seq_len, + cg_bufs=cg_bufs, + ) prev_kv = _forward_kv_cache_context.kv_cache_data attn_md = forward_context.attn_metadata attn_md.gdn_metadata = forward_context.gdn_metadata @@ -1265,8 +1757,16 @@ def bind( attn_md.rtp_kernel_seq_size_per_block = ( forward_context.rtp_kernel_seq_size_per_block ) + attn_md.rtp_seq_size_per_block = getattr( + forward_context, 
"rtp_seq_size_per_block", 0 + ) attn_md.rtp_layer_group_map = forward_context.layer_group_map + restore_mla_attrs: list[tuple[Any, str, Any]] = [] + restore_mla_indices: list[tuple[list[Any], int, Any]] = [] try: + restore_mla_attrs, restore_mla_indices = cls._attach_mla_layer_caches( + forward_context + ) set_kv_cache_data(forward_context.kv_cache_data) set_forward_context( attn_metadata=attn_md, @@ -1276,5 +1776,414 @@ def bind( ) yield finally: + for target, index, old_cache in reversed(restore_mla_indices): + target[index] = old_cache + for target, attr, old_cache in reversed(restore_mla_attrs): + setattr(target, attr, old_cache) reset_forward_context() set_kv_cache_data(prev_kv if prev_kv is not None else {}) + + +@dataclass(frozen=True) +class RTPForwardMLAContext(RTPForwardContext): + @classmethod + def _resolve_plugin_block_table( + cls, + *, + attn_inputs: Any, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None, + in_capture: bool, + ) -> torch.Tensor | None: + physical_block_table = getattr(attn_inputs, "kv_cache_block_id_device", None) + if physical_block_table is not None and physical_block_table.numel() > 0: + return physical_block_table + kernel_block_table = cls._select_block_table_for_layer(attn_inputs=attn_inputs) + if kernel_block_table is None or kernel_block_table.numel() == 0: + return None + return cls._recover_physical_block_table_from_kernel( + kernel_block_table, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs if in_capture else None, + ) + + @classmethod + def _build_indexer_block_tables( + cls, + *, + block_table_i32: torch.Tensor, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_max_seq_len: int, + in_capture: bool, + cg_bufs: dict | None, + ) -> torch.Tensor: + if in_capture: + expected_kernel_cols = 0 + if cg_max_seq_len > 0 and int(kernel_seq_size_per_block) > 0: + expected_kernel_cols = ( + int(cg_max_seq_len) 
+ int(kernel_seq_size_per_block) - 1 + ) // int(kernel_seq_size_per_block) + if ( + expected_kernel_cols > 0 + and int(block_table_i32.shape[1]) >= expected_kernel_cols + ): + return block_table_i32 + return cls._expand_block_table_for_atom_indexer_capture( + block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + cg_bufs=cg_bufs, + ) + return cls._expand_block_table_for_atom_indexer( + block_table_i32, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=int(kernel_seq_size_per_block), + ) + + @classmethod + def _resolve_mla_layer_map( + cls, layer_maps: RTPForwardContext.LayerMaps + ) -> Dict[int, Any]: + del cls + return layer_maps[2] + + +@dataclass(frozen=True) +class RTPForwardQwen35HybridContext(RTPForwardContext): + @staticmethod + def _build_seq_lens(attn_inputs: Any, *, device: torch.device) -> torch.Tensor: + """Qwen3.5 decode-cudagraph compatible seq_lens priority. + + Keep the validated sequence_lengths_plus_1_d ordering from + `develop/rtp_atom_0526_qwen35_cuda_graph_ok`. + """ + input_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "input_lengths", None), + device=device, + ) + if input_lengths is None: + raise ValueError( + "RTP plugin requires attention_inputs.input_lengths for seq_lens." + ) + is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) + if is_prefill: + prefix_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "prefix_lengths_d", None), + device=device, + ) + if prefix_lengths is None: + prefix_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "prefix_lengths", None), + device=device, + ) + if prefix_lengths is None: + raise ValueError( + "RTP prefill requires attention_inputs.prefix_lengths for seq_lens." 
+ ) + if int(prefix_lengths.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin prefix_lengths/input_lengths batch mismatch " + f"(prefix_lengths={int(prefix_lengths.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return (prefix_lengths + input_lengths).contiguous() + + non_cuda_graph_mode = not torch.cuda.is_current_stream_capturing() and not bool( + getattr(attn_inputs, "is_cuda_graph", False) + ) + if non_cuda_graph_mode: + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, + ) + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return sequence_lengths_plus_1.contiguous() + + sequence_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths", None), + device=device, + ) + if sequence_lengths is not None: + if int(sequence_lengths.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths/input_lengths batch mismatch " + f"(sequence_lengths={int(sequence_lengths.numel())}, " + f"input_lengths={int(input_lengths.numel())})." + ) + return (sequence_lengths + input_lengths).contiguous() + + if not non_cuda_graph_mode: + sequence_lengths_plus_1 = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "sequence_lengths_plus_1_d", None), + device=device, + ) + if sequence_lengths_plus_1 is not None: + if int(sequence_lengths_plus_1.numel()) != int(input_lengths.numel()): + raise ValueError( + "RTP plugin sequence_lengths_plus_1_d/input_lengths batch mismatch " + f"(sequence_lengths_plus_1_d={int(sequence_lengths_plus_1.numel())}, " + f"input_lengths={int(input_lengths.numel())})." 
+ ) + return sequence_lengths_plus_1.contiguous() + + raise ValueError( + "RTP decode requires attention_inputs.sequence_lengths_plus_1_d or " + "sequence_lengths for seq_lens." + ) + + @classmethod + def _resolve_plugin_block_table( + cls, + *, + attn_inputs: Any, + seq_size_per_block: int, + kernel_seq_size_per_block: int, + cg_bufs: dict | None, + in_capture: bool, + ) -> torch.Tensor | None: + del cls, seq_size_per_block, kernel_seq_size_per_block, cg_bufs, in_capture + return RTPForwardContext._select_block_table_for_layer(attn_inputs=attn_inputs) + + @staticmethod + def _build_query_start_loc_for_plugin( + *, + attn_inputs: Any, + seq_lens: torch.Tensor, + num_tokens: int, + device: torch.device, + cg_bufs: dict | None = None, + ) -> torch.Tensor: + batch_size = int(seq_lens.numel()) + if batch_size <= 0: + raise ValueError( + "RTP plugin cannot build query_start_loc with empty seq_lens." + ) + + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture and cg_bufs is not None: + return cg_bufs["query_start_loc"][: batch_size + 1] + + if in_capture: + raise ValueError( + "RTP plugin capture requires prewarmed cg_bufs for query_start_loc " + f"(batch={batch_size}, num_tokens={int(num_tokens)})." 
+ ) + + qsl = RTPForwardContext._query_start_loc(attn_inputs, device=device) + if qsl is not None and qsl.numel() == batch_size + 1: + lengths = qsl[1:] - qsl[:-1] + qsl_stats = torch.stack([qsl[-1], torch.min(lengths)], dim=0).to( + device="cpu" + ) + qsl_total_tokens, qsl_min_len = [int(v) for v in qsl_stats.tolist()] + if qsl_total_tokens == int(num_tokens) and qsl_min_len > 0: + return qsl.contiguous() + + input_lengths = RTPForwardContext._non_empty_int32( + getattr(attn_inputs, "input_lengths", None), + device=device, + ) + if input_lengths is not None and int(input_lengths.numel()) == batch_size: + input_stats = torch.stack( + [torch.min(input_lengths), torch.sum(input_lengths)], + dim=0, + ).to(device="cpu") + min_input_len, total_input_len = [int(v) for v in input_stats.tolist()] + if min_input_len > 0 and total_input_len == int(num_tokens): + prefix = torch.zeros((1,), dtype=torch.int32, device=device) + return torch.cat( + [prefix, input_lengths.cumsum(dim=0)], dim=0 + ).contiguous() + + if int(num_tokens) == batch_size: + prefix = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) + return prefix.contiguous() + if batch_size == 1: + return torch.tensor([0, int(num_tokens)], dtype=torch.int32, device=device) + + raise ValueError( + "RTP plugin failed to build valid query_start_loc for plugin attention " + f"(batch={batch_size}, num_tokens={int(num_tokens)})." 
+ ) + + @classmethod + def _build_plugin_attention_metadata( + cls, + *, + attn_inputs: Any, + positions: torch.Tensor, + seq_size_per_block: int, + kernel_seq_size_per_block: int = 0, + cg_max_seq_len: int = 0, + cg_bufs: dict | None = None, + ) -> AttentionMetaData: + del kernel_seq_size_per_block + block_table = cls._resolve_plugin_block_table( + attn_inputs=attn_inputs, + seq_size_per_block=int(seq_size_per_block), + kernel_seq_size_per_block=0, + cg_bufs=cg_bufs, + in_capture=torch.cuda.is_current_stream_capturing(), + ) + if block_table is None or block_table.numel() == 0: + raise ValueError( + "RTP plugin requires kv_cache_kernel_block_id_device for plugin attention metadata." + ) + device = positions.device + is_prefill = bool(getattr(attn_inputs, "is_prefill", False)) + in_capture = torch.cuda.is_current_stream_capturing() + if in_capture and cg_bufs is None: + raise RuntimeError( + "RTP plugin capture requires prewarmed cg_bufs; metadata fallback path is disabled." + ) + seq_lens = cls._build_seq_lens(attn_inputs, device=device) + if in_capture and cg_bufs is not None: + bs_now = int(seq_lens.shape[0]) + seq_lens_buf = cg_bufs["seq_lens_i32"] + if int(seq_lens_buf.shape[0]) < bs_now: + raise RuntimeError( + "RTP plugin prewarmed seq_lens_i32 buffer is too small " + f"(buffer={int(seq_lens_buf.shape[0])}, required={bs_now})." 
+ ) + seq_lens_view = seq_lens_buf[:bs_now] + seq_lens_view.copy_(seq_lens, non_blocking=True) + seq_lens = seq_lens_view + else: + seq_lens = seq_lens.to( + device=device, dtype=torch.int32, non_blocking=True + ).contiguous() + batch_size = int(seq_lens.numel()) + + if in_capture and not is_prefill: + positions = positions[:batch_size] + num_actual_tokens = int(positions.numel()) + + query_start_loc = cls._build_query_start_loc_for_plugin( + attn_inputs=attn_inputs, + seq_lens=seq_lens, + num_tokens=num_actual_tokens, + device=device, + cg_bufs=cg_bufs, + ) + slot_mapping = cls._build_slot_mapping( + positions=positions, + query_start_loc=query_start_loc, + block_table=block_table, + seq_size_per_block=seq_size_per_block, + cg_bufs=cg_bufs, + ) + + is_dummy_warmup = False + if in_capture: + max_query_len = 1 + if cg_max_seq_len <= 0: + raise RuntimeError( + "RTP plugin cuda-graph capture requires cg_max_seq_len; " + "did you forget to thread it through RTPForwardContext.bind?" + ) + max_seq_len = int(cg_max_seq_len) + num_actual_kv_tokens = max_seq_len * batch_size + else: + query_lens = query_start_loc[1:] - query_start_loc[:-1] + stats = torch.stack( + [ + torch.max(query_lens), + torch.max(seq_lens), + torch.sum(seq_lens), + ], + dim=0, + ).to(device="cpu") + max_query_len, max_seq_len, num_actual_kv_tokens = [ + int(v) for v in stats.tolist() + ] + if max_seq_len <= 0: + is_dummy_warmup = True + max_seq_len = int(cg_max_seq_len) if cg_max_seq_len > 0 else 1 + if max_query_len <= 0: + max_query_len = 1 + + decode_md = None + prefill_md = None + if is_prefill: + prefill_md = AiterFlashAttentionPrefillMetadata( + max_query_len=max_query_len, + max_seq_len=max_seq_len, + query_start_loc=query_start_loc, + ) + else: + decode_md = AiterFlashAttentionDecodeMetadata( + max_query_len=max_query_len, + max_seq_len=max_seq_len, + query_start_loc=query_start_loc, + ) + + if in_capture and cg_bufs is not None: + bt_buf = cg_bufs["block_table_i32"] + bs_now = 
int(block_table.shape[0]) + cols_now = int(block_table.shape[1]) + if int(bt_buf.shape[0]) < bs_now or int(bt_buf.shape[1]) < cols_now: + raise RuntimeError( + "RTP plugin prewarmed block_table_i32 buffer is too small " + f"(buffer={tuple(bt_buf.shape)}, required=({bs_now}, {cols_now}))." + ) + bt_view = bt_buf[:bs_now, :cols_now] + bt_view.copy_(block_table, non_blocking=True) + block_table_i32 = bt_view + else: + block_table_i32 = block_table.to( + device=device, dtype=torch.int32, non_blocking=True + ).contiguous() + + plugin_md = AiterFlashAttentionMetadataForPluginMode( + num_actual_tokens=num_actual_tokens, + num_actual_kv_tokens=num_actual_kv_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + block_table=block_table_i32, + num_decodes=0 if is_prefill else batch_size, + num_decode_tokens=0 if is_prefill else num_actual_tokens, + num_prefills=batch_size if is_prefill else 0, + num_prefill_tokens=num_actual_tokens if is_prefill else 0, + num_extends=0, + num_extend_tokens=0, + decode_metadata=decode_md, + prefill_metadata=prefill_md, + extend_metadata=None, + use_cascade=False, + common_prefix_len=0, + total_tokens=0, + context=None, + ) + plugin_md.rtp_cu_seqlens_q = query_start_loc + plugin_md.is_dummy_warmup = bool(is_dummy_warmup) + prefix_lengths = getattr(attn_inputs, "prefix_lengths", None) + if ( + prefix_lengths is not None + and int(prefix_lengths.numel()) > 0 + and not in_capture + ): + plugin_md.rtp_has_prefix = bool((prefix_lengths > 0).any().item()) + else: + plugin_md.rtp_has_prefix = False + + attn_metadata = AttentionMetaData( + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + block_tables=plugin_md.block_table, + slot_mapping=slot_mapping, + context_lens=seq_lens, + ) + attn_metadata.plugin_metadata = plugin_md + return attn_metadata diff --git a/tests/plugin/test_rtpllm_forward_context_semantics.py 
b/tests/plugin/test_rtpllm_forward_context_semantics.py index bd88b5b468..e316879ee1 100644 --- a/tests/plugin/test_rtpllm_forward_context_semantics.py +++ b/tests/plugin/test_rtpllm_forward_context_semantics.py @@ -40,16 +40,27 @@ def _install_forward_context_stubs(): utils_forward_context = types.ModuleType("atom.utils.forward_context") utils_forward_context.AttentionMetaData = _KwargsObject utils_forward_context.Context = _KwargsObject - utils_forward_context._forward_kv_cache_context = {} + utils_forward_context._forward_kv_cache_context = SimpleNamespace(kv_cache_data={}) utils_forward_context.reset_forward_context = lambda *args, **kwargs: None utils_forward_context.set_forward_context = lambda *args, **kwargs: None - utils_forward_context.set_kv_cache_data = lambda *args, **kwargs: None + utils_forward_context.get_forward_context = ( + lambda *args, **kwargs: SimpleNamespace() + ) + + def _set_kv_cache_data(value): + utils_forward_context._forward_kv_cache_context.kv_cache_data = value + + utils_forward_context.set_kv_cache_data = _set_kv_cache_data sys.modules["atom.utils.forward_context"] = utils_forward_context _install_forward_context_stubs() -from atom.plugin.rtpllm.utils.forward_context import RTPForwardContext # noqa: E402 +from atom.plugin.rtpllm.utils.forward_context import ( # noqa: E402 + RTPForwardContext, + RTPForwardMLAContext, + RTPForwardQwen35HybridContext, +) def _make_attn_inputs( @@ -59,6 +70,7 @@ def _make_attn_inputs( sequence_lengths=None, sequence_lengths_plus_1_d=None, cu_seqlens=None, + kv_cache_block_id_device=None, kv_cache_kernel_block_id_device=None, is_prefill=False, is_cuda_graph=False, @@ -69,6 +81,7 @@ def _make_attn_inputs( sequence_lengths=sequence_lengths, sequence_lengths_plus_1_d=sequence_lengths_plus_1_d, cu_seqlens=cu_seqlens, + kv_cache_block_id_device=kv_cache_block_id_device, kv_cache_kernel_block_id_device=kv_cache_kernel_block_id_device, is_prefill=is_prefill, is_cuda_graph=is_cuda_graph, @@ -132,7 +145,208 @@ 
def test_rtpllm_forward_context_decode_metadata_state_indices_shape(): assert md.non_spec_state_indices_tensor.cpu().tolist() == [125] -def test_rtpllm_decode_seq_lens_priority_splits_graph_and_eager_modes(): +def test_plugin_attention_metadata_slot_mapping_uses_physical_block_table(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1], dtype=torch.int32), + sequence_lengths=torch.tensor([1030], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[7, 8]], dtype=torch.int32), + kv_cache_kernel_block_id_device=torch.tensor( + [[700, 701, 702]], dtype=torch.int32 + ), + is_prefill=False, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.tensor([1029], dtype=torch.int32), + seq_size_per_block=1024, + ) + + assert md.plugin_metadata.block_table.cpu().tolist() == [[7, 8]] + assert md.plugin_metadata.slot_mapping.cpu().tolist() == [8 * 1024 + 5] + + +def test_recover_physical_block_table_accepts_expanded_kernel_layout(): + expanded = torch.tensor( + [[448, 449, 450, 451, 452, 453, 454, 455]], dtype=torch.int32 + ) + + recovered = RTPForwardContext._recover_physical_block_table_from_kernel( + expanded, + seq_size_per_block=1024, + kernel_seq_size_per_block=128, + ) + + assert recovered.cpu().tolist() == [[56]] + + +def test_recover_physical_block_table_keeps_compact_physical_layout(): + compact = torch.tensor([[7, 8, 9]], dtype=torch.int32) + + recovered = RTPForwardContext._recover_physical_block_table_from_kernel( + compact, + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + ) + + assert recovered.cpu().tolist() == [[7, 8, 9]] + + +def test_plugin_attention_metadata_keeps_indexer_block_table_expanded(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1030], dtype=torch.int32), + prefix_lengths=torch.tensor([0], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[7, 8]], dtype=torch.int32), + is_prefill=True, + ) + + md = 
RTPForwardMLAContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.arange(1030, dtype=torch.int32), + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + ) + + assert md.plugin_metadata.block_table.cpu().tolist() == [[7, 8]] + assert md.block_tables.shape == (1, 128) + assert md.block_tables[0, :4].cpu().tolist() == [448, 449, 450, 451] + assert md.block_tables[0, 64:68].cpu().tolist() == [512, 513, 514, 515] + + +def test_qwen35_context_does_not_use_glm5_indexer_block_expansion(): + block_table = torch.tensor([[7, 8]], dtype=torch.int32) + + qwen_block_tables = RTPForwardQwen35HybridContext._build_indexer_block_tables( + block_table_i32=block_table, + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + cg_max_seq_len=0, + in_capture=False, + cg_bufs=None, + ) + glm5_block_tables = RTPForwardMLAContext._build_indexer_block_tables( + block_table_i32=block_table, + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + cg_max_seq_len=0, + in_capture=False, + cg_bufs=None, + ) + + assert qwen_block_tables.shape == (1, 2) + assert qwen_block_tables.cpu().tolist() == [[7, 8]] + assert glm5_block_tables.shape[1] > qwen_block_tables.shape[1] + + +def test_plugin_attention_metadata_keeps_physical_block_table_for_base_context(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([1030], dtype=torch.int32), + prefix_lengths=torch.tensor([0], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[7, 8]], dtype=torch.int32), + is_prefill=True, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.arange(1030, dtype=torch.int32), + seq_size_per_block=1024, + kernel_seq_size_per_block=16, + ) + + assert md.plugin_metadata.block_table.cpu().tolist() == [[7, 8]] + assert md.block_tables.shape == (1, 2) + assert md.block_tables.cpu().tolist() == [[7, 8]] + + +def test_base_context_capture_recovers_physical_table_with_prewarmed_buffer(): + attn_inputs = 
_make_attn_inputs( + input_lengths=torch.tensor([1], dtype=torch.int32), + sequence_lengths=torch.tensor([35], dtype=torch.int32), + kv_cache_kernel_block_id_device=torch.tensor( + [[448, 449, 450, 451, 452, 453, 454, 455]], dtype=torch.int32 + ), + is_prefill=False, + is_cuda_graph=True, + ) + + cg_bufs = {"physical_block_table_i32": torch.empty((1, 1), dtype=torch.int32)} + block_table = RTPForwardContext._resolve_plugin_block_table( + attn_inputs=attn_inputs, + seq_size_per_block=1024, + kernel_seq_size_per_block=128, + cg_bufs=cg_bufs, + in_capture=True, + ) + + assert block_table is not None + assert block_table.cpu().tolist() == [[56]] + + +def test_plugin_attention_metadata_builds_req_id_per_token(): + attn_inputs = _make_attn_inputs( + input_lengths=torch.tensor([2, 1], dtype=torch.int32), + prefix_lengths=torch.tensor([0, 0], dtype=torch.int32), + cu_seqlens=torch.tensor([0, 2, 3], dtype=torch.int32), + kv_cache_block_id_device=torch.tensor([[3], [4]], dtype=torch.int32), + kv_cache_kernel_block_id_device=torch.tensor([[30], [40]], dtype=torch.int32), + is_prefill=True, + ) + + md = RTPForwardContext._build_plugin_attention_metadata( + attn_inputs=attn_inputs, + positions=torch.tensor([0, 1, 0], dtype=torch.int32), + seq_size_per_block=1024, + ) + + assert md.plugin_metadata.req_id_per_token.cpu().tolist() == [0, 0, 1] + assert md.plugin_metadata.sparse_block_size == 1024 + assert md.cu_seqlens_q.cpu().tolist() == [0, 2, 3] + assert md.cu_seqlens_k.cpu().tolist() == [0, 2, 3] + assert md.cu_seqlen_ks.cpu().tolist() == [0, 0, 2] + assert md.cu_seqlen_ke.cpu().tolist() == [1, 2, 3] + assert md.total_kv == 3 + + +def test_build_req_id_per_token_prefers_prewarmed_i32_buffer(monkeypatch): + query_start_loc = torch.tensor([0, 1, 2, 3], dtype=torch.int32) + seq_id_i32 = torch.arange(8, dtype=torch.int32) + + monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True) + + req_id = RTPForwardContext._build_req_id_per_token( + 
query_start_loc=query_start_loc, + num_tokens=3, + device=query_start_loc.device, + cg_bufs={ + "seq_id": torch.arange(8, dtype=torch.int64), + "seq_id_i32": seq_id_i32, + }, + ) + + assert req_id.dtype == torch.int32 + assert req_id.data_ptr() == seq_id_i32.data_ptr() + assert req_id.cpu().tolist() == [0, 1, 2] + + +def test_build_req_id_per_token_requires_prewarmed_i32_buffer_in_capture(monkeypatch): + query_start_loc = torch.tensor([0, 1], dtype=torch.int32) + + monkeypatch.setattr(torch.cuda, "is_current_stream_capturing", lambda: True) + + try: + RTPForwardContext._build_req_id_per_token( + query_start_loc=query_start_loc, + num_tokens=1, + device=query_start_loc.device, + cg_bufs={"seq_id": torch.arange(1, dtype=torch.int64)}, + ) + except RuntimeError as exc: + assert "prewarmed seq_id_i32" in str(exc) + else: + raise AssertionError("expected missing seq_id_i32 to fail during capture") + + +def test_rtpllm_decode_seq_lens_uses_rtp_plus_one_in_graph_and_eager_modes(): input_lengths = torch.tensor([1], dtype=torch.int32) sequence_lengths = torch.tensor([35], dtype=torch.int32) sequence_lengths_plus_1 = torch.tensor([35], dtype=torch.int32) @@ -158,4 +372,214 @@ def test_rtpllm_decode_seq_lens_priority_splits_graph_and_eager_modes(): graph_seq_lens = RTPForwardContext._build_seq_lens( graph_inputs, device=input_lengths.device ) - assert graph_seq_lens.cpu().tolist() == [36] + assert graph_seq_lens.cpu().tolist() == [35] + + +def test_collect_layer_maps_keeps_mla_layers_separate(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + mla_layer = RTPMLAAttention(sparse_backend=object(), layer_num=7) + model = SimpleNamespace(modules=lambda: [mla_layer]) + + gdn_map, full_attn_map, mla_map = RTPForwardContext.collect_layer_maps(model) + + assert gdn_map == {} + assert full_attn_map == {} + assert mla_map == {7: mla_layer} + + +def test_collect_layer_maps_keeps_sparse_mla_owner_for_indexer_cache(): + from 
atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + mla_layer = RTPMLAAttention(sparse_backend=object(), layer_num=7) + sparse_owner = SimpleNamespace( + layer_num=7, + indexer=SimpleNamespace(), + mla_attn=mla_layer, + ) + model = SimpleNamespace(modules=lambda: [sparse_owner, mla_layer]) + + gdn_map, full_attn_map, mla_map = RTPForwardContext.collect_layer_maps(model) + + assert gdn_map == {} + assert full_attn_map == {} + assert mla_map == {7: sparse_owner} + + +def test_collect_layer_maps_recognizes_atom_mla_wrapper_by_indexer_and_mla_attn(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + inner_mla = RTPMLAAttention(sparse_backend=object(), layer_num=9) + atom_wrapper = SimpleNamespace( + layer_num=9, + indexer=SimpleNamespace(), + mla_attn=inner_mla, + ) + model = SimpleNamespace(modules=lambda: [atom_wrapper]) + + _, _, mla_map = RTPForwardContext.collect_layer_maps(model) + + assert mla_map == {9: atom_wrapper} + + +def test_build_kv_cache_tensors_threads_raw_layer_cache_for_mla(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + layer_cache = SimpleNamespace(kv_cache_base=torch.empty(2, 3)) + runtime = SimpleNamespace( + kv_cache=SimpleNamespace(get_layer_cache=lambda layer_num: layer_cache) + ) + mla_layer = RTPMLAAttention(sparse_backend=object(), layer_num=7) + + cache_tensors = RTPForwardContext._build_kv_cache_tensors( + runtime=runtime, + layer_maps=({}, {}, {7: mla_layer}), + ) + + assert cache_tensors["layer_7"].layer_num == 7 + assert cache_tensors["layer_7"].k_cache is layer_cache + + +def test_bind_temporarily_attaches_mla_layer_cache(monkeypatch): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + old_cache = SimpleNamespace(name="old-cache") + new_cache = SimpleNamespace(name="new-cache") + mla_layer = RTPMLAAttention( + sparse_backend=object(), layer_num=7, kv_cache=old_cache + ) + 
def test_bind_writes_kv_cache_to_mla_attn_owner_not_outer_wrapper(monkeypatch):
    """``bind`` swaps the cache on the inner ``mla_attn`` layer, never on the wrapper.

    The MLA layer map points at an outer wrapper that owns both an
    ``indexer`` and the real ``RTPMLAAttention`` (``mla_attn``). While the
    context is bound, only the inner layer must see the new cache; after
    exit the inner layer must have its previous cache back. The wrapper's
    own ``kv_cache`` attribute must be untouched throughout.
    """
    from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention

    outer_cache = SimpleNamespace(name="outer-cache")
    old_inner_cache = SimpleNamespace(name="old-inner-cache")
    new_cache = SimpleNamespace(kv_cache_base=torch.empty(2, 3))
    indexer = SimpleNamespace(
        head_dim=128,
        k_cache=SimpleNamespace(kv_cache=[torch.empty(0)]),
    )
    mla_layer = RTPMLAAttention(
        sparse_backend=object(),
        layer_num=7,
        kv_cache=old_inner_cache,
    )
    outer = SimpleNamespace(
        layer_num=7,
        indexer=indexer,
        mla_attn=mla_layer,
        kv_cache=outer_cache,
    )
    forward_context = SimpleNamespace(
        attn_metadata=SimpleNamespace(),
        gdn_metadata=SimpleNamespace(),
        rtp_attn_inputs=SimpleNamespace(),
        rtp_kernel_seq_size_per_block=16,
        layer_group_map={},
        kv_cache_data={"layer_7": SimpleNamespace(k_cache=new_cache)},
        context=SimpleNamespace(),
        num_tokens=1,
        mla_layer_map={7: outer},
    )

    monkeypatch.setattr(
        RTPForwardContext,
        "build",
        classmethod(lambda cls, **kwargs: forward_context),
    )
    # Stub the global atom config exactly like the sibling ``bind`` tests do,
    # so this test does not depend on whatever config an earlier test (or the
    # environment) happened to leave installed.
    # NOTE(review): assumes ``bind`` may consult the config when an indexer is
    # present, as in the sparse-indexer sibling test — confirm against ``bind``.
    monkeypatch.setattr(
        "atom.plugin.rtpllm.utils.forward_context.get_current_atom_config",
        lambda: SimpleNamespace(kv_cache_block_size=16),
    )

    with RTPForwardContext.bind(
        model=SimpleNamespace(),
        runtime=SimpleNamespace(),
        inputs=SimpleNamespace(),
        positions=torch.tensor([0], dtype=torch.int32),
    ):
        assert outer.kv_cache is outer_cache
        assert mla_layer.kv_cache is new_cache

    assert outer.kv_cache is outer_cache
    assert mla_layer.kv_cache is old_inner_cache
_FORBIDDEN_CUDA_SPARSE_MODULES = (
    "flashmla_sparse",
    "flash_mla",
    "sparse_mla",
    "attention_mla_sparse",
)


def _guard_sparse_kernel_imports(monkeypatch):
    """Make any import of a CUDA sparse-kernel module fail the current test.

    ``builtins.__import__`` is replaced (through ``monkeypatch``, so it is
    restored at teardown) with a wrapper that raises ``AssertionError`` when
    a forbidden module name appears either

    * as a component of the dotted module name (``import a.flash_mla.b``), or
    * in the from-list (``from a import flash_mla``). In that form the
      dotted name passed to ``__import__`` is only ``a``, so checking the
      name alone would let the import through.

    Every other import is forwarded unchanged to the original ``__import__``.
    """
    original_import = builtins.__import__
    forbidden = frozenset(_FORBIDDEN_CUDA_SPARSE_MODULES)

    # Parameter names mirror the real ``__import__`` signature so that callers
    # passing ``fromlist=``/``level=`` by keyword keep working.
    def _guarded_import(name, globals=None, locals=None, fromlist=(), level=0):
        if forbidden.intersection(name.split(".")):
            blocked = name
        else:
            from_hits = forbidden.intersection(fromlist or ())
            blocked = f"{name}.{min(from_hits)}" if from_hits else None
        if blocked is not None:
            raise AssertionError(
                f"GLM5 RTP sparse tests must not import CUDA sparse kernel: {blocked}"
            )
        return original_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", _guarded_import)
"atom.models.deepseek_v2", fake_deepseek) + + tensor = torch.empty(1) + output = module.rtp_sparse_attn_indexer( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + tensor, + 128, + None, + 2048, + 64, + 4096, + 1, + tensor, + tensor, + tensor, + 1e-6, + tensor, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert output is expected + assert len(calls) == 1 + assert calls[0][0] is tensor + assert calls[0][1] == "indexer.prefix" + assert calls[0][6:12] == (128, None, 2048, 64, 4096, 1) + + +def test_rtp_sparse_attn_indexer_uses_rtp_topk_path_when_context_exists(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = _forward_context_module() + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_prefill=False, is_dummy_run=False, batch_size=1), + attn_metadata=SimpleNamespace(max_seqlen_q=1), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + def _unexpected_call(*args, **kwargs): + raise AssertionError("RTP context path must not call deepseek sparse indexer") + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer = _unexpected_call + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + expected = torch.empty(1) + calls = [] + + def _fake_topk_only(*args): + calls.append(args) + return expected + + monkeypatch.setattr( + module, "_run_rtp_sparse_attn_indexer_topk_only", _fake_topk_only + ) + tensor = torch.empty(1) + + output = module.rtp_sparse_attn_indexer( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + tensor, + 128, + None, + 2048, + 64, + 4096, + 1, + torch.empty(1, 2048, dtype=torch.int32), + tensor, + tensor, + 1e-6, + tensor, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert output is expected + assert len(calls) == 1 + assert calls[0][-2:] == ( + fake_forward_context.context, + fake_forward_context.attn_metadata, + ) + 
+ +def test_rtp_sparse_attn_indexer_fake_bridge_forwards_to_main_fake(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + calls = [] + expected = torch.empty(1) + + def fake_sparse_attn_indexer_fake(*args): + calls.append(args) + return expected + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer_fake = fake_sparse_attn_indexer_fake + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + tensor = torch.empty(1) + output = module.rtp_sparse_attn_indexer_fake( + tensor, + "indexer.prefix", + tensor, + tensor, + tensor, + tensor, + 128, + None, + 2048, + 64, + 4096, + 1, + tensor, + tensor, + tensor, + 1e-6, + tensor, + tensor, + tensor, + 1.0, + True, + False, + ) + + assert output is expected + assert len(calls) == 1 + assert calls[0][0] is tensor + assert calls[0][1] == "indexer.prefix" + assert calls[0][6:12] == (128, None, 2048, 64, 4096, 1) + + +def test_rtp_sparse_attn_indexer_short_prefill_fills_causal_topk(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = _forward_context_module() + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_prefill=True, is_dummy_run=False), + attn_metadata=SimpleNamespace(max_seqlen_k=4), + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + def _unexpected_call(*args, **kwargs): + raise AssertionError( + "short prefill path should not call deepseek sparse_attn_indexer" + ) + + fake_deepseek = type(sys)("atom.models.deepseek_v2") + fake_deepseek.sparse_attn_indexer = _unexpected_call + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + topk_buffer = torch.full((3, 8), -99, dtype=torch.int32) + positions = torch.tensor([0, 1, 3], dtype=torch.int32) + tensor = torch.empty(3, 2) + weights = torch.randn(3, 4) + out = module.rtp_sparse_attn_indexer( + tensor, + "indexer.prefix", + 
class _FakeSparseImpl:
    """Recording double for the sparse implementation injected into the backend.

    Each ``forward`` call is stored in ``calls`` as a dict keyed by argument
    name; the return value is a tensor filled with ``7`` whose last dimension
    is ``v_head_dim``.
    """

    def __init__(self, v_head_dim: int = 5):
        self.calls = []
        self.v_head_dim = v_head_dim

    def forward(
        self,
        q,
        compressed_kv,
        k_pe,
        kv_cache,
        layer_id,
        *,
        topk_indices,
        attn_metadata,
    ):
        record = dict(
            q=q,
            compressed_kv=compressed_kv,
            k_pe=k_pe,
            kv_cache=kv_cache,
            layer_id=layer_id,
            topk_indices=topk_indices,
            attn_metadata=attn_metadata,
        )
        self.calls.append(record)
        tokens, heads = q.shape[0], q.shape[1]
        return q.new_full((tokens, heads, self.v_head_dim), 7)


def _build_backend(backend_cls, sparse_impl):
    """Instantiate ``backend_cls`` around ``sparse_impl``.

    The constructor must declare a ``sparse_impl`` parameter, otherwise an
    ``AssertionError`` is raised. ``v_head_dim`` is forwarded only when the
    constructor declares it, taken from the implementation (default ``5``).
    """
    accepted = inspect.signature(backend_cls).parameters
    if "sparse_impl" not in accepted:
        raise AssertionError(
            "RTPSparseMlaBackend must accept an injected sparse implementation"
        )

    ctor_kwargs = {"sparse_impl": sparse_impl}
    if "v_head_dim" in accepted:
        ctor_kwargs["v_head_dim"] = int(getattr(sparse_impl, "v_head_dim", 5))
    return backend_cls(**ctor_kwargs)


def _make_inputs():
    """Return ``(q, compressed_kv, k_pe, kv_cache, layer_id)`` for a 3-token call."""
    q = torch.randn(3, 2, 4)
    compressed_kv = torch.randn(3, 8)
    k_pe = torch.randn(3, 3)
    kv_cache = SimpleNamespace(name="kv-cache")
    layer_id = 11
    return (q, compressed_kv, k_pe, kv_cache, layer_id)
def test_sparse_backend_prefill_without_topk_raises(monkeypatch):
    """A prefill step with ``topk_indices=None`` raises and never reaches the impl."""
    backend_cls = _load_sparse_backend(monkeypatch)

    # Forward context describing one real (non-dummy, non-warmup) prefill.
    plugin_metadata = SimpleNamespace(num_prefills=1, is_dummy_warmup=False)
    fake_forward_context = SimpleNamespace(
        context=SimpleNamespace(is_dummy_run=False),
        attn_metadata=SimpleNamespace(plugin_metadata=plugin_metadata),
    )
    monkeypatch.setattr(
        _forward_context_module(),
        "get_forward_context",
        lambda: fake_forward_context,
        raising=False,
    )

    sparse_impl = _FakeSparseImpl()
    backend = _build_backend(backend_cls, sparse_impl)
    inputs = _make_inputs()

    module = importlib.import_module(_SPARSE_BACKEND_MODULE)
    raised = None
    try:
        backend.forward(*inputs, topk_indices=None)
    except module._SparseUnavailable as exc:
        raised = exc

    if raised is None:
        raise AssertionError("Expected missing prefill topk_indices to raise")
    assert "requires topk_indices" in str(raised)
    assert sparse_impl.calls == []
def test_sparse_backend_threads_kv_cache_and_layer_id_to_sparse_impl(monkeypatch):
    """The backend hands its positional inputs to the sparse impl untouched."""
    sparse_impl = _FakeSparseImpl()
    backend = _build_backend(_load_sparse_backend(monkeypatch), sparse_impl)
    q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs()
    topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32)

    backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk)

    recorded = sparse_impl.calls[0]
    passed_by_identity = {
        "q": q,
        "compressed_kv": compressed_kv,
        "k_pe": k_pe,
        "kv_cache": kv_cache,
    }
    for key, value in passed_by_identity.items():
        # Identity, not equality: the very same objects must arrive.
        assert recorded[key] is value
    assert recorded["layer_id"] == layer_id
is_dummy_warmup=False) + ) + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=attn_metadata, + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + class _MissingPrefillSparse: + def forward(self, *args, **kwargs): + raise module._SparseUnavailable("flash_mla_sparse_fwd unavailable") + + sparse_impl = _MissingPrefillSparse() + backend = _build_backend(backend_cls, sparse_impl) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) + + try: + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + except module._SparseUnavailable: + pass + else: + raise AssertionError( + "prefill sparse unavailability must not fall back to dense" + ) + + +def test_sparse_backend_decode_missing_sparse_kernel_still_raises(monkeypatch): + backend_cls = _load_sparse_backend(monkeypatch) + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + forward_context_mod = _forward_context_module() + + attn_metadata = SimpleNamespace( + plugin_metadata=SimpleNamespace(num_prefills=0, is_dummy_warmup=False) + ) + fake_forward_context = SimpleNamespace( + context=SimpleNamespace(is_dummy_run=False), + attn_metadata=attn_metadata, + ) + monkeypatch.setattr( + forward_context_mod, + "get_forward_context", + lambda: fake_forward_context, + raising=False, + ) + + class _MissingDecodeSparse: + def forward(self, *args, **kwargs): + raise module._SparseUnavailable("flash_mla_sparse_fwd unavailable") + + backend = _build_backend(backend_cls, _MissingDecodeSparse()) + q, compressed_kv, k_pe, kv_cache, layer_id = _make_inputs() + topk = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.int32) + + try: + backend.forward(q, compressed_kv, k_pe, kv_cache, layer_id, topk_indices=topk) + except module._SparseUnavailable: + pass + else: + raise AssertionError("decode sparse 
def test_sparse_backend_forward_signature_matches_dense_boundary(monkeypatch):
    """Pin the exact parameter order of ``forward`` and the ``topk_indices`` default."""
    expected_order = (
        "self",
        "q",
        "compressed_kv",
        "k_pe",
        "kv_cache",
        "layer_id",
        "topk_indices",
        "positions",
    )

    backend_cls = _load_sparse_backend(monkeypatch)
    forward_params = inspect.signature(backend_cls.forward).parameters

    assert tuple(forward_params) == expected_order
    assert forward_params["topk_indices"].default is None
fake_metadata_info, raising=False + ) + monkeypatch.setattr(aiter, "get_mla_metadata_v1", fake_metadata_v1, raising=False) + + fake_mla = type(sys)("aiter.mla") + + def fake_mla_decode_fwd( + q, + kv, + output, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + *args, + **kwargs, + ): + calls["mla_decode_fwd"] = { + "q": q, + "kv": kv, + "output": output, + "qo_indptr": qo_indptr, + "paged_kv_indptr": paged_kv_indptr, + "paged_kv_indices": paged_kv_indices, + "paged_kv_last_page_len": paged_kv_last_page_len, + "args": args, + "kwargs": kwargs, + } + output.fill_(3) + + fake_mla.mla_decode_fwd = fake_mla_decode_fwd + monkeypatch.setitem(sys.modules, "aiter.mla", fake_mla) + + fake_sparse_helpers = type(sys)("atom.plugin.attention_mla_sparse") + + def fake_generate_sparse_seqlen( + query_lens, seq_lens, query_start_loc, topk, num_tokens, max_query_len + ): + return torch.tensor([3, 2], dtype=torch.int32, device=query_lens.device) + + def fake_convert( + req_id, + block_table, + token_indices, + cu_seqlens, + out, + BLOCK_SIZE=1, + NUM_TOPK_TOKENS=0, + BLOCK_N=128, + ): + out[:5] = torch.tensor([0, 1, 2, 4, 5], dtype=torch.int32, device=out.device) + + fake_sparse_helpers.generate_sparse_seqlen_triton = fake_generate_sparse_seqlen + fake_sparse_helpers.triton_convert_req_index_to_global_index = fake_convert + monkeypatch.setitem( + sys.modules, + "atom.plugin.attention_mla_sparse", + fake_sparse_helpers, + ) + + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=4, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=3, + ) + q_latent = torch.randn(2, 2, 5, dtype=torch.bfloat16) + kv_cache = torch.empty(8, 1, 5, dtype=torch.uint8) + topk = torch.tensor([[0, 1, 2], [0, 1, -1]], dtype=torch.int32) + attn_metadata = SimpleNamespace( + plugin_metadata=SimpleNamespace( + query_start_loc=torch.tensor([0, 1, 2], 
def test_real_sparse_cache_dtype_uses_aiter_fp8_layout():
    """``_cache_dtype_name`` reports ``"fp8"`` for uint8 caches and ``"auto"`` for bf16."""
    backend_module = importlib.import_module(_SPARSE_BACKEND_MODULE)
    mla_modules = SimpleNamespace(
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        num_heads=2,
        rotary_emb=None,
        kv_b_proj=SimpleNamespace(weight=torch.empty(0)),
    )
    impl = backend_module._RealSparseMlaImpl(mla_modules=mla_modules, v_head_dim=128)

    expectations = (
        (torch.uint8, "fp8"),
        (torch.bfloat16, "auto"),
    )
    for cache_dtype, expected_name in expectations:
        cache = torch.empty(1, 576, dtype=cache_dtype)
        assert impl._cache_dtype_name(cache) == expected_name
+ + fake_new_helpers = type(sys)(new_module_name) + + def fake_convert(): + return None + + fake_new_helpers.triton_convert_req_index_to_global_index = fake_convert + monkeypatch.setitem(sys.modules, new_module_name, fake_new_helpers) + + assert module._resolve_plugin_sparse_index_converter() is fake_convert + + +def test_real_sparse_eager_metadata_workspace_skips_refill(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + metadata_calls = [] + + fake_aiter = type(sys)("aiter") + fake_aiter.dtypes = SimpleNamespace(d_dtypes={"bf16": "bf16", "fp16": "fp16"}) + + def fake_metadata_info(*args, **kwargs): + return ( + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + (4, torch.int32), + ) + + def fake_metadata_v1(*args, **kwargs): + metadata_calls.append((args, kwargs)) + + fake_aiter.get_mla_metadata_info_v1 = fake_metadata_info + fake_aiter.get_mla_metadata_v1 = fake_metadata_v1 + monkeypatch.setitem(sys.modules, "aiter", fake_aiter) + monkeypatch.setattr( + torch.cuda, "is_current_stream_capturing", lambda: False, raising=False + ) + + fake_sparse_helpers = type(sys)("atom.plugin.attention_mla_sparse") + + def fake_convert( + req_id, + block_table, + token_indices, + cu_seqlens, + out, + BLOCK_SIZE=1, + NUM_TOPK_TOKENS=0, + BLOCK_N=128, + ): + del req_id, block_table, token_indices, BLOCK_SIZE, NUM_TOPK_TOKENS, BLOCK_N + out[: int(cu_seqlens[-1].item())].zero_() + + fake_sparse_helpers.triton_convert_req_index_to_global_index = fake_convert + monkeypatch.setitem( + sys.modules, + "atom.plugin.attention_mla_sparse", + fake_sparse_helpers, + ) + + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=4, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=3, + ) + q_latent = torch.randn(2, 2, 5) + kv_cache = torch.randn(8, 1, 5) + topk = torch.tensor([[0, 1, 2], [0, 1, -1]], 
dtype=torch.int32) + plugin_metadata = SimpleNamespace( + query_start_loc=torch.tensor([0, 1, 2], dtype=torch.int32), + seq_lens=torch.tensor([3, 2], dtype=torch.int32), + req_id_per_token=torch.tensor([0, 1], dtype=torch.int32), + block_table=torch.tensor([[0], [1]], dtype=torch.int32), + ) + attn_metadata = SimpleNamespace(plugin_metadata=plugin_metadata) + + first = impl._build_atom_sparse_metadata( + q_latent=q_latent, + kv_cache_base=kv_cache, + topk_indices=topk, + attn_metadata=attn_metadata, + block_size=4, + ) + second = impl._build_atom_sparse_metadata( + q_latent=q_latent, + kv_cache_base=kv_cache, + topk_indices=topk, + attn_metadata=attn_metadata, + block_size=4, + ) + + assert len(metadata_calls) == 1 + assert second.work_meta_data is first.work_meta_data + assert plugin_metadata._rtp_sparse_eager_meta_workspace["metadata_ready"] is True + + +def test_real_sparse_decode_rejects_oob_paged_kv_indices(monkeypatch): + module = importlib.import_module(_SPARSE_BACKEND_MODULE) + decode_called = {"value": False} + monkeypatch.setenv("ATOM_RTP_GLM5_SPARSE_VALIDATE", "1") + + fake_mla = type(sys)("aiter.mla") + + def fake_mla_decode_fwd(*args, **kwargs): + decode_called["value"] = True + + fake_mla.mla_decode_fwd = fake_mla_decode_fwd + monkeypatch.setitem(sys.modules, "aiter.mla", fake_mla) + + impl = module._RealSparseMlaImpl( + mla_modules=SimpleNamespace( + kv_lora_rank=4, + qk_nope_head_dim=2, + qk_rope_head_dim=1, + num_heads=2, + rotary_emb=None, + kv_b_proj=SimpleNamespace(weight=torch.empty(0)), + ), + v_head_dim=3, + ) + q_latent = torch.randn(2, 2, 5) + kv_cache = torch.randn(8, 1, 5) + topk = torch.tensor([[0, 1, 2], [0, 1, -1]], dtype=torch.int32) + attn_metadata = SimpleNamespace(plugin_metadata=SimpleNamespace()) + + oob_meta = module._AtomSparseMetadata( + qo_indptr=torch.tensor([0, 1, 2], dtype=torch.int32), + paged_kv_indptr=torch.tensor([0, 3, 6], dtype=torch.int32), + # kv_buffer has 8 slots, index=8 is out of range. 
class _FakeSparseBackend:
    """Recording stand-in for the sparse MLA backend used by attention-level tests.

    ``forward`` accepts the same parameters that
    ``test_sparse_backend_forward_signature_matches_dense_boundary`` pins for
    the real backend, including the optional ``positions``. A caller that
    forwards ``positions`` therefore gets a recorded call instead of a
    ``TypeError`` from the fake.

    Attributes:
        v_head_dim: size of the last dimension of the tensor ``forward`` returns.
        calls: one dict per ``forward`` call, keyed by argument name.
    """

    def __init__(self, v_head_dim: int):
        self.v_head_dim = v_head_dim
        self.calls = []

    def forward(
        self,
        q,
        compressed_kv,
        k_pe,
        kv_cache,
        layer_id,
        topk_indices=None,
        positions=None,
    ):
        """Record the call and return an uninitialised tensor.

        The result is allocated via ``q.new_empty`` with shape
        ``(q.shape[0], q.shape[1], v_head_dim)``.
        """
        self.calls.append(
            {
                "q": q,
                "compressed_kv": compressed_kv,
                "k_pe": k_pe,
                "kv_cache": kv_cache,
                "layer_id": layer_id,
                "topk_indices": topk_indices,
                "positions": positions,
            }
        )
        return q.new_empty((q.shape[0], q.shape[1], self.v_head_dim))
class _FakeQProj:
    """Query-projection double: records each call and replays a canned output."""

    def __init__(self, output):
        self.calls = []
        self.output = output

    def __call__(self, query, q_scale=None):
        recorded = (query, q_scale)
        self.calls.append(recorded)
        return self.output


class _FakeOProj:
    """Output-projection double: records its input and returns it unchanged."""

    def __init__(self):
        self.calls = []

    def __call__(self, tensor):
        self.calls += [tensor]
        return tensor
def test_constructor_injects_indexer_and_topk_indices_buffer_owner_path():
    """The attention layer exposes the module bundle's indexer and its top-k buffer."""
    topk_buffer = torch.tensor([[4, 1, 3, 0]], dtype=torch.int32)
    indexer = SimpleNamespace(topk_indices_buffer=topk_buffer, index_topk=4)
    attention_cls = _load_rtp_mla_attention()

    attention = attention_cls(
        mla_modules=SimpleNamespace(
            q_proj=object(),
            o_proj=object(),
            kv_b_proj=object(),
            indexer=indexer,
            v_head_dim=3,
        )
    )

    # Both must be the very same objects, not copies.
    assert attention.indexer is indexer
    assert attention.topk_indices_buffer is topk_buffer
def test_dummy_run_does_not_emit_topk_to_sparse_backend(monkeypatch):
    """During a dummy run the backend is called with ``topk_indices=None``."""
    _guard_sparse_kernel_imports(monkeypatch)
    _patch_forward_context(
        monkeypatch,
        is_dummy_run=True,
        is_prefill=False,
        max_seqlen_k=4096,
    )
    topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32)
    token_count = topk_values.shape[0]
    attention, modules, backend = _make_attention(topk_values)

    _run_attention(attention, token_count=token_count)

    first_backend_call = backend.calls[0]
    assert modules.indexer.calls == []
    assert first_backend_call["topk_indices"] is None
_run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.calls == [] + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert torch.equal(topk_indices, topk_values) + + +def test_prefill_within_topk_buffer_padding_still_emits_topk(monkeypatch): + _guard_sparse_kernel_imports(monkeypatch) + _patch_forward_context( + monkeypatch, + is_dummy_run=False, + is_prefill=True, + max_seqlen_k=5, + ) + topk_values = torch.tensor([[4, 1, 3, 0], [2, 0, 1, 3]], dtype=torch.int32) + attention, modules, backend = _make_attention(topk_values) + + _run_attention(attention, token_count=topk_values.shape[0]) + + assert modules.indexer.index_topk == 4 + assert modules.indexer.topk_indices_buffer.shape[1] == 6 + assert modules.indexer.calls == [] + topk_indices = backend.calls[0]["topk_indices"] + assert topk_indices is not None + assert torch.equal(topk_indices, topk_values) diff --git a/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py new file mode 100644 index 0000000000..d8c37cefad --- /dev/null +++ b/tests/plugin/test_rtpllm_glm5_wrapper_lifecycle.py @@ -0,0 +1,619 @@ +"""Lifecycle tests for the GLM5 rtp-llm wrapper.""" + +import ast +from contextlib import nullcontext +import importlib +import os +from pathlib import Path +import sys +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, call, patch + +import torch + +_ATOM_ROOT = Path(__file__).resolve().parents[2] +_FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS = { + "flashmla_sparse", + "flash_mla", + "sparse_mla", + "attention_mla_sparse", +} + + +def _package(name: str) -> ModuleType: + module = ModuleType(name) + module.__path__ = [] + return module + + +def _install_fake_rtp_modules() -> dict[str, ModuleType]: + fake_config_mod = ModuleType("rtp_llm.config.model_config") + + class _FakeModelConfig: + pass + + fake_config_mod.ModelConfig = _FakeModelConfig + + fake_factory_register_mod 
= ModuleType("rtp_llm.model_factory_register") + fake_factory_register_mod.register_model = MagicMock() + fake_factory_register_mod._model_factory = {} + fake_factory_register_mod._hf_architecture_2_ft = {} + + fake_deepseek_mod = ModuleType("rtp_llm.models.deepseek_v2") + + class _FakeDeepSeekV2: + def _get_device_str(self): + return "cpu" + + def _create_python_model(self): + self.native_create_python_model_called = True + + def load(self, skip_python_model=False): + self.native_load_called = skip_python_model + + fake_deepseek_mod.DeepSeekV2 = _FakeDeepSeekV2 + + fake_weight_info_mod = ModuleType("rtp_llm.model_loader.model_weight_info") + + class _FakeModelWeights: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.global_weights = {} + + def set_global_weight(self, name, tensor): + self.global_weights[name] = tensor + + class _FakeModelDeployWeightInfo: + pass + + fake_weight_info_mod.ModelDeployWeightInfo = _FakeModelDeployWeightInfo + fake_weight_info_mod.ModelWeights = _FakeModelWeights + + fake_module_base_mod = ModuleType("rtp_llm.models_py.model_desc.module_base") + + class _FakeGptModelBase: + def __init__(self, *args, **kwargs): + self.init_args = args + self.init_kwargs = kwargs + + fake_module_base_mod.GptModelBase = _FakeGptModelBase + + fake_ops_mod = ModuleType("rtp_llm.ops") + + class _FakeParallelismConfig: + pass + + fake_ops_mod.ParallelismConfig = _FakeParallelismConfig + + fake_compute_ops_mod = ModuleType("rtp_llm.ops.compute_ops") + + class _FakePyModelInputs: + pass + + class _FakePyModelOutputs: + def __init__(self, hidden_states): + self.hidden_states = hidden_states + + fake_compute_ops_mod.PyModelInputs = _FakePyModelInputs + fake_compute_ops_mod.PyModelOutputs = _FakePyModelOutputs + + fake_weight_mod = ModuleType("rtp_llm.utils.model_weight") + fake_weight_mod.W = SimpleNamespace( + lm_head="lm_head", + embedding="embedding", + final_ln_gamma="final_ln_gamma", + ) + + fake_loader_mod = 
ModuleType("atom.model_loader.loader") + + class _FakeWeightsMapper: + def __init__(self, **kwargs): + self.kwargs = kwargs + + fake_loader_mod.WeightsMapper = _FakeWeightsMapper + fake_loader_mod.load_model_in_plugin_mode = MagicMock() + + return { + "atom.model_loader": _package("atom.model_loader"), + "atom.model_loader.loader": fake_loader_mod, + "rtp_llm": _package("rtp_llm"), + "rtp_llm.config": _package("rtp_llm.config"), + "rtp_llm.config.model_config": fake_config_mod, + "rtp_llm.model_factory_register": fake_factory_register_mod, + "rtp_llm.models": _package("rtp_llm.models"), + "rtp_llm.models.deepseek_v2": fake_deepseek_mod, + "rtp_llm.model_loader": _package("rtp_llm.model_loader"), + "rtp_llm.model_loader.model_weight_info": fake_weight_info_mod, + "rtp_llm.models_py": _package("rtp_llm.models_py"), + "rtp_llm.models_py.model_desc": _package("rtp_llm.models_py.model_desc"), + "rtp_llm.models_py.model_desc.module_base": fake_module_base_mod, + "rtp_llm.ops": fake_ops_mod, + "rtp_llm.ops.compute_ops": fake_compute_ops_mod, + "rtp_llm.utils": _package("rtp_llm.utils"), + "rtp_llm.utils.model_weight": fake_weight_mod, + } + + +def _make_wrapper_instance(cls): + instance = cls.__new__(cls) + instance.model_config = SimpleNamespace( + num_layers=1, + compute_dtype=torch.bfloat16, + ) + instance.parallelism_config = SimpleNamespace() + instance.max_generate_batch_size = 1 + instance.fmha_config = None + instance.hw_kernel_config = None + instance.device_resource_config = None + return instance + + +def test_glm5_load_skip_python_model_does_not_create_atom_model(): + fake_modules = _install_fake_rtp_modules() + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + instance = 
_make_wrapper_instance(module.ATOMGlm5Moe) + instance._create_python_model = MagicMock() + + instance.load(skip_python_model=True) + + instance._create_python_model.assert_not_called() + assert instance.device == "cpu" + assert isinstance(instance.model_weights_loader, module._NoopModelWeightsLoader) + assert isinstance(instance.weight_manager, module._NoopWeightManager) + + +def _patch_optional_attr(module, attr): + if hasattr(module, attr): + return patch.object(module, attr) + return nullcontext(MagicMock(name=attr)) + + +def _read_plugin_file(relative_path: str) -> str: + return (_ATOM_ROOT / relative_path).read_text() + + +def test_glm5_create_python_model_lets_prepare_model_own_mla_patching(): + fake_modules = _install_fake_rtp_modules() + fake_atom_model = MagicMock(name="atom_model") + fake_atom_model.to.return_value = fake_atom_model + fake_utils_mod = ModuleType("atom.plugin.rtpllm.utils") + + class _FakeRTPForwardMLAContext: + @staticmethod + def collect_layer_maps(model): + return ({}, {}, {}) + + fake_utils_mod.RTPForwardMLAContext = _FakeRTPForwardMLAContext + + with ( + patch.dict( + sys.modules, + fake_modules, + ), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + patch.dict(sys.modules, {"atom.plugin.rtpllm.utils": fake_utils_mod}), + patch( + "atom.prepare_model", return_value=fake_atom_model, create=True + ) as prepare_model, + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + instance = _make_wrapper_instance(module.ATOMGlm5Moe) + instance.device = "cpu" + instance.weight = MagicMock() + + with ( + _patch_optional_attr( + module, "apply_attention_mla_rtpllm_patch" + ) as mla_patch, + _patch_optional_attr( + module, "apply_deepseek_mla_rtpllm_patch" + ) as deepseek_patch, + ): + result = instance._create_python_model() + + prepare_model.assert_called_once_with(config=instance, 
engine="rtpllm") + mla_patch.assert_not_called() + deepseek_patch.assert_not_called() + load_model_in_plugin_mode = fake_modules[ + "atom.model_loader.loader" + ].load_model_in_plugin_mode + load_model_in_plugin_mode.assert_called_once() + assert result is instance.py_model + + +def test_glm5_support_cuda_graph_honors_eager_env(): + fake_modules = _install_fake_rtp_modules() + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + { + "RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models", + "ENABLE_CUDA_GRAPH": "0", + }, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + instance = _make_wrapper_instance(module.ATOMGlm5Moe) + + assert instance.support_cuda_graph() is False + + +def test_glm5_runtime_uses_mla_forward_context_class(): + fake_modules = _install_fake_rtp_modules() + fake_utils_mod = ModuleType("atom.plugin.rtpllm.utils") + marker_context_cls = object() + fake_utils_mod.RTPForwardMLAContext = marker_context_cls + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict(sys.modules, {"atom.plugin.rtpllm.utils": fake_utils_mod}), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + module.RTPForwardContext = None + + context_cls = module._ATOMGlm5MoeRuntime._get_forward_context_cls() + + assert context_cls is marker_context_cls + + +def test_glm5_runtime_forward_wraps_model_call_in_rtp_context(monkeypatch): + fake_modules = _install_fake_rtp_modules() + expected_input_ids = torch.tensor([10, 11], dtype=torch.int64) + position_ids = torch.tensor([5, 6], dtype=torch.int32) + hidden_states = torch.randn(2, 4) + events = [] + + class _FakeAtomModel(torch.nn.Module): + def 
__init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(1)) + + def forward(self, *, input_ids, positions, intermediate_tensors, inputs_embeds): + events.append(("model", bool(_FakeRTPForwardContext.in_context))) + assert torch.equal(input_ids, expected_input_ids) + assert torch.equal(positions, position_ids.to(torch.long)) + assert positions.dtype == torch.long + assert intermediate_tensors is None + assert inputs_embeds is None + return hidden_states + + class _FakeBind: + def __enter__(self): + _FakeRTPForwardContext.in_context = True + events.append(("enter", None)) + + def __exit__(self, exc_type, exc, tb): + events.append(("exit", None)) + _FakeRTPForwardContext.in_context = False + + class _FakeRTPForwardContext: + in_context = False + + @staticmethod + def collect_layer_maps(model): + return ({}, {}, {}) + + @staticmethod + def bind(**kwargs): + assert torch.equal(kwargs["positions"], position_ids.to(torch.long)) + assert kwargs["positions"].dtype == torch.long + return _FakeBind() + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + monkeypatch.setattr(module, "RTPForwardContext", _FakeRTPForwardContext) + runtime = module._ATOMGlm5MoeRuntime( + model_config=SimpleNamespace(max_seq_len=16), + parallelism_config=SimpleNamespace(), + weights=MagicMock(), + max_generate_batch_size=2, + atom_model=_FakeAtomModel(), + ) + runtime.kv_cache = SimpleNamespace() + inputs = SimpleNamespace( + input_ids=expected_input_ids, + input_hiddens=None, + attention_inputs=SimpleNamespace(position_ids=position_ids), + ) + + output = runtime.forward(inputs) + + assert output.hidden_states is hidden_states + assert events == [("enter", None), ("model", True), ("exit", None)] + + +def 
test_glm5_runtime_prepare_fmha_impl_bypasses_native_mla_factory(monkeypatch): + fake_modules = _install_fake_rtp_modules() + + class _FakeRTPForwardContext: + @staticmethod + def collect_layer_maps(model): + return ({}, {}, {}) + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + monkeypatch.setattr(module, "RTPForwardContext", _FakeRTPForwardContext) + atom_model = torch.nn.Linear(1, 1) + runtime = module._ATOMGlm5MoeRuntime( + model_config=SimpleNamespace(max_seq_len=16), + parallelism_config=SimpleNamespace(), + weights=MagicMock(), + max_generate_batch_size=2, + atom_model=atom_model, + ) + inputs = SimpleNamespace(attention_inputs=SimpleNamespace()) + + attn_pyobj = runtime.prepare_fmha_impl(inputs, is_cuda_graph=False) + + assert attn_pyobj.fmha_params is None + assert attn_pyobj.is_cuda_graph is False + assert hasattr(attn_pyobj, "prepare_cuda_graph") + + +def test_glm5_runtime_decode_positions_prefer_sequence_lengths_plus_one(): + fake_modules = _install_fake_rtp_modules() + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + runtime = object.__new__(module._ATOMGlm5MoeRuntime) + attn_inputs = SimpleNamespace( + input_lengths=torch.tensor([1, 2], dtype=torch.int32), + is_prefill=False, + sequence_lengths=torch.tensor([999, 999], dtype=torch.int32), + sequence_lengths_plus_1_d=torch.tensor([35, 50], dtype=torch.int32), + ) + + positions = runtime._build_positions_from_attention_inputs( + attn_inputs=attn_inputs, + 
model_device=torch.device("cpu"), + ) + + assert positions.cpu().tolist() == [34, 48, 49] + + +def test_glm5_runtime_graph_decode_ignores_stale_position_ids(): + fake_modules = _install_fake_rtp_modules() + + with ( + patch.dict(sys.modules, fake_modules), + patch.dict( + os.environ, + {"RTP_LLM_EXTERNAL_MODEL_PACKAGES": "atom.plugin.rtpllm.models"}, + ), + ): + sys.modules.pop("atom.plugin.rtpllm.models.glm5", None) + module = importlib.import_module("atom.plugin.rtpllm.models.glm5") + module = importlib.reload(module) + runtime = object.__new__(module._ATOMGlm5MoeRuntime) + inputs = SimpleNamespace( + bert_embedding_inputs=None, + attention_inputs=SimpleNamespace( + input_lengths=torch.tensor([1, 2], dtype=torch.int32), + is_prefill=False, + is_cuda_graph=True, + position_ids=torch.tensor([0, 0, 0], dtype=torch.int32), + sequence_lengths_plus_1_d=torch.tensor([35, 50], dtype=torch.int32), + ), + ) + + positions = runtime._extract_positions( + inputs=inputs, + model_device=torch.device("cpu"), + token_num=3, + ) + + assert positions.cpu().tolist() == [34, 48, 49] + + +def test_rtpllm_wrapper_registers_glm5_override_and_alias(): + register_model_mock = MagicMock() + + fake_rtp_register_mod = ModuleType("rtp_llm.model_factory_register") + fake_rtp_register_mod.register_model = register_model_mock + fake_rtp_register_mod._model_factory = {} + fake_rtp_register_mod._hf_architecture_2_ft = {} + + fake_atom_register_mod = ModuleType("atom.plugin.register") + fake_atom_register_mod._ATOM_SUPPORTED_MODELS = {} + + fake_atom_deepseek_mod = ModuleType("atom.models.deepseek_v2") + + class _FakeGlmMoeDsaForCausalLM: + pass + + fake_atom_deepseek_mod.GlmMoeDsaForCausalLM = _FakeGlmMoeDsaForCausalLM + + fake_atom_qwen_mod = ModuleType("atom.plugin.rtpllm.models.qwen3_5") + + class _FakeATOMQwen35Moe: + pass + + fake_atom_qwen_mod.ATOMQwen35Moe = _FakeATOMQwen35Moe + + fake_atom_glm_mod = ModuleType("atom.plugin.rtpllm.models.glm5") + + class _FakeATOMGlm5Moe: + pass + + 
fake_atom_glm_mod.ATOMGlm5Moe = _FakeATOMGlm5Moe + + fake_modules = { + "rtp_llm": _package("rtp_llm"), + "rtp_llm.models": _package("rtp_llm.models"), + "rtp_llm.model_factory_register": fake_rtp_register_mod, + "atom.models.deepseek_v2": fake_atom_deepseek_mod, + "atom.plugin.register": fake_atom_register_mod, + "atom.plugin.rtpllm.models.qwen3_5": fake_atom_qwen_mod, + "atom.plugin.rtpllm.models.glm5": fake_atom_glm_mod, + } + + with patch.dict(sys.modules, fake_modules): + sys.modules.pop("atom.plugin.rtpllm.models", None) + sys.modules.pop("atom.plugin.rtpllm.models.base_model_wrapper", None) + importlib.import_module("atom.plugin.rtpllm.models") + + assert fake_rtp_register_mod._model_factory["glm_5"] is _FakeATOMGlm5Moe + assert ( + fake_rtp_register_mod._hf_architecture_2_ft["GlmMoeDsaForCausalLM"] + == "glm_5" + ) + assert ( + fake_atom_register_mod._ATOM_SUPPORTED_MODELS["GlmMoeDsaForCausalLM"] + is _FakeGlmMoeDsaForCausalLM + ) + register_model_mock.assert_has_calls( + [ + call("atom_qwen35_moe", _FakeATOMQwen35Moe, []), + call("atom_glm5_moe", _FakeATOMGlm5Moe, []), + ], + any_order=False, + ) + + +def test_mla_attention_legacy_boundary_shape_stays_executable_during_migration(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + q = torch.empty(2, 4, 256) + compressed_kv = torch.empty(2, 512) + k_pe = torch.empty(2, 64) + positions = torch.arange(2, dtype=torch.int32) + attention = RTPMLAAttention(mla_modules=SimpleNamespace(v_head_dim=128)) + + output = attention(q, compressed_kv, k_pe, positions=positions) + + assert output.shape == (2, 4, 128) + + +def test_mla_attention_is_marked_as_mla_adapter(): + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import RTPMLAAttention + + assert RTPMLAAttention.use_mla is True + + +def test_glm5_wrapper_does_not_use_mha_or_qwen_patches(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "RTPFullAttention" not in source + assert 
"apply_attention_mha_rtpllm_patch" not in source + assert "apply_attention_gdn_rtpllm_patch" not in source + assert "apply_qwen3_next_rtpllm_patch" not in source + + +def test_glm5_wrapper_does_not_import_or_call_deepseek_mla_patch(): + source = _read_plugin_file("atom/plugin/rtpllm/models/glm5.py") + + assert "apply_deepseek_mla_rtpllm_patch" not in source + + +def test_rtp_mla_prepare_no_longer_contains_deepseek_forward_monkey_patch(): + assert not ( + _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_mla_prepare.py" + ).exists() + + +def test_glm5_mla_backend_is_not_full_attention_adapter(): + source = _read_plugin_file( + "atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py" + ) + + assert "class RTPMLAAttention" in source + assert "use_mla" in source + assert "RTPFullAttention" not in source + + +def test_sparse_mla_backend_has_no_import_time_cuda_sparse_kernel_dependencies(): + backend_path = ( + _ATOM_ROOT / "atom/plugin/rtpllm/attention_backend/rtp_sparse_mla_backend.py" + ) + assert backend_path.exists() + + tree = ast.parse(backend_path.read_text()) + imported_modules = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.update(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module is not None: + imported_modules.add(node.module) + + assert not any( + forbidden in module_name.split(".") + for module_name in imported_modules + for forbidden in _FORBIDDEN_IMPORT_TIME_SPARSE_KERNELS + ) + + +def test_rtp_mla_patch_updates_deepseek_attention_symbol(monkeypatch): + import types + + from atom.plugin.rtpllm.attention_backend.rtp_mla_attention import ( + RTPMLAAttention, + apply_attention_mla_rtpllm_patch, + ) + + sentinel = object() + fake_ops = types.ModuleType("atom.model_ops") + fake_ops.Attention = sentinel + fake_base_attention = types.ModuleType("atom.model_ops.base_attention") + fake_base_attention.Attention = sentinel + fake_deepseek = 
types.ModuleType("atom.models.deepseek_v2") + fake_deepseek.Attention = sentinel + monkeypatch.setitem(sys.modules, "atom.model_ops", fake_ops) + monkeypatch.setitem( + sys.modules, "atom.model_ops.base_attention", fake_base_attention + ) + monkeypatch.setitem(sys.modules, "atom.models.deepseek_v2", fake_deepseek) + + apply_attention_mla_rtpllm_patch() + + assert fake_ops.Attention is RTPMLAAttention + assert fake_base_attention.Attention is RTPMLAAttention + assert fake_deepseek.Attention is RTPMLAAttention diff --git a/tests/plugin/test_rtpllm_model_wrapper.py b/tests/plugin/test_rtpllm_model_wrapper.py index 9ff4838d2a..cafbbbbd4c 100644 --- a/tests/plugin/test_rtpllm_model_wrapper.py +++ b/tests/plugin/test_rtpllm_model_wrapper.py @@ -3,7 +3,7 @@ import importlib import sys from types import ModuleType -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch def _package(name: str) -> ModuleType: @@ -26,12 +26,19 @@ class _FakeATOMQwen35Moe: pass fake_atom_qwen_mod.ATOMQwen35Moe = _FakeATOMQwen35Moe + fake_atom_glm_mod = ModuleType("atom.plugin.rtpllm.models.glm5") + + class _FakeATOMGlm5Moe: + pass + + fake_atom_glm_mod.ATOMGlm5Moe = _FakeATOMGlm5Moe fake_modules = { "rtp_llm": _package("rtp_llm"), "rtp_llm.models": _package("rtp_llm.models"), "rtp_llm.model_factory_register": fake_register_mod, "atom.plugin.rtpllm.models.qwen3_5": fake_atom_qwen_mod, + "atom.plugin.rtpllm.models.glm5": fake_atom_glm_mod, } with patch.dict(sys.modules, fake_modules): @@ -46,6 +53,10 @@ class _FakeATOMQwen35Moe: ] == "qwen35_moe" ) - register_model_mock.assert_called_with( - "atom_qwen35_moe", _FakeATOMQwen35Moe, [] + register_model_mock.assert_has_calls( + [ + call("atom_qwen35_moe", _FakeATOMQwen35Moe, []), + call("atom_glm5_moe", _FakeATOMGlm5Moe, []), + ], + any_order=False, ) diff --git a/tests/plugin/test_rtpllm_prepare_model.py b/tests/plugin/test_rtpllm_prepare_model.py index 0ff0114fa0..6dcc7c3460 100644 --- 
a/tests/plugin/test_rtpllm_prepare_model.py +++ b/tests/plugin/test_rtpllm_prepare_model.py @@ -65,3 +65,47 @@ def test_prepare_model_rtpllm_happy_path(): fake_quant_config.remap_layer_name.assert_called_once() fake_model_cls.assert_called_once_with(atom_config=fake_atom_config) assert result is fake_model + + +def test_prepare_model_rtpllm_glm5_reapplies_mla_attention_patch(): + fake_atom_config = _Obj( + hf_config=_Obj(architectures=["GlmMoeDsaForCausalLM"]), + plugin_config=_Obj(is_plugin_mode=True), + quant_config=_Obj( + exclude_layers=[], + remap_layer_name=MagicMock(), + ), + ) + fake_model = MagicMock(name="FakeGlm5") + fake_model_cls = MagicMock(return_value=fake_model) + + fake_register = MagicMock() + fake_register._ATOM_SUPPORTED_MODELS = {"GlmMoeDsaForCausalLM": fake_model_cls} + fake_register.register_ops_to_sglang = MagicMock() + fake_register.init_aiter_dist = MagicMock() + fake_register.set_attn_cls = MagicMock() + + fake_config_mod = MagicMock() + fake_config_mod.generate_atom_config_for_plugin_mode = MagicMock( + return_value=fake_atom_config + ) + + fake_rtpllm_attention_backend = MagicMock() + + with patch.dict( + sys.modules, + { + "atom.plugin.register": fake_register, + "atom.plugin.config": fake_config_mod, + "atom.plugin.rtpllm.attention_backend": fake_rtpllm_attention_backend, + }, + ): + result = plugin_prepare.prepare_model( + config=_Obj(model_config=_Obj()), engine="rtpllm" + ) + + fake_register.set_attn_cls.assert_called_once() + fake_rtpllm_attention_backend.apply_attention_mla_rtpllm_patch.assert_called_once() + fake_atom_config.quant_config.remap_layer_name.assert_called_once() + fake_model_cls.assert_called_once_with(atom_config=fake_atom_config) + assert result is fake_model