From 5c4916699852e15065269dd90c0f7d563538923f Mon Sep 17 00:00:00 2001 From: Sukjun Hwang Date: Wed, 15 Apr 2026 22:02:26 -0400 Subject: [PATCH] Update angle normalization using remainder function Change the angle calculation to use a remainder (modulo 2*pi) for angle normalization. --- mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py b/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py index c6b8596d1..7222cc772 100644 --- a/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py +++ b/mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py @@ -75,6 +75,8 @@ def rotary_qk_inference_kernel( # Match angle_dt: tanh(angle_proj) * dt * pi angle_proj = tl.sigmoid(2.0 * angle_proj) * 2.0 - 1.0 # tanh angle = angle_state + angle_proj * dt * 3.141592653589793 # (rotary_dim // 2) + TWO_PI: tl.constexpr = 6.283185307179586 + angle = angle - TWO_PI * tl.floor(angle / TWO_PI) OUT_ANGLE_STATE = OUT_ANGLE_STATE + rd_half * stride_out_angle_state[2] tl.store(OUT_ANGLE_STATE, angle, mask=mask_angle) @@ -254,6 +256,7 @@ def apply_rotary_qk_inference_reference( # Match angle_dt: tanh(angle_proj) * dt * pi angle_proj = torch.tanh(angle_proj) angle = angle_state + angle_proj * dt[:, :, None] * math.pi # (B, N, S) + angle = torch.remainder(angle, 2 * math.pi) angle_state_new = angle angle = angle.unsqueeze(1).expand(-1, mimo_dim, -1, -1) # (B, R, N, S)