feat(v4): prefix chunk and chunk long prefill related #1236

Open
jiayyu wants to merge 9 commits into main from fpz/v4_prefix_chunk

@jiayyu jiayyu commented Jun 16, 2026

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

jiayyu and others added 8 commits June 16, 2026 09:30
In _forward_prefill_cached_chunked, a seq with no cached tokens in the
first chunk gets lse=-inf from flash_attn. Seeding the running accumulator
with that -inf and later merging against another -inf suffix computes
-inf-(-inf)=NaN in merge_attn_states, permanently poisoning that seq's
output. Only triggers with multiple seqs chunked together at high
concurrency (total_kv > attn_prefill_chunk_size). Sanitize the seed lse
with a large finite sentinel so an absent seq carries ~zero weight without
producing NaN.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
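
The failure mode described in the commit above can be reproduced with plain floats. The following is a minimal sketch, not the ATOM kernel: `merge_weights`, `sanitize`, and the `SENTINEL` value are hypothetical names and values chosen for illustration, and the real code operates on tensors.

```python
import math

NEG_INF = float("-inf")
SENTINEL = -1e30  # hypothetical "large finite" stand-in for an lse of -inf


def merge_weights(prefix_lse, suffix_lse):
    """Softmax weights used to blend two partial attention outputs."""
    max_lse = max(prefix_lse, suffix_lse)
    p = math.exp(prefix_lse - max_lse)  # -inf - (-inf) is NaN here
    s = math.exp(suffix_lse - max_lse)
    return p / (p + s), s / (p + s)


def sanitize(lse):
    """Replace -inf with a finite sentinel before seeding the accumulator."""
    return SENTINEL if lse == NEG_INF else lse


# Both sides empty and unsanitized: the weights are NaN.
print(merge_weights(NEG_INF, NEG_INF))

# Sanitized seed merged with a real suffix: the seed gets zero weight.
print(merge_weights(sanitize(NEG_INF), 2.0))  # (0.0, 1.0)
```

Note that a single `-inf` is harmless in this sketch, since `exp(-inf)` is simply 0; the NaN only appears when both operands are `-inf`, which matches the multi-seq condition the commit describes.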
HCA prefill used the per-seq committed count (ctx_end//128) for every
token, missing the (pos+1)//128 per-token causal cap that CSA already has
(and that the reference get_compress_topk_idxs applies). Under chunked
prefill ctx_end is the chunk's end, so the same logical token saw a
different number of HCA compressed groups depending on which chunk
computed it -> chunked != single-shot -> ~0.02 GSM8K drop.

Cap HCA per-token visibility to min((pos+1)//128, n_committed_hca) in the
indptr build, the prefill-indices kernel (new HCA_RATIO constexpr), and
the reference impl. Decode is unaffected (decode token is at seq end, the
cap is a no-op).

Verified GSM8K (V4-Pro, num_concurrent=4, fp8): chunked 0.93 -> 0.9507,
single-shot 0.9515 (no regression).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
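
To see why the per-token cap makes chunked and single-shot prefill agree, here is a small sketch of the two visibility rules from the commit above. The function names are hypothetical; only the formulas `ctx_end//128` and `min((pos+1)//128, n_committed_hca)` come from the commit message.

```python
HCA_RATIO = 128  # tokens per HCA compressed group, per the commit message


def hca_visible_uncapped(pos, ctx_end):
    """Old behaviour: every token sees the per-seq committed count."""
    return ctx_end // HCA_RATIO


def hca_visible_capped(pos, ctx_end):
    """New behaviour: additionally cap by the token's own position."""
    n_committed_hca = ctx_end // HCA_RATIO
    return min((pos + 1) // HCA_RATIO, n_committed_hca)


pos = 299  # same logical token in a 1024-token prompt

# The token is computed either in a chunk ending at 512 or in one shot (1024).
print(hca_visible_uncapped(pos, 512), hca_visible_uncapped(pos, 1024))  # 4 8
print(hca_visible_capped(pos, 512), hca_visible_capped(pos, 1024))      # 2 2
```

For a decode token at the end of the sequence, `pos + 1 == ctx_end`, so both rules return the same count, which is why the commit calls the cap a no-op for decode.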
…round

The earlier MLA NaN fix sanitized the seed lse at the call site
(nan_to_num on suf_lse). Now that merge_attn_states lives in ATOM
(triton_merge_attn_states.py, sole caller is the MLA chunked path —
plugin/vllm imports vllm's own copy), fix it at the root instead.

When a token's prefix AND suffix are both empty (max_lse == -inf), the
kernel computed -inf-(-inf)=NaN and a 0/0 scale that poisoned the output.
This is reachable in ATOM's global-axis chunked prefill: a short seq can
fall entirely outside a chunk. Guard both_empty: force a finite 0/0-split
so out=0 (correct for empty attention) and keep lse=-inf. The call-site
nan_to_num is now redundant and reverts to chunked_lse = suf_lse.

Verified GSM8K (R1-MXFP4, tp4, fp8, num_concurrent=64, long-prefill 512):
0.9431 — same as the call-site workaround, no regression.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
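
The root-level guard can be illustrated with a scalar reference merge. This is a simplified sketch under stated assumptions, not the Triton kernel in `triton_merge_attn_states.py`: it works on Python lists for one token, and it uses an early return where the commit describes forcing a finite split inside the kernel.

```python
import math

NEG_INF = float("-inf")


def merge_attn_states(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """Merge two partial attention results for one token (illustrative only)."""
    max_lse = max(prefix_lse, suffix_lse)
    if max_lse == NEG_INF:
        # both_empty: no keys on either side, so out=0 and lse stays -inf.
        return [0.0] * len(prefix_out), NEG_INF
    p = math.exp(prefix_lse - max_lse)
    s = math.exp(suffix_lse - max_lse)
    denom = p + s  # at least 1.0, because one of p or s is exp(0)
    out = [(p * a + s * b) / denom for a, b in zip(prefix_out, suffix_out)]
    return out, max_lse + math.log(denom)


# A short seq that falls entirely outside the chunk: both sides are empty.
print(merge_attn_states([0.0, 0.0], NEG_INF, [0.0, 0.0], NEG_INF))

# Only the suffix has content: the result is the suffix unchanged.
print(merge_attn_states([0.0, 0.0], NEG_INF, [1.0, 2.0], 1.5))
```

With the guard in the merge itself, callers can pass `-inf` through unchanged, which is consistent with the commit reverting the call site to `chunked_lse = suf_lse`.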
Copilot AI review requested due to automatic review settings June 16, 2026 10:45

Copilot AI left a comment

Pull request overview

This PR introduces a per-request cap for chunked-prefill scheduling (long_prefill_token_threshold) and fixes DeepSeek-V4 chunked-prefill correctness by making HCA compressed-group visibility causal per token (so chunked vs single-shot prefill yields consistent results). It also hardens Triton attention-state merging against NaNs when both prefix and suffix attention are empty.

Changes:

  • Add long_prefill_token_threshold to config + CLI and enforce it in Scheduler.schedule(); add unit tests covering expected scheduling behavior.
  • Fix DeepSeek-V4 paged-prefill index building to apply a per-token causal cap for HCA groups in both the Triton kernel and the reference/CPU meta builder.
  • Prevent NaNs in merge_attn_states_kernel when both prefix and suffix are empty for a token.
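
As a rough illustration of the first bullet, a per-request prefill cap can be thought of as one more term in the chunk-size minimum. This is a sketch only: `next_prefill_chunk` is a hypothetical helper, not `Scheduler.schedule()`, and treating a threshold of 0 as "no cap" is an assumption that the PR text does not confirm.

```python
def next_prefill_chunk(remaining, budget, long_prefill_token_threshold=0):
    """Number of prompt tokens to schedule for one request this step.

    remaining: prompt tokens of this request not yet prefilled.
    budget: tokens still available in the batch for this step.
    long_prefill_token_threshold: per-request cap; 0 is assumed to disable it.
    """
    chunk = min(remaining, budget)
    if long_prefill_token_threshold > 0:
        chunk = min(chunk, long_prefill_token_threshold)
    return chunk


print(next_prefill_chunk(4000, 2048))       # 2048, limited by the batch budget
print(next_prefill_chunk(4000, 2048, 512))  # 512, limited by the per-request cap
print(next_prefill_chunk(100, 2048, 512))   # 100, short prompt fits in one step
```

Under this reading, a long prompt is spread over more scheduling steps, leaving batch budget for other requests in each step.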

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| atom/model_engine/scheduler.py | Apply per-request prefill cap during scheduling; adjust partial-prefill accounting to use num_tokens. |
| atom/model_engine/arg_utils.py | Add CLI flag --long-prefill-token-threshold and plumb into engine args. |
| atom/config.py | Add config field + validation for long_prefill_token_threshold; remove prior DeepSeek-V4 prefix-caching auto-disable. |
| atom/model_ops/v4_kernels/paged_prefill_indices.py | Make HCA indices causal per token in the Triton kernel and the Python reference; add HCA_RATIO constexpr. |
| atom/model_ops/attentions/deepseek_v4_attn.py | Update CPU-side per-token HCA counts/indptr sizing to match causal behavior. |
| atom/model_ops/attentions/triton_merge_attn_states.py | Guard merge math to avoid -inf - (-inf) NaNs and 0/0 scales. |
| tests/conftest.py | Extend MockConfig with long_prefill_token_threshold. |
| tests/test_scheduler.py | Add test suite for long_prefill_token_threshold behavior. |

Comment thread atom/config.py
Comment on lines 1175 to 1179
arches = getattr(self.hf_config, "architectures", None) or []
if any("DeepseekV4" in str(a) for a in arches):
    v4_block_size = 128
    if self.kv_cache_block_size != v4_block_size:
        self.kv_cache_block_size = v4_block_size
Comment thread tests/test_scheduler.py
Comment on lines +255 to +259
# Simulate postprocess marking it partial (would normally happen after
# forward returns and num_cached_tokens < num_prompt_tokens).
seq.num_cached_tokens = 8
seq.is_partial_prefill = True
sched._partial_prefill_count += 1
Comment on lines 1392 to 1399
 def get_next_batch_info(self) -> tuple[bool, int, int]:
     # Check for partial prefills in running (chunked prefill resume)
     for seq in self.running:
-        if seq.num_cached_tokens < seq.num_prompt_tokens:
-            remaining = seq.num_prompt_tokens - seq.num_cached_tokens
+        if seq.num_cached_tokens < seq.num_tokens:
+            remaining = seq.num_tokens - seq.num_cached_tokens
             chunk = min(remaining, self.max_num_batched_tokens)
             return (True, chunk, 1)
     # Only consider waiting seqs that are not blocked on a remote KV
Comment on lines +96 to +104
# Per-token CAUSAL HCA visibility: token at `pos` may see only the
# `(pos+1)//HCA_RATIO` compressed groups committed up to its own position
# (matches the reference `get_compress_topk_idxs` prefill mask, and mirrors
# the CSA `(pos+1)//4` cap). Without this cap every token saw the per-seq
# `n_committed_hca = ctx_end//128`, which over-reads FUTURE groups and makes
# a token's output depend on the forward's total length (chunked != single).
n_hca = tl.minimum(
    (pos + 1) // HCA_RATIO, tl.load(n_committed_hca_per_seq_ptr + bid)
)
Comment thread atom/config.py
Comment on lines 1175 to +1180
 arches = getattr(self.hf_config, "architectures", None) or []
 if any("DeepseekV4" in str(a) for a in arches):
     v4_block_size = 128
     if self.kv_cache_block_size != v4_block_size:
         self.kv_cache_block_size = v4_block_size
-    # TODO: V4's per-request SWA buffer cannot be restored from the classical
-    # KV pool on prefix cache hit, so disable prefix caching silently.
-    if self.enable_prefix_caching:
-        import logging
-
-        logging.getLogger(__name__).warning(
-            "DeepSeek-V4 does not support prefix caching "
-            "(SWA buffer is not cacheable); disabling automatically."
-        )
-        self.enable_prefix_caching = False

2 participants