diff --git a/mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py b/mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py index 4a2fb49f6..e19daca7f 100644 --- a/mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py +++ b/mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py @@ -46,7 +46,7 @@ def get_gmem_tiled_copy(dtype: Type[cutlass.Numeric], major_mode_size: int, num_ class Mamba3Step(): - def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, remove_gate: bool = False, remove_outproj: bool = False): + def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, remove_gate: bool = False, remove_outproj: bool = False, update_kv_state: bool = False): assert num_warps >= 2 assert dstate % 8 == 0, "dstate must be multiple of 8" # for vectorized load /store self.tile_D = tile_D @@ -55,6 +55,12 @@ def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, self.num_warps = num_warps self.remove_gate = remove_gate self.remove_outproj = remove_outproj + # When True (indexed mode only), the kernel stores the new key/value + # states (its B and x inputs) into mBstate/mXstate after consuming the + # old values — saves the caller's separate scatter kernels. Requires a + # single D-tile per (b, h) (tile_D >= D), otherwise the bidd=0 CTA's + # Bstate write would race other CTAs' reads. + self.update_kv_state = update_kv_state def _setup_smem_layouts(self): self.sState_layout = cute.make_ordered_layout((self.tile_D, self.dstate), order=(1, 0)) @@ -102,6 +108,10 @@ def __call__( mOut: cute.Tensor, # (B, H, D) or (B, R, H, D) if remove_outproj mZ: Optional[cute.Tensor], # (B, H, D), None if remove_gate mZproj: Optional[cute.Tensor], # (R, H, D), None if remove_gate + mStateBatchIdx: Optional[cute.Tensor], # (B,) int32 — row of the state pools + # for each batch element; when given, mState/mStateOut/mBstate/mXstate + # are pools of shape (P, ...) indexed indirectly (avoids the PyTorch + # gather/scatter round-trip on the SSM state). stream: cuda.CUstream, ): self.dtype = mState.element_type @@ -206,6 +216,7 @@ class SharedStorage: mOut, mZ, mZproj, + mStateBatchIdx, self.sState_layout, self.sBC_layout, self.sProj_layout, @@ -217,8 +228,9 @@ class SharedStorage: tiled_copy_B_s2r, vecsize_dstate, ).launch( - # grid: (d, h, b) - grid=[cute.ceil_div(mState.shape[2], self.tile_D), mState.shape[1], mState.shape[0]], + # grid: (d, h, b) — batch from mX (mState may be a larger pool + # when mStateBatchIdx is used) + grid=[cute.ceil_div(mState.shape[2], self.tile_D), mState.shape[1], mX.shape[0]], block=[num_threads, 1, 1], stream=stream, ) @@ -242,6 +254,7 @@ def kernel( mOut: cute.Tensor, # (B, H, D) or (B, R, H, D) if remove_outproj mZ: Optional[cute.Tensor], # (B, H, D), None if remove_gate mZproj: Optional[cute.Tensor], # (R, H, D), None if remove_gate + mStateBatchIdx: Optional[cute.Tensor], # (B,) int32 pool-row indices sState_layout: cute.Layout | cute.ComposedLayout, sBC_layout: cute.Layout | cute.ComposedLayout, sProj_layout: cute.Layout | cute.ComposedLayout, @@ -260,24 +273,39 @@ def kernel( limit_d = mState.shape[2] + # Pool-row index for the state tensors (indirect when mStateBatchIdx given). + # Batch padding (e.g. vLLM CUDA-graph capture sizes) uses negative + # indices (PAD_SLOT_ID = -1): clamp the row for the loads and suppress + # the state write so padded lanes never touch a real pool row. + valid_st = Boolean(True) + if const_expr(mStateBatchIdx is not None): + idx_val = Int32(mStateBatchIdx[bidb]) + valid_st = idx_val >= Int32(0) + bidb_st = idx_val + if not valid_st: + bidb_st = Int32(0) + else: + bidb_st = bidb + # /////////////////////////////////////////////////////////////////////////////// # Slice for CTA # /////////////////////////////////////////////////////////////////////////////// # (tile_D, N) gState, gStateOut = [ - cute.local_tile(t[bidb, bidh, None, None], (self.tile_D, self.dstate), (bidd, 0)) + cute.local_tile(t[bidb_st, bidh, None, None], (self.tile_D, self.dstate), (bidd, 0)) for t in (mState, mStateOut) ] # (R, N) - gBstate, gB, gC = [ + gBstate = cute.local_tile( + mBstate[bidb_st, None, bidh, None], (self.mimo, self.dstate), (0, 0) + ) + gB, gC = [ cute.local_tile(t[bidb, None, bidh, None], (self.mimo, self.dstate), (0, 0)) - for t in (mBstate, mB, mC) + for t in (mB, mC) ] # (tile_D,) - gXstate, gX = [ - cute.local_tile(t[bidb, bidh, None], (self.tile_D,), (bidd,)) - for t in (mXstate, mX) - ] + gXstate = cute.local_tile(mXstate[bidb_st, bidh, None], (self.tile_D,), (bidd,)) + gX = cute.local_tile(mX[bidb, bidh, None], (self.tile_D,), (bidd,)) if const_expr(mOutproj is not None): # Output is (B, H, D), outproj reduces MIMO rank gOut = cute.local_tile(mOut[bidb, bidh, None], (self.tile_D,), (bidd,)) @@ -482,7 +510,11 @@ def kernel( cute.arch.cp_async_commit_group() # Write state back to StateOut (may be same memory as State for in-place) - cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut) + if const_expr(mStateBatchIdx is not None): + if valid_st: + cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut) + else: + cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut) # Do state @ C cute.arch.cp_async_wait_group(1) # C is done loading @@ -530,6 +562,22 @@ def kernel( cute.arch.cp_async_wait_group(0) # Zproj and Outproj are done loading cute.arch.sync_threads() + # Store the new key/value states (this step's B and x) into the pools, + # now that every thread has consumed the old Bstate/Xstate. Single + # D-tile per (b, h) is guaranteed by the wrapper (tile_D >= D), so no + # other CTA still reads these rows. tSrB / tXrX hold the original + # (pre-fp32) values, so the store is bit-exact with the caller-side + # `k_pool[slots] = B; v_pool[slots] = x` it replaces. + if const_expr(self.update_kv_state and mStateBatchIdx is not None): + if valid_st: + tpd_b = self.dstate // vecsize_dstate + if tidx < tpd_b: + tSgBstate_w = smem_thr_copy_B.partition_S(gBstate)[None, None, 0] + cute.autovec_copy(tSrB, tSgBstate_w) + if warp_idx == 0: + if not need_bound_check_X or lane_idx < num_loads_X: + cute.autovec_copy(tXrX, tXgXstate) + if const_expr(mOutproj is not None): # Gate: z_r * sigmoid(z_r) if const_expr(mZ is not None): @@ -552,6 +600,11 @@ def kernel( else: out_val += out[r] * out_proj_val + # Skip padding tokens: zero the output (selective_state_update + # does the same for state_batch_indices < 0). + if const_expr(mStateBatchIdx is not None): + if not valid_st: + out_val = Float32(0.0) # Write final output (B, H, D) if lane_idx < lanes_per_D: gOut[warp_idx * lanes_per_D + lane_idx] = out_val.to(mOut.element_type) @@ -559,8 +612,13 @@ def kernel( # No outproj: write per-rank output (B, R, H, D) for r in cutlass.range_constexpr(self.mimo): gOut_r = cute.local_tile(mOut[bidb, r, bidh, None], (self.tile_D,), (bidd,)) + out_r_val = out[r] + # Skip padding tokens (see above). + if const_expr(mStateBatchIdx is not None): + if not valid_st: + out_r_val = Float32(0.0) if lane_idx < lanes_per_D: - gOut_r[warp_idx * lanes_per_D + lane_idx] = out[r].to(mOut.element_type) + gOut_r[warp_idx * lanes_per_D + lane_idx] = out_r_val.to(mOut.element_type) def mamba3_step_fn( @@ -581,17 +639,42 @@ def mamba3_step_fn( out: Tensor = None, # (B, H, D) or (B, R, H, D) if remove_outproj z: Optional[Tensor] = None, # (B, H, D), None if remove_gate zproj: Optional[Tensor] = None, # (R, H, D), None if remove_gate + state_batch_indices: Optional[Tensor] = None, # (B,) int32 — when given, + # state/Bstate/Xstate are pools of shape (P, ...) and row + # state_batch_indices[b] holds batch element b's state. The state is updated + # in place in the pool (state_out must be None). Avoids gather/scatter. + update_kv_state: bool = False, # kernel also stores this step's B and x + # into Bstate/Xstate after consuming the old values, replacing the + # caller's scatter kernels. Requires state_batch_indices and tile_D >= headdim + # (a single D-tile per (b, h): the Bstate write would race other CTAs' + # reads otherwise). NOTE: mutates Bstate/Xstate. tile_D: int = 64, num_warps: int = 2, ) -> None: has_z = z is not None has_outproj = outproj is not None + has_state_batch_idx = state_batch_indices is not None inplace = state_out is None - batch, nheads, hdim, dstate = state.shape + pool, nheads, hdim, dstate = state.shape + if update_kv_state: + assert has_state_batch_idx, "update_kv_state requires state_batch_indices" + assert tile_D >= hdim, ( + f"update_kv_state requires a single D-tile per (b, h) " + f"(tile_D={tile_D} >= headdim={hdim}); with multiple D-tiles the " + f"bidd=0 CTA's Bstate write would race other CTAs' reads" + ) mimo = Bstate.shape[1] - assert state.shape == (batch, nheads, hdim, dstate) - assert Bstate.shape == (batch, mimo, nheads, dstate) - assert Xstate.shape == (batch, nheads, hdim) + batch = x.shape[0] + if has_state_batch_idx: + assert inplace, "state_batch_indices requires in-place update (state_out=None)" + assert state_batch_indices.shape == (batch,) + assert state_batch_indices.dtype == torch.int32 + assert state_batch_indices.is_cuda + else: + assert pool == batch + assert state.shape == (pool, nheads, hdim, dstate) + assert Bstate.shape == (pool, mimo, nheads, dstate) + assert Xstate.shape == (pool, nheads, hdim) assert A.shape == (batch, nheads) assert B.shape == (batch, mimo, nheads, dstate) assert C.shape == (batch, mimo, nheads, dstate) @@ -615,7 +698,7 @@ def mamba3_step_fn( if inplace: state_out = state else: - assert state_out.shape == (batch, nheads, hdim, dstate) + assert state_out.shape == (pool, nheads, hdim, dstate) required_tensors = [state, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, state_out, out] if has_outproj: @@ -650,13 +733,17 @@ def mamba3_step_fn( trap.dtype, has_z, has_outproj, + has_state_batch_idx, + update_kv_state, ) if compile_key not in mamba3_step_fn.compile_cache: - mamba3_step_op = Mamba3Step(tile_D, dstate, mimo, num_warps, remove_gate=not has_z, remove_outproj=not has_outproj) + mamba3_step_op = Mamba3Step(tile_D, dstate, mimo, num_warps, remove_gate=not has_z, remove_outproj=not has_outproj, update_kv_state=update_kv_state) # Create symbolic dimensions for batch and nheads batch_sym = cute.sym_int() nheads_sym = cute.sym_int() + # Pool row count is independent of batch when state_batch_indices is used + pool_sym = cute.sym_int() if has_state_batch_idx else batch_sym # Divisibility for strides (128-bit alignment) div_state = 128 // state_cute_dtype.width @@ -669,9 +756,9 @@ def mamba3_step_fn( div_trap = 128 // trap_cute_dtype.width # Create fake tensors with symbolic batch/nheads dimensions - state_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state) - Bstate_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b) - Xstate_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) + state_fake = make_fake_tensor(state_cute_dtype, (pool_sym, nheads_sym, hdim, dstate), div_state) + Bstate_fake = make_fake_tensor(b_cute_dtype, (pool_sym, mimo, nheads_sym, dstate), div_b) + Xstate_fake = make_fake_tensor(x_cute_dtype, (pool_sym, nheads_sym, hdim), div_x) A_fake = make_fake_tensor(a_cute_dtype, (batch_sym, nheads_sym), div_a) B_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b) C_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b) @@ -681,13 +768,16 @@ def mamba3_step_fn( trap_fake = make_fake_tensor(trap_cute_dtype, (batch_sym, nheads_sym), div_trap) xproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) outproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_outproj else None - state_out_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state) + state_out_fake = make_fake_tensor(state_cute_dtype, (pool_sym, nheads_sym, hdim, dstate), div_state) if has_outproj: out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) else: out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, mimo, nheads_sym, hdim), div_x) z_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) if has_z else None zproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_z else None + state_batch_idx_fake = ( + make_fake_tensor(Int32, (batch_sym,)) if has_state_batch_idx else None + ) fake_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) @@ -709,6 +799,7 @@ def mamba3_step_fn( out_fake, z_fake, zproj_fake, + state_batch_idx_fake, fake_stream, options="--enable-tvm-ffi", ) @@ -732,6 +823,7 @@ def mamba3_step_fn( out, z, zproj, + state_batch_indices, ) diff --git a/tests/ops/test_mamba3_step_indexed.py b/tests/ops/test_mamba3_step_indexed.py new file mode 100644 index 000000000..bae32b8dd --- /dev/null +++ b/tests/ops/test_mamba3_step_indexed.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025, Tri Dao. +"""Parity tests for mamba3_step_fn's indexed state-pool mode. + +The indexed path (``state_indices``) must be byte-identical to the +gather -> dense kernel -> scatter reference, including: +- negative (PAD_SLOT_ID) indices from CUDA-graph batch padding, +- untouched pool rows staying untouched, +- ``update_kv_state=True`` storing the new B/x states exactly as the + caller-side scatter would. +""" +import pytest +import torch + +from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn + +H, D, N, R = 48, 64, 128, 4 +P = 37 # pool rows + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="requires CUDA" +) + + +def _rand_inputs(B, state_dtype=torch.float32): + dev = "cuda" + ssm_pool = torch.randn(P, H, D, N, device=dev, dtype=state_dtype) + k_pool = torch.randn(P, R, H, N, device=dev, dtype=torch.bfloat16) + v_pool = torch.randn(P, H, D, device=dev, dtype=torch.bfloat16) + return dict( + ssm_pool=ssm_pool, k_pool=k_pool, v_pool=v_pool, + A=-torch.rand(B, H, device=dev, dtype=torch.float32), + B=torch.randn(B, R, H, N, device=dev, dtype=torch.bfloat16), + C=torch.randn(B, R, H, N, device=dev, dtype=torch.bfloat16), + Dp=torch.randn(H, device=dev, dtype=torch.float32), + x=torch.randn(B, H, D, device=dev, dtype=torch.bfloat16), + dt=torch.rand(B, H, device=dev, dtype=torch.float32), + trap=torch.rand(B, H, device=dev, dtype=torch.float32), + xpj=torch.randn(R, H, D, device=dev, dtype=torch.bfloat16), + opj=torch.randn(R, H, D, device=dev, dtype=torch.bfloat16), + z=torch.randn(B, H, D, device=dev, dtype=torch.bfloat16), + zpj=torch.randn(R, H, D, device=dev, dtype=torch.bfloat16), + ) + + +@requires_cuda +@pytest.mark.parametrize("B,npad", [(1, 0), (5, 0), (32, 0), (8, 3), (24, 8)]) +@pytest.mark.parametrize("update_kv", [False, True]) +def test_indexed_matches_gather_scatter(B: int, npad: int, update_kv: bool): + torch.manual_seed(0) + t = _rand_inputs(B) + real = B - npad + slots = torch.randperm(P, device="cuda")[:B].to(torch.int32) + slots[real:] = -1 # PAD_SLOT_ID lanes + rs = slots[:real].long() + + # Reference: gather -> dense kernel -> scatter (real lanes only). + pool_ref = t["ssm_pool"].clone() + k_ref, v_ref = t["k_pool"].clone(), t["v_pool"].clone() + st = pool_ref[rs] + st_out = torch.empty_like(st) + y_ref = torch.empty_like(t["x"][:real]) + mamba3_step_fn(st, k_ref[rs], v_ref[rs], t["A"][:real], t["B"][:real], + t["C"][:real], t["Dp"], t["x"][:real], t["dt"][:real], + t["trap"][:real], t["xpj"], t["opj"], st_out, y_ref, + z=t["z"][:real], zproj=t["zpj"], tile_D=64, num_warps=4) + pool_ref[rs] = st_out + k_ref[rs] = t["B"][:real] # caller-side kv update (old semantics) + v_ref[rs] = t["x"][:real] + + # Indexed: in-place pool update via slot indices. + pool_new = t["ssm_pool"].clone() + k_new, v_new = t["k_pool"].clone(), t["v_pool"].clone() + y_new = torch.empty_like(t["x"]) + mamba3_step_fn(pool_new, k_new, v_new, t["A"], t["B"], t["C"], t["Dp"], + t["x"], t["dt"], t["trap"], t["xpj"], t["opj"], None, + y_new, z=t["z"], zproj=t["zpj"], state_batch_indices=slots, + update_kv_state=update_kv, tile_D=64, num_warps=4) + if not update_kv: + k_new[torch.where(slots >= 0, slots, 0).long()[:real]] = t["B"][:real] + v_new[torch.where(slots >= 0, slots, 0).long()[:real]] = t["x"][:real] + torch.cuda.synchronize() + + assert torch.equal(y_new[:real], y_ref) + # Padding lanes produce zeroed outputs (selective_state_update semantics). + if npad: + assert torch.equal(y_new[real:], torch.zeros_like(y_new[real:])) + assert torch.equal(pool_new, pool_ref) + assert torch.equal(k_new, k_ref) + assert torch.equal(v_new, v_ref) + # untouched pool rows must be untouched + mask = torch.ones(P, dtype=torch.bool, device="cuda") + mask[rs] = False + assert torch.equal(pool_new[mask], t["ssm_pool"][mask]) + + +@requires_cuda +def test_update_kv_state_requires_single_d_tile(): + torch.manual_seed(0) + t = _rand_inputs(4) + slots = torch.arange(4, device="cuda", dtype=torch.int32) + y = torch.empty_like(t["x"]) + with pytest.raises(AssertionError, match="single D-tile"): + mamba3_step_fn(t["ssm_pool"], t["k_pool"], t["v_pool"], t["A"], + t["B"], t["C"], t["Dp"], t["x"], t["dt"], t["trap"], + t["xpj"], t["opj"], None, y, z=t["z"], zproj=t["zpj"], + state_batch_indices=slots, update_kv_state=True, + tile_D=32, num_warps=4)