
Fix Mamba2.step() handling of D when D_has_hdim=True #901

Closed
Chessing234 wants to merge 1 commit into state-spaces:main from Chessing234:fix/issue-887-D_has_hdim-step

Conversation

@Chessing234

Contributor

Summary

  • Fixes Mamba2.step() to correctly handle self.D when D_has_hdim=True, where D has shape (nheads * headdim,) instead of (nheads,)
  • Both code paths in step() (with and without selective_state_update) now reshape D from (nheads * headdim,) to (nheads, headdim) with the rearrange pattern "(h p) -> h p" when D_has_hdim=True, consistent with forward()
  • Without this fix, autoregressive decoding produces incorrect outputs when D_has_hdim=True because D is misinterpreted as per-head rather than per-head-dim
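To make the shape issue concrete, here is a minimal sketch of the D skip term for one decoding step. It is not the code from this PR: it uses NumPy as a stand-in for PyTorch, `reshape` in place of einops `rearrange` with the pattern "(h p) -> h p", and `skip_term` is a hypothetical helper name chosen for illustration.

```python
import numpy as np

def skip_term(D, x, nheads, headdim, D_has_hdim):
    """Return the D * x skip contribution for one decoding step.

    x has shape (batch, nheads, headdim). Hypothetical helper, not library code.
    """
    if D_has_hdim:
        # One value per (head, channel): (nheads * headdim,) -> (nheads, headdim)
        D = D.reshape(nheads, headdim)
    else:
        # One value per head: (nheads,) -> (nheads, 1), broadcast over headdim
        D = D.reshape(nheads, 1)
    return D * x

nheads, headdim = 2, 3
x = np.ones((1, nheads, headdim))
D = np.arange(nheads * headdim, dtype=np.float64)  # [0, 1, 2, 3, 4, 5]
out = skip_term(D, x, nheads, headdim, D_has_hdim=True)
print(out[0])  # each head gets its own row of per-channel values
```

In this sketch, passing the same (nheads * headdim,) vector with D_has_hdim=False fails at the reshape, which illustrates why the per-head assumption does not fit a per-head-dim D.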

Fixes #887

Test plan

  • Verify step() output matches forward() output for a single token when D_has_hdim=True
  • Confirm no regression when D_has_hdim=False (default behavior unchanged)
  • Test both code paths: with selective_state_update available and with the fallback path
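The first two checks in the plan could be structured roughly as follows. This is a sketch under stated assumptions: it covers only the D skip term, uses NumPy stand-ins rather than the real Mamba2 module, and the names `d_skip_forward`, `d_skip_step`, and `check_step_matches_forward` are hypothetical.

```python
import numpy as np

def d_skip_forward(D, x, D_has_hdim):
    # x: (batch, seqlen, nheads, headdim), stand-in for the full-sequence path
    _, _, h, p = x.shape
    D = D.reshape(h, p) if D_has_hdim else D.reshape(h, 1)
    return D * x

def d_skip_step(D, x_t, D_has_hdim):
    # x_t: (batch, nheads, headdim), stand-in for the single-token path
    _, h, p = x_t.shape
    D = D.reshape(h, p) if D_has_hdim else D.reshape(h, 1)
    return D * x_t

def check_step_matches_forward(D_has_hdim, nheads=4, headdim=8, seed=0):
    rng = np.random.default_rng(seed)
    D = rng.standard_normal(nheads * headdim if D_has_hdim else nheads)
    x = rng.standard_normal((2, 1, nheads, headdim))  # a single token
    full = d_skip_forward(D, x, D_has_hdim)[:, 0]
    step = d_skip_step(D, x[:, 0], D_has_hdim)
    return bool(np.allclose(full, step))

print(check_step_matches_forward(D_has_hdim=True))
print(check_step_matches_forward(D_has_hdim=False))
```

A test against the actual module would compare step() and forward() outputs end to end, once with selective_state_update available and once with the fallback path; the sketch above only shows the comparison pattern.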

🤖 Generated with Claude Code

When D_has_hdim=True, self.D has shape (nheads * headdim,) instead of
(nheads,). The forward() method correctly reshapes D from (h*p,) to
(h, p) in this case, but step() always treated D as if it had shape
(nheads,), leading to incorrect results during autoregressive decoding.

Fix both code paths in step() (with and without selective_state_update)
to use rearrange("(h p) -> h p") when D_has_hdim=True, consistent with
forward().

Fixes #887

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Chessing234

Contributor Author

Closing in favor of #903 (duplicate)

Chessing234 closed this Apr 7, 2026

Development

Successfully merging this pull request may close these issues.

Mamba2.step() handles D incorrectly when D_has_dim=True
