Skip to content

Add state_batch_indices to mamba3_step_fn (paged-inference state pools, parity with selective_state_update)#971

Open
jgcb00 wants to merge 1 commit into
state-spaces:mainfrom
jgcb00:mamba3-state-pools
Open

Add state_batch_indices to mamba3_step_fn (paged-inference state pools, parity with selective_state_update)#971
jgcb00 wants to merge 1 commit into
state-spaces:mainfrom
jgcb00:mamba3-state-pools

Conversation

@jgcb00

@jgcb00 jgcb00 commented Jun 11, 2026

Copy link
Copy Markdown

Motivation

Inference servers (vLLM and friends) keep per-request recurrent states in a
pool of P rows and address them by slot index, because requests come and
go at different times. Mamba2 supports this directly:
selective_state_update(..., state_batch_indices=...) reads and updates pool
rows in place. The Mamba3 MIMO step kernel has no equivalent, so a server must
do, per layer, per decode step:

state = state_pool[slots]            # gather   (read + write the full state)
mamba3_step_fn(state, ..., out)      # update   (read + write again)
state_pool[slots] = state_out        # scatter  (read + write a third time)

That is 3× the necessary HBM traffic on the dominant tensor of decode (the
fp32 SSM state — e.g. 48 heads × 64 × 128 = 1.5 MB per layer per request).
Profiling a mamba3-MIMO hybrid served with vLLM on GH200, this round-trip
accounted for 58% of all decode GPU time at batch 256; removing it gave
~2× end-to-end decode throughput (3.5k → 8k tok/s at 256 concurrent
requests). After the change the kernel runs at ~87% of HBM peak reading and
writing the pools directly.

What this PR does

Adds an optional state_batch_indices: (batch,) int32 argument to
mamba3_step_fn, mirroring the selective_state_update contract:

  • When given, state / Bstate / Xstate are pools of shape (P, ...) and
    the kernel reads/updates row state_batch_indices[b] in place
    (state_out must be None). The pool row count is independent of the
    batch size (separate symbolic dim; the launch grid takes batch from x).
  • Padding semantics match Mamba2 ("skip padding tokens"): a negative
    index (e.g. CUDA-graph batch padding) produces a zeroed output and leaves
    all states untouched
    . Implementation note: instead of Mamba2's early
    return, loads are clamped to row 0 and all writes are masked — CuteDSL's
    cp.async commit/wait pipeline spans the kernel body and does not admit a
    mid-kernel return; the observable contract is identical and is asserted by
    the tests.
  • Adds an opt-in update_kv_state: bool (asserts tile_D >= headdim): the
    kernel also stores this step's B/x into Bstate/Xstate after
    consuming the old values, replacing the caller's two scatter kernels. With
    a single D-tile, exactly one CTA owns each (batch, head) row, so there is
    no cross-CTA hazard (asserted in the wrapper). This is the MIMO analog of
    causal_conv1d_update(..., conv_state_indices=...): Mamba2's auxiliary
    state is updated by its own indexed kernel, while MIMO's auxiliary states
    (previous step's B/x) are already inputs of this kernel.
  • Fully backward-compatible: without state_batch_indices nothing changes
    (same compile cache keys for existing callers).

Correctness

tests/ops/test_mamba3_step_indexed.py (included):

  • Byte-identical outputs, updated SSM states, and B/x states vs the
    gather → dense kernel → scatter reference, for batch sizes 1–32, with both
    update_kv_state modes, at H=48, D=64, N=128, R=4 (fp32 state, bf16 B/x).
  • Batches with trailing negative (padded) indices: padded lanes return zeroed
    outputs, real lanes unchanged, and untouched pool rows verified
    untouched
    .
  • The update_kv_state + tile_D < headdim misuse raises.

Also validated end-to-end in serving (greedy outputs coherent at 256-way
concurrency; gsm8k / HumanEval / RULER scores unchanged vs the gather/scatter
path). bf16 state storage works through the same path (the kernel already
accumulates in fp32).

Notes for review

  • The diff only re-points the four state tensors through the indirection;
    every other tensor keeps direct batch addressing.
  • state_batch_indices is asserted int32 (the cute compile declares the
    tensor dtype — stricter than Triton's any-int load in Mamba2).
  • One commit, rebased on current main.

Mirror of selective_state_update's state_batch_indices for the mamba3 MIMO
step kernel: when given, state/Bstate/Xstate are pools of shape (P, ...)
and the kernel reads/updates row state_batch_indices[b] in place — removes
the gather/scatter round-trip on the fp32 SSM state that inference servers
otherwise pay (58% of decode GPU time at batch 256 in our vLLM deployment;
~2x end-to-end decode throughput).

Padding semantics match selective_state_update: negative indices (CUDA-graph
batch padding) skip the token — output zeroed, state untouched.

Optional update_kv_state (explicit, asserts tile_D >= headdim): the kernel
also stores this step's B/x into Bstate/Xstate after consuming the old
values, replacing the caller's scatter kernels (single CTA per (b, h) row,
no cross-CTA hazard).

Byte-identical vs the gather/scatter reference incl. padded lanes; see
tests/ops/test_mamba3_step_indexed.py.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants