From 97f1a7b86b8d078eef8dcbd3b26dd2c14a527e95 Mon Sep 17 00:00:00 2001 From: ShaoChunLee Date: Thu, 18 Jun 2026 23:34:56 +0000 Subject: [PATCH] add integration --- atom/model_ops/v4_kernels/paged_decode.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/v4_kernels/paged_decode.py b/atom/model_ops/v4_kernels/paged_decode.py index e91acd947e..19e024b62d 100644 --- a/atom/model_ops/v4_kernels/paged_decode.py +++ b/atom/model_ops/v4_kernels/paged_decode.py @@ -56,6 +56,7 @@ import triton.language as tl from aiter.ops.triton.utils.device_info import get_num_sms from atom.model_ops.sparse_attn_v4 import _sparse_attn_ragged_torch +from aiter.ops.triton.attention.pa_decode_sparse import pa_decode_sparse LOG2E = 1.4426950408889634 # log2(e); folded into qk_scale so softmax can use exp2. _MAX_KV_SPLITS = 64 # Hard cap on kv_splits (see _kv_splits_heuristic). @@ -916,7 +917,18 @@ def sparse_attn_v4_paged_decode( When ``kv_scales`` is provided, ``unified_kv`` must be fp8 (e4m3fnuz) and will be dequantized in-kernel using 1xGROUP_SIZE (default 64) block scales. """ - if os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1": + if os.environ.get("ATOM_USE_AITER_TRITON_ATTN", "0") == "1": + return pa_decode_sparse( + q, + unified_kv, + kv_indices, + kv_indptr, + attn_sink, + softmax_scale, + has_invalid=False, + kv_scales=kv_scales, + ) + elif os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1": return _sparse_attn_v4_paged_decode_triton( q, unified_kv,