Remove host syncs from `compute_dacs_segsum_triton_varlen` #969
Open
jgcb00 wants to merge 2 commits into
Problem
The chunk-mapping construction loops over packed sequences reading scalars
from GPU tensors element by element:
That is `2*num_sequences + 1` GPU→CPU synchronizations plus `2*num_sequences` tiny slice-fill kernels per call. Each `.item()` stalls
the CPU on the GPU stream. Under packed-varlen inference serving this runs
once per layer per prefill batch: with 29 mamba3 layers and 64 packed prompts
that is ~3,700 pipeline stalls (~100–200 ms) per prefill wave, dominating
time-to-first-token. Varlen training forwards pay the same pattern once per
layer per micro-batch.
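
The stall count quoted above follows directly from the per-call formula. A minimal arithmetic sketch (the function name `stalls_per_prefill_wave` is made up for illustration and is not part of the PR):

```python
def stalls_per_prefill_wave(num_layers: int, num_sequences: int) -> int:
    """Host syncs per prefill wave, using the 2*num_sequences + 1 count per call."""
    syncs_per_call = 2 * num_sequences + 1  # per layer, per prefill batch
    return num_layers * syncs_per_call


# 29 mamba3 layers, 64 packed prompts (the figures quoted in the description)
print(stalls_per_prefill_wave(29, 64))  # 3741
```

Dividing the quoted 100–200 ms by this count suggests a few tens of microseconds per stall, which is an inference from the numbers above rather than a separate measurement.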
Fix — fully vectorized on device, zero syncs
Sequence `i` owns the contiguous global chunk slots `[cu[i]//C + i, cu[i]//C + i + len_i//C + 1)`. The starts are strictly
increasing (`floor(a+b) >= floor(a) + floor(b)`), so the owner of slot `g` is
`searchsorted(range_starts, g, right=True) - 1`, slots at/after the owner's
range end are inactive padding, and the local chunk index is `g - range_starts[owner]`:
The inactive-slot sentinel (`ceil(len_0 / C)`, same as before) stays a
device-side scalar. ~7 small kernels replace the loop; no `.item()`/`.tolist()`
anywhere, so the call no longer serializes the CPU against the GPU stream.
This also matters under async scheduling, where any sync blocks pipelining
of the next step's host work.
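
The ownership rule described above can be checked on the host with plain Python. The following is only an illustrative sketch, not the PR's code: `chunk_ranges`, `map_by_loop`, `map_by_search`, the `num_slots` argument, and the `-1` padding marker are all hypothetical (the PR uses `ceil(len_0 / C)` as its sentinel and runs on device tensors). `bisect_right` stands in for `searchsorted(..., right=True)`, and `cu` is assumed to be the cumulative sequence lengths with `cu[0] == 0`.

```python
from bisect import bisect_right


def chunk_ranges(cu, C):
    """Half-open slot range [start, end) owned by each packed sequence."""
    n = len(cu) - 1
    starts = [cu[i] // C + i for i in range(n)]
    ends = [starts[i] + (cu[i + 1] - cu[i]) // C + 1 for i in range(n)]
    return starts, ends


def map_by_loop(cu, C, num_slots, pad=-1):
    """Reference: walk sequences one by one and fill their slots."""
    starts, ends = chunk_ranges(cu, C)
    owner = [pad] * num_slots
    local = [pad] * num_slots
    for i, (s, e) in enumerate(zip(starts, ends)):
        for g in range(s, e):
            owner[g] = i
            local[g] = g - s
    return owner, local


def map_by_search(cu, C, num_slots, pad=-1):
    """Per-slot lookup: owner is the last range start <= g, if g is inside its range."""
    starts, ends = chunk_ranges(cu, C)
    owner, local = [], []
    for g in range(num_slots):
        i = bisect_right(starts, g) - 1
        active = i >= 0 and g < ends[i]  # at/after the range end means padding
        owner.append(i if active else pad)
        local.append(g - starts[i] if active else pad)
    return owner, local


cu = [0, 5, 6, 14]  # three sequences of lengths 5, 1, 8
owner, local = map_by_search(cu, C=4, num_slots=7)
print(owner)  # [0, 0, 1, 2, 2, 2, -1]
print(local)  # [0, 1, 0, 0, 1, 2, -1]
```

Because each slot is resolved independently in `map_by_search`, the same computation can be expressed as a handful of elementwise tensor operations, which is what allows the per-sequence loop and its scalar reads to be removed.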
The old per-sequence overflow assert is dropped: by the same floor
inequality the last range end is always `<= nchunks`, so it could never fire.
Validation / measurements
lengths exactly multiple of C, single-token sequences, mixed lengths, and
32 packed sequences.
our integration layer, prefill syncs dropped ~520 → ~7 per 8-prompt batch
and TTFT at 256 concurrent requests dropped 1.31 s → 0.85 s.