Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mamba_ssm/modules/mamba3.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,13 @@ def forward(self, u, seq_idx=None, cu_seqlens=None, inference_params=None):
inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
angle_dt_state, ssm_state, k_state, v_state = self._get_states_from_cache(inference_params, inference_batch)
if inference_params.seqlen_offset > 0:
out, _, _, _, _ = self.step(u, angle_dt_state, ssm_state, k_state, v_state)
return out
# Decode: forward receives (batch, 1, dim) but step operates on
# (batch, dim). Squeeze the singleton seqlen going in and restore it
# coming out so the caller sees the same (batch, seqlen, dim) layout
# as the prefill path.
assert seqlen == 1, "Mamba3 decode step expects a single token (seqlen=1)."
out, _, _, _, _ = self.step(u.squeeze(1), angle_dt_state, ssm_state, k_state, v_state)
return out.unsqueeze(1)

# Apply in_proj
zxBCdtAtrap = self.in_proj(u)
Expand Down
65 changes: 65 additions & 0 deletions tests/ops/cute/test_mamba3_mimo_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,71 @@ def test_step_matches_forward_fp32(variant: VariantConfig, is_outproj_norm: bool
)


@pytest.mark.parametrize(
    "variant",
    [
        pytest.param(SISO, id="siso"),
        pytest.param(MIMO, id="mimo"),
    ],
)
@pytest.mark.parametrize(
    "is_outproj_norm",
    [
        pytest.param(False, id="outproj_norm_false"),
        pytest.param(True, id="outproj_norm_true"),
    ],
)
def test_forward_decode_dispatch_matches_step(
    variant: VariantConfig, is_outproj_norm: bool
) -> None:
    """Check that ``forward``'s decode branch agrees with direct ``step()`` calls.

    When ``inference_params.seqlen_offset > 0``, ``Mamba3.forward`` is expected
    to accept a ``(batch, 1, dim)`` token, hand it to ``step()``, and give back
    a ``(batch, 1, dim)`` tensor. Autoregressive generation goes through this
    branch, whereas ``test_step_matches_forward_fp32`` only calls ``step()``
    with 2D ``(batch, dim)`` inputs, so a 3D tensor reaching ``step()`` (and
    breaking the rearrange in ``_preprocess``) went unnoticed.

    The reference is built by calling ``step()`` token by token; the path under
    test feeds the same tokens through ``forward`` with an ``InferenceParams``
    cache, and the two are compared position by position.
    """
    mamba3_cls = _mamba3_cls()
    cfg = _case_config(variant, is_outproj_norm=is_outproj_norm)
    config_label = f"forward_decode[variant={variant.mimo_dim}, is_outproj_norm={is_outproj_norm}]"
    d_model = cfg["d_model"]
    expected_shape = (BATCH, 1, d_model)

    # Seed before constructing the model so parameters and inputs are reproducible.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model = mamba3_cls(**cfg)
    model.eval()

    u = torch.randn(BATCH, SEQLEN, d_model, device=DEVICE, dtype=DTYPE)

    with torch.no_grad():
        # Reference trajectory: call step() on one (batch, dim) slice at a time,
        # threading the returned states into the next call.
        cache = model.allocate_inference_cache(BATCH, 1, device=DEVICE, dtype=DTYPE)
        ref_steps = []
        for token in u.unbind(dim=1):
            out_step, *cache = model.step(token, *cache)
            ref_steps.append(out_step)
        ref = torch.stack(ref_steps, dim=1)  # (batch, seqlen, dim)

        # Path under test. The first token is sent while seqlen_offset is still
        # at its initial value (the prefill call that populates the cache); each
        # later token is sent with seqlen_offset advanced so forward() takes its
        # decode branch with a (batch, 1, dim) input.
        inference_params = InferenceParams(max_seqlen=SEQLEN, max_batch_size=BATCH)
        prefill_out = model(u[:, :1], inference_params=inference_params)
        decode_outs = [prefill_out]
        for offset in range(1, SEQLEN):
            inference_params.seqlen_offset = offset
            out = model(u[:, offset : offset + 1], inference_params=inference_params)
            assert out.shape == expected_shape, (
                f"forward decode must return (batch, 1, dim), got {tuple(out.shape)} "
                f"for {config_label}"
            )
            decode_outs.append(out)
        decode = torch.cat(decode_outs, dim=1)  # (batch, seqlen, dim)

    # Compare per position so a failure reports the offending step index.
    for step_idx in range(SEQLEN):
        _assert_close(
            decode[:, step_idx],
            ref[:, step_idx],
            label="forward-decode-vs-step",
            cfg=config_label,
            step=step_idx,
        )


def run_step_benchmark(variant: VariantConfig, *, is_outproj_norm: bool) -> None:
_require_cuda_and_kernel_deps()
from triton.testing import do_bench_cudagraph
Expand Down