Skip to content

Remove host syncs from compute_dacs_segsum_triton_varlen#969

Open
jgcb00 wants to merge 2 commits into
state-spaces:mainfrom
jgcb00:main
Open

Remove host syncs from compute_dacs_segsum_triton_varlen#969
jgcb00 wants to merge 2 commits into
state-spaces:mainfrom
jgcb00:main

Conversation

@jgcb00

@jgcb00 jgcb00 commented Jun 11, 2026

Copy link
Copy Markdown

Problem

The chunk-mapping construction loops over packed sequences reading scalars
from GPU tensors element by element:

for i in range(num_sequences):
    start = int(cu_seqlens[i].item())          # GPU->CPU sync
    n = int(chunks_per_seq[i].item())          # GPU->CPU sync
    state_seq_mapping[chunk_start:chunk_end] = i                  # kernel
    state_chunk_in_seq[chunk_start:chunk_end] = torch.arange(...) # kernel

That is 2*num_sequences + 1 GPU→CPU synchronizations plus
2*num_sequences tiny slice-fill kernels per call. Each .item() stalls
the CPU on the GPU stream. Under packed-varlen inference serving this runs
once per layer per prefill batch: with 29 mamba3 layers and 64 packed prompts
that is ~3,700 pipeline stalls (~100–200 ms) per prefill wave, dominating
time-to-first-token. Varlen training forwards pay the same pattern once per
layer per micro-batch.

Fix — fully vectorized on device, zero syncs

Sequence i owns the contiguous global chunk slots
[cu[i]//C + i, cu[i]//C + i + len_i//C + 1). The starts are strictly
increasing (floor(a+b) >= floor(a) + floor(b)), so the owner of slot g is
searchsorted(range_starts, g, right=True) - 1, slots at/after the owner's
range end are inactive padding, and the local chunk index is
g - range_starts[owner]:

range_starts = cu_seqlens[:-1] // C + arange(NS)
range_ends   = range_starts + (cu_seqlens[1:] - cu_seqlens[:-1]) // C + 1
owner  = searchsorted(range_starts, arange(nchunks), right=True) - 1
active = arange(nchunks) < range_ends[owner]
state_seq_mapping  = where(active, owner, 0)
state_chunk_in_seq = where(active, g - range_starts[owner], sentinel)

The inactive-slot sentinel (ceil(len_0 / C), same as before) stays a
device-side scalar. ~7 small kernels replace the loop; no .item()/.tolist()
anywhere, so the call no longer serializes the CPU against the GPU stream —
which also matters under async scheduling, where any sync blocks pipelining
of the next step's host work.

The old per-sequence overflow assert is dropped: by the same floor
inequality the last range end is always <= nchunks, so it could never fire.

Validation / measurements

  • Mapping tensors bit-identical to the original loop for: NS=1 short/long,
    lengths exactly multiple of C, single-token sequences, mixed lengths, and
    32 packed sequences.
  • End-to-end greedy outputs of our mamba3-MIMO hybrid unchanged.
  • In serving (vLLM, GH200): together with hoisting two smaller sync sites in
    our integration layer, prefill syncs dropped ~520 → ~7 per 8-prompt batch
    and TTFT at 256 concurrent requests dropped 1.31 s → 0.85 s.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant