feat(v4): prefix chunk and chunk long prefill related #1236
Open
jiayyu wants to merge 9 commits into
In _forward_prefill_cached_chunked, a seq with no cached tokens in the first chunk gets lse=-inf from flash_attn. Seeding the running accumulator with that -inf and later merging against another -inf suffix computes -inf-(-inf)=NaN in merge_attn_states, permanently poisoning that seq's output. Only triggers with multiple seqs chunked together at high concurrency (total_kv > attn_prefill_chunk_size). Sanitize the seed lse with a large finite sentinel so an absent seq carries ~zero weight without producing NaN. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
HCA prefill used the per-seq committed count (ctx_end//128) for every token, missing the (pos+1)//128 per-token causal cap that CSA already has (and that the reference get_compress_topk_idxs applies). Under chunked prefill ctx_end is the chunk's end, so the same logical token saw a different number of HCA compressed groups depending on which chunk computed it -> chunked != single-shot -> ~0.02 GSM8K drop. Cap HCA per-token visibility to min((pos+1)//128, n_committed_hca) in the indptr build, the prefill-indices kernel (new HCA_RATIO constexpr), and the reference impl. Decode is unaffected (decode token is at seq end, the cap is a no-op). Verified GSM8K (V4-Pro, num_concurrent=4, fp8): chunked 0.93 -> 0.9507, single-shot 0.9515 (no regression). Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
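The visibility rule described in this commit can be illustrated with a small, pure-Python sketch. It is not the Triton kernel from the PR; the function names are hypothetical, and only the `min((pos+1)//128, n_committed_hca)` formula and the 128-token ratio come from the commit message.

```python
HCA_RATIO = 128  # tokens per HCA compressed group, as stated in the commit message


def hca_visible_uncapped(pos: int, ctx_end: int) -> int:
    """Old behaviour: every token sees the per-seq committed count."""
    return ctx_end // HCA_RATIO


def hca_visible_causal(pos: int, ctx_end: int) -> int:
    """Fixed behaviour: cap visibility at the groups committed up to `pos`."""
    n_committed_hca = ctx_end // HCA_RATIO
    return min((pos + 1) // HCA_RATIO, n_committed_hca)


if __name__ == "__main__":
    pos = 129  # same logical token in both runs
    for label, ctx_end in (("chunked, chunk ends at 256", 256), ("single-shot, 512 tokens", 512)):
        print(label, hca_visible_uncapped(pos, ctx_end), hca_visible_causal(pos, ctx_end))
```

With the uncapped rule the token at position 129 sees 2 groups in the chunked run and 4 in the single-shot run; with the causal cap it sees 1 group in both. For a decode token (`pos == ctx_end - 1`) the two rules agree, which is why decode is unaffected.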
…round The earlier MLA NaN fix sanitized the seed lse at the call site (nan_to_num on suf_lse). Now that merge_attn_states lives in ATOM (triton_merge_attn_states.py, sole caller is the MLA chunked path — plugin/vllm imports vllm's own copy), fix it at the root instead. When a token's prefix AND suffix are both empty (max_lse == -inf), the kernel computed -inf-(-inf)=NaN and a 0/0 scale that poisoned the output. This is reachable in ATOM's global-axis chunked prefill: a short seq can fall entirely outside a chunk. Guard both_empty: force a finite 0/0-split so out=0 (correct for empty attention) and keep lse=-inf. The call-site nan_to_num is now redundant and reverts to chunked_lse = suf_lse. Verified GSM8K (R1-MXFP4, tp4, fp8, num_concurrent=64, long-prefill 512): 0.9431 — same as the call-site workaround, no regression. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
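A scalar sketch of the log-sum-exp merge may help show where the NaN comes from and what the `both_empty` guard changes. This is an illustration in plain Python, not the Triton kernel in `triton_merge_attn_states.py`; the function signature and the `guard` flag are assumptions made for the example.

```python
import math

NEG_INF = float("-inf")


def merge_attn_states(out_p, lse_p, out_s, lse_s, guard=True):
    """Merge prefix/suffix partial attention outputs for one token.

    out_p / out_s are per-token output vectors (lists of floats);
    lse_p / lse_s are their log-sum-exp values (-inf means "no keys seen").
    """
    max_lse = max(lse_p, lse_s)
    if guard and max_lse == NEG_INF:
        # Both sides empty: attention over no keys is zero, lse stays -inf.
        return [0.0] * len(out_p), NEG_INF
    # Without the guard, -inf - (-inf) evaluates to NaN here.
    w_p = math.exp(lse_p - max_lse)
    w_s = math.exp(lse_s - max_lse)
    denom = w_p + w_s
    out = [(a * w_p + b * w_s) / denom for a, b in zip(out_p, out_s)]
    return out, max_lse + math.log(denom)


if __name__ == "__main__":
    print(merge_attn_states([0.0], NEG_INF, [0.0], NEG_INF, guard=False))
    print(merge_attn_states([0.0], NEG_INF, [0.0], NEG_INF, guard=True))
```

When only one side is empty, its weight is `exp(-inf) = 0`, so the merge already returns the other side unchanged; the guard matters only when both sides are empty.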
Pull request overview
This PR introduces a per-request cap for chunked-prefill scheduling (long_prefill_token_threshold) and fixes DeepSeek-V4 chunked-prefill correctness by making HCA compressed-group visibility causal per token (so chunked vs single-shot prefill yields consistent results). It also hardens Triton attention-state merging against NaNs when both prefix and suffix attention are empty.
Changes:
- Add `long_prefill_token_threshold` to config + CLI and enforce it in `Scheduler.schedule()`; add unit tests covering expected scheduling behavior.
- Fix DeepSeek-V4 paged-prefill index building to apply a per-token causal cap for HCA groups in both the Triton kernel and the reference/CPU meta builder.
- Prevent NaNs in `merge_attn_states_kernel` when both prefix and suffix are empty for a token.
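As a rough illustration of the per-request cap in the first bullet, the sketch below splits one long prompt into prefill chunks. It does not reproduce `Scheduler.schedule()`; the helper names are hypothetical, and treating a threshold of `0` as "cap disabled" is an assumption, not something this PR page states.

```python
def next_prefill_chunk(remaining: int, token_budget: int,
                       long_prefill_token_threshold: int = 0) -> int:
    """Tokens to schedule for one request in this step (hypothetical helper)."""
    if token_budget <= 0:
        raise ValueError("token_budget must be positive")
    chunk = min(remaining, token_budget)
    # Assumption: a threshold of 0 means the per-request cap is disabled.
    if long_prefill_token_threshold > 0:
        chunk = min(chunk, long_prefill_token_threshold)
    return chunk


def plan_chunks(num_tokens: int, token_budget: int, threshold: int) -> list[int]:
    """Chunk sizes used to prefill a single request of `num_tokens` tokens."""
    cached, chunks = 0, []
    while cached < num_tokens:
        chunk = next_prefill_chunk(num_tokens - cached, token_budget, threshold)
        chunks.append(chunk)
        cached += chunk
    return chunks


if __name__ == "__main__":
    print(plan_chunks(1300, 2048, 512))  # capped per request
    print(plan_chunks(1300, 2048, 0))    # cap disabled under the stated assumption
```

In a real scheduler the leftover budget from a capped request would be available to other requests in the same step, which is the point of capping a single long prefill.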
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `atom/model_engine/scheduler.py` | Apply per-request prefill cap during scheduling; adjust partial-prefill accounting to use `num_tokens`. |
| `atom/model_engine/arg_utils.py` | Add CLI flag `--long-prefill-token-threshold` and plumb into engine args. |
| `atom/config.py` | Add config field + validation for `long_prefill_token_threshold`; remove prior DeepSeek-V4 prefix-caching auto-disable. |
| `atom/model_ops/v4_kernels/paged_prefill_indices.py` | Make HCA indices causal per token in the Triton kernel and the Python reference; add `HCA_RATIO` constexpr. |
| `atom/model_ops/attentions/deepseek_v4_attn.py` | Update CPU-side per-token HCA counts/indptr sizing to match causal behavior. |
| `atom/model_ops/attentions/triton_merge_attn_states.py` | Guard merge math to avoid `-inf - (-inf)` NaNs and 0/0 scales. |
| `tests/conftest.py` | Extend `MockConfig` with `long_prefill_token_threshold`. |
| `tests/test_scheduler.py` | Add test suite for `long_prefill_token_threshold` behavior. |
Comment on lines 1175 to 1179

    arches = getattr(self.hf_config, "architectures", None) or []
    if any("DeepseekV4" in str(a) for a in arches):
        v4_block_size = 128
        if self.kv_cache_block_size != v4_block_size:
            self.kv_cache_block_size = v4_block_size

Comment on lines +255 to +259

    # Simulate postprocess marking it partial (would normally happen after
    # forward returns and num_cached_tokens < num_prompt_tokens).
    seq.num_cached_tokens = 8
    seq.is_partial_prefill = True
    sched._partial_prefill_count += 1

Comment on lines 1392 to 1399

      def get_next_batch_info(self) -> tuple[bool, int, int]:
          # Check for partial prefills in running (chunked prefill resume)
          for seq in self.running:
    -         if seq.num_cached_tokens < seq.num_prompt_tokens:
    -             remaining = seq.num_prompt_tokens - seq.num_cached_tokens
    +         if seq.num_cached_tokens < seq.num_tokens:
    +             remaining = seq.num_tokens - seq.num_cached_tokens
                  chunk = min(remaining, self.max_num_batched_tokens)
                  return (True, chunk, 1)
          # Only consider waiting seqs that are not blocked on a remote KV

Comment on lines +96 to +104

    # Per-token CAUSAL HCA visibility: token at `pos` may see only the
    # `(pos+1)//HCA_RATIO` compressed groups committed up to its own position
    # (matches the reference `get_compress_topk_idxs` prefill mask, and mirrors
    # the CSA `(pos+1)//4` cap). Without this cap every token saw the per-seq
    # `n_committed_hca = ctx_end//128`, which over-reads FUTURE groups and makes
    # a token's output depend on the forward's total length (chunked != single).
    n_hca = tl.minimum(
        (pos + 1) // HCA_RATIO, tl.load(n_committed_hca_per_seq_ptr + bid)
    )

Comment on lines 1175 to +1180

    arches = getattr(self.hf_config, "architectures", None) or []
    if any("DeepseekV4" in str(a) for a in arches):
        v4_block_size = 128
        if self.kv_cache_block_size != v4_block_size:
            self.kv_cache_block_size = v4_block_size
        # TODO: V4's per-request SWA buffer cannot be restored from the classical
        # KV pool on prefix cache hit, so disable prefix caching silently.
        if self.enable_prefix_caching:
            import logging

            logging.getLogger(__name__).warning(
                "DeepSeek-V4 does not support prefix caching "
                "(SWA buffer is not cacheable); disabling automatically."
            )
            self.enable_prefix_caching = False

Motivation
Technical Details
Test Plan
Test Result
Submission Checklist