Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 114 additions & 22 deletions mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_gmem_tiled_copy(dtype: Type[cutlass.Numeric], major_mode_size: int, num_


class Mamba3Step():
def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, remove_gate: bool = False, remove_outproj: bool = False):
def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, remove_gate: bool = False, remove_outproj: bool = False, update_kv_state: bool = False):
assert num_warps >= 2
assert dstate % 8 == 0, "dstate must be multiple of 8" # for vectorized load /store
self.tile_D = tile_D
Expand All @@ -55,6 +55,12 @@ def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4,
self.num_warps = num_warps
self.remove_gate = remove_gate
self.remove_outproj = remove_outproj
# When True (indexed mode only), the kernel stores the new key/value
# states (its B and x inputs) into mBstate/mXstate after consuming the
# old values — saves the caller's separate scatter kernels. Requires a
# single D-tile per (b, h) (tile_D >= D), otherwise the bidd=0 CTA's
# Bstate write would race other CTAs' reads.
self.update_kv_state = update_kv_state

def _setup_smem_layouts(self):
self.sState_layout = cute.make_ordered_layout((self.tile_D, self.dstate), order=(1, 0))
Expand Down Expand Up @@ -102,6 +108,10 @@ def __call__(
mOut: cute.Tensor, # (B, H, D) or (B, R, H, D) if remove_outproj
mZ: Optional[cute.Tensor], # (B, H, D), None if remove_gate
mZproj: Optional[cute.Tensor], # (R, H, D), None if remove_gate
mStateBatchIdx: Optional[cute.Tensor], # (B,) int32 — row of the state pools
# for each batch element; when given, mState/mStateOut/mBstate/mXstate
# are pools of shape (P, ...) indexed indirectly (avoids the PyTorch
# gather/scatter round-trip on the SSM state).
stream: cuda.CUstream,
):
self.dtype = mState.element_type
Expand Down Expand Up @@ -206,6 +216,7 @@ class SharedStorage:
mOut,
mZ,
mZproj,
mStateBatchIdx,
self.sState_layout,
self.sBC_layout,
self.sProj_layout,
Expand All @@ -217,8 +228,9 @@ class SharedStorage:
tiled_copy_B_s2r,
vecsize_dstate,
).launch(
# grid: (d, h, b)
grid=[cute.ceil_div(mState.shape[2], self.tile_D), mState.shape[1], mState.shape[0]],
# grid: (d, h, b) — batch from mX (mState may be a larger pool
# when mStateBatchIdx is used)
grid=[cute.ceil_div(mState.shape[2], self.tile_D), mState.shape[1], mX.shape[0]],
block=[num_threads, 1, 1],
stream=stream,
)
Expand All @@ -242,6 +254,7 @@ def kernel(
mOut: cute.Tensor, # (B, H, D) or (B, R, H, D) if remove_outproj
mZ: Optional[cute.Tensor], # (B, H, D), None if remove_gate
mZproj: Optional[cute.Tensor], # (R, H, D), None if remove_gate
mStateBatchIdx: Optional[cute.Tensor], # (B,) int32 pool-row indices
sState_layout: cute.Layout | cute.ComposedLayout,
sBC_layout: cute.Layout | cute.ComposedLayout,
sProj_layout: cute.Layout | cute.ComposedLayout,
Expand All @@ -260,24 +273,39 @@ def kernel(

limit_d = mState.shape[2]

# Pool-row index for the state tensors (indirect when mStateBatchIdx given).
# Batch padding (e.g. vLLM CUDA-graph capture sizes) uses negative
# indices (PAD_SLOT_ID = -1): clamp the row for the loads and suppress
# the state write so padded lanes never touch a real pool row.
valid_st = Boolean(True)
if const_expr(mStateBatchIdx is not None):
idx_val = Int32(mStateBatchIdx[bidb])
valid_st = idx_val >= Int32(0)
bidb_st = idx_val
if not valid_st:
bidb_st = Int32(0)
else:
bidb_st = bidb

# ///////////////////////////////////////////////////////////////////////////////
# Slice for CTA
# ///////////////////////////////////////////////////////////////////////////////
# (tile_D, N)
gState, gStateOut = [
cute.local_tile(t[bidb, bidh, None, None], (self.tile_D, self.dstate), (bidd, 0))
cute.local_tile(t[bidb_st, bidh, None, None], (self.tile_D, self.dstate), (bidd, 0))
for t in (mState, mStateOut)
]
# (R, N)
gBstate, gB, gC = [
gBstate = cute.local_tile(
mBstate[bidb_st, None, bidh, None], (self.mimo, self.dstate), (0, 0)
)
gB, gC = [
cute.local_tile(t[bidb, None, bidh, None], (self.mimo, self.dstate), (0, 0))
for t in (mBstate, mB, mC)
for t in (mB, mC)
]
# (tile_D,)
gXstate, gX = [
cute.local_tile(t[bidb, bidh, None], (self.tile_D,), (bidd,))
for t in (mXstate, mX)
]
gXstate = cute.local_tile(mXstate[bidb_st, bidh, None], (self.tile_D,), (bidd,))
gX = cute.local_tile(mX[bidb, bidh, None], (self.tile_D,), (bidd,))
if const_expr(mOutproj is not None):
# Output is (B, H, D), outproj reduces MIMO rank
gOut = cute.local_tile(mOut[bidb, bidh, None], (self.tile_D,), (bidd,))
Expand Down Expand Up @@ -482,7 +510,11 @@ def kernel(
cute.arch.cp_async_commit_group()

# Write state back to StateOut (may be same memory as State for in-place)
cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut)
if const_expr(mStateBatchIdx is not None):
if valid_st:
cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut)
else:
cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut)

# Do state @ C
cute.arch.cp_async_wait_group(1) # C is done loading
Expand Down Expand Up @@ -530,6 +562,22 @@ def kernel(
cute.arch.cp_async_wait_group(0) # Zproj and Outproj are done loading
cute.arch.sync_threads()

# Store the new key/value states (this step's B and x) into the pools,
# now that every thread has consumed the old Bstate/Xstate. Single
# D-tile per (b, h) is guaranteed by the wrapper (tile_D >= D), so no
# other CTA still reads these rows. tSrB / tXrX hold the original
# (pre-fp32) values, so the store is bit-exact with the caller-side
# `k_pool[slots] = B; v_pool[slots] = x` it replaces.
if const_expr(self.update_kv_state and mStateBatchIdx is not None):
if valid_st:
tpd_b = self.dstate // vecsize_dstate
if tidx < tpd_b:
tSgBstate_w = smem_thr_copy_B.partition_S(gBstate)[None, None, 0]
cute.autovec_copy(tSrB, tSgBstate_w)
if warp_idx == 0:
if not need_bound_check_X or lane_idx < num_loads_X:
cute.autovec_copy(tXrX, tXgXstate)

if const_expr(mOutproj is not None):
# Gate: z_r * sigmoid(z_r)
if const_expr(mZ is not None):
Expand All @@ -552,15 +600,25 @@ def kernel(
else:
out_val += out[r] * out_proj_val

# Skip padding tokens: zero the output (selective_state_update
# does the same for state_batch_indices < 0).
if const_expr(mStateBatchIdx is not None):
if not valid_st:
out_val = Float32(0.0)
# Write final output (B, H, D)
if lane_idx < lanes_per_D:
gOut[warp_idx * lanes_per_D + lane_idx] = out_val.to(mOut.element_type)
else:
# No outproj: write per-rank output (B, R, H, D)
for r in cutlass.range_constexpr(self.mimo):
gOut_r = cute.local_tile(mOut[bidb, r, bidh, None], (self.tile_D,), (bidd,))
out_r_val = out[r]
# Skip padding tokens (see above).
if const_expr(mStateBatchIdx is not None):
if not valid_st:
out_r_val = Float32(0.0)
if lane_idx < lanes_per_D:
gOut_r[warp_idx * lanes_per_D + lane_idx] = out[r].to(mOut.element_type)
gOut_r[warp_idx * lanes_per_D + lane_idx] = out_r_val.to(mOut.element_type)


def mamba3_step_fn(
Expand All @@ -581,17 +639,42 @@ def mamba3_step_fn(
out: Tensor = None, # (B, H, D) or (B, R, H, D) if remove_outproj
z: Optional[Tensor] = None, # (B, H, D), None if remove_gate
zproj: Optional[Tensor] = None, # (R, H, D), None if remove_gate
state_batch_indices: Optional[Tensor] = None, # (B,) int32 — when given,
# state/Bstate/Xstate are pools of shape (P, ...) and row
# state_batch_indices[b] holds batch element b's state. The state is updated
# in place in the pool (state_out must be None). Avoids gather/scatter.
update_kv_state: bool = False, # kernel also stores this step's B and x
# into Bstate/Xstate after consuming the old values, replacing the
# caller's scatter kernels. Requires state_batch_indices and tile_D >= headdim
# (a single D-tile per (b, h): the Bstate write would race other CTAs'
# reads otherwise). NOTE: mutates Bstate/Xstate.
tile_D: int = 64,
num_warps: int = 2,
) -> None:
has_z = z is not None
has_outproj = outproj is not None
has_state_batch_idx = state_batch_indices is not None
inplace = state_out is None
batch, nheads, hdim, dstate = state.shape
pool, nheads, hdim, dstate = state.shape
if update_kv_state:
assert has_state_batch_idx, "update_kv_state requires state_batch_indices"
assert tile_D >= hdim, (
f"update_kv_state requires a single D-tile per (b, h) "
f"(tile_D={tile_D} >= headdim={hdim}); with multiple D-tiles the "
f"bidd=0 CTA's Bstate write would race other CTAs' reads"
)
mimo = Bstate.shape[1]
assert state.shape == (batch, nheads, hdim, dstate)
assert Bstate.shape == (batch, mimo, nheads, dstate)
assert Xstate.shape == (batch, nheads, hdim)
batch = x.shape[0]
if has_state_batch_idx:
assert inplace, "state_batch_indices requires in-place update (state_out=None)"
assert state_batch_indices.shape == (batch,)
assert state_batch_indices.dtype == torch.int32
assert state_batch_indices.is_cuda
else:
assert pool == batch
assert state.shape == (pool, nheads, hdim, dstate)
assert Bstate.shape == (pool, mimo, nheads, dstate)
assert Xstate.shape == (pool, nheads, hdim)
assert A.shape == (batch, nheads)
assert B.shape == (batch, mimo, nheads, dstate)
assert C.shape == (batch, mimo, nheads, dstate)
Expand All @@ -615,7 +698,7 @@ def mamba3_step_fn(
if inplace:
state_out = state
else:
assert state_out.shape == (batch, nheads, hdim, dstate)
assert state_out.shape == (pool, nheads, hdim, dstate)

required_tensors = [state, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, state_out, out]
if has_outproj:
Expand Down Expand Up @@ -650,13 +733,17 @@ def mamba3_step_fn(
trap.dtype,
has_z,
has_outproj,
has_state_batch_idx,
update_kv_state,
)
if compile_key not in mamba3_step_fn.compile_cache:
mamba3_step_op = Mamba3Step(tile_D, dstate, mimo, num_warps, remove_gate=not has_z, remove_outproj=not has_outproj)
mamba3_step_op = Mamba3Step(tile_D, dstate, mimo, num_warps, remove_gate=not has_z, remove_outproj=not has_outproj, update_kv_state=update_kv_state)

# Create symbolic dimensions for batch and nheads
batch_sym = cute.sym_int()
nheads_sym = cute.sym_int()
# Pool row count is independent of batch when state_batch_indices is used
pool_sym = cute.sym_int() if has_state_batch_idx else batch_sym

# Divisibility for strides (128-bit alignment)
div_state = 128 // state_cute_dtype.width
Expand All @@ -669,9 +756,9 @@ def mamba3_step_fn(
div_trap = 128 // trap_cute_dtype.width

# Create fake tensors with symbolic batch/nheads dimensions
state_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state)
Bstate_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b)
Xstate_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x)
state_fake = make_fake_tensor(state_cute_dtype, (pool_sym, nheads_sym, hdim, dstate), div_state)
Bstate_fake = make_fake_tensor(b_cute_dtype, (pool_sym, mimo, nheads_sym, dstate), div_b)
Xstate_fake = make_fake_tensor(x_cute_dtype, (pool_sym, nheads_sym, hdim), div_x)
A_fake = make_fake_tensor(a_cute_dtype, (batch_sym, nheads_sym), div_a)
B_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b)
C_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b)
Expand All @@ -681,13 +768,16 @@ def mamba3_step_fn(
trap_fake = make_fake_tensor(trap_cute_dtype, (batch_sym, nheads_sym), div_trap)
xproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj)
outproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_outproj else None
state_out_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state)
state_out_fake = make_fake_tensor(state_cute_dtype, (pool_sym, nheads_sym, hdim, dstate), div_state)
if has_outproj:
out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x)
else:
out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, mimo, nheads_sym, hdim), div_x)
z_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) if has_z else None
zproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_z else None
state_batch_idx_fake = (
make_fake_tensor(Int32, (batch_sym,)) if has_state_batch_idx else None
)

fake_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

Expand All @@ -709,6 +799,7 @@ def mamba3_step_fn(
out_fake,
z_fake,
zproj_fake,
state_batch_idx_fake,
fake_stream,
options="--enable-tvm-ffi",
)
Expand All @@ -732,6 +823,7 @@ def mamba3_step_fn(
out,
z,
zproj,
state_batch_indices,
)


Expand Down
Loading