From e54428efb39beeeec41c80d26ce41074cb677b84 Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Thu, 18 Jun 2026 11:40:47 +0800 Subject: [PATCH 1/3] support TBO decode in Deepseek v4 --- atom/model_engine/model_runner.py | 4 +- atom/model_ops/attentions/deepseek_v4_attn.py | 296 +++++++++++++++++- atom/model_ops/module_dispatch_ops.py | 5 +- atom/model_ops/moe.py | 8 + atom/models/deepseek_v4.py | 41 ++- atom/utils/tbo/ubatch_wrapper.py | 69 +++- 6 files changed, 391 insertions(+), 32 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index db018399e2..1d958e9b45 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -2311,9 +2311,9 @@ def capture_cudagraph(self): context.positions = mrope_positions num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad - # Create ubatch slices for TBO capture (need >= 2 requests) + # Create ubatch slices for TBO capture (need > 2 requests) ubatch_slices = None - if is_tbo and self.config.enable_tbo_decode and bs >= 2: + if is_tbo and self.config.enable_tbo_decode and bs > 2: ubatch_slices = maybe_create_ubatch_slices( num_reqs=bs, num_tokens=num_tokens, diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index 2ff9de3dd3..d9421727d4 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -263,6 +263,9 @@ class DeepseekV4AttentionMetadataBuilder(CommonAttentionBuilder): block_size = 128 + # Number of micro-batches for Two-Batch Overlap (TBO). + _NUM_TBO_UBATCHES = 2 + def __init__(self, model_runner): super().__init__(model_runner) hf = model_runner.config.hf_config @@ -384,6 +387,8 @@ def __init__(self, model_runner): # `torch.as_tensor(arr)` allocations. self._alloc_v4_metadata_buffers() + self._ubatch_decode_meta: Optional[list] = None + @property def prep_stream(self): return self.model_runner.async_execute_stream @@ -1176,8 +1181,187 @@ def prepare_decode(self, batch: ScheduledBatch, bs: int): sum_scheduled_tokens, positions_gpu=positions, ) + + self._ubatch_decode_meta = None + if ( + self.model_runner.config.enable_tbo_decode + and scheduled_bs > 2 + and not batch.is_dummy_run + ): + self._prepare_ubatch_decode( + scheduled_bs=scheduled_bs, + bs=bs, + max_seqlen_q=max_seqlen_q, + context_lens_np=context_lens_np, + state_slot_np=state_slot_np, + positions_np=positions_np, + ) + return attn_metadata, positions + def _prepare_ubatch_decode( + self, + *, + scheduled_bs: int, + bs: int, + max_seqlen_q: int, + context_lens_np: np.ndarray, + state_slot_np: np.ndarray, + positions_np: np.ndarray, + ) -> None: + """Split a decode batch into two micro-batches (by request) and build + each one's V4 decode metadata into ``ub{0,1}_`` prefixed buffers. + + Mirrors :meth:`prepare_decode` but operates on a per-ubatch request + slice. The two resulting :class:`AttentionMetaData_DSV4` objects are + cached on ``self._ubatch_decode_meta`` and returned by + :meth:`build_ubatch_metadata`. + + Token layout in a decode fwd is request-major with ``max_seqlen_q`` + tokens per request, so ubatch token ranges fall on request boundaries. + """ + var = self.model_runner.forward_vars + N = self._NUM_TBO_UBATCHES + enforce_eager = self.model_runner.enforce_eager + if enforce_eager: + split_total = scheduled_bs + half = scheduled_bs // N + padded_list = [half, scheduled_bs - half] + ub_ranges = [(0, half), (half, split_total)] + else: + from atom.utils.tbo.ubatch_wrapper import UBatchWrapper + + ctx = get_forward_context() + padded_list = [ + UBatchWrapper._decode_ub_padded_bs(ctx, i, N, bs) for i in range(N) + ] + # Real-request ranges partition scheduled_bs; each ubatch owns up to + # its padded capacity, the tail ubatch takes the remainder. Pad rows + # beyond the real reqs carry sentinels (filled below). + ub_ranges = [] + req_start = 0 + for i in range(N): + if i == N - 1: + req_end = scheduled_bs + else: + req_end = min(scheduled_bs, req_start + padded_list[i]) + ub_ranges.append((req_start, req_end)) + req_start = req_end + split_total = scheduled_bs + + metas: list = [] + for ub_idx, (req_start, req_end) in enumerate(ub_ranges): + p = f"ub{ub_idx}_" + padded_bs = padded_list[ub_idx] + # Real requests that fall into this ubatch's [req_start, req_end), + # clamped to scheduled_bs (cudagraph pad rows beyond scheduled_bs + # carry sentinels, exercised only during capture's synthetic batch). + ub_real_reqs = max(0, min(scheduled_bs, req_end) - req_start) + tok_start = req_start * max_seqlen_q + ub_real_tokens = ub_real_reqs * max_seqlen_q + + # ---- per-seq slices into ub buffers ---- + ub_ctx_np = context_lens_np[req_start : req_start + ub_real_reqs] + var[f"{p}context_lens"].np[:ub_real_reqs] = ub_ctx_np + var[f"{p}context_lens"].np[ub_real_reqs:padded_bs] = 0 + + ub_state_np = state_slot_np[req_start : req_start + ub_real_reqs] + if len(ub_state_np) < ub_real_reqs: + ub_state_np = np.zeros(ub_real_reqs, dtype=np.int32) + var[f"{p}v4_meta_state_slot_groups"].np[:ub_real_reqs] = ub_state_np + var[f"{p}v4_meta_state_slot_groups"].np[ub_real_reqs:padded_bs] = 0 + state_slot_np_ub = ( + var[f"{p}v4_meta_state_slot_groups"].np[:padded_bs].copy() + ) + + var[f"{p}block_tables"].np[:ub_real_reqs] = var["block_tables"].np[ + req_start : req_start + ub_real_reqs + ] + var[f"{p}block_tables"].np[ub_real_reqs:padded_bs] = 0 + + # positions: copy the ubatch's token slice (values match the global + # positions slice the UBatchWrapper Context will expose). + ub_positions_np = positions_np[tok_start : tok_start + ub_real_tokens] + var[f"{p}positions"].np[:ub_real_tokens] = ub_positions_np + var[f"{p}positions"].np[ub_real_tokens : padded_bs * max_seqlen_q] = 0 + + # cu_seqlens_q: uniform max_seqlen_q per real req, padded tail flat. + cu = np.arange( + 0, (ub_real_reqs + 1) * max_seqlen_q, max_seqlen_q, dtype=np.int32 + ) + var[f"{p}cu_seqlens_q"].np[: ub_real_reqs + 1] = cu + var[f"{p}cu_seqlens_q"].np[ub_real_reqs + 1 : padded_bs + 1] = ( + ub_real_reqs * max_seqlen_q + ) + + # ---- H2D ---- + ub_sum_tokens = max(ub_real_tokens, 1) + positions_gpu = var[f"{p}positions"].copy_to_gpu(padded_bs * max_seqlen_q) + cu_seqlens_q_gpu = var[f"{p}cu_seqlens_q"].copy_to_gpu(padded_bs + 1) + context_lens_gpu = var[f"{p}context_lens"].copy_to_gpu(padded_bs) + block_tables_gpu = var[f"{p}block_tables"].copy_to_gpu(padded_bs) + state_slot_gpu = var[f"{p}v4_meta_state_slot_groups"].copy_to_gpu(padded_bs) + + # ---- compress plans (per ubatch buffer set) ---- + extend_lens_np = np.full(ub_real_reqs, max_seqlen_q, dtype=np.int32) + ctx_for_plan = context_lens_np[req_start : req_start + ub_real_reqs] + compress_plans = self._build_compress_plans( + extend_lens_np, + ctx_for_plan, + for_decode_cg=True, + buf_prefix_ubatch=p, + ) + + attn_metadata = AttentionMetaData_DSV4( + cu_seqlens_q=cu_seqlens_q_gpu, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=int(ub_ctx_np.max()) if ub_real_reqs > 0 else 1, + min_seqlen_q=0, + dropout_p=0.0, + has_cached=False, + total_kv=int(ub_ctx_np.sum()) if ub_real_reqs > 0 else 0, + num_cached_tokens=None, + block_tables=block_tables_gpu, + context_lens=context_lens_gpu, + state=AttnState.DECODE, + ) + attn_metadata.state_slot_mapping = state_slot_gpu + attn_metadata.state_slot_mapping_cpu = state_slot_np_ub + attn_metadata.compress_plans = compress_plans + + # token_num_per_seq over PADDED bs (pad reqs contribute max_seqlen_q + # each so batch_id_per_token covers padded_total_tokens). + token_num_per_seq = np.full(ub_real_reqs, max_seqlen_q, dtype=np.int32) + self._attach_v4_per_fwd_meta( + attn_metadata, + token_num_per_seq, + state_slot_np_ub, + ub_real_reqs, + ub_real_tokens, + padded_bs=padded_bs, + max_q_len=max_seqlen_q, + buf_prefix_ubatch=p, + ) + self._attach_v4_indexer_meta( + attn_metadata, + max(ub_real_reqs, 1), + ub_sum_tokens, + positions_gpu=positions_gpu, + ) + metas.append(attn_metadata) + + self._ubatch_decode_meta = metas + + def build_ubatch_metadata( + self, ubatch_idx: int, padded_bs: int + ) -> AttentionMetaData_DSV4: + assert self._ubatch_decode_meta is not None, ( + "build_ubatch_metadata called but no ubatch decode metadata was " + "prepared — ensure enable_tbo_decode is set and prepare_decode ran." + ) + return self._ubatch_decode_meta[ubatch_idx] + def prepare_prefill(self, batch: ScheduledBatch): """V4 prefill prep: extends parent to always populate block_tables and state_slot_mapping. @@ -1454,6 +1638,7 @@ def _attach_v4_per_fwd_meta( *, padded_bs: Optional[int] = None, max_q_len: Optional[int] = None, + buf_prefix_ubatch: str = "", ) -> None: """Hoist per-fwd, layer-invariant metadata used by every V4 layer. @@ -1517,7 +1702,7 @@ def _attach_v4_per_fwd_meta( # context_lens is int32 on the buffer; keep dtype through divide so # n_committed_{csa,hca} stay int32 (max value ~max_model_len // 4 ≪ 2^31). - ctx_per_seq_np = var["context_lens"].np[:scheduled_bs] + ctx_per_seq_np = var[f"{buf_prefix_ubatch}context_lens"].np[:scheduled_bs] # Single source of truth for n_committed_{csa,hca}_per_seq on CPU. # Stashed on attn_metadata so paged_decode_meta / paged_prefill_meta / # v4_indexer_meta can read instead of each re-running `ctx // k`. @@ -1533,7 +1718,7 @@ def _attach_v4_per_fwd_meta( # from `var["positions"].gpu` — saves O(T·win) numpy work + 4 MB # staging buffer. The `positions` H2D is already done by the caller. attn_metadata.batch_id_per_token = self._stage( - "v4_batch_id_per_token", batch_id_per_token_np + f"{buf_prefix_ubatch}v4_batch_id_per_token", batch_id_per_token_np ) # Stage n_committed to GPU. For CG-replay safety: aiter # `top_k_per_row_decode` iterates the CAPTURED grid (= padded_bs * @@ -1548,7 +1733,7 @@ def _attach_v4_per_fwd_meta( # `batch_id_per_token = -1` sentinel masks pad rows out of # `csa_translate_pack`, so the value just needs to be "big enough" # to keep row_len non-negative. Use `index_topk` (≥ 1024 ≫ next_n). - n_csa_buf = var["v4_n_committed_csa_per_seq"] + n_csa_buf = var[f"{buf_prefix_ubatch}v4_n_committed_csa_per_seq"] n_csa_buf.np[:scheduled_bs] = n_committed_csa_per_seq_np if is_pure_decode and padded_bs is not None and padded_bs > scheduled_bs: n_csa_buf.np[scheduled_bs:padded_bs] = self.index_topk @@ -1563,6 +1748,7 @@ def _attach_v4_per_fwd_meta( scheduled_bs=scheduled_bs, total_tokens=total_tokens, padded_total_tokens=padded_total_tokens, + buf_prefix_ubatch=buf_prefix_ubatch, ) def _attach_v4_paged_decode_meta( @@ -1573,6 +1759,7 @@ def _attach_v4_paged_decode_meta( scheduled_bs: int, total_tokens: int, padded_total_tokens: Optional[int] = None, + buf_prefix_ubatch: str = "", ) -> None: """Phase B: build per-fwd paged-decode index buffers (layer-invariant). @@ -1653,7 +1840,7 @@ def _attach_v4_paged_decode_meta( # n_csa (which happens for early tokens in chunked-prefill verify # batches and MTP draft mid-iters). index_topk = self.index_topk - positions_np_view = var["positions"].np[:T] + positions_np_view = var[f"{buf_prefix_ubatch}positions"].np[:T] n_committed_hca_per_token = n_committed_hca_per_seq[batch_id_per_token_np] # actual_swa_count[t] = min(positions[t]+1, win). Matches the kernel's @@ -1703,14 +1890,20 @@ def _attach_v4_paged_decode_meta( if T_pad > T: hca_indptr_np[T + 1 :].fill(int(hca_indptr_np[T])) - swa_indptr_gpu = self._stage("v4_kv_indptr_swa", swa_indptr_np) - csa_indptr_gpu = self._stage("v4_kv_indptr_csa", csa_indptr_np) - hca_indptr_gpu = self._stage("v4_kv_indptr_hca", hca_indptr_np) + swa_indptr_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indptr_swa", swa_indptr_np + ) + csa_indptr_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indptr_csa", csa_indptr_np + ) + hca_indptr_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indptr_hca", hca_indptr_np + ) # batch_id_per_token + n_committed_csa_per_seq already staged in # `_attach_v4_per_fwd_meta`. # ----- HCA compress paged offsets (CPU numpy, vectorized) ----- - block_tables_np_full = var["block_tables"].np[:scheduled_bs] + block_tables_np_full = var[f"{buf_prefix_ubatch}block_tables"].np[:scheduled_bs] hca_total_indices = int(hca_indptr_np[T]) hca_indices_np = np.full(hca_total_indices, -1, dtype=np.int32) # n_committed_hca_per_seq is int32; gather stays int32. @@ -1732,7 +1925,9 @@ def _attach_v4_paged_decode_meta( swa_pages + block_tables_np_full[bid_expanded, entry_offsets] ).astype(np.int32) # Stage to GPU (HCA compress section at head; SWA prefix scattered below). - hca_indices_gpu = self._stage("v4_kv_indices_hca", hca_indices_np) + hca_indices_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indices_hca", hca_indices_np + ) # ----- Write SWA / CSA / HCA window-prefix paged offsets (1 kernel) ----- # Kernel computes `n = min(positions[t]+1, win)` and ring-index @@ -1742,12 +1937,12 @@ def _attach_v4_paged_decode_meta( # persistent forward_vars buffers — no allocator churn (the prior # `index_copy_` chain raced under MTP-3 long-prefill; this kernel # also fixes that, see skill `debug-agent-locate-kernel`). - swa_indices_gpu = var["v4_kv_indices_swa"].gpu - csa_indices_gpu = var["v4_kv_indices_csa"].gpu + swa_indices_gpu = var[f"{buf_prefix_ubatch}v4_kv_indices_swa"].gpu + csa_indices_gpu = var[f"{buf_prefix_ubatch}v4_kv_indices_csa"].gpu write_v4_paged_decode_indices( state_slot_per_seq=attn_metadata.state_slot_mapping, batch_id_per_token=batch_id_per_token_gpu, - positions=var["positions"].gpu, + positions=var[f"{buf_prefix_ubatch}positions"].gpu, swa_indptr=swa_indptr_gpu, csa_indptr=csa_indptr_gpu, hca_indptr=hca_indptr_gpu, @@ -1992,7 +2187,12 @@ def _build_paged_prefill_meta( attn_metadata.swa_pages = swa_pages def _build_compress_plans( - self, extend_lens_np, context_lens_np, *, for_decode_cg: bool + self, + extend_lens_np, + context_lens_np, + *, + for_decode_cg: bool, + buf_prefix_ubatch: str = "", ): """Build per-ratio CompressPlan dict consumed by batched compressor. @@ -2028,8 +2228,8 @@ def _build_compress_plans( var = self.model_runner.forward_vars plan_buffers = { ratio: { - "compress": var[f"v4_compress_plan_{ratio}"], - "write": var[f"v4_write_plan_{ratio}"], + "compress": var[f"{buf_prefix_ubatch}v4_compress_plan_{ratio}"], + "write": var[f"{buf_prefix_ubatch}v4_write_plan_{ratio}"], } for ratio, _ in self._unique_compress_ratios_overlap } @@ -2195,6 +2395,16 @@ def build_for_cudagraph_capture( positions_gpu=positions, ) + if self.model_runner.config.enable_tbo_decode and bs > 2: + self._prepare_ubatch_decode( + scheduled_bs=bs, + bs=bs, + max_seqlen_q=max_q_len, + context_lens_np=context_lens_np, + state_slot_np=state_slot_np, + positions_np=positions_np.astype(np.int32), + ) + context = Context( positions=positions, is_prefill=False, @@ -2347,8 +2557,64 @@ def _alloc_v4_metadata_buffers(self) -> None: per_seq_max = (self.max_spec_steps + ratio) // ratio self._decode_compress_cap[ratio] = bs * per_seq_max + if getattr(self.model_runner.config, "enable_tbo_decode", False): + self._alloc_v4_ubatch_decode_buffers(bufs, i32, i64) + self.model_runner.forward_vars.update(bufs) + def _alloc_v4_ubatch_decode_buffers(self, bufs: dict, i32: dict, i64: dict) -> None: + """Clone decode-path metadata buffers into ``ub{0,1}_`` prefixed sets. + + Mirrors the sizes chosen in :meth:`_alloc_v4_metadata_buffers` for the + decode-relevant buffers plus the global per-fwd inputs the decode + helpers read (``positions`` / ``context_lens`` / ``block_tables`` / + ``cu_seqlens_q``). Only invoked when ``enable_tbo_decode`` is set. + """ + mnbt = self.max_num_batched_tokens + bs = self.max_bs + win = self.window_size + T_dec = self.max_decode_tokens + max_q_len = 1 + self.max_spec_steps + max_blocks = self.max_num_blocks_per_seq // self.block_ratio + + for ub_idx in range(self._NUM_TBO_UBATCHES): + p = f"ub{ub_idx}_" + # Global per-fwd decode inputs (live in model_runner.forward_vars + # for the non-TBO path; cloned here so each ubatch slices its own). + bufs[f"{p}positions"] = CpuGpuBuffer(T_dec, **i64) + bufs[f"{p}context_lens"] = CpuGpuBuffer(bs, **i32) + bufs[f"{p}block_tables"] = CpuGpuBuffer(bs, max_blocks, **i32) + bufs[f"{p}cu_seqlens_q"] = CpuGpuBuffer(bs + 1, **i32) + + # V4 decode metadata buffers. + bufs[f"{p}v4_meta_state_slot_groups"] = CpuGpuBuffer(bs, **i32) + bufs[f"{p}v4_kv_indices_swa"] = CpuGpuBuffer(T_dec * win, **i32) + bufs[f"{p}v4_kv_indices_csa"] = CpuGpuBuffer( + T_dec * (win + self.index_topk), **i32 + ) + bufs[f"{p}v4_kv_indices_hca"] = CpuGpuBuffer( + T_dec * (win + self.max_committed_hca), **i32 + ) + bufs[f"{p}v4_kv_indptr_swa"] = CpuGpuBuffer(T_dec + 1, **i32) + bufs[f"{p}v4_kv_indptr_csa"] = CpuGpuBuffer(T_dec + 1, **i32) + bufs[f"{p}v4_kv_indptr_hca"] = CpuGpuBuffer(T_dec + 1, **i32) + bufs[f"{p}v4_n_committed_csa_per_seq"] = CpuGpuBuffer(bs, **i32) + bufs[f"{p}v4_batch_id_per_token"] = CpuGpuBuffer(mnbt, **i64) + bufs[f"{p}v4_indexer_cu_committed"] = CpuGpuBuffer(bs + 1, **i32) + + for ratio, is_overlap in self._unique_compress_ratios_overlap: + K_pool = (2 if is_overlap else 1) * ratio + max_compress = mnbt // ratio + bs + max_write = min(mnbt, bs * K_pool) + cbuf = CpuGpuBuffer(max_compress, 4, **i32) + wbuf = CpuGpuBuffer(max_write, 4, **i32) + cbuf.cpu.fill_(-1) + cbuf.copy_to_gpu() + wbuf.cpu.fill_(-1) + wbuf.copy_to_gpu() + bufs[f"{p}v4_compress_plan_{ratio}"] = cbuf + bufs[f"{p}v4_write_plan_{ratio}"] = wbuf + def _stage(self, name: str, arr) -> torch.Tensor: """Write numpy `arr` into `forward_vars[name]` (CpuGpuBuffer) and return its GPU view sliced to len(arr). Auto-casts dtype to match diff --git a/atom/model_ops/module_dispatch_ops.py b/atom/model_ops/module_dispatch_ops.py index 096e2c013c..3a3e5d9c3b 100644 --- a/atom/model_ops/module_dispatch_ops.py +++ b/atom/model_ops/module_dispatch_ops.py @@ -51,7 +51,10 @@ def maybe_dual_stream_forward( ] threshold = envs.ATOM_DUAL_STREAM_MOE_TOKEN_THRESHOLD num_tokens = hidden_states.shape[0] - if self._use_dual_stream and 0 < num_tokens <= threshold: + # Under TBO the two micro-batches already overlap on separate threads + from atom.utils.tbo.ubatching import tbo_active + + if self._use_dual_stream and 0 < num_tokens <= threshold and not tbo_active(): return self.dual_stream_moe_forward(hidden_states) return self.single_stream_moe_forward(hidden_states) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index ad45981e36..8a15613461 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -243,6 +243,14 @@ def all_gather_with_padding( # use_custom=False falls back to torch.distributed.all_gather_into_tensor # (NCCL), whose WorkNCCL end-event recorded inside CUDAGraph capture is # later queried by the watchdog thread -> hipErrorCapturedEvent crash. + # + # Under TBO the two micro-batch threads issue this collective concurrently. + # The custom CA/IPC all-gather uses a single process-wide signal/workspace, + # which two concurrent threads corrupt -> cross-rank deadlock. Fall back to + # the stock NCCL all_gather (thread-safe) while a TBO overlap is active. + from atom.utils.tbo.ubatching import tbo_active + + use_cag = use_cag and not tbo_active() gathered_hidden_states = get_dp_group().all_gather( padded_x, use_custom=use_cag, dim=0 ) diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index fc9e88431b..2fd6c8e434 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -874,6 +874,12 @@ def __init__( ) self.norm = RMSNorm(self.head_dim, args.norm_eps) + # Fixed CUDAGraph-stable scratch for `wkv_gate(x)` output on the captured + # decode path (lazily sized on first graph forward; see forward()). Keyed + # per TBO ubatch id so the two concurrent ubatch threads never share the + # same scratch (mirrors the per-ubatch topK / mori buffers). + self._combined_cg_buf: dict = {} + # External tensors — assigned by the owning Attention / Indexer at first forward. self.kv_cache: Optional[torch.Tensor] = None self.rotary_emb: Optional[_V4RoPE] = None @@ -980,6 +986,25 @@ def forward( # stride must be 1). coff_d = (1 + overlap) * d combined = self.wkv_gate(x) + # TBO decode: copy `combined` into a fixed-address buffer so CUDAGraph + # capture/replay see a stable pointer (allocator may re-place it). + from atom.utils.tbo.ubatching import tbo_active, tbo_current_ubatch_id + + _fc = get_forward_context() + if getattr(_fc, "in_hipgraph", False) and tbo_active(): + ub = tbo_current_ubatch_id() + n_tok = combined.shape[0] + buf = self._combined_cg_buf.get(ub) + if buf is None or buf.shape[0] < n_tok or buf.shape[1] != combined.shape[1]: + buf = torch.empty( + combined.shape[0], + combined.shape[1], + dtype=combined.dtype, + device=combined.device, + ) + self._combined_cg_buf[ub] = buf + buf[:n_tok].copy_(combined) + combined = buf[:n_tok] kv, score = torch.split(combined, [coff_d, coff_d], dim=-1) # ====== Unified fused kernel path (CSA + Indexer) ====== @@ -1322,10 +1347,12 @@ def _score_topk_decode( """ total_tokens = q_fp8.size(0) n_committed_per_seq_gpu = indexer_meta["n_committed_per_seq_gpu"] # int32 [bs] - bs = block_tables.size(0) - # V4-Pro has no MTP, so next_n = total_tokens // bs = 1. The reshape - # also handles future multi-token decode (MTP) without code change. - next_n = total_tokens // bs + # NOTE: derive the query batch size from the ACTUAL number of query + # tokens, NOT from block_tables.size(0). Under TBO the per-ubatch + # block_tables / n_committed are padded to a DP-unified bucket and will + # get errors if we try to use the padded rows. + next_n = max(1, int(get_forward_context().attn_metadata.max_seqlen_q)) + bs = total_tokens // next_n # deepgemm requires Q in [bs, next_n, heads, head_dim], KV in # [num_blocks, block_size, n_head=1, hidden_dim+scale_dim] (4D). q_4d = q_fp8.view( @@ -1632,7 +1659,11 @@ def maybe_compressors_async( Waits resolve instantly: side streams ~25us, main Q/KV chain ~87us.""" fc = get_forward_context() current_stream = fc.main_stream - use_async_compress = self._use_async_compress and fc.in_hipgraph + from atom.utils.tbo.ubatching import tbo_active + + use_async_compress = ( + self._use_async_compress and fc.in_hipgraph and not tbo_active() + ) has_compressor = self.compressor is not None has_indexer = self.indexer is not None and not self.skip_topk if use_async_compress: diff --git a/atom/utils/tbo/ubatch_wrapper.py b/atom/utils/tbo/ubatch_wrapper.py index d01f0c7daf..622e54f4d8 100644 --- a/atom/utils/tbo/ubatch_wrapper.py +++ b/atom/utils/tbo/ubatch_wrapper.py @@ -73,6 +73,8 @@ def _run_ubatches( N = len(ctx.ubatch_slices) compute_stream = torch.cuda.current_stream() + ub_dp_metadata = self._make_ubatch_dp_metadata(ctx, N) + full_graph_bs = ctx.context.graph_bs forward_contexts = [] ub_inputs = [] @@ -94,10 +96,7 @@ def _run_ubatches( if ctx.context.is_prefill: padded_bs = ub_num_reqs else: - if i < N - 1: - padded_bs = full_graph_bs // N - else: - padded_bs = full_graph_bs - (full_graph_bs // N) * (N - 1) + padded_bs = self._decode_ub_padded_bs(ctx, i, N, full_graph_bs) ub_ctx = self._make_ubatch_context( original_ctx, ub_slice, @@ -105,6 +104,7 @@ def _run_ubatches( i, ub_num_reqs, ub_graph_bs=ub_graph_bs_list[i], + dp_metadata=ub_dp_metadata[i] if ub_dp_metadata is not None else None, ) forward_contexts.append(ub_ctx) ub_token_slice = ( @@ -216,6 +216,7 @@ def capture_tbo_graph( full_graph_bs = ctx.context.graph_bs # only padding for all_gather/reduce_scatter pass all_gahter_dp_size = self._get_dp_size() if self.dp_gather_scatter else 1 + ub_dp_metadata = self._make_ubatch_dp_metadata(ctx, N) forward_contexts = [] ub_inputs = [] for i, ub_slice in enumerate(ctx.ubatch_slices): @@ -229,6 +230,7 @@ def capture_tbo_graph( padded_bs, i, ub_graph_bs=padded_bs * all_gahter_dp_size, + dp_metadata=ub_dp_metadata[i] if ub_dp_metadata is not None else None, ) forward_contexts.append(ub_ctx) ub_inputs.append( @@ -324,6 +326,55 @@ def _get_dp_size() -> int: except Exception: return 1 + def _make_ubatch_dp_metadata(self, ctx: ForwardContext, N: int): + """Build per-ubatch :class:`DPMetadata` so the MoE DP collective uses + each ubatch's own per-rank token counts. + + Returns ``None`` when DP is disabled / no dp_metadata on the parent + context (the shared metadata is then reused, which is correct for the + single-rank case). Otherwise returns a list of length ``N``. + + Each ubatch's per-rank token count is obtained with the same CPU + all_reduce that :meth:`DPMetadata.num_tokens_across_dp` uses, one per + ubatch. This is a CPU collective (cheap) and keeps every rank's + all_gatherv / reduce_scatterv consistently sized. + """ + if ctx.dp_metadata is None: + return None + from atom.config import get_current_atom_config + from atom.utils.forward_context import DPMetadata + + parallel_config = get_current_atom_config().parallel_config + metas = [] + for ub_slice in ctx.ubatch_slices: + ub_tokens = ub_slice.token_slice.stop - ub_slice.token_slice.start + metas.append(DPMetadata.make(parallel_config, int(ub_tokens), None)) + return metas + + @staticmethod + def _decode_ub_padded_bs( + ctx: ForwardContext, i: int, N: int, full_graph_bs: int + ) -> int: + """Per-ubatch padded request count for a decode micro-batch. + + Must be IDENTICAL across DP ranks: the MoE all_gather/reduce_scatter + pads each ubatch to this size, so a per-rank-local split (which differs + when ranks carry different decode batch sizes, e.g. during drain) + desyncs the collective and faults. Derive it from the DP-unified + ``ub_max_tokens_across_dp`` (MAX-reduced in ModelRunner._preprocess), + converting the per-ubatch token max back to a request count via + ``max_seqlen_q``. Falls back to the local split only when DP is off or + the precomputed value is unavailable. + """ + ub_max = ctx.ub_max_tokens_across_dp + if ub_max is not None and len(ub_max) == N: + max_q = getattr(ctx.attn_metadata, "max_seqlen_q", 1) or 1 + return max(1, ub_max[i] // max_q) + # Fallback: local split (single-rank / value not precomputed). + if i < N - 1: + return full_graph_bs // N + return full_graph_bs - (full_graph_bs // N) * (N - 1) + @staticmethod def _compute_ub_graph_bs( ctx: ForwardContext, @@ -353,12 +404,11 @@ def _compute_ub_graph_bs( ub_sizes.append(ub_num_tokens) return ub_sizes else: + # Decode: use the DP-unified per-ubatch padded_bs (same value on + # every rank) so MoE's pad_for_all_gather sizes match cross-rank. result = [] for i in range(N): - if i < N - 1: - padded_bs = full_graph_bs // N - else: - padded_bs = full_graph_bs - (full_graph_bs // N) * (N - 1) + padded_bs = UBatchWrapper._decode_ub_padded_bs(ctx, i, N, full_graph_bs) result.append(padded_bs * dp_size) return result @@ -370,6 +420,7 @@ def _make_ubatch_context( ubatch_idx: int = 0, actual_num_reqs: int | None = None, ub_graph_bs: int | None = None, + dp_metadata=None, ) -> ForwardContext: """Build a ForwardContext for a single micro-batch.""" ub_num_reqs = ub_slice.request_slice.stop - ub_slice.request_slice.start @@ -410,7 +461,7 @@ def _make_ubatch_context( no_compile_layers=ctx.no_compile_layers, kv_cache_data=ctx.kv_cache_data, context=ub_context, - dp_metadata=ctx.dp_metadata, # shared across ubatches + dp_metadata=dp_metadata if dp_metadata is not None else ctx.dp_metadata, spec_decode_metadata=None, # not supported with TBO ubatch_slices=None, # prevent recursion main_stream=ctx.main_stream, From b4b3b7e6c4a33882b6601b85812c874e31eeb178 Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Thu, 18 Jun 2026 12:15:41 +0800 Subject: [PATCH 2/3] format --- atom/model_ops/attentions/deepseek_v4_attn.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index d9421727d4..d9935da0ff 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -1270,9 +1270,7 @@ def _prepare_ubatch_decode( ub_state_np = np.zeros(ub_real_reqs, dtype=np.int32) var[f"{p}v4_meta_state_slot_groups"].np[:ub_real_reqs] = ub_state_np var[f"{p}v4_meta_state_slot_groups"].np[ub_real_reqs:padded_bs] = 0 - state_slot_np_ub = ( - var[f"{p}v4_meta_state_slot_groups"].np[:padded_bs].copy() - ) + state_slot_np_ub = var[f"{p}v4_meta_state_slot_groups"].np[:padded_bs].copy() var[f"{p}block_tables"].np[:ub_real_reqs] = var["block_tables"].np[ req_start : req_start + ub_real_reqs @@ -1890,15 +1888,9 @@ def _attach_v4_paged_decode_meta( if T_pad > T: hca_indptr_np[T + 1 :].fill(int(hca_indptr_np[T])) - swa_indptr_gpu = self._stage( - f"{buf_prefix_ubatch}v4_kv_indptr_swa", swa_indptr_np - ) - csa_indptr_gpu = self._stage( - f"{buf_prefix_ubatch}v4_kv_indptr_csa", csa_indptr_np - ) - hca_indptr_gpu = self._stage( - f"{buf_prefix_ubatch}v4_kv_indptr_hca", hca_indptr_np - ) + swa_indptr_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indptr_swa", swa_indptr_np) + csa_indptr_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indptr_csa", csa_indptr_np) + hca_indptr_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indptr_hca", hca_indptr_np) # batch_id_per_token + n_committed_csa_per_seq already staged in # `_attach_v4_per_fwd_meta`. @@ -1925,9 +1917,7 @@ def _attach_v4_paged_decode_meta( swa_pages + block_tables_np_full[bid_expanded, entry_offsets] ).astype(np.int32) # Stage to GPU (HCA compress section at head; SWA prefix scattered below). - hca_indices_gpu = self._stage( - f"{buf_prefix_ubatch}v4_kv_indices_hca", hca_indices_np - ) + hca_indices_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indices_hca", hca_indices_np) # ----- Write SWA / CSA / HCA window-prefix paged offsets (1 kernel) ----- # Kernel computes `n = min(positions[t]+1, win)` and ring-index @@ -2187,12 +2177,7 @@ def _build_paged_prefill_meta( attn_metadata.swa_pages = swa_pages def _build_compress_plans( - self, - extend_lens_np, - context_lens_np, - *, - for_decode_cg: bool, - buf_prefix_ubatch: str = "", + self, extend_lens_np, context_lens_np, *, for_decode_cg: bool, buf_prefix_ubatch: str = "" ): """Build per-ratio CompressPlan dict consumed by batched compressor. @@ -2395,7 +2380,10 @@ def build_for_cudagraph_capture( positions_gpu=positions, ) - if self.model_runner.config.enable_tbo_decode and bs > 2: + if ( + self.model_runner.config.enable_tbo_decode + and bs > 2 + ): self._prepare_ubatch_decode( scheduled_bs=bs, bs=bs, @@ -2574,7 +2562,6 @@ def _alloc_v4_ubatch_decode_buffers(self, bufs: dict, i32: dict, i64: dict) -> N bs = self.max_bs win = self.window_size T_dec = self.max_decode_tokens - max_q_len = 1 + self.max_spec_steps max_blocks = self.max_num_blocks_per_seq // self.block_ratio for ub_idx in range(self._NUM_TBO_UBATCHES): From 02645321dd614ef843e17200c5073fce22e393d7 Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Thu, 18 Jun 2026 12:19:03 +0800 Subject: [PATCH 3/3] format --- atom/model_ops/attentions/deepseek_v4_attn.py | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index d9935da0ff..fe2e2a3b65 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -1270,7 +1270,9 @@ def _prepare_ubatch_decode( ub_state_np = np.zeros(ub_real_reqs, dtype=np.int32) var[f"{p}v4_meta_state_slot_groups"].np[:ub_real_reqs] = ub_state_np var[f"{p}v4_meta_state_slot_groups"].np[ub_real_reqs:padded_bs] = 0 - state_slot_np_ub = var[f"{p}v4_meta_state_slot_groups"].np[:padded_bs].copy() + state_slot_np_ub = ( + var[f"{p}v4_meta_state_slot_groups"].np[:padded_bs].copy() + ) var[f"{p}block_tables"].np[:ub_real_reqs] = var["block_tables"].np[ req_start : req_start + ub_real_reqs @@ -1888,9 +1890,15 @@ def _attach_v4_paged_decode_meta( if T_pad > T: hca_indptr_np[T + 1 :].fill(int(hca_indptr_np[T])) - swa_indptr_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indptr_swa", swa_indptr_np) - csa_indptr_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indptr_csa", csa_indptr_np) - hca_indptr_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indptr_hca", hca_indptr_np) + swa_indptr_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indptr_swa", swa_indptr_np + ) + csa_indptr_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indptr_csa", csa_indptr_np + ) + hca_indptr_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indptr_hca", hca_indptr_np + ) # batch_id_per_token + n_committed_csa_per_seq already staged in # `_attach_v4_per_fwd_meta`. @@ -1917,7 +1925,9 @@ def _attach_v4_paged_decode_meta( swa_pages + block_tables_np_full[bid_expanded, entry_offsets] ).astype(np.int32) # Stage to GPU (HCA compress section at head; SWA prefix scattered below). - hca_indices_gpu = self._stage(f"{buf_prefix_ubatch}v4_kv_indices_hca", hca_indices_np) + hca_indices_gpu = self._stage( + f"{buf_prefix_ubatch}v4_kv_indices_hca", hca_indices_np + ) # ----- Write SWA / CSA / HCA window-prefix paged offsets (1 kernel) ----- # Kernel computes `n = min(positions[t]+1, win)` and ring-index @@ -2177,7 +2187,12 @@ def _build_paged_prefill_meta( attn_metadata.swa_pages = swa_pages def _build_compress_plans( - self, extend_lens_np, context_lens_np, *, for_decode_cg: bool, buf_prefix_ubatch: str = "" + self, + extend_lens_np, + context_lens_np, + *, + for_decode_cg: bool, + buf_prefix_ubatch: str = "", ): """Build per-ratio CompressPlan dict consumed by batched compressor. @@ -2380,10 +2395,7 @@ def build_for_cudagraph_capture( positions_gpu=positions, ) - if ( - self.model_runner.config.enable_tbo_decode - and bs > 2 - ): + if self.model_runner.config.enable_tbo_decode and bs > 2: self._prepare_ubatch_decode( scheduled_bs=bs, bs=bs,