Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ class Config:
model: str
trust_remote_code: bool = False
max_num_batched_tokens: int = 16384
long_prefill_token_threshold: int = 0
attn_prefill_chunk_size: int = 16384
scheduler_delay_factor: float = 0.0
max_num_seqs: int = 512
Expand Down Expand Up @@ -1104,6 +1105,19 @@ def __post_init__(self):
self.max_model_len, hf_config_max_position_embeddings
)
# assert self.max_num_batched_tokens >= self.max_model_len
if self.long_prefill_token_threshold > 0:
if self.long_prefill_token_threshold > self.max_model_len:
raise ValueError(
f"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than max_model_len ({self.max_model_len})."
)
if self.long_prefill_token_threshold < self.kv_cache_block_size:
raise ValueError(
f"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) must be >= "
f"kv_cache_block_size ({self.kv_cache_block_size})."
)
if not is_plugin_mode():
if self.torch_profiler_dir is not None:
os.makedirs(self.torch_profiler_dir, exist_ok=True)
Expand Down Expand Up @@ -1163,16 +1177,7 @@ def __post_init__(self):
v4_block_size = 128
if self.kv_cache_block_size != v4_block_size:
self.kv_cache_block_size = v4_block_size
Comment on lines 1175 to 1179
# TODO: V4's per-request SWA buffer cannot be restored from the classical
# KV pool on prefix cache hit, so disable prefix caching silently.
if self.enable_prefix_caching:
import logging

logging.getLogger(__name__).warning(
"DeepSeek-V4 does not support prefix caching "
"(SWA buffer is not cacheable); disabling automatically."
)
self.enable_prefix_caching = False

Comment on lines -1175 to +1180

