From c45840cc1ed58c1f9854949fb1d2f2047767d879 Mon Sep 17 00:00:00 2001
From: Taksh
Date: Mon, 6 Apr 2026 16:37:34 +0530
Subject: [PATCH] Fix Mamba2.step() handling of D when D_has_hdim=True

When D_has_hdim=True, self.D has shape (nheads * headdim,) instead of
(nheads,). The forward() method correctly reshapes D from (h*p,) to
(h, p) in this case, but step() always treated D as if it had shape
(nheads,), leading to incorrect results during autoregressive decoding.

Fix both code paths in step() (with and without selective_state_update)
to use rearrange("(h p) -> h p") when D_has_hdim=True, consistent with
forward().

Fixes #887

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 mamba_ssm/modules/mamba2.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 36b16d471..facae2d75 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -316,7 +316,10 @@ def step(self, hidden_states, conv_state, ssm_state):
             dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
             ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
             y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
-            y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
+            if self.D_has_hdim:
+                y = y + rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim) * x
+            else:
+                y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
             y = rearrange(y, "b h p -> b (h p)")
             if not self.rmsnorm:
                 y = y * self.act(z) # (B D)
@@ -324,7 +327,10 @@ def step(self, hidden_states, conv_state, ssm_state):
             A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
             dt = repeat(dt, "b h -> b h p", p=self.headdim)
             dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
-            D = repeat(self.D, "h -> h p", p=self.headdim)
+            if self.D_has_hdim:
+                D = rearrange(self.D, "(h p) -> h p", p=self.headdim)
+            else:
+                D = repeat(self.D, "h -> h p", p=self.headdim)
             B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
             C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
             x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)