Skip to content

Fix Mamba-3 decode path: forward() to step() shape handling#978

Open
sanowl wants to merge 1 commit into
state-spaces:mainfrom
sanowl:fix/mamba3-decode-shape
Open

Fix Mamba-3 decode path: forward() to step() shape handling#978
sanowl wants to merge 1 commit into
state-spaces:mainfrom
sanowl:fix/mamba3-decode-shape

Conversation

@sanowl

@sanowl sanowl commented Jun 20, 2026

Copy link
Copy Markdown

Mamba3.forward() routes to step() during autoregressive decode (seqlen_offset > 0),
passing the (batch, 1, dim) tensor it receives. step() expects (batch, dim):
_preprocess uses 2-axis einops patterns such as rearrange(x, "b (h p) -> b h p").
Decoding through the module crashes at the first generated token. Even without the
crash, the previous return out returns 2D while the prefill path returns 3D.
Mamba2.step() handles this by squeezing the input and unsqueezing the output;
Mamba3 did neither.

the Fix

  • Squeeze the singleton seqlen before step(), restore it after, with an
    assert seqlen == 1 guard (single-token decode, matching Mamba2).
  • Add test_forward_decode_dispatch_matches_step: prefill via forward, decode
    through the seqlen_offset > 0 branch, and check parity against the direct-step
    reference (SISO and MIMO, out-norm on and off).

I have not run this. No NVIDIA GPU available locally (the CUTE step kernel requires
H100). The change is verified by code review and matches the existing Mamba-2
convention; the added test should validate it on CUDA / CI.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant