Add state_batch_indices to mamba3_step_fn (paged-inference state pools, parity with selective_state_update)#971
Open
jgcb00 wants to merge 1 commit into
Open
Conversation
Mirror of selective_state_update's state_batch_indices for the mamba3 MIMO step kernel: when given, state/Bstate/Xstate are pools of shape (P, ...) and the kernel reads/updates row state_batch_indices[b] in place — removes the gather/scatter round-trip on the fp32 SSM state that inference servers otherwise pay (58% of decode GPU time at batch 256 in our vLLM deployment; ~2x end-to-end decode throughput). Padding semantics match selective_state_update: negative indices (CUDA-graph batch padding) skip the token — output zeroed, state untouched. Optional update_kv_state (explicit, asserts tile_D >= headdim): the kernel also stores this step's B/x into Bstate/Xstate after consuming the old values, replacing the caller's scatter kernels (single CTA per (b, h) row, no cross-CTA hazard). Byte-identical vs the gather/scatter reference incl. padded lanes; see tests/ops/test_mamba3_step_indexed.py. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Inference servers (vLLM and friends) keep per-request recurrent states in a
pool of
Prows and address them by slot index, because requests come andgo at different times. Mamba2 supports this directly:
selective_state_update(..., state_batch_indices=...)reads and updates poolrows in place. The Mamba3 MIMO step kernel has no equivalent, so a server must
do, per layer, per decode step:
That is 3× the necessary HBM traffic on the dominant tensor of decode (the
fp32 SSM state — e.g. 48 heads × 64 × 128 = 1.5 MB per layer per request).
Profiling a mamba3-MIMO hybrid served with vLLM on GH200, this round-trip
accounted for 58% of all decode GPU time at batch 256; removing it gave
~2× end-to-end decode throughput (3.5k → 8k tok/s at 256 concurrent
requests). After the change the kernel runs at ~87% of HBM peak reading and
writing the pools directly.
What this PR does
Adds an optional
state_batch_indices: (batch,) int32argument tomamba3_step_fn, mirroring theselective_state_updatecontract:state/Bstate/Xstateare pools of shape(P, ...)andthe kernel reads/updates row
state_batch_indices[b]in place(
state_outmust beNone). The pool row count is independent of thebatch size (separate symbolic dim; the launch grid takes batch from
x).index (e.g. CUDA-graph batch padding) produces a zeroed output and leaves
all states untouched. Implementation note: instead of Mamba2's early
return, loads are clamped to row 0 and all writes are masked — CuteDSL'scp.asynccommit/wait pipeline spans the kernel body and does not admit amid-kernel return; the observable contract is identical and is asserted by
the tests.
update_kv_state: bool(assertstile_D >= headdim): thekernel also stores this step's
B/xintoBstate/Xstateafterconsuming the old values, replacing the caller's two scatter kernels. With
a single D-tile, exactly one CTA owns each
(batch, head)row, so there isno cross-CTA hazard (asserted in the wrapper). This is the MIMO analog of
causal_conv1d_update(..., conv_state_indices=...): Mamba2's auxiliarystate is updated by its own indexed kernel, while MIMO's auxiliary states
(previous step's B/x) are already inputs of this kernel.
state_batch_indicesnothing changes(same compile cache keys for existing callers).
Correctness
tests/ops/test_mamba3_step_indexed.py(included):gather → dense kernel → scatter reference, for batch sizes 1–32, with both
update_kv_statemodes, at H=48, D=64, N=128, R=4 (fp32 state, bf16 B/x).outputs, real lanes unchanged, and untouched pool rows verified
untouched.
update_kv_state+tile_D < headdimmisuse raises.Also validated end-to-end in serving (greedy outputs coherent at 256-way
concurrency; gsm8k / HumanEval / RULER scores unchanged vs the gather/scatter
path). bf16 state storage works through the same path (the kernel already
accumulates in fp32).
Notes for review
every other tensor keeps direct batch addressing.
state_batch_indicesis assertedint32(the cute compile declares thetensor dtype — stricter than Triton's any-int load in Mamba2).
main.