def compute_hash(self) -> str:
"""
Expand Down
12 changes: 12 additions & 0 deletions atom/model_engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class EngineArgs:
block_size: int = 16
max_model_len: Optional[int] = None
max_num_batched_tokens: int = 16384
long_prefill_token_threshold: int = 0
attn_prefill_chunk_size: int = 16384
enable_chunked_prefill: bool = True
scheduler_delay_factor: float = 0.0
Expand Down Expand Up @@ -192,6 +193,17 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
default=16384,
help="Maximum number of tokens to batch together in async engine",
)
parser.add_argument(
"--long-prefill-token-threshold",
type=int,
default=0,
help=(
"For chunked prefill, cap a single request's per-step prefill "
"size at this many tokens. 0 disables the cap (request is only "
"bounded by max_num_batched_tokens). Useful to interleave long "
"prefills with decode for lower ITL."
),
)
parser.add_argument(
"--attn-prefill-chunk-size",
type=int,
Expand Down
16 changes: 12 additions & 4 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ class Scheduler:
def __init__(self, config: Config):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.long_prefill_token_threshold = config.long_prefill_token_threshold
self.max_model_len = config.max_model_len
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
Expand Down Expand Up @@ -721,7 +722,9 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
break
if not seq.is_partial_prefill:
continue
remaining = seq.num_prompt_tokens - seq.num_cached_tokens
remaining = seq.num_tokens - seq.num_cached_tokens
if 0 < self.long_prefill_token_threshold < remaining:
remaining = self.long_prefill_token_threshold
budget_remaining = self.max_num_batched_tokens - num_batched_tokens
chunk = min(remaining, budget_remaining)
if chunk <= 0:
Expand Down Expand Up @@ -821,6 +824,11 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
num_new_tokens = (
seq.num_tokens - num_cached_blocks * self.block_manager.block_size
)
if (
self.enable_chunked_prefill
and 0 < self.long_prefill_token_threshold < num_new_tokens
):
num_new_tokens = self.long_prefill_token_threshold
budget_remaining = self.max_num_batched_tokens - num_batched_tokens
if self.enable_chunked_prefill:
chunk = min(num_new_tokens, budget_remaining)
Expand Down Expand Up @@ -1036,7 +1044,7 @@ def postprocess(
# multiple steps (hash_blocks clips to fully-filled blocks).
self.block_manager.hash_blocks(seq, chunk)
seq.num_cached_tokens += chunk
now_partial = seq.num_cached_tokens < seq.num_prompt_tokens
now_partial = seq.num_cached_tokens < seq.num_tokens
if now_partial != seq.is_partial_prefill:
self._partial_prefill_count += 1 if now_partial else -1
seq.is_partial_prefill = now_partial
Expand Down Expand Up @@ -1384,8 +1392,8 @@ def has_requests(self) -> bool:
def get_next_batch_info(self) -> tuple[bool, int, int]:
# Check for partial prefills in running (chunked prefill resume)
for seq in self.running:
if seq.num_cached_tokens < seq.num_prompt_tokens:
remaining = seq.num_prompt_tokens - seq.num_cached_tokens
if seq.num_cached_tokens < seq.num_tokens:
remaining = seq.num_tokens - seq.num_cached_tokens
chunk = min(remaining, self.max_num_batched_tokens)
return (True, chunk, 1)
# Only consider waiting seqs that are not blocked on a remote KV
Comment on lines 1392 to 1399
Expand Down
15 changes: 12 additions & 3 deletions atom/model_ops/attentions/deepseek_v4_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,9 +1886,18 @@ def _build_paged_prefill_meta(
),
index_topk,
).astype(np.int32)
n_hca_per_token_np = n_committed_hca_per_seq_np[batch_id_per_token_np].astype(
np.int32
)
# Per-token CAUSAL HCA visibility (mirrors CSA above and the reference
# `get_compress_topk_idxs` prefill mask): token at `pos` sees only the
# `(pos+1)//128` HCA groups committed up to its own position, capped by
# the per-seq committed count. Without `(pos+1)//128`, every token used
# the per-seq `ctx_end//128`, over-reading FUTURE groups and making a
# token's output depend on the forward's total length (chunked breaks).
# MUST stay in sync with the kernel's inline cap in
# `_v4_paged_prefill_indices_kernel` (HCA_RATIO).
n_hca_per_token_np = np.minimum(
(positions_arr + 1) // 128,
n_committed_hca_per_seq_np[batch_id_per_token_np],
).astype(np.int32)

# 4 indptrs on CPU; last element = total (no D2H to size buffers).
ext_indptr_np = np.zeros(T + 1, dtype=np.int32)
Expand Down
23 changes: 18 additions & 5 deletions atom/model_ops/attentions/triton_merge_attn_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,25 @@ def merge_attn_states_kernel(
s_lse = float("-inf") if s_lse == float("inf") else s_lse

max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
# Both prefix AND suffix are empty for this token (no KV on either side) ->
# max_lse == -inf. The naive `p_lse - max_lse` would compute -inf-(-inf)=NaN
# and `out_se` would be 0, making the scale 0/0=NaN that poisons the output.
# This happens in ATOM's global-axis chunked prefill: a short seq can fall
# entirely outside a chunk, so its tokens see an empty prefix AND suffix in
# that chunk. Force a safe 0/0-split: subtract a finite max so each side's
# exp is 0 (out = 0*p_out + 0*s_out = 0, correct for empty attention) and
# keep the merged lse at -inf so any downstream merge stays consistent.
both_empty = max_lse == float("-inf")
safe_max = tl.where(both_empty, 0.0, max_lse)
p_lse = p_lse - safe_max
s_lse = s_lse - safe_max
# Will reuse precomputed Exp values for scale factor computation.
p_se = tl.exp(p_lse)
s_se = tl.exp(s_lse)
out_se = p_se + s_se

if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
out_lse = tl.where(both_empty, float("-inf"), tl.log(out_se) + safe_max)
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)

p_out = tl.load(
Expand All @@ -157,8 +167,11 @@ def merge_attn_states_kernel(
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = p_se / out_se
s_scale = s_se / out_se
# both_empty -> out_se == 0; guard the denominator so the scale is 0/1=0
# (not 0/0=NaN). p_out/s_out are 0 for empty attention, so out stays 0.
safe_out_se = tl.where(both_empty, 1.0, out_se)
p_scale = p_se / safe_out_se
s_scale = s_se / safe_out_se
out = p_out * p_scale + s_out * s_scale

if USE_FP8:
Expand Down
17 changes: 15 additions & 2 deletions atom/model_ops/v4_kernels/paged_prefill_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _v4_paged_prefill_indices_kernel(
win: tl.constexpr,
cs, # win_with_spec — SWA ring stride (NOT constexpr because varies w/ mtp_k)
swa_pages, # state_slot count * cs — boundary into HCA compress section
HCA_RATIO: tl.constexpr, # HCA compress ratio (128) for per-token causal cap
BLOCK_N: tl.constexpr, # next_pow2(win) — covers SWA prefix and extend segments
):
"""One program per token. Writes four per-token segments:
Expand All @@ -92,7 +93,15 @@ def _v4_paged_prefill_indices_kernel(
chunk_start = tl.load(chunk_start_per_seq_ptr + bid)
cu_q = tl.load(cu_seqlens_q_per_seq_ptr + bid)
state_slot = tl.load(state_slot_per_seq_ptr + bid)
n_hca = tl.load(n_committed_hca_per_seq_ptr + bid)
# Per-token CAUSAL HCA visibility: token at `pos` may see only the
# `(pos+1)//HCA_RATIO` compressed groups committed up to its own position
# (matches the reference `get_compress_topk_idxs` prefill mask, and mirrors
# the CSA `(pos+1)//4` cap). Without this cap every token saw the per-seq
# `n_committed_hca = ctx_end//128`, which over-reads FUTURE groups and makes
# a token's output depend on the forward's total length (chunked != single).
n_hca = tl.minimum(
(pos + 1) // HCA_RATIO, tl.load(n_committed_hca_per_seq_ptr + bid)
)
Comment on lines +96 to +104

# Per-token derived quantities (single-pass arithmetic).
token_pos_in_chunk = pos - chunk_start
Expand Down Expand Up @@ -159,6 +168,7 @@ def write_v4_paged_prefill_indices(
win: int,
cs: int,
swa_pages: int,
hca_ratio: int = 128,
) -> None:
"""One-shot GPU build of the V4 paged-prefill index buffers.

Expand Down Expand Up @@ -247,6 +257,7 @@ def write_v4_paged_prefill_indices(
win=win,
cs=cs,
swa_pages=swa_pages,
HCA_RATIO=hca_ratio,
BLOCK_N=BLOCK_N,
)

Expand All @@ -272,6 +283,7 @@ def write_v4_paged_prefill_indices_reference(
win: int,
cs: int,
swa_pages: int,
hca_ratio: int = 128,
) -> None:
"""Pure-Python equivalent of ``write_v4_paged_prefill_indices``.
Per-token Python loop — slow but readable; used for unit-test bit-exact
Expand Down Expand Up @@ -301,7 +313,8 @@ def write_v4_paged_prefill_indices_reference(
chunk_start = cs_per_seq_cpu[bid]
cu_q = cu_q_cpu[bid]
state_slot = state_slot_cpu[bid]
n_hca = n_hca_cpu[bid]
# Per-token causal HCA cap (mirrors kernel + reference get_compress_topk_idxs).
n_hca = min((pos + 1) // hca_ratio, n_hca_cpu[bid])

token_pos_in_chunk = pos - chunk_start
swa_low = max(pos - win + 1, 0)
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(self, **overrides):
enable_prefix_caching=False,
max_num_seqs=4,
max_num_batched_tokens=64,
long_prefill_token_threshold=0,
max_model_len=64,
bos_token_id=1,
eos_token_id=2,
Expand Down
127 changes: 127 additions & 0 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,133 @@ def test_decode_preemption(self, seq_factory):
assert SequenceStatus.WAITING in statuses


# ── long_prefill_token_threshold ──────────────────────────────────────────


class TestLongPrefillTokenThreshold:
"""Per-request cap on prefill tokens per step (vLLM parity)."""

def test_disabled_by_default(self, seq_factory):
"""threshold=0 → no per-request cap, only max_num_batched_tokens applies."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=1000,
enable_chunked_prefill=True,
)
)
sched.add(seq_factory(list(range(20))))
batch, _ = sched.schedule()
assert list(batch.num_scheduled_tokens) == [20]

def test_caps_single_long_request(self, seq_factory):
"""A 20-token prompt with threshold=8 → first step does 8 tokens."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=1000,
long_prefill_token_threshold=8,
enable_chunked_prefill=True,
)
)
sched.add(seq_factory(list(range(20))))
batch, _ = sched.schedule()
assert list(batch.num_scheduled_tokens) == [8]

def test_short_request_unaffected(self, seq_factory):
"""Prompt shorter than threshold → full prefill in one step."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=1000,
long_prefill_token_threshold=16,
enable_chunked_prefill=True,
)
)
sched.add(seq_factory([1, 2, 3, 4, 5]))
batch, _ = sched.schedule()
assert list(batch.num_scheduled_tokens) == [5]

def test_applied_per_request_not_batch(self, seq_factory):
"""Two long prompts each capped at 8 → batch carries 16 tokens."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=1000,
long_prefill_token_threshold=8,
enable_chunked_prefill=True,
)
)
sched.add(seq_factory(list(range(20))))
sched.add(seq_factory(list(range(20, 40))))
batch, _ = sched.schedule()
assert list(batch.num_scheduled_tokens) == [8, 8]
assert batch.total_tokens_num_prefill == 16

