From 025c6e0f2384c7620986a6d4d79731c32ffe431c Mon Sep 17 00:00:00 2001 From: Ipnon Date: Thu, 2 Apr 2026 20:35:21 +0800 Subject: [PATCH] Fix Mamba2 step() D handling when D_has_hdim=True When D_has_hdim=True, self.D has shape (nheads * headdim,) but step() treated it as (nheads,) in both code paths, producing silent wrong outputs during inference. forward() already handled this correctly via rearrange("(h p) -> h p"). Add conditional reshape in both step() paths to match forward(). Add regression test comparing forward/step consistency. Fixes #887, fixes #888. --- mamba_ssm/modules/mamba2.py | 10 ++++- tests/test_mamba2_d_has_hdim.py | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 tests/test_mamba2_d_has_hdim.py diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d471..facae2d75 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -316,7 +316,10 @@ def step(self, hidden_states, conv_state, ssm_state): dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) - y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + if self.D_has_hdim: + y = y + rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim) * x + else: + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x y = rearrange(y, "b h p -> b (h p)") if not self.rmsnorm: y = y * self.act(z) # (B D) @@ -324,7 +327,10 @@ def step(self, hidden_states, conv_state, ssm_state): A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) dt = repeat(dt, "b h -> b h p", p=self.headdim) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) - D = repeat(self.D, "h -> h p", p=self.headdim) + if self.D_has_hdim: + D = rearrange(self.D, "(h p) -> h p", p=self.headdim) + else: + D = repeat(self.D, "h -> h p", p=self.headdim) B = rearrange(B, "b (g n) -> b g n", g=self.ngroups) C = rearrange(C, "b (g n) -> b g n", g=self.ngroups) x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) diff --git a/tests/test_mamba2_d_has_hdim.py b/tests/test_mamba2_d_has_hdim.py new file mode 100644 index 000000000..a573bfbc6 --- /dev/null +++ b/tests/test_mamba2_d_has_hdim.py @@ -0,0 +1,65 @@ +"""Test that Mamba2 forward() and step() produce consistent outputs when D_has_hdim=True. + +Regression test for https://github.com/state-spaces/mamba/issues/887 +""" +import torch +import pytest +from einops import rearrange + +from mamba_ssm.modules.mamba2 import Mamba2 + + +@pytest.mark.parametrize("D_has_hdim", [True, False]) +def test_mamba2_step_forward_consistency(D_has_hdim): + """step() must match forward() for D_has_hdim=True and D_has_hdim=False.""" + torch.manual_seed(42) + batch, seqlen = 2, 16 + d_model, headdim, d_state = 256, 64, 64 + device = "cuda" + dtype = torch.float32 + + model = Mamba2( + d_model=d_model, + headdim=headdim, + d_state=d_state, + d_conv=4, + D_has_hdim=D_has_hdim, + ngroups=1, + rmsnorm=False, + use_mem_eff_path=False, + device=device, + dtype=dtype, + ) + model.eval() + + # Randomize D so non-uniform values expose the bug + with torch.no_grad(): + model.D.copy_(torch.randn_like(model.D)) + + x = torch.randn(batch, seqlen, d_model, device=device, dtype=dtype) + + # Forward pass — reference output + with torch.no_grad(): + out_forward = model(x) + + # Step pass — one token at a time + conv_state, ssm_state = model.allocate_inference_cache(batch, seqlen, dtype=dtype) + step_outputs = [] + with torch.no_grad(): + for t in range(seqlen): + out_t, conv_state, ssm_state = model.step( + x[:, t : t + 1, :], conv_state, ssm_state + ) + step_outputs.append(out_t) + out_step = torch.cat(step_outputs, dim=1) + + # After conv warmup (d_conv - 1 = 3 tokens), outputs should match + d_conv = model.d_conv + out_fwd_tail = out_forward[:, d_conv - 1 :, :] + out_step_tail = out_step[:, d_conv - 1 :, :] + + max_diff = (out_fwd_tail - out_step_tail).abs().max().item() + print(f"D_has_hdim={D_has_hdim}, max diff after conv warmup: {max_diff:.2e}") + assert torch.allclose(out_fwd_tail, out_step_tail, rtol=1e-3, atol=1e-3), ( + f"forward/step mismatch with D_has_hdim={D_has_hdim}: max diff {max_diff:.2e}" + )