Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion atom/model_ops/v4_kernels/paged_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import triton.language as tl
from aiter.ops.triton.utils.device_info import get_num_sms
from atom.model_ops.sparse_attn_v4 import _sparse_attn_ragged_torch
from aiter.ops.triton.attention.pa_decode_sparse import pa_decode_sparse

LOG2E = 1.4426950408889634 # log2(e); folded into qk_scale so softmax can use exp2.
_MAX_KV_SPLITS = 64 # Hard cap on kv_splits (see _kv_splits_heuristic).
Expand Down Expand Up @@ -916,7 +917,18 @@ def sparse_attn_v4_paged_decode(
When ``kv_scales`` is provided, ``unified_kv`` must be fp8 (e4m3fnuz) and
will be dequantized in-kernel using 1xGROUP_SIZE (default 64) block scales.
"""
if os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1":
if os.environ.get("ATOM_USE_AITER_TRITON_ATTN", "0") == "1":
return pa_decode_sparse(
q,
unified_kv,
kv_indices,
kv_indptr,
attn_sink,
softmax_scale,
has_invalid=False,
kv_scales=kv_scales,
)
elif os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1":
return _sparse_attn_v4_paged_decode_triton(
q,
unified_kv,
Expand Down
Loading