diff --git a/mamba_ssm/modules/mamba3.py b/mamba_ssm/modules/mamba3.py
index 34c119c2c..75680c4d1 100644
--- a/mamba_ssm/modules/mamba3.py
+++ b/mamba_ssm/modules/mamba3.py
@@ -169,8 +169,13 @@ def forward(self, u, seq_idx=None, cu_seqlens=None, inference_params=None):
             inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
             angle_dt_state, ssm_state, k_state, v_state = self._get_states_from_cache(inference_params, inference_batch)
             if inference_params.seqlen_offset > 0:
-                out, _, _, _, _ = self.step(u, angle_dt_state, ssm_state, k_state, v_state)
-                return out
+                # Decode: forward() receives (batch, 1, dim) but step() operates on
+                # (batch, dim). Drop the singleton seqlen axis going in and restore it
+                # on the way out so callers see the same 3D layout as the prefill
+                # path. The assert matters: squeeze(1) is a silent no-op if seqlen != 1.
+                assert seqlen == 1, "Mamba3 decode step expects a single token (seqlen=1)."
+                out, _, _, _, _ = self.step(u.squeeze(1), angle_dt_state, ssm_state, k_state, v_state)
+                return out.unsqueeze(1)
 
         # Apply in_proj
         zxBCdtAtrap = self.in_proj(u)
diff --git a/tests/ops/cute/test_mamba3_mimo_step.py b/tests/ops/cute/test_mamba3_mimo_step.py
index 0d40e9dc6..ffd9ebeab 100644
--- a/tests/ops/cute/test_mamba3_mimo_step.py
+++ b/tests/ops/cute/test_mamba3_mimo_step.py
@@ -260,6 +260,71 @@ def test_step_matches_forward_fp32(variant: VariantConfig, is_outproj_norm: bool
         )
 
 
+@pytest.mark.parametrize("variant", [pytest.param(SISO, id="siso"), pytest.param(MIMO, id="mimo")])
+@pytest.mark.parametrize(
+    "is_outproj_norm",
+    [
+        pytest.param(False, id="outproj_norm_false"),
+        pytest.param(True, id="outproj_norm_true"),
+    ],
+)
+def test_forward_decode_dispatch_matches_step(
+    variant: VariantConfig, is_outproj_norm: bool
+) -> None:
+    """Regression: ``Mamba3.forward`` with ``inference_params.seqlen_offset > 0``
+    must route a ``(batch, 1, dim)`` token through ``step()`` and return the same shape.
+
+    This is the path autoregressive generation actually exercises, and it was
+    previously untested: ``test_step_matches_forward_fp32`` drives ``step()``
+    directly with 2D ``(batch, dim)`` tensors, while ``forward`` handed ``step()``
+    a 3D tensor, crashing ``_preprocess``'s rearrange. Every token's forward
+    output is compared against the (already-validated) direct ``step()`` reference.
+    """
+    Mamba3 = _mamba3_cls()
+    cfg = _case_config(variant, is_outproj_norm=is_outproj_norm)
+    config_label = f"forward_decode[variant={variant.mimo_dim}, is_outproj_norm={is_outproj_norm}]"
+
+    torch.manual_seed(42)
+    torch.cuda.manual_seed_all(42)
+    model = Mamba3(**cfg)
+    model.eval()
+
+    u = torch.randn(BATCH, SEQLEN, cfg["d_model"], device=DEVICE, dtype=DTYPE)
+
+    with torch.no_grad():
+        # Reference: feed step() one (batch, dim) token at a time, threading the state through.
+        state = model.allocate_inference_cache(BATCH, 1, device=DEVICE, dtype=DTYPE)
+        ref_steps = []
+        for t in range(SEQLEN):
+            out_step, *state = model.step(u[:, t], *state)
+            ref_steps.append(out_step)
+        ref = torch.stack(ref_steps, dim=1)  # (batch, seqlen, dim)
+
+        # Path under test: forward()'s decode dispatch. A one-token prefill
+        # (seqlen_offset == 0) populates the cache; the loop then advances
+        # seqlen_offset by hand so each later (batch, 1, dim) token hits step().
+        inference_params = InferenceParams(max_seqlen=SEQLEN, max_batch_size=BATCH)
+        decode_outs = [model(u[:, :1], inference_params=inference_params)]
+        for t in range(1, SEQLEN):
+            inference_params.seqlen_offset = t
+            out = model(u[:, t : t + 1], inference_params=inference_params)
+            assert out.shape == (BATCH, 1, cfg["d_model"]), (
+                f"forward decode must return (batch, 1, dim), got {tuple(out.shape)} "
+                f"for {config_label}"
+            )
+            decode_outs.append(out)
+        decode = torch.cat(decode_outs, dim=1)  # (batch, seqlen, dim)
+
+    for t in range(SEQLEN):
+        _assert_close(
+            decode[:, t],
+            ref[:, t],
+            label="forward-decode-vs-step",
+            cfg=config_label,
+            step=t,
+        )
+
+
 def run_step_benchmark(variant: VariantConfig, *, is_outproj_norm: bool) -> None:
     _require_cuda_and_kernel_deps()
     from triton.testing import do_bench_cudagraph