diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 9ad2916282..dc2cecec32 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -8,6 +8,7 @@ from aiter import fused_qk_norm_rope_cache_quant_shuffle from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache from aiter.jit.utils.chip_info import get_gfx +from aiter import dtypes from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits from aiter.ops.triton.unified_attention import unified_attention from atom.config import get_current_atom_config @@ -588,31 +589,100 @@ def paged_attention_asm( @mark_trace(prefix="paged_attention_persistent_asm", torch_compile=False) def paged_attention_persistent_asm( - self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext + self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext, sink=None ): attn_metadata = fwd_ctx.attn_metadata - output = torch.empty_like(q) + output = torch.empty(*q.shape, dtype=torch.bfloat16, device=q.device) - aiter.pa_persistent_fwd( - Q=q, - K=k_cache, - V=v_cache, - output=output, - max_qlen=attn_metadata.max_seqlen_q, - qo_indptr=attn_metadata.cu_seqlens_q, - kv_indptr=attn_metadata.kv_indptr, - kv_indices=attn_metadata.kv_indices, - context_lens=attn_metadata.context_lens, - K_QScale=k_scale, - V_QScale=v_scale, - work_indptr=attn_metadata.work_indptr, - work_info=attn_metadata.work_info_set, - reduce_indptr=attn_metadata.reduce_indptr, - reduce_final_map=attn_metadata.reduce_final_map, - reduce_partial_map=attn_metadata.reduce_partial_map, - softmax_scale=self.scale, - mask=1, - ) + if self.sinks is None: + aiter.pa_persistent_fwd( + Q=q, + K=k_cache, + V=v_cache, + output=output, + max_qlen=attn_metadata.max_seqlen_q, + qo_indptr=attn_metadata.cu_seqlens_q, + kv_indptr=attn_metadata.kv_indptr, + kv_indices=attn_metadata.kv_indices, + context_lens=attn_metadata.context_lens, + K_QScale=k_scale, + V_QScale=v_scale, + work_indptr=attn_metadata.work_indptr, + work_info=attn_metadata.work_info_set, + reduce_indptr=attn_metadata.reduce_indptr, + reduce_final_map=attn_metadata.reduce_final_map, + reduce_partial_map=attn_metadata.reduce_partial_map, + softmax_scale=self.scale, + mask=1, + ) + else: + device = q.device + total_s, nhead, v_head_dim = output.shape + softmax_scale = self.scale if self.scale is not None else 1.0 / (v_head_dim**0.5) + split_o = torch.empty( + ( + attn_metadata.reduce_partial_map.size(0) + * attn_metadata.max_seqlen_q, + 1, + nhead, + v_head_dim, + ), + dtype=dtypes.fp32, + device=device, + ) + split_lse = torch.empty( + ( + attn_metadata.reduce_partial_map.size(0) + * attn_metadata.max_seqlen_q, + 1, + nhead, + 1, + ), + dtype=dtypes.fp32, + device=device, + ) + final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + assert self.num_heads % self.num_kv_heads == 0 + aiter.pa_decode_bf16_asm( + q.to(dtypes.fp8), + k_cache, + v_cache, + attn_metadata.kv_indices, + attn_metadata.context_lens, + self.scale, + attn_metadata.kv_indptr, + gqa=self.num_heads // self.num_kv_heads, + mtp=attn_metadata.max_seqlen_q, + query_scale=self.kv_scale, + key_scale=self.kv_scale, + value_scale=self.kv_scale, + qo_indptr=attn_metadata.cu_seqlens_q, + work_indptr=attn_metadata.work_indptr, + work_info=attn_metadata.work_info_set, + split_o=split_o, + split_lse=split_lse, + sink=sink, + ) + bs = attn_metadata.cu_seqlens_q.shape[0] - 1 + if int(attn_metadata.max_seqlen_k) > 256: + final_lse = torch.empty( + (bs * attn_metadata.max_seqlen_q, self.num_heads), + dtype=torch.float32, + device=q.device, + ) + aiter.pa_reduce_v1( + split_o, + split_lse, + attn_metadata.reduce_indptr, + attn_metadata.reduce_final_map, + attn_metadata.reduce_partial_map, + attn_metadata.max_seqlen_q, + 16, + output.view( + bs * attn_metadata.max_seqlen_q, self.num_heads, self.head_dim + ), + final_lse, + ) return output @@ -739,14 +809,19 @@ def dispatch_backend( return self.prefill_attention_triton return self.prefill_attention else: - if use_unified_attn or self.use_triton_attn or self.use_flash_layout: + # if use_unified_attn or self.use_triton_attn or self.use_flash_layout: + # return self.paged_attention_triton + # else: + # atom_config = get_current_atom_config() + # if atom_config.kv_cache_block_size == 1024 or ( + # atom_config.kv_cache_block_size == 256 and self.sinks is not None + # ): + # return self.paged_attention_persistent_asm + # return self.paged_attention_asm + if self.sliding_window != -1: return self.paged_attention_triton else: - # Only use pa persistent when block_size == 1024 - atom_config = get_current_atom_config() - if atom_config.kv_cache_block_size == 1024: - return self.paged_attention_persistent_asm - return self.paged_attention_asm + return self.paged_attention_persistent_asm def forward( self, diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index 078fed004f..26bb47c799 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -47,21 +47,10 @@ def __init__( device=None, model_runner=None, ): - self.block_size = 1024 if model_runner.block_size == 1024 else 16 - if envs.ATOM_USE_UNIFIED_ATTN: - # SHUFFLE (pre-shuffled) KV cache: use the logical block size directly - # as the physical block size so block_ratio == 1 and - # unified_attention's block_table needs no logical->physical - # conversion. Pass --block-size equal to the performant physical - # page: fp8 packs x=16 - 128; bf16 packs x=8 - 64 (both keep a - # 128-byte physical page, i.e. block_size // x == 8). - expected = 128 if model_runner.kv_cache_dtype in ("fp8",) else 64 - assert model_runner.block_size == expected, ( - f"ATOM_USE_UNIFIED_ATTN=1 expects --block-size {expected} " - f"for {model_runner.kv_cache_dtype} KV cache (so block_ratio == 1), " - f"got --block-size {model_runner.block_size}" - ) - self.block_size = model_runner.block_size + # self.block_size = 1024 if model_runner.block_size == 1024 else 16 + self.block_size = model_runner.block_size if model_runner.block_size in [1024, 256] else 16 + # if envs.ATOM_USE_UNIFIED_ATTN: + # self.block_size = 128 assert ( model_runner.block_size % self.block_size == 0 ), f"model_runner.block_size must be divisible by block_size but got {model_runner.block_size=}, block_size={self.block_size}, please set --block-size (model_runner.block_size) to be divisible by {self.block_size}" @@ -240,7 +229,6 @@ def set_aiter_persistent_worker_buffers(self, bs: int): 1, hf_config.num_key_value_heads // get_tp_group().world_size ) block_size = self.block_size - var = self.model_runner.forward_vars max_qlen = var["max_qlen"] @@ -753,7 +741,7 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): ] ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used} - if self.block_size == 1024: + if self.block_size in [256, 1024]: ctx_pa_ps = self.set_aiter_persistent_worker_buffers(bs) ctx.update(ctx_pa_ps) @@ -876,7 +864,7 @@ def _prepare_ubatch_decode( ) # Set PA persistent worker buffers for this ubatch - if self.block_size == 1024: + if self.block_size in [256, 1024]: self._set_ubatch_pa_buffers(padded_bs, max_seqlen_q, ub_idx) def _set_ubatch_pa_buffers(self, padded_bs, max_q_len, ubatch_idx): @@ -922,7 +910,7 @@ def build_ubatch_metadata( max_q_len = var["max_qlen"] # Compute PA work buffers for this ubatch - if self.block_size == 1024: + if self.block_size in [256, 1024]: self._set_ubatch_pa_buffers(padded_bs, max_q_len, ubatch_idx) attn = AttentionMetaData( @@ -943,8 +931,8 @@ def build_ubatch_metadata( return attn def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData: - var = self.model_runner.forward_vars - if self.block_size == 1024: + var = self.model_runner.forward_vars + if self.block_size in [256, 1024]: ctx_pa_ps = self.set_aiter_persistent_worker_buffers(bs) else: ctx_pa_ps = {}