From 13ac121f958883b0d61ef727e103e4edb85bc1d1 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Thu, 4 Jun 2026 04:51:29 +0000 Subject: [PATCH 1/9] wip --- atom/config.py | 14 +++ atom/model_engine/arg_utils.py | 12 +++ atom/model_engine/model_runner.py | 6 ++ atom/model_engine/scheduler.py | 22 ++++ atom/utils/debug_helper/__init__.py | 2 + atom/utils/debug_helper/dump.py | 149 ++++++++++++++++++++++++++++ atom/utils/envs.py | 9 ++ tests/conftest.py | 1 + tests/test_scheduler.py | 127 ++++++++++++++++++++++++ 9 files changed, 342 insertions(+) diff --git a/atom/config.py b/atom/config.py index d9b582d601..541f23d2a1 100644 --- a/atom/config.py +++ b/atom/config.py @@ -981,6 +981,7 @@ class Config: model: str trust_remote_code: bool = False max_num_batched_tokens: int = 16384 + long_prefill_token_threshold: int = 0 attn_prefill_chunk_size: int = 16384 scheduler_delay_factor: float = 0.0 max_num_seqs: int = 512 @@ -1104,6 +1105,19 @@ def __post_init__(self): self.max_model_len, hf_config_max_position_embeddings ) # assert self.max_num_batched_tokens >= self.max_model_len + if self.long_prefill_token_threshold > 0: + if self.long_prefill_token_threshold > self.max_model_len: + raise ValueError( + f"long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than max_model_len ({self.max_model_len})." + ) + if self.long_prefill_token_threshold < self.kv_cache_block_size: + raise ValueError( + f"long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) must be >= " + f"kv_cache_block_size ({self.kv_cache_block_size})." + ) if not is_plugin_mode(): if self.torch_profiler_dir is not None: os.makedirs(self.torch_profiler_dir, exist_ok=True) diff --git a/atom/model_engine/arg_utils.py b/atom/model_engine/arg_utils.py index 0cba4b2e33..526a94d4be 100644 --- a/atom/model_engine/arg_utils.py +++ b/atom/model_engine/arg_utils.py @@ -38,6 +38,7 @@ class EngineArgs: block_size: int = 16 max_model_len: Optional[int] = None max_num_batched_tokens: int = 16384 + long_prefill_token_threshold: int = 0 attn_prefill_chunk_size: int = 16384 enable_chunked_prefill: bool = True scheduler_delay_factor: float = 0.0 @@ -192,6 +193,17 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: default=16384, help="Maximum number of tokens to batch together in async engine", ) + parser.add_argument( + "--long-prefill-token-threshold", + type=int, + default=0, + help=( + "For chunked prefill, cap a single request's per-step prefill " + "size at this many tokens. 0 disables the cap (request is only " + "bounded by max_num_batched_tokens). Useful to interleave long " + "prefills with decode for lower ITL." + ), + ) parser.add_argument( "--attn-prefill-chunk-size", type=int, diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index db018399e2..0d7c2112ed 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -659,12 +659,18 @@ def __init__(self, rank: int, config: Config): # See atom/utils/debug_helper/. from atom.utils.debug_helper import ( install_block_forward_hooks, + install_partial_prefill_dump_hooks, maybe_dump_weights_and_exit, ) _n_fwd_hooks = install_block_forward_hooks(self.model) if _n_fwd_hooks > 0: logger.info(f"[ATOM_FWD_DUMP] {_n_fwd_hooks} Block forward hooks installed") + _n_partial_hooks = install_partial_prefill_dump_hooks(self.model) + if _n_partial_hooks > 0: + logger.info( + f"[ATOM_PARTIAL_DUMP] {_n_partial_hooks} partial-prefill hooks installed" + ) maybe_dump_weights_and_exit(self.model) if self.config.speculative_config and get_pp_group().is_last_rank: diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index dab098f1a9..16d5402687 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -415,6 +415,7 @@ class Scheduler: def __init__(self, config: Config): self.max_num_seqs = config.max_num_seqs self.max_num_batched_tokens = config.max_num_batched_tokens + self.long_prefill_token_threshold = config.long_prefill_token_threshold self.max_model_len = config.max_model_len self.bos_token_id = config.bos_token_id self.eos_token_id = config.eos_token_id @@ -722,6 +723,8 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if not seq.is_partial_prefill: continue remaining = seq.num_prompt_tokens - seq.num_cached_tokens + if 0 < self.long_prefill_token_threshold < remaining: + remaining = self.long_prefill_token_threshold budget_remaining = self.max_num_batched_tokens - num_batched_tokens chunk = min(remaining, budget_remaining) if chunk <= 0: @@ -821,6 +824,11 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_new_tokens = ( seq.num_tokens - num_cached_blocks * self.block_manager.block_size ) + if ( + self.enable_chunked_prefill + and 0 < self.long_prefill_token_threshold < num_new_tokens + ): + num_new_tokens = self.long_prefill_token_threshold budget_remaining = self.max_num_batched_tokens - num_batched_tokens if self.enable_chunked_prefill: chunk = min(num_new_tokens, budget_remaining) @@ -883,6 +891,20 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: f"(cached: {cached_per_req}, new: {num_scheduled_tokens}), " f"req_ids: {tuple(scheduled_seqs.keys())}" ) + n_partial = sum( + 1 + for s, c in zip(scheduled_seqs.values(), num_scheduled_tokens) + if s.num_cached_tokens + c < s.num_prompt_tokens + ) + if n_partial >= 2: + logger.warning( + f"[MULTI-CHUNK-DEBUG] {n_partial}/{num_seqs_prefill} partial-prefill in batch, " + f"req_ids={list(scheduled_seqs.keys())}, " + f"cached={num_cached_tokens_list}, " + f"chunks={num_scheduled_tokens}, " + f"prompt_lens={[s.num_prompt_tokens for s in scheduled_seqs.values()]}, " + f"block_table_lens={[len(s.block_table) for s in scheduled_seqs.values()]}" + ) self.prev_prompt = True # lip: TODO for prefill/decode mixed batch diff --git a/atom/utils/debug_helper/__init__.py b/atom/utils/debug_helper/__init__.py index d23de5e141..5c9fcc6d8e 100644 --- a/atom/utils/debug_helper/__init__.py +++ b/atom/utils/debug_helper/__init__.py @@ -31,6 +31,7 @@ ) from atom.utils.debug_helper.dump import ( install_block_forward_hooks, + install_partial_prefill_dump_hooks, maybe_dump_weights_and_exit, maybe_log_topk, ) @@ -43,6 +44,7 @@ __all__ = [ # dump "install_block_forward_hooks", + "install_partial_prefill_dump_hooks", "maybe_dump_weights_and_exit", "maybe_log_topk", # compare primitives diff --git a/atom/utils/debug_helper/dump.py b/atom/utils/debug_helper/dump.py index 0122efbb72..4c7443a1f8 100644 --- a/atom/utils/debug_helper/dump.py +++ b/atom/utils/debug_helper/dump.py @@ -158,6 +158,155 @@ def _find_layer_id(mod_name: str) -> Optional[int]: return n +# === Partial-prefill chunk dump ====================================== + + +def install_partial_prefill_dump_hooks(model: torch.nn.Module) -> int: + """Install per-Block hooks that dump seq-0's chunk-end hidden state. + + Built for chunked-prefill regression debugging: when running with + --long-prefill-token-threshold, multiple seqs can be partial-prefill in the + same batch. To compare "1 seq per batch" vs "N seqs per batch" without + shape mismatch, we extract the hidden state at seq 0's chunk-end position + from each prefill forward and tag the file with the chunk index. + + Output filename: + {DIR}/layer{LL}_{Cls}_prompt0_chunk{N}_rank{R}.pt + where N is the cumulative number of prefill forwards that have included + seq 0 (= chunk-N's hidden state). + + Per-rank file contains: + hidden: [hidden_dim] — seq-0 last-token hidden state + cu_seqlens_q: full tensor (for sanity check) + bs: batch size this forward + chunk_idx: same as N above + """ + dump_dir = envs.ATOM_PARTIAL_DUMP_DIR + if not dump_dir: + return 0 + + os.makedirs(dump_dir, exist_ok=True) + wanted = _parse_layer_set(envs.ATOM_PARTIAL_DUMP_LAYERS) + block_classes = { + c.strip() for c in envs.ATOM_PARTIAL_DUMP_BLOCK_CLASS.split(",") if c.strip() + } + layer_attr = envs.ATOM_FWD_DUMP_LAYER_ATTR + rank = _get_rank() + + # Per-(layer, cls) chunk counter — increments once per prefill forward + # that includes seq 0. Decode forwards are skipped. + _chunk_counters: dict[tuple[int, str], int] = {} + + def _make_hook(layer_id: int, cls_name: str): + base = os.path.join(dump_dir, f"layer{layer_id:02d}_{cls_name}_rank{rank}") + + def _hook(_mod, _args, output): + import sys as _sys + try: + from atom.utils.forward_context import get_forward_context + + fwd_ctx = get_forward_context() + ctx = fwd_ctx.context + if not ctx.is_prefill or ctx.is_draft or ctx.is_dummy_run: + return + attn = fwd_ctx.attn_metadata + cu_q = attn.cu_seqlens_q + ctx_lens = attn.context_lens + if cu_q is None or cu_q.numel() < 2: + return + if ctx_lens is None or ctx_lens.numel() < 1: + return + t = output[0] if isinstance(output, tuple) else output + if not isinstance(t, torch.Tensor): + return + # Pick the first 2-D tensor arg as the per-prompt fingerprint + # source (DeepseekV2DecoderLayer.forward signature is + # (positions, hidden_states, residual); positions is 1-D, + # hidden_states is 2-D [tokens, hidden]). + inp = None + for a in (_args or ()): + if isinstance(a, torch.Tensor) and a.ndim == 2: + inp = a + break + + bs = int(cu_q.numel() - 1) + cu_q_cpu = cu_q.detach().cpu().tolist() + + import hashlib as _hl + + for sidx in range(bs): + seq_start = cu_q_cpu[sidx] + seq_end = cu_q_cpu[sidx + 1] - 1 + if seq_end < 0 or t.shape[0] <= seq_end: + continue + ctx_end = int(ctx_lens[sidx].item()) + + # Fingerprint: hash the FIRST 4 input rows of this seq's + # chunk. 4 × hidden_dim = ~28KB; uniquely identifies + # (prompt, chunk) across runs. + if inp is not None and inp.shape[0] > seq_start: + end = min(seq_start + 4, inp.shape[0]) + seg = inp[seq_start:end].detach().cpu().contiguous() + seg_bytes = seg.view(torch.uint8).numpy().tobytes() + fp = _hl.sha1(seg_bytes).hexdigest()[:12] + else: + fp = "noinp" + + hidden = t[seq_end].detach().cpu() + fname = f"{base}_ctx{ctx_end:05d}_fp{fp}_bs{bs}.pt" + if os.path.exists(fname): + continue + torch.save( + { + "hidden": hidden, + "ctx_end": ctx_end, + "fp": fp, + "bs": bs, + "sidx": sidx, + }, + fname, + ) + except Exception as exc: + print( + f"[ATOM_PARTIAL_DUMP] layer={layer_id} cls={cls_name} " + f"hook error: {exc!r}", + file=_sys.stderr, + flush=True, + ) + + return _hook + + block_layer_ids: dict[str, int] = {} + for name, mod in model.named_modules(): + lid = getattr(mod, layer_attr, None) + if lid is not None: + block_layer_ids[name] = int(lid) + + def _find_layer_id(mod_name: str) -> Optional[int]: + if mod_name in block_layer_ids: + return block_layer_ids[mod_name] + parts = mod_name.split(".") + for i in range(len(parts) - 1, 0, -1): + parent = ".".join(parts[:i]) + if parent in block_layer_ids: + return block_layer_ids[parent] + return None + + n = 0 + for name, mod in model.named_modules(): + cls = mod.__class__.__name__ + if cls not in block_classes: + continue + lid = _find_layer_id(name) + if lid is None: + continue + if wanted is not None and lid not in wanted: + continue + mod.register_forward_hook(_make_hook(lid, cls)) + n += 1 + return n + + # === Weight dump ===================================================== diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 46554050b5..bc91688aed 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -168,6 +168,15 @@ "ATOM_FWD_DUMP_LAYER_ATTR", "layer_id" ), "ATOM_FWD_DUMP_ONE_SHOT": lambda: os.getenv("ATOM_FWD_DUMP_ONE_SHOT", "1") == "1", + # Partial-prefill chunk dump: capture the last-token hidden state of seq 0 + # in each prefill batch, tagged by chunk index. Used to bisect "single + # partial seq" vs "multiple partial seqs in same batch" divergence. + # Set ATOM_PARTIAL_DUMP_DIR to enable. Independent of ATOM_FWD_DUMP_*. + "ATOM_PARTIAL_DUMP_DIR": lambda: os.getenv("ATOM_PARTIAL_DUMP_DIR", ""), + "ATOM_PARTIAL_DUMP_LAYERS": lambda: os.getenv("ATOM_PARTIAL_DUMP_LAYERS", "0,30,60"), + "ATOM_PARTIAL_DUMP_BLOCK_CLASS": lambda: os.getenv( + "ATOM_PARTIAL_DUMP_BLOCK_CLASS", "Block" + ), # Per-rank weight dump + sys.exit(0) — for byte-equal weight comparison. "ATOM_WEIGHT_DUMP_DIR": lambda: os.getenv("ATOM_WEIGHT_DUMP_DIR", ""), "ATOM_WEIGHT_DUMP_LAYERS": lambda: os.getenv("ATOM_WEIGHT_DUMP_LAYERS", "0"), diff --git a/tests/conftest.py b/tests/conftest.py index 326335cb9f..e8ded61b5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,6 +118,7 @@ def __init__(self, **overrides): enable_prefix_caching=False, max_num_seqs=4, max_num_batched_tokens=64, + long_prefill_token_threshold=0, max_model_len=64, bos_token_id=1, eos_token_id=2, diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index da43358ecd..67734851e1 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -136,6 +136,133 @@ def test_decode_preemption(self, seq_factory): assert SequenceStatus.WAITING in statuses +# ── long_prefill_token_threshold ────────────────────────────────────────── + + +class TestLongPrefillTokenThreshold: + """Per-request cap on prefill tokens per step (vLLM parity).""" + + def test_disabled_by_default(self, seq_factory): + """threshold=0 → no per-request cap, only max_num_batched_tokens applies.""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=1000, + enable_chunked_prefill=True, + ) + ) + sched.add(seq_factory(list(range(20)))) + batch, _ = sched.schedule() + assert list(batch.num_scheduled_tokens) == [20] + + def test_caps_single_long_request(self, seq_factory): + """A 20-token prompt with threshold=8 → first step does 8 tokens.""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=1000, + long_prefill_token_threshold=8, + enable_chunked_prefill=True, + ) + ) + sched.add(seq_factory(list(range(20)))) + batch, _ = sched.schedule() + assert list(batch.num_scheduled_tokens) == [8] + + def test_short_request_unaffected(self, seq_factory): + """Prompt shorter than threshold → full prefill in one step.""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=1000, + long_prefill_token_threshold=16, + enable_chunked_prefill=True, + ) + ) + sched.add(seq_factory([1, 2, 3, 4, 5])) + batch, _ = sched.schedule() + assert list(batch.num_scheduled_tokens) == [5] + + def test_applied_per_request_not_batch(self, seq_factory): + """Two long prompts each capped at 8 → batch carries 16 tokens.""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=1000, + long_prefill_token_threshold=8, + enable_chunked_prefill=True, + ) + ) + sched.add(seq_factory(list(range(20)))) + sched.add(seq_factory(list(range(20, 40)))) + batch, _ = sched.schedule() + assert list(batch.num_scheduled_tokens) == [8, 8] + assert batch.total_tokens_num_prefill == 16 + + def test_min_with_budget_remaining(self, seq_factory): + """budget < threshold → chunk is bounded by budget, not threshold.""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=10, + long_prefill_token_threshold=8, + enable_chunked_prefill=True, + ) + ) + sched.add(seq_factory(list(range(20)))) # capped at 8 + sched.add(seq_factory(list(range(20, 40)))) # budget left = 2 + batch, _ = sched.schedule() + assert list(batch.num_scheduled_tokens) == [8, 2] + + def test_ignored_when_chunked_prefill_disabled(self, seq_factory): + """No chunked prefill → threshold is a no-op (full prompt or reject).""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=1000, + long_prefill_token_threshold=8, + enable_chunked_prefill=False, + ) + ) + sched.add(seq_factory(list(range(20)))) + batch, _ = sched.schedule() + # Full 20-token prompt scheduled in one shot, threshold ignored. + assert list(batch.num_scheduled_tokens) == [20] + + def test_partial_prefill_resume_capped(self, seq_factory): + """Phase-1 resume of a partial-prefill seq is also capped by threshold.""" + sched = Scheduler( + MockConfig( + num_kvcache_blocks=100, + kv_cache_block_size=4, + max_num_batched_tokens=8, # forces chunking on the 20-tok prompt + long_prefill_token_threshold=8, + enable_chunked_prefill=True, + ) + ) + seq = seq_factory(list(range(20))) + sched.add(seq) + + # Step 1: new request, capped at 8. + batch1, _ = sched.schedule() + assert list(batch1.num_scheduled_tokens) == [8] + # Simulate postprocess marking it partial (would normally happen after + # forward returns and num_cached_tokens < num_prompt_tokens). + seq.num_cached_tokens = 8 + seq.is_partial_prefill = True + sched._partial_prefill_count += 1 + + # Step 2: partial-prefill resume, also capped at 8 (not 12 remaining). + batch2, _ = sched.schedule() + assert list(batch2.num_scheduled_tokens) == [8] + + # ── prefix caching ──────────────────────────────────────────────────────── From 4624d0d9e2fad01aa591bef2bfc8a844772f013d Mon Sep 17 00:00:00 2001 From: jiayyu Date: Tue, 9 Jun 2026 03:03:31 +0000 Subject: [PATCH 2/9] disable debug code for now --- atom/model_engine/scheduler.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 16d5402687..3c1c49923d 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -891,20 +891,22 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: f"(cached: {cached_per_req}, new: {num_scheduled_tokens}), " f"req_ids: {tuple(scheduled_seqs.keys())}" ) - n_partial = sum( - 1 - for s, c in zip(scheduled_seqs.values(), num_scheduled_tokens) - if s.num_cached_tokens + c < s.num_prompt_tokens - ) - if n_partial >= 2: - logger.warning( - f"[MULTI-CHUNK-DEBUG] {n_partial}/{num_seqs_prefill} partial-prefill in batch, " - f"req_ids={list(scheduled_seqs.keys())}, " - f"cached={num_cached_tokens_list}, " - f"chunks={num_scheduled_tokens}, " - f"prompt_lens={[s.num_prompt_tokens for s in scheduled_seqs.values()]}, " - f"block_table_lens={[len(s.block_table) for s in scheduled_seqs.values()]}" - ) + # [MULTI-CHUNK-DEBUG] disabled — kept for re-enabling when investigating + # batch-invariance regressions under --long-prefill-token-threshold. + # n_partial = sum( + # 1 + # for s, c in zip(scheduled_seqs.values(), num_scheduled_tokens) + # if s.num_cached_tokens + c < s.num_prompt_tokens + # ) + # if n_partial >= 2: + # logger.warning( + # f"[MULTI-CHUNK-DEBUG] {n_partial}/{num_seqs_prefill} partial-prefill in batch, " + # f"req_ids={list(scheduled_seqs.keys())}, " + # f"cached={num_cached_tokens_list}, " + # f"chunks={num_scheduled_tokens}, " + # f"prompt_lens={[s.num_prompt_tokens for s in scheduled_seqs.values()]}, " + # f"block_table_lens={[len(s.block_table) for s in scheduled_seqs.values()]}" + # ) self.prev_prompt = True # lip: TODO for prefill/decode mixed batch From 0bebacd07d5f2f52d234d815b76a2838caf55a2c Mon Sep 17 00:00:00 2001 From: jiayyu Date: Wed, 10 Jun 2026 06:29:03 +0000 Subject: [PATCH 3/9] fix(mla): prevent NaN in chunked cached-prefix attention LSE merge In _forward_prefill_cached_chunked, a seq with no cached tokens in the first chunk gets lse=-inf from flash_attn. Seeding the running accumulator with that -inf and later merging against another -inf suffix computes -inf-(-inf)=NaN in merge_attn_states, permanently poisoning that seq's output. Only triggers with multiple seqs chunked together at high concurrency (total_kv > attn_prefill_chunk_size). Sanitize the seed lse with a large finite sentinel so an absent seq carries ~zero weight without producing NaN. Co-Authored-By: Claude Opus 4 --- atom/model_ops/attention_mla.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index ee7941da9a..990a89360d 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -530,7 +530,17 @@ def _forward_prefill_cached_chunked( ) if chunked_out is None: chunked_out = suf_out - chunked_lse = suf_lse + # A seq with no cached tokens in this (first) chunk gets + # lse=-inf from flash_attn. If that -inf seeds the running + # accumulator, a later merge against another -inf suffix (the + # same seq absent again) computes -inf-(-inf)=NaN in + # merge_attn_states, permanently poisoning that seq's output. + # Replace -inf with a large finite sentinel so the seq carries + # ~zero weight (exp(sentinel-max)=0) without producing NaN; its + # real prefix contribution is merged in once a later chunk + # covers it. Only the seed needs sanitizing — once the prefix + # side is finite, max_lse stays finite for all later merges. + chunked_lse = torch.nan_to_num(suf_lse, neginf=-1e30) else: tmp_out = torch.empty_like(new_out) tmp_lse = torch.empty_like(new_lse) From f27db731d0697d86532f003ec0af03d1fa4d9446 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Wed, 10 Jun 2026 06:29:15 +0000 Subject: [PATCH 4/9] fix(v4): add per-token causal cap to HCA prefill visibility HCA prefill used the per-seq committed count (ctx_end//128) for every token, missing the (pos+1)//128 per-token causal cap that CSA already has (and that the reference get_compress_topk_idxs applies). Under chunked prefill ctx_end is the chunk's end, so the same logical token saw a different number of HCA compressed groups depending on which chunk computed it -> chunked != single-shot -> ~0.02 GSM8K drop. Cap HCA per-token visibility to min((pos+1)//128, n_committed_hca) in the indptr build, the prefill-indices kernel (new HCA_RATIO constexpr), and the reference impl. Decode is unaffected (decode token is at seq end, the cap is a no-op). Verified GSM8K (V4-Pro, num_concurrent=4, fp8): chunked 0.93 -> 0.9507, single-shot 0.9515 (no regression). Co-Authored-By: Claude Opus 4 --- atom/model_ops/attentions/deepseek_v4_attn.py | 15 ++++++++++++--- .../v4_kernels/paged_prefill_indices.py | 17 +++++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index 2ff9de3dd3..ee849b1f48 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -1886,9 +1886,18 @@ def _build_paged_prefill_meta( ), index_topk, ).astype(np.int32) - n_hca_per_token_np = n_committed_hca_per_seq_np[batch_id_per_token_np].astype( - np.int32 - ) + # Per-token CAUSAL HCA visibility (mirrors CSA above and the reference + # `get_compress_topk_idxs` prefill mask): token at `pos` sees only the + # `(pos+1)//128` HCA groups committed up to its own position, capped by + # the per-seq committed count. Without `(pos+1)//128`, every token used + # the per-seq `ctx_end//128`, over-reading FUTURE groups and making a + # token's output depend on the forward's total length (chunked breaks). + # MUST stay in sync with the kernel's inline cap in + # `_v4_paged_prefill_indices_kernel` (HCA_RATIO). + n_hca_per_token_np = np.minimum( + (positions_arr + 1) // 128, + n_committed_hca_per_seq_np[batch_id_per_token_np], + ).astype(np.int32) # 4 indptrs on CPU; last element = total (no D2H to size buffers). ext_indptr_np = np.zeros(T + 1, dtype=np.int32) diff --git a/atom/model_ops/v4_kernels/paged_prefill_indices.py b/atom/model_ops/v4_kernels/paged_prefill_indices.py index c1375da1db..d7edfaf8ce 100644 --- a/atom/model_ops/v4_kernels/paged_prefill_indices.py +++ b/atom/model_ops/v4_kernels/paged_prefill_indices.py @@ -71,6 +71,7 @@ def _v4_paged_prefill_indices_kernel( win: tl.constexpr, cs, # win_with_spec — SWA ring stride (NOT constexpr because varies w/ mtp_k) swa_pages, # state_slot count * cs — boundary into HCA compress section + HCA_RATIO: tl.constexpr, # HCA compress ratio (128) for per-token causal cap BLOCK_N: tl.constexpr, # next_pow2(win) — covers SWA prefix and extend segments ): """One program per token. Writes four per-token segments: @@ -92,7 +93,15 @@ def _v4_paged_prefill_indices_kernel( chunk_start = tl.load(chunk_start_per_seq_ptr + bid) cu_q = tl.load(cu_seqlens_q_per_seq_ptr + bid) state_slot = tl.load(state_slot_per_seq_ptr + bid) - n_hca = tl.load(n_committed_hca_per_seq_ptr + bid) + # Per-token CAUSAL HCA visibility: token at `pos` may see only the + # `(pos+1)//HCA_RATIO` compressed groups committed up to its own position + # (matches the reference `get_compress_topk_idxs` prefill mask, and mirrors + # the CSA `(pos+1)//4` cap). Without this cap every token saw the per-seq + # `n_committed_hca = ctx_end//128`, which over-reads FUTURE groups and makes + # a token's output depend on the forward's total length (chunked != single). + n_hca = tl.minimum( + (pos + 1) // HCA_RATIO, tl.load(n_committed_hca_per_seq_ptr + bid) + ) # Per-token derived quantities (single-pass arithmetic). token_pos_in_chunk = pos - chunk_start @@ -159,6 +168,7 @@ def write_v4_paged_prefill_indices( win: int, cs: int, swa_pages: int, + hca_ratio: int = 128, ) -> None: """One-shot GPU build of the V4 paged-prefill index buffers. @@ -247,6 +257,7 @@ def write_v4_paged_prefill_indices( win=win, cs=cs, swa_pages=swa_pages, + HCA_RATIO=hca_ratio, BLOCK_N=BLOCK_N, ) @@ -272,6 +283,7 @@ def write_v4_paged_prefill_indices_reference( win: int, cs: int, swa_pages: int, + hca_ratio: int = 128, ) -> None: """Pure-Python equivalent of ``write_v4_paged_prefill_indices``. Per-token Python loop — slow but readable; used for unit-test bit-exact @@ -301,7 +313,8 @@ def write_v4_paged_prefill_indices_reference( chunk_start = cs_per_seq_cpu[bid] cu_q = cu_q_cpu[bid] state_slot = state_slot_cpu[bid] - n_hca = n_hca_cpu[bid] + # Per-token causal HCA cap (mirrors kernel + reference get_compress_topk_idxs). + n_hca = min((pos + 1) // hca_ratio, n_hca_cpu[bid]) token_pos_in_chunk = pos - chunk_start swa_low = max(pos - win + 1, 0) From 64f95644409cf4323e7d0528d1f31b2152067e2f Mon Sep 17 00:00:00 2001 From: jiayyu Date: Wed, 10 Jun 2026 07:16:12 +0000 Subject: [PATCH 5/9] remove debug code --- atom/model_engine/model_runner.py | 6 -- atom/model_engine/scheduler.py | 24 +---- atom/utils/debug_helper/__init__.py | 2 - atom/utils/debug_helper/dump.py | 149 ---------------------------- atom/utils/envs.py | 9 -- 5 files changed, 4 insertions(+), 186 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 0d7c2112ed..db018399e2 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -659,18 +659,12 @@ def __init__(self, rank: int, config: Config): # See atom/utils/debug_helper/. from atom.utils.debug_helper import ( install_block_forward_hooks, - install_partial_prefill_dump_hooks, maybe_dump_weights_and_exit, ) _n_fwd_hooks = install_block_forward_hooks(self.model) if _n_fwd_hooks > 0: logger.info(f"[ATOM_FWD_DUMP] {_n_fwd_hooks} Block forward hooks installed") - _n_partial_hooks = install_partial_prefill_dump_hooks(self.model) - if _n_partial_hooks > 0: - logger.info( - f"[ATOM_PARTIAL_DUMP] {_n_partial_hooks} partial-prefill hooks installed" - ) maybe_dump_weights_and_exit(self.model) if self.config.speculative_config and get_pp_group().is_last_rank: diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 3c1c49923d..8e90901abe 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -722,7 +722,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: break if not seq.is_partial_prefill: continue - remaining = seq.num_prompt_tokens - seq.num_cached_tokens + remaining = seq.num_tokens - seq.num_cached_tokens if 0 < self.long_prefill_token_threshold < remaining: remaining = self.long_prefill_token_threshold budget_remaining = self.max_num_batched_tokens - num_batched_tokens @@ -891,22 +891,6 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: f"(cached: {cached_per_req}, new: {num_scheduled_tokens}), " f"req_ids: {tuple(scheduled_seqs.keys())}" ) - # [MULTI-CHUNK-DEBUG] disabled — kept for re-enabling when investigating - # batch-invariance regressions under --long-prefill-token-threshold. - # n_partial = sum( - # 1 - # for s, c in zip(scheduled_seqs.values(), num_scheduled_tokens) - # if s.num_cached_tokens + c < s.num_prompt_tokens - # ) - # if n_partial >= 2: - # logger.warning( - # f"[MULTI-CHUNK-DEBUG] {n_partial}/{num_seqs_prefill} partial-prefill in batch, " - # f"req_ids={list(scheduled_seqs.keys())}, " - # f"cached={num_cached_tokens_list}, " - # f"chunks={num_scheduled_tokens}, " - # f"prompt_lens={[s.num_prompt_tokens for s in scheduled_seqs.values()]}, " - # f"block_table_lens={[len(s.block_table) for s in scheduled_seqs.values()]}" - # ) self.prev_prompt = True # lip: TODO for prefill/decode mixed batch @@ -1060,7 +1044,7 @@ def postprocess( # multiple steps (hash_blocks clips to fully-filled blocks). self.block_manager.hash_blocks(seq, chunk) seq.num_cached_tokens += chunk - now_partial = seq.num_cached_tokens < seq.num_prompt_tokens + now_partial = seq.num_cached_tokens < seq.num_tokens if now_partial != seq.is_partial_prefill: self._partial_prefill_count += 1 if now_partial else -1 seq.is_partial_prefill = now_partial @@ -1408,8 +1392,8 @@ def has_requests(self) -> bool: def get_next_batch_info(self) -> tuple[bool, int, int]: # Check for partial prefills in running (chunked prefill resume) for seq in self.running: - if seq.num_cached_tokens < seq.num_prompt_tokens: - remaining = seq.num_prompt_tokens - seq.num_cached_tokens + if seq.num_cached_tokens < seq.num_tokens: + remaining = seq.num_tokens - seq.num_cached_tokens chunk = min(remaining, self.max_num_batched_tokens) return (True, chunk, 1) # Only consider waiting seqs that are not blocked on a remote KV diff --git a/atom/utils/debug_helper/__init__.py b/atom/utils/debug_helper/__init__.py index 5c9fcc6d8e..d23de5e141 100644 --- a/atom/utils/debug_helper/__init__.py +++ b/atom/utils/debug_helper/__init__.py @@ -31,7 +31,6 @@ ) from atom.utils.debug_helper.dump import ( install_block_forward_hooks, - install_partial_prefill_dump_hooks, maybe_dump_weights_and_exit, maybe_log_topk, ) @@ -44,7 +43,6 @@ __all__ = [ # dump "install_block_forward_hooks", - "install_partial_prefill_dump_hooks", "maybe_dump_weights_and_exit", "maybe_log_topk", # compare primitives diff --git a/atom/utils/debug_helper/dump.py b/atom/utils/debug_helper/dump.py index 4c7443a1f8..0122efbb72 100644 --- a/atom/utils/debug_helper/dump.py +++ b/atom/utils/debug_helper/dump.py @@ -158,155 +158,6 @@ def _find_layer_id(mod_name: str) -> Optional[int]: return n -# === Partial-prefill chunk dump ====================================== - - -def install_partial_prefill_dump_hooks(model: torch.nn.Module) -> int: - """Install per-Block hooks that dump seq-0's chunk-end hidden state. - - Built for chunked-prefill regression debugging: when running with - --long-prefill-token-threshold, multiple seqs can be partial-prefill in the - same batch. To compare "1 seq per batch" vs "N seqs per batch" without - shape mismatch, we extract the hidden state at seq 0's chunk-end position - from each prefill forward and tag the file with the chunk index. - - Output filename: - {DIR}/layer{LL}_{Cls}_prompt0_chunk{N}_rank{R}.pt - where N is the cumulative number of prefill forwards that have included - seq 0 (= chunk-N's hidden state). - - Per-rank file contains: - hidden: [hidden_dim] — seq-0 last-token hidden state - cu_seqlens_q: full tensor (for sanity check) - bs: batch size this forward - chunk_idx: same as N above - """ - dump_dir = envs.ATOM_PARTIAL_DUMP_DIR - if not dump_dir: - return 0 - - os.makedirs(dump_dir, exist_ok=True) - wanted = _parse_layer_set(envs.ATOM_PARTIAL_DUMP_LAYERS) - block_classes = { - c.strip() for c in envs.ATOM_PARTIAL_DUMP_BLOCK_CLASS.split(",") if c.strip() - } - layer_attr = envs.ATOM_FWD_DUMP_LAYER_ATTR - rank = _get_rank() - - # Per-(layer, cls) chunk counter — increments once per prefill forward - # that includes seq 0. Decode forwards are skipped. - _chunk_counters: dict[tuple[int, str], int] = {} - - def _make_hook(layer_id: int, cls_name: str): - base = os.path.join(dump_dir, f"layer{layer_id:02d}_{cls_name}_rank{rank}") - - def _hook(_mod, _args, output): - import sys as _sys - try: - from atom.utils.forward_context import get_forward_context - - fwd_ctx = get_forward_context() - ctx = fwd_ctx.context - if not ctx.is_prefill or ctx.is_draft or ctx.is_dummy_run: - return - attn = fwd_ctx.attn_metadata - cu_q = attn.cu_seqlens_q - ctx_lens = attn.context_lens - if cu_q is None or cu_q.numel() < 2: - return - if ctx_lens is None or ctx_lens.numel() < 1: - return - t = output[0] if isinstance(output, tuple) else output - if not isinstance(t, torch.Tensor): - return - # Pick the first 2-D tensor arg as the per-prompt fingerprint - # source (DeepseekV2DecoderLayer.forward signature is - # (positions, hidden_states, residual); positions is 1-D, - # hidden_states is 2-D [tokens, hidden]). - inp = None - for a in (_args or ()): - if isinstance(a, torch.Tensor) and a.ndim == 2: - inp = a - break - - bs = int(cu_q.numel() - 1) - cu_q_cpu = cu_q.detach().cpu().tolist() - - import hashlib as _hl - - for sidx in range(bs): - seq_start = cu_q_cpu[sidx] - seq_end = cu_q_cpu[sidx + 1] - 1 - if seq_end < 0 or t.shape[0] <= seq_end: - continue - ctx_end = int(ctx_lens[sidx].item()) - - # Fingerprint: hash the FIRST 4 input rows of this seq's - # chunk. 4 × hidden_dim = ~28KB; uniquely identifies - # (prompt, chunk) across runs. - if inp is not None and inp.shape[0] > seq_start: - end = min(seq_start + 4, inp.shape[0]) - seg = inp[seq_start:end].detach().cpu().contiguous() - seg_bytes = seg.view(torch.uint8).numpy().tobytes() - fp = _hl.sha1(seg_bytes).hexdigest()[:12] - else: - fp = "noinp" - - hidden = t[seq_end].detach().cpu() - fname = f"{base}_ctx{ctx_end:05d}_fp{fp}_bs{bs}.pt" - if os.path.exists(fname): - continue - torch.save( - { - "hidden": hidden, - "ctx_end": ctx_end, - "fp": fp, - "bs": bs, - "sidx": sidx, - }, - fname, - ) - except Exception as exc: - print( - f"[ATOM_PARTIAL_DUMP] layer={layer_id} cls={cls_name} " - f"hook error: {exc!r}", - file=_sys.stderr, - flush=True, - ) - - return _hook - - block_layer_ids: dict[str, int] = {} - for name, mod in model.named_modules(): - lid = getattr(mod, layer_attr, None) - if lid is not None: - block_layer_ids[name] = int(lid) - - def _find_layer_id(mod_name: str) -> Optional[int]: - if mod_name in block_layer_ids: - return block_layer_ids[mod_name] - parts = mod_name.split(".") - for i in range(len(parts) - 1, 0, -1): - parent = ".".join(parts[:i]) - if parent in block_layer_ids: - return block_layer_ids[parent] - return None - - n = 0 - for name, mod in model.named_modules(): - cls = mod.__class__.__name__ - if cls not in block_classes: - continue - lid = _find_layer_id(name) - if lid is None: - continue - if wanted is not None and lid not in wanted: - continue - mod.register_forward_hook(_make_hook(lid, cls)) - n += 1 - return n - - # === Weight dump ===================================================== diff --git a/atom/utils/envs.py b/atom/utils/envs.py index bc91688aed..46554050b5 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -168,15 +168,6 @@ "ATOM_FWD_DUMP_LAYER_ATTR", "layer_id" ), "ATOM_FWD_DUMP_ONE_SHOT": lambda: os.getenv("ATOM_FWD_DUMP_ONE_SHOT", "1") == "1", - # Partial-prefill chunk dump: capture the last-token hidden state of seq 0 - # in each prefill batch, tagged by chunk index. Used to bisect "single - # partial seq" vs "multiple partial seqs in same batch" divergence. - # Set ATOM_PARTIAL_DUMP_DIR to enable. Independent of ATOM_FWD_DUMP_*. - "ATOM_PARTIAL_DUMP_DIR": lambda: os.getenv("ATOM_PARTIAL_DUMP_DIR", ""), - "ATOM_PARTIAL_DUMP_LAYERS": lambda: os.getenv("ATOM_PARTIAL_DUMP_LAYERS", "0,30,60"), - "ATOM_PARTIAL_DUMP_BLOCK_CLASS": lambda: os.getenv( - "ATOM_PARTIAL_DUMP_BLOCK_CLASS", "Block" - ), # Per-rank weight dump + sys.exit(0) — for byte-equal weight comparison. "ATOM_WEIGHT_DUMP_DIR": lambda: os.getenv("ATOM_WEIGHT_DUMP_DIR", ""), "ATOM_WEIGHT_DUMP_LAYERS": lambda: os.getenv("ATOM_WEIGHT_DUMP_LAYERS", "0"), From 287fbb78750f4443d369c225e0c84832d1c998d6 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Fri, 12 Jun 2026 02:45:51 +0000 Subject: [PATCH 6/9] fix(mla): handle both-empty merge in the kernel, drop call-site workaround MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier MLA NaN fix sanitized the seed lse at the call site (nan_to_num on suf_lse). Now that merge_attn_states lives in ATOM (triton_merge_attn_states.py, sole caller is the MLA chunked path — plugin/vllm imports vllm's own copy), fix it at the root instead. When a token's prefix AND suffix are both empty (max_lse == -inf), the kernel computed -inf-(-inf)=NaN and a 0/0 scale that poisoned the output. This is reachable in ATOM's global-axis chunked prefill: a short seq can fall entirely outside a chunk. Guard both_empty: force a finite 0/0-split so out=0 (correct for empty attention) and keep lse=-inf. The call-site nan_to_num is now redundant and reverts to chunked_lse = suf_lse. Verified GSM8K (R1-MXFP4, tp4, fp8, num_concurrent=64, long-prefill 512): 0.9431 — same as the call-site workaround, no regression. Co-Authored-By: Claude Opus 4 --- atom/model_ops/attention_mla.py | 15 ++++-------- .../attentions/triton_merge_attn_states.py | 23 +++++++++++++++---- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 990a89360d..6354ba19cb 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -530,17 +530,10 @@ def _forward_prefill_cached_chunked( ) if chunked_out is None: chunked_out = suf_out - # A seq with no cached tokens in this (first) chunk gets - # lse=-inf from flash_attn. If that -inf seeds the running - # accumulator, a later merge against another -inf suffix (the - # same seq absent again) computes -inf-(-inf)=NaN in - # merge_attn_states, permanently poisoning that seq's output. - # Replace -inf with a large finite sentinel so the seq carries - # ~zero weight (exp(sentinel-max)=0) without producing NaN; its - # real prefix contribution is merged in once a later chunk - # covers it. Only the seed needs sanitizing — once the prefix - # side is finite, max_lse stays finite for all later merges. - chunked_lse = torch.nan_to_num(suf_lse, neginf=-1e30) + # A seq absent from this chunk has lse=-inf; both-(-inf) merges + # are handled inside merge_attn_states (see its both_empty + # guard), so the seed needs no sanitizing here. + chunked_lse = suf_lse else: tmp_out = torch.empty_like(new_out) tmp_lse = torch.empty_like(new_lse) diff --git a/atom/model_ops/attentions/triton_merge_attn_states.py b/atom/model_ops/attentions/triton_merge_attn_states.py index 4aefe3e85a..828496261a 100644 --- a/atom/model_ops/attentions/triton_merge_attn_states.py +++ b/atom/model_ops/attentions/triton_merge_attn_states.py @@ -128,15 +128,25 @@ def merge_attn_states_kernel( s_lse = float("-inf") if s_lse == float("inf") else s_lse max_lse = tl.maximum(p_lse, s_lse) - p_lse = p_lse - max_lse - s_lse = s_lse - max_lse + # Both prefix AND suffix are empty for this token (no KV on either side) -> + # max_lse == -inf. The naive `p_lse - max_lse` would compute -inf-(-inf)=NaN + # and `out_se` would be 0, making the scale 0/0=NaN that poisons the output. + # This happens in ATOM's global-axis chunked prefill: a short seq can fall + # entirely outside a chunk, so its tokens see an empty prefix AND suffix in + # that chunk. Force a safe 0/0-split: subtract a finite max so each side's + # exp is 0 (out = 0*p_out + 0*s_out = 0, correct for empty attention) and + # keep the merged lse at -inf so any downstream merge stays consistent. + both_empty = max_lse == float("-inf") + safe_max = tl.where(both_empty, 0.0, max_lse) + p_lse = p_lse - safe_max + s_lse = s_lse - safe_max # Will reuse precomputed Exp values for scale factor computation. p_se = tl.exp(p_lse) s_se = tl.exp(s_lse) out_se = p_se + s_se if OUTPUT_LSE: - out_lse = tl.log(out_se) + max_lse + out_lse = tl.where(both_empty, float("-inf"), tl.log(out_se) + safe_max) tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) p_out = tl.load( @@ -157,8 +167,11 @@ def merge_attn_states_kernel( # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. - p_scale = p_se / out_se - s_scale = s_se / out_se + # both_empty -> out_se == 0; guard the denominator so the scale is 0/1=0 + # (not 0/0=NaN). p_out/s_out are 0 for empty attention, so out stays 0. + safe_out_se = tl.where(both_empty, 1.0, out_se) + p_scale = p_se / safe_out_se + s_scale = s_se / safe_out_se out = p_out * p_scale + s_out * s_scale if USE_FP8: From ef553c5875282e5732192c166a6137609f2ddaae Mon Sep 17 00:00:00 2001 From: jiayyu Date: Fri, 12 Jun 2026 02:50:04 +0000 Subject: [PATCH 7/9] rm log --- atom/model_ops/attention_mla.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 6354ba19cb..ee7941da9a 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -530,9 +530,6 @@ def _forward_prefill_cached_chunked( ) if chunked_out is None: chunked_out = suf_out - # A seq absent from this chunk has lse=-inf; both-(-inf) merges - # are handled inside merge_attn_states (see its both_empty - # guard), so the seed needs no sanitizing here. chunked_lse = suf_lse else: tmp_out = torch.empty_like(new_out) From 8bb8ceb1688b4f16373f70a66cfc41dff5905e18 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Tue, 16 Jun 2026 09:27:58 +0000 Subject: [PATCH 8/9] commit --- atom/config.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/atom/config.py b/atom/config.py index 541f23d2a1..f1d4e318d7 100644 --- a/atom/config.py +++ b/atom/config.py @@ -1177,16 +1177,7 @@ def __post_init__(self): v4_block_size = 128 if self.kv_cache_block_size != v4_block_size: self.kv_cache_block_size = v4_block_size - # TODO: V4's per-request SWA buffer cannot be restored from the classical - # KV pool on prefix cache hit, so disable prefix caching silently. - if self.enable_prefix_caching: - import logging - - logging.getLogger(__name__).warning( - "DeepSeek-V4 does not support prefix caching " - "(SWA buffer is not cacheable); disabling automatically." - ) - self.enable_prefix_caching = False + def compute_hash(self) -> str: """ From 9476eb854d3e6aa7b53fcddf656a3db2fc941666 Mon Sep 17 00:00:00 2001 From: jiayyu Date: Tue, 16 Jun 2026 10:50:30 +0000 Subject: [PATCH 9/9] fix format --- atom/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/atom/config.py b/atom/config.py index f1d4e318d7..1e70bb90d7 100644 --- a/atom/config.py +++ b/atom/config.py @@ -1178,7 +1178,6 @@ def __post_init__(self): if self.kv_cache_block_size != v4_block_size: self.kv_cache_block_size = v4_block_size - def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config,