Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 103 additions & 28 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiter import fused_qk_norm_rope_cache_quant_shuffle
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
from aiter.jit.utils.chip_info import get_gfx
from aiter import dtypes
from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits
from aiter.ops.triton.unified_attention import unified_attention
from atom.config import get_current_atom_config

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F401> reported by reviewdog 🐶
atom.config.get_current_atom_config imported but unused

reviewdog suggestion errorGitHub comment range and suggestion line range must be same. L14-L14 v.s. L14-L15

Expand Down Expand Up @@ -588,31 +589,100 @@ def paged_attention_asm(

@mark_trace(prefix="paged_attention_persistent_asm", torch_compile=False)
def paged_attention_persistent_asm(
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext, sink=None
):
attn_metadata = fwd_ctx.attn_metadata
output = torch.empty_like(q)
output = torch.empty(*q.shape, dtype=torch.bfloat16, device=q.device)

aiter.pa_persistent_fwd(
Q=q,
K=k_cache,
V=v_cache,
output=output,
max_qlen=attn_metadata.max_seqlen_q,
qo_indptr=attn_metadata.cu_seqlens_q,
kv_indptr=attn_metadata.kv_indptr,
kv_indices=attn_metadata.kv_indices,
context_lens=attn_metadata.context_lens,
K_QScale=k_scale,
V_QScale=v_scale,
work_indptr=attn_metadata.work_indptr,
work_info=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
softmax_scale=self.scale,
mask=1,
)
if self.sinks is None:
aiter.pa_persistent_fwd(
Q=q,
K=k_cache,
V=v_cache,
output=output,
max_qlen=attn_metadata.max_seqlen_q,
qo_indptr=attn_metadata.cu_seqlens_q,
kv_indptr=attn_metadata.kv_indptr,
kv_indices=attn_metadata.kv_indices,
context_lens=attn_metadata.context_lens,
K_QScale=k_scale,
V_QScale=v_scale,
work_indptr=attn_metadata.work_indptr,
work_info=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
softmax_scale=self.scale,
mask=1,
)
else:
device = q.device
total_s, nhead, v_head_dim = output.shape
softmax_scale = self.scale if self.scale is not None else 1.0 / (v_head_dim**0.5)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable softmax_scale is assigned to but never used

Suggested change
softmax_scale = self.scale if self.scale is not None else 1.0 / (v_head_dim**0.5)
self.scale if self.scale is not None else 1.0 / (v_head_dim**0.5)

split_o = torch.empty(
(
attn_metadata.reduce_partial_map.size(0)
* attn_metadata.max_seqlen_q,
1,
nhead,
v_head_dim,
),
dtype=dtypes.fp32,
device=device,
)
split_lse = torch.empty(
(
attn_metadata.reduce_partial_map.size(0)
* attn_metadata.max_seqlen_q,
1,
nhead,
1,
),
dtype=dtypes.fp32,
device=device,
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
assert self.num_heads % self.num_kv_heads == 0
aiter.pa_decode_bf16_asm(
q.to(dtypes.fp8),
k_cache,
v_cache,
attn_metadata.kv_indices,
attn_metadata.context_lens,
self.scale,
attn_metadata.kv_indptr,
gqa=self.num_heads // self.num_kv_heads,
mtp=attn_metadata.max_seqlen_q,
query_scale=self.kv_scale,
key_scale=self.kv_scale,
value_scale=self.kv_scale,
qo_indptr=attn_metadata.cu_seqlens_q,
work_indptr=attn_metadata.work_indptr,
work_info=attn_metadata.work_info_set,
split_o=split_o,
split_lse=split_lse,
sink=sink,
)
bs = attn_metadata.cu_seqlens_q.shape[0] - 1
if int(attn_metadata.max_seqlen_k) > 256:
final_lse = torch.empty(
(bs * attn_metadata.max_seqlen_q, self.num_heads),
dtype=torch.float32,
device=q.device,
)
aiter.pa_reduce_v1(
split_o,
split_lse,
attn_metadata.reduce_indptr,
attn_metadata.reduce_final_map,
attn_metadata.reduce_partial_map,
attn_metadata.max_seqlen_q,
16,
output.view(
bs * attn_metadata.max_seqlen_q, self.num_heads, self.head_dim
),
final_lse,
)

return output

Expand Down Expand Up @@ -739,14 +809,19 @@ def dispatch_backend(
return self.prefill_attention_triton
return self.prefill_attention
else:
if use_unified_attn or self.use_triton_attn or self.use_flash_layout:
# if use_unified_attn or self.use_triton_attn or self.use_flash_layout:
# return self.paged_attention_triton
# else:
# atom_config = get_current_atom_config()
# if atom_config.kv_cache_block_size == 1024 or (
# atom_config.kv_cache_block_size == 256 and self.sinks is not None
# ):
# return self.paged_attention_persistent_asm
# return self.paged_attention_asm
if self.sliding_window != -1:
return self.paged_attention_triton
else:
# Only use pa persistent when block_size == 1024
atom_config = get_current_atom_config()
if atom_config.kv_cache_block_size == 1024:
return self.paged_attention_persistent_asm
return self.paged_attention_asm
return self.paged_attention_persistent_asm

def forward(
self,
Expand Down
30 changes: 9 additions & 21 deletions atom/model_ops/attentions/aiter_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,10 @@ def __init__(
device=None,
model_runner=None,
):
self.block_size = 1024 if model_runner.block_size == 1024 else 16
if envs.ATOM_USE_UNIFIED_ATTN:
# SHUFFLE (pre-shuffled) KV cache: use the logical block size directly
# as the physical block size so block_ratio == 1 and
# unified_attention's block_table needs no logical->physical
# conversion. Pass --block-size equal to the performant physical
# page: fp8 packs x=16 - 128; bf16 packs x=8 - 64 (both keep a
# 128-byte physical page, i.e. block_size // x == 8).
expected = 128 if model_runner.kv_cache_dtype in ("fp8",) else 64
assert model_runner.block_size == expected, (
f"ATOM_USE_UNIFIED_ATTN=1 expects --block-size {expected} "
f"for {model_runner.kv_cache_dtype} KV cache (so block_ratio == 1), "
f"got --block-size {model_runner.block_size}"
)
self.block_size = model_runner.block_size
# self.block_size = 1024 if model_runner.block_size == 1024 else 16
self.block_size = model_runner.block_size if model_runner.block_size in [1024, 256] else 16
# if envs.ATOM_USE_UNIFIED_ATTN:
# self.block_size = 128
assert (
model_runner.block_size % self.block_size == 0
), f"model_runner.block_size must be divisible by block_size but got {model_runner.block_size=}, block_size={self.block_size}, please set --block-size (model_runner.block_size) to be divisible by {self.block_size}"
Expand Down Expand Up @@ -240,7 +229,6 @@ def set_aiter_persistent_worker_buffers(self, bs: int):
1, hf_config.num_key_value_heads // get_tp_group().world_size
)
block_size = self.block_size

var = self.model_runner.forward_vars
max_qlen = var["max_qlen"]

Expand Down Expand Up @@ -753,7 +741,7 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int):
]

ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used}
if self.block_size == 1024:
if self.block_size in [256, 1024]:
ctx_pa_ps = self.set_aiter_persistent_worker_buffers(bs)
ctx.update(ctx_pa_ps)

Expand Down Expand Up @@ -876,7 +864,7 @@ def _prepare_ubatch_decode(
)

# Set PA persistent worker buffers for this ubatch
if self.block_size == 1024:
if self.block_size in [256, 1024]:
self._set_ubatch_pa_buffers(padded_bs, max_seqlen_q, ub_idx)

def _set_ubatch_pa_buffers(self, padded_bs, max_q_len, ubatch_idx):
Expand Down Expand Up @@ -922,7 +910,7 @@ def build_ubatch_metadata(
max_q_len = var["max_qlen"]

# Compute PA work buffers for this ubatch
if self.block_size == 1024:
if self.block_size in [256, 1024]:
self._set_ubatch_pa_buffers(padded_bs, max_q_len, ubatch_idx)

attn = AttentionMetaData(
Expand All @@ -943,8 +931,8 @@ def build_ubatch_metadata(
return attn

def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData:
var = self.model_runner.forward_vars
if self.block_size == 1024:
var = self.model_runner.forward_vars
if self.block_size in [256, 1024]:
ctx_pa_ps = self.set_aiter_persistent_worker_buffers(bs)
else:
ctx_pa_ps = {}
Expand Down
Loading