From 609cb499a26955a724b26bc4b4b7b49951f29fd4 Mon Sep 17 00:00:00 2001 From: Taksh Date: Tue, 2 Jun 2026 08:18:47 +0530 Subject: [PATCH] fix: skip duplicate B/C RMSNorm in checkpoint_lvl=1 backward When checkpoint_lvl==1, B and C are already RMS-normalized in forward and saved via save_for_backward. Re-applying rms_norm_forward in backward double-normalizes them before selective_scan_cuda.bwd. Fixes #885 Co-authored-by: Cursor --- mamba_ssm/ops/selective_scan_interface.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359c..cbe9635db 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -300,21 +300,6 @@ def backward(ctx, dout): delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous() delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps) delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() - if b_rms_weight is not None: - # Recompute & RMSNorm B - B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous() - B = rms_norm_forward( - B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps - ) - B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - if c_rms_weight is not None: - # Recompute & RMSNorm C - C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous() - C = rms_norm_forward( - C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps - ) - C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() - # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). dxz = torch.empty_like(xz) # (batch, dim, seqlen)