fix: skip duplicate B/C RMSNorm in checkpoint_lvl=1 backward#964

Open
Chessing234 wants to merge 1 commit into
state-spaces:mainfrom
Chessing234:fix/double-rmsnorm-checkpoint-lvl1

Conversation

@Chessing234

Contributor

Fixes #885

Bug

With checkpoint_lvl=1 and b_rms_weight/c_rms_weight set, MambaInnerFn.backward applies RMSNorm to B and C a second time, even though forward already normalized them and saved the post-norm tensors.

Root cause

In forward, delta is cleared before save_for_backward, so backward recomputes it exactly once. B and C, however, are saved post-norm, yet backward still reruns rms_norm_forward on those saved values.

Why this fix is correct

Drop the redundant B/C RMSNorm block in the checkpoint_lvl == 1 path so backward uses the same post-norm B/C that forward passed to selective_scan_cuda.fwd, matching the intended checkpointing behavior.
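
To illustrate why the second pass matters, here is a minimal pure-Python sketch, not the code from this repository. It assumes the usual RMSNorm definition, x / sqrt(mean(x^2) + eps) * weight; the helper `rms_norm` and the values of `B_raw` and `b_rms_weight` are hypothetical stand-ins for one row of B and its weight. With a non-unit weight, normalizing the saved tensor again changes it, so backward would see a different B than forward did.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm over a 1-D list: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

# Hypothetical values standing in for one row of B and its b_rms_weight.
B_raw = [1.0, -2.0, 3.0, 0.5]
b_rms_weight = [0.5, 2.0, 1.5, 1.0]

# Forward: normalize once; this is the tensor passed to the scan and saved.
B_forward = rms_norm(B_raw, b_rms_weight)

# Old backward path: normalizes the already-normalized saved tensor again.
B_buggy = rms_norm(B_forward, b_rms_weight)

# Patched backward path: uses the saved post-norm tensor unchanged.
B_fixed = B_forward

# Largest elementwise gap between what forward used and what old backward used.
max_diff = max(abs(a - b) for a, b in zip(B_forward, B_buggy))
print("max |forward - double-normalized| =", max_diff)

# With a unit weight the second pass is nearly a no-op, so the mismatch
# only becomes visible once the RMSNorm weights differ from 1.
ones = [1.0] * len(B_raw)
once = rms_norm(B_raw, ones)
twice = rms_norm(once, ones)
unit_diff = max(abs(a - b) for a, b in zip(once, twice))
print("unit-weight gap =", unit_diff)
```

Under these assumptions the mismatch comes from applying the weight twice and rescaling by a second RMS, which is consistent with the fix of reusing the saved post-norm B and C as they are.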

Made with Cursor

When checkpoint_lvl==1, B and C are already RMS-normalized in forward
and saved via save_for_backward. Re-applying rms_norm_forward in backward
double-normalizes them before selective_scan_cuda.bwd.

Fixes state-spaces#885

Co-authored-by: Cursor <cursoragent@cursor.com>

Development

Successfully merging this pull request may close these issues.

[bug] double rmsnorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1

1 participant