def test_min_with_budget_remaining(self, seq_factory):
"""budget < threshold → chunk is bounded by budget, not threshold."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=10,
long_prefill_token_threshold=8,
enable_chunked_prefill=True,
)
)
sched.add(seq_factory(list(range(20)))) # capped at 8
sched.add(seq_factory(list(range(20, 40)))) # budget left = 2
batch, _ = sched.schedule()
assert list(batch.num_scheduled_tokens) == [8, 2]

def test_ignored_when_chunked_prefill_disabled(self, seq_factory):
"""No chunked prefill → threshold is a no-op (full prompt or reject)."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=1000,
long_prefill_token_threshold=8,
enable_chunked_prefill=False,
)
)
sched.add(seq_factory(list(range(20))))
batch, _ = sched.schedule()
# Full 20-token prompt scheduled in one shot, threshold ignored.
assert list(batch.num_scheduled_tokens) == [20]

def test_partial_prefill_resume_capped(self, seq_factory):
"""Phase-1 resume of a partial-prefill seq is also capped by threshold."""
sched = Scheduler(
MockConfig(
num_kvcache_blocks=100,
kv_cache_block_size=4,
max_num_batched_tokens=8, # forces chunking on the 20-tok prompt
long_prefill_token_threshold=8,
enable_chunked_prefill=True,
)
)
seq = seq_factory(list(range(20)))
sched.add(seq)

# Step 1: new request, capped at 8.
batch1, _ = sched.schedule()
assert list(batch1.num_scheduled_tokens) == [8]
# Simulate postprocess marking it partial (would normally happen after
# forward returns and num_cached_tokens < num_prompt_tokens).
seq.num_cached_tokens = 8
seq.is_partial_prefill = True
sched._partial_prefill_count += 1
Comment on lines +255 to +259

# Step 2: partial-prefill resume, also capped at 8 (not 12 remaining).
batch2, _ = sched.schedule()
assert list(batch2.num_scheduled_tokens) == [8]


# ── prefix caching ────────────────────────────────────────────────────────


Expand Down
Loading