Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion atom/plugin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _generate_atom_config_from_rtpllm_config(config: Any):

return Config(
model=rtpllm_model_config.ckpt_path,
max_num_batched_tokens=max(16384, max_generate_batch_size),
max_num_batched_tokens=max(max_model_len, max_generate_batch_size),
max_num_seqs=max_generate_batch_size,
max_model_len=max_model_len,
gpu_memory_utilization=0.9,
Expand Down
8 changes: 8 additions & 0 deletions atom/plugin/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _prepare_model_atom_sglang(
def _prepare_model_atom_rtpllm(
config: Any,
atom_config: Any,
model_arch: str,
model_cls: Any,
set_attn_cls: Any,
init_aiter_dist: Any,
Expand All @@ -120,6 +121,12 @@ def _prepare_model_atom_rtpllm(
)

set_attn_cls()
if model_arch == "GlmMoeDsaForCausalLM":
from atom.plugin.rtpllm.attention_backend import (
apply_attention_mla_rtpllm_patch,
)

apply_attention_mla_rtpllm_patch()

# init aiter dist for using aiter custom collective ops
init_aiter_dist(config=atom_config)
Expand Down Expand Up @@ -172,6 +179,7 @@ def prepare_model(config: Any, engine: str):
return _prepare_model_atom_rtpllm(
config,
atom_config,
model_arch,
model_cls,
set_attn_cls,
init_aiter_dist,
Expand Down
7 changes: 7 additions & 0 deletions atom/plugin/rtpllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""RTP-LLM plugin helpers.

Keep the package root import side-effect free. RTP external model registration
is triggered by importing ``atom.plugin.rtpllm.models``.
"""

__all__: list[str] = []
33 changes: 30 additions & 3 deletions atom/plugin/rtpllm/attention_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
from .attention_gdn import apply_attention_gdn_rtpllm_patch
from .attention_switch import apply_attention_mha_rtpllm_patch
from .rtp_full_attention import AttentionForRTPLLM, RTPFullAttention
from .rtp_mla_attention import RTPMLAAttention, apply_attention_mla_rtpllm_patch
from .rtp_sparse_mla_backend import RTPSparseMlaBackend


def __getattr__(name):
    """Resolve selected package attributes lazily, on access.

    Each supported name imports the submodule that defines it only when the
    attribute is requested. ``RTPAttention`` resolves to the same object as
    ``RTPFullAttention``.

    Args:
        name: Attribute name being looked up on this module.

    Returns:
        The object imported from the matching submodule.

    Raises:
        AttributeError: If ``name`` is not one of the lazily resolved names.
    """
    if name == "AttentionForRTPLLM":
        from .rtp_full_attention import AttentionForRTPLLM as resolved
    elif name in ("RTPFullAttention", "RTPAttention"):
        # Both spellings map onto the same class.
        from .rtp_full_attention import RTPFullAttention as resolved
    elif name == "apply_attention_gdn_rtpllm_patch":
        from .attention_gdn import apply_attention_gdn_rtpllm_patch as resolved
    elif name == "apply_attention_mha_rtpllm_patch":
        from .attention_switch import apply_attention_mha_rtpllm_patch as resolved
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return resolved


__all__ = [
"AttentionForRTPLLM",
"RTPFullAttention",
"RTPMLAAttention",
"RTPSparseMlaBackend",
"apply_attention_gdn_rtpllm_patch",
"apply_attention_mha_rtpllm_patch",
"apply_attention_mla_rtpllm_patch",
]
253 changes: 253 additions & 0 deletions atom/plugin/rtpllm/attention_backend/rtp_mla_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
"""RTP-style MLA adapter for GLM5 rtp-llm plugin mode."""

from __future__ import annotations

import inspect
from types import MethodType
from typing import Optional

import torch


def _resolve_index_topk(attn) -> int:
    """Return the top-k size configured for ``attn`` as an ``int``.

    The first non-``None`` value wins, searched in this order:
    ``attn.indexer.index_topk``, ``attn.indexer.topk_tokens``,
    ``attn.index_topk``, ``attn.config.index_topk``.

    Args:
        attn: Object that may carry ``indexer``, ``index_topk`` and/or
            ``config`` attributes.

    Returns:
        The located value converted with ``int``.

    Raises:
        AttributeError: If none of the candidate attributes holds a value.
    """
    indexer = getattr(attn, "indexer", None)
    config = getattr(attn, "config", None)
    candidates = (
        (indexer, "index_topk"),
        (indexer, "topk_tokens"),
        (attn, "index_topk"),
        (config, "index_topk"),
    )
    for holder, field in candidates:
        if holder is None:
            continue
        found = getattr(holder, field, None)
        if found is not None:
            return int(found)
    raise AttributeError("GLM5 RTP MLA indexer requires index_topk/topk_tokens")


def _get_topk_indices_buffer(attn) -> torch.Tensor:
    """Return the top-k indices buffer associated with ``attn``.

    The first non-``None`` value wins, searched in this order:
    ``attn.indexer.topk_indices_buffer``, ``attn.topk_indices_buffer``,
    ``attn._topk_indices_buffer``.

    Args:
        attn: Object that may carry an ``indexer`` and/or its own buffer.

    Returns:
        The located buffer object, returned as-is.

    Raises:
        AttributeError: If no buffer is found in any of the locations.
    """
    indexer = getattr(attn, "indexer", None)
    lookups = (
        (indexer, "topk_indices_buffer"),
        (attn, "topk_indices_buffer"),
        (attn, "_topk_indices_buffer"),
    )
    # Attributes are read one at a time so later fallbacks are only touched
    # when the earlier ones yield nothing.
    for owner, attr_name in lookups:
        found = getattr(owner, attr_name, None) if owner is not None else None
        if found is not None:
            return found
    raise AttributeError("GLM5 RTP MLA indexer requires topk_indices_buffer")


def _should_emit_topk_indices(attn) -> bool:
    """Tell whether top-k indices should be produced for the current call.

    Looks up the active forward context and returns ``False`` only when its
    ``context.is_dummy_run`` flag is truthy. If the forward context cannot be
    imported or fetched for any reason, the answer defaults to ``True``.

    Args:
        attn: Unused; kept so the helper has the same shape as its siblings.

    Returns:
        ``False`` during a dummy run, ``True`` otherwise.
    """
    try:
        from atom.utils.forward_context import get_forward_context

        active = get_forward_context()
    except Exception:
        # Best effort: with no usable forward context, behave as a real run.
        return True

    run_context = getattr(active, "context", None)
    return not getattr(run_context, "is_dummy_run", False)


def _use_rtp_sparse_attn_indexer(indexer: object | None) -> None:
if indexer is None or not hasattr(indexer, "sparse_attn_indexer_impl"):
return
__import__("atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend")
indexer.sparse_attn_indexer_impl = torch.ops.aiter.rtp_sparse_attn_indexer
if getattr(indexer, "_atom_rtp_topk_buffer_patched", False) or not hasattr(
indexer, "forward"
):
return
original_forward = indexer.forward

def _forward_with_topk_buffer(self, hidden_states, *args, **kwargs):
num_tokens = int(hidden_states.shape[0])
topk_tokens = getattr(self, "topk_tokens", None)
if topk_tokens is None:
topk_tokens = getattr(self, "index_topk")
topk_tokens = int(topk_tokens)
buffer = getattr(self, "topk_indices_buffer", None)
needs_new_buffer = (
buffer is None
or buffer.dim() != 2
or buffer.device != hidden_states.device
or int(buffer.shape[0]) < num_tokens
or int(buffer.shape[1]) < topk_tokens
)
if needs_new_buffer:
buffer = torch.empty(
num_tokens,
topk_tokens,
dtype=torch.int32,
device=hidden_states.device,
)
self.topk_indices_buffer = buffer
self.sparse_kv_indices_buffer = self.topk_indices_buffer
return original_forward(hidden_states, *args, **kwargs)

indexer.forward = MethodType(_forward_with_topk_buffer, indexer)
indexer._atom_rtp_topk_buffer_patched = True


class RTPMLAAttention:
    """Adapter exposing the native GLM5 MLA call contract on an RTP backend.

    The instance pulls projection layers and dimensions from the
    ``mla_modules`` keyword argument, delegates the attention computation to
    ``sparse_backend.forward`` and, when it projected the query itself, also
    applies ``o_proj`` to the result.
    """

    use_mla = True

    def __init__(self, *args, **kwargs) -> None:
        """Collect modules, dimensions and the backend from ``kwargs``.

        Recognised keyword arguments: ``mla_modules``, ``sparse_backend``,
        ``scale``, ``kv_cache``, ``layer_id`` (falling back to ``layer_num``,
        then ``0``). All positional and keyword arguments are also stored
        verbatim on ``self.args`` / ``self.kwargs``.
        """
        self.args = args
        self.kwargs = kwargs

        modules = kwargs.get("mla_modules")
        self.mla_modules = modules
        for attr in ("q_proj", "o_proj", "kv_b_proj", "indexer"):
            setattr(self, attr, getattr(modules, attr, None))
        _use_rtp_sparse_attn_indexer(self.indexer)

        for attr in (
            "qk_head_dim",
            "v_head_dim",
            "q_lora_rank",
            "kv_lora_rank",
            "num_heads",
        ):
            setattr(self, attr, getattr(modules, attr, None))
        self.num_local_heads = getattr(modules, "num_local_heads", self.num_heads)
        self.index_topk = getattr(modules, "index_topk", None)
        # ``getattr`` on a missing (None) indexer already yields None.
        self.topk_indices_buffer = getattr(self.indexer, "topk_indices_buffer", None)

        # An injected backend wins; otherwise build one when modules exist.
        backend = kwargs.get("sparse_backend")
        if backend is None and modules is not None:
            from atom.plugin.rtpllm.attention_backend.rtp_sparse_mla_backend import (
                RTPSparseMlaBackend,
            )

            backend = RTPSparseMlaBackend(
                v_head_dim=modules.v_head_dim,
                mla_modules=modules,
                scale=kwargs.get("scale"),
            )
        self.sparse_backend = backend

        self.kv_cache = kwargs.get("kv_cache")
        layer_fallback = kwargs.get("layer_num", 0)
        self.layer_id = int(kwargs.get("layer_id", layer_fallback))
        self._sparse_backend_accepts_positions = (
            self.sparse_backend is not None
            and self._backend_accepts_positions(self.sparse_backend)
        )

    @staticmethod
    def _backend_accepts_positions(backend: object) -> bool:
        """Report whether ``backend.forward`` can take a ``positions`` keyword.

        True when the signature names ``positions`` or has a ``**kwargs``
        parameter; False when the signature cannot be inspected at all.
        """
        try:
            params = inspect.signature(backend.forward).parameters
        except (AttributeError, TypeError, ValueError):
            return False
        if "positions" in params:
            return True
        return any(
            spec.kind == inspect.Parameter.VAR_KEYWORD for spec in params.values()
        )

    def _project_query(
        self, query: torch.Tensor, q_scale: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, bool]:
        """Return ``(query, projected)`` with ``query`` in 3-D form if possible.

        A query that is already 3-D, or any query when no ``q_proj`` is set,
        is passed through with ``projected=False``. Otherwise ``q_proj`` is
        applied and, if its output is not 3-D, reshaped to
        ``(-1, num_heads, qk_head_dim)``. Whichever of the two sizes is missing
        is derived from the last dimension; a derived ``qk_head_dim`` is cached
        on the instance.

        Raises:
            AttributeError: If neither a head count nor ``qk_head_dim`` is known.
        """
        if query.ndim == 3 or self.q_proj is None:
            return query, False

        projected = self.q_proj(query, q_scale)
        if projected.ndim == 3:
            return projected, True

        heads = self.num_local_heads
        if heads is None:
            heads = self.num_heads
        if heads is None:
            if self.qk_head_dim is None:
                raise AttributeError("GLM5 RTP MLA native contract requires num_heads")
            heads = projected.shape[-1] // int(self.qk_head_dim)
        if self.qk_head_dim is None:
            self.qk_head_dim = projected.shape[-1] // int(heads)

        shaped = projected.reshape(-1, int(heads), int(self.qk_head_dim))
        return shaped, True

    def _resolve_topk_indices(
        self,
        query: torch.Tensor,
        q_scale: Optional[torch.Tensor],
        positions: Optional[torch.Tensor],
        explicit_topk_indices: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """Pick the top-k indices to hand to the backend.

        Explicit indices are returned untouched. Without an indexer, or when
        ``_should_emit_topk_indices`` says no, the result is ``None``. Otherwise
        the shared buffer is sliced to ``query.shape[0]`` rows and the resolved
        top-k width. ``q_scale`` and ``positions`` are accepted but not read.
        """
        if explicit_topk_indices is not None:
            return explicit_topk_indices
        if self.indexer is None or not _should_emit_topk_indices(self):
            return None

        width = _resolve_index_topk(self)
        shared = _get_topk_indices_buffer(self)
        return shared[: query.shape[0], :width]

    def forward(
        self,
        query: torch.Tensor,
        compressed_kv: torch.Tensor,
        k_pe: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        q_scale: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention through the sparse backend.

        Args:
            query: Query input; projected via ``q_proj`` unless already 3-D.
            compressed_kv: Forwarded to the backend unchanged.
            k_pe: Forwarded to the backend unchanged.
            positions: Forwarded only if the backend accepts ``positions``.
            q_scale: Second argument passed to ``q_proj``.
            topk_indices: Explicit indices that bypass the indexer buffer.
            **kwargs: Accepted for call compatibility; not otherwise used.

        Returns:
            The backend output, flattened per row and passed through ``o_proj``
            when this method did the query projection and ``o_proj`` is set.

        Raises:
            NotImplementedError: If no sparse backend is configured.
        """
        if self.sparse_backend is None:
            raise NotImplementedError(
                "RTPMLAAttention requires an attention backend for contract execution"
            )

        q, native_projected = self._project_query(query, q_scale)
        # ``topk_indices`` is a named parameter, so it can never also appear in
        # ``kwargs``; the argument value is the only possible source.
        resolved_indices = self._resolve_topk_indices(
            query, q_scale, positions, topk_indices
        )

        backend_kwargs = {"topk_indices": resolved_indices}
        if self._sparse_backend_accepts_positions:
            backend_kwargs["positions"] = positions

        attn_output = self.sparse_backend.forward(
            q,
            compressed_kv,
            k_pe,
            self.kv_cache,
            self.layer_id,
            **backend_kwargs,
        )

        if not native_projected or self.o_proj is None:
            return attn_output
        flattened = attn_output.reshape(attn_output.shape[0], -1).contiguous()
        return self.o_proj(flattened)

    __call__ = forward


def apply_attention_mla_rtpllm_patch() -> None:
    """Rebind ATOM's generic ``Attention`` symbol to ``RTPMLAAttention``.

    The adapter class is assigned to ``atom.model_ops.RTPMLAAttention``,
    ``atom.model_ops.Attention`` and ``atom.model_ops.base_attention.Attention``.
    ``atom.models.deepseek_v2.Attention`` is rebound as well, using the
    already-loaded module if there is one and importing it otherwise; if that
    import fails the function returns quietly after the first three
    assignments.

    Raises:
        ImportError: If ``atom.model_ops`` or ``atom.model_ops.base_attention``
            cannot be imported.
    """

    import importlib
    import sys

    model_ops = importlib.import_module("atom.model_ops")
    base_attention_mod = importlib.import_module("atom.model_ops.base_attention")

    model_ops.RTPMLAAttention = RTPMLAAttention
    for target in (model_ops, base_attention_mod):
        target.Attention = RTPMLAAttention

    deepseek_mod = sys.modules.get("atom.models.deepseek_v2")
    if deepseek_mod is None:
        try:
            deepseek_mod = importlib.import_module("atom.models.deepseek_v2")
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so this also
            # covers the module being absent; there is nothing more to patch.
            return
    deepseek_mod.Attention = RTPMLAAttention
Loading
Loading