From 5de28fe384004829133d15fcc8936952eba79b25 Mon Sep 17 00:00:00 2001 From: Ali Nasiri Sarvi Date: Thu, 9 Apr 2026 16:38:09 +0000 Subject: [PATCH] Cache ctx.saved_tensors to avoid double-access under activation checkpointing torch.utils.checkpoint unpack hooks only allow a single unpack per tensor. --- mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py b/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py index ca5152001..809b31ecd 100644 --- a/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py +++ b/mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py @@ -165,7 +165,8 @@ def backward( except Exception: pass - if len(ctx.saved_tensors) == 0: + _saved = ctx.saved_tensors + if len(_saved) == 0: raise RuntimeError( "Backward called but forward ran without gradient tracking. " "Ensure inputs require grad or run under torch.enable_grad()." @@ -176,7 +177,7 @@ def backward( (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, Angles_Cumsum, D_save, Z_save, Input_SSM_State_save, Input_K_State_save, Input_V_State_save, Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot, Scale, Gamma, - Final_SSM_State_save, cu_seqlens_save) = ctx.saved_tensors + Final_SSM_State_save, cu_seqlens_save) = _saved D = D_save if ctx.has_D else None Z = Z_save if ctx.has_Z else None