From 2239cfb4d58637034df1f45249ee3e30079ceb9e Mon Sep 17 00:00:00 2001
From: turusore08
Date: Mon, 15 Jun 2026 02:19:58 +0700
Subject: [PATCH] Installation and Build Failure on Modern Cloud Environments
 (Python 3.12+ / PyTorch 2.X / Kaggle Environments)

---
Review notes (this section is not part of the commit message):

* Line breaks and indentation of this patch were reconstructed from a
  whitespace-collapsed copy; verify context lines against the target tree
  before applying. The post-image blob ids on the "index" lines of
  mamba2.py, mamba3.py and selective_scan_interface.py are stale.
* selective_scan_interface.py, hybrid_chunk_scan_pytorch: the inter-chunk
  scan assigned into scan_a[..., stride:] and then read a_right, a view of
  the same storage, so scan_b was built from already-updated factors. The
  scan now builds new tensors with torch.cat, which also keeps the function
  differentiable. The math.log2 step count (and "import math") is replaced
  by a doubling stride.
* mamba2.py / mamba3.py, RMSNormGated fallback: with group_size=None the
  variance was computed but never applied. The module also accepts
  device/dtype instead of swallowing them in **kwargs.
* mamba2.py, mamba_split_conv1d_scan_combined fallback: the norm weight is
  used directly (wrapping it in a new nn.Parameter detached it from the
  module's parameter), d_inner falls back to nheads * headdim, `activation`
  is honoured, and a missing outproj_weight skips the projection.
* mamba2.py, mamba_chunk_scan_combined fallback: dt_bias, dt_softplus and
  dt_limit are applied instead of being dropped through **kwargs.

 mamba_ssm/modules/block.py                |  43 +++-
 mamba_ssm/modules/mamba2.py               | 175 +++++++++++++++-
 mamba_ssm/modules/mamba2_simple.py        | 120 ++++++++++-
 mamba_ssm/modules/mamba3.py               |  43 +++-
 mamba_ssm/modules/ssd_minimal.py          |   5 +-
 mamba_ssm/ops/selective_scan_interface.py | 238 +++++++++++++++++++++-
 6 files changed, 611 insertions(+), 13 deletions(-)

diff --git a/mamba_ssm/modules/block.py b/mamba_ssm/modules/block.py
index 1bd968a0b..895431f06 100644
--- a/mamba_ssm/modules/block.py
+++ b/mamba_ssm/modules/block.py
@@ -4,7 +4,48 @@
 import torch
 from torch import nn, Tensor
 
-from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
+try:
+    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
+except ImportError:
+    RMSNorm = None
+    layer_norm_fn = None
+
+if RMSNorm is None:
+    class RMSNorm(nn.Module):
+        def __init__(self, dim, eps=1e-5):
+            super().__init__()
+            self.eps = eps
+            self.weight = nn.Parameter(torch.ones(dim))
+            self.register_parameter("bias", None)
+
+        def forward(self, x):
+            variance = x.pow(2).mean(-1, keepdim=True)
+            return x * torch.rsqrt(variance + self.eps) * self.weight
+
+if layer_norm_fn is None:
+    def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False, is_rms_norm=False):
+        res = (x + residual) if residual is not None else x
+        if residual_in_fp32:
+            res_out = res.to(torch.float32)
+        else:
+            res_out = res
+
+        if is_rms_norm:
+            variance = res.pow(2).mean(-1, keepdim=True)
+            normed = res * torch.rsqrt(variance + eps) * weight
+            if bias is not None:
+                normed = normed + bias
+        else:
+            mean = res.mean(-1, keepdim=True)
+            var = res.var(-1, keepdim=True, unbiased=False)
+            normed = (res - mean) * torch.rsqrt(var + eps) * weight
+            if bias is not None:
+                normed = normed + bias
+
+        if prenorm:
+            return normed, res_out
+        else:
+            return normed
 
 
 class Block(nn.Module):
diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 36b16d471..69e8a305d 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -23,14 +23,181 @@
 except ImportError:
     selective_state_update = None
 
-from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
+try:
+    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
+except ImportError:
+    RMSNormGated = None
+
+try:
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
+    from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
+except ImportError:
+    mamba_chunk_scan_combined = None
+    mamba_split_conv1d_scan_combined = None
+
+
+def _rmsnorm_gated_ref(x, weight, z=None, eps=1e-5, group_size=None, norm_before_gate=False):
+    """Gated RMSNorm in plain PyTorch, shared by the fallbacks below.
+
+    The last dimension of ``x`` is RMS-normalized (per ``group_size`` channels
+    when given) and scaled by ``weight``. With ``z``, ``silu(z)`` multiplies the
+    result after the norm if ``norm_before_gate``, else the input before it.
+    ``weight`` is used directly so gradients reach the caller's parameter.
+    """
+    if z is not None and not norm_before_gate:
+        x = x * F.silu(z)
+    if group_size is None:
+        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+    else:
+        grouped = x.reshape(*x.shape[:-1], -1, group_size)
+        grouped = grouped * torch.rsqrt(grouped.pow(2).mean(-1, keepdim=True) + eps)
+        x = grouped.reshape(x.shape)
+    out = x * weight
+    if z is not None and norm_before_gate:
+        out = out * F.silu(z)
+    return out
+
+
+if RMSNormGated is None:
+    class RMSNormGated(nn.Module):
+        """Fallback gated RMSNorm used when the Triton kernel cannot be imported."""
+
+        def __init__(self, dim, eps=1e-5, norm_before_gate=False, group_size=None,
+                     device=None, dtype=None, **kwargs):
+            super().__init__()
+            self.eps = eps
+            self.norm_before_gate = norm_before_gate
+            self.group_size = group_size
+            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+
+        def forward(self, x, z=None):
+            return _rmsnorm_gated_ref(
+                x, self.weight, z=z, eps=self.eps,
+                group_size=self.group_size, norm_before_gate=self.norm_before_gate,
+            )
+
+if mamba_chunk_scan_combined is None:
+    def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, seq_idx=None,
+                                  initial_states=None, return_final_states=False, dt_bias=None,
+                                  dt_softplus=False, dt_limit=(0.0, float("inf")), **kwargs):
+        """Reference SSD chunk scan built on ``ssd_minimal_discrete``.
+
+        Returns ``y``, or ``(y, final_state)`` when ``return_final_states`` is
+        True. ``seq_idx`` and unrecognized keywords are accepted for call
+        compatibility but are not used.
+        """
+        seqlen, nheads = x.shape[1], x.shape[2]
+        ngroups = B.shape[2]
+
+        # dt pre-processing options of the Triton API; apply them here instead
+        # of silently dropping them through **kwargs.
+        if dt_bias is not None:
+            dt = dt + dt_bias
+        if dt_softplus:
+            dt = F.softplus(dt)
+        if tuple(dt_limit) != (0.0, float("inf")):
+            dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
+
+        # Pad sequence length to a multiple of chunk_size
+        pad_len = (chunk_size - (seqlen % chunk_size)) % chunk_size
+        if pad_len > 0:
+            x = rearrange(F.pad(rearrange(x, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')
+            dt = rearrange(F.pad(rearrange(dt, 'b l h -> b h l'), (0, pad_len), value=0.0), 'b h l -> b l h')
+            B = rearrange(F.pad(rearrange(B, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
+            C = rearrange(F.pad(rearrange(C, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
+            if z is not None:
+                z = rearrange(F.pad(rearrange(z, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')
+
+        if ngroups < nheads:
+            B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
+            C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
+        A_discrete = A.view(1, 1, -1) * dt
+        X_discrete = x * dt.unsqueeze(-1)
+        from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete
+        # NOTE(review): initial_states is forwarded unchanged; confirm that
+        # ssd_minimal_discrete expects the same layout as the Triton kernel.
+        y, final_state = ssd_minimal_discrete(X_discrete, A_discrete, B, C, chunk_size, initial_states=initial_states)
+
+        if pad_len > 0:
+            y = y[:, :-pad_len]
+            x = x[:, :-pad_len]
+            if z is not None:
+                z = z[:, :-pad_len]
+
+        if D is not None:
+            if D.dim() == 2:
+                y = y + x * D.unsqueeze(0).unsqueeze(0)
+            else:
+                y = y + x * D.view(1, 1, -1, 1)
+        if z is not None:
+            y = y * F.silu(z)
+        if return_final_states:
+            return y, final_state
+        return y
+
+if mamba_split_conv1d_scan_combined is None:
+    def mamba_split_conv1d_scan_combined(
+        zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D=None, chunk_size=256,
+        seq_idx=None, activation="swish", rmsnorm_weight=None, rmsnorm_eps=1e-5,
+        outproj_weight=None, outproj_bias=None, headdim=128, ngroups=1,
+        norm_before_gate=False, initial_states=None, **kwargs
+    ):
+        """Reference version of the fused split + causal conv1d + SSD scan.
+
+        ``zxbcdt`` is split along its last dimension into the gate ``z``, the
+        conv input ``xBC`` and ``dt``; the scan output is gated/normalized and,
+        when ``outproj_weight`` is given, projected with it.
+        """
+        nheads = A.shape[0]
+        hdim = headdim if headdim is not None else 64
+        # Without a norm weight the inner width follows from the head layout.
+        d_inner = rmsnorm_weight.shape[0] if rmsnorm_weight is not None else nheads * hdim
+        d_state = (zxbcdt.shape[-1] - 2 * d_inner - nheads) // (2 * ngroups) if ngroups > 0 else 64
+        z, xBC, dt = torch.split(zxbcdt, [d_inner, d_inner + 2 * ngroups * d_state, nheads], dim=-1)
+        dt = F.softplus(dt + dt_bias)
+        if causal_conv1d_fn is not None:
+            xBC = causal_conv1d_fn(
+                x=xBC.transpose(1, 2),
+                weight=conv1d_weight,
+                bias=conv1d_bias,
+                activation=activation,
+            ).transpose(1, 2)
+        else:
+            d_conv = conv1d_weight.shape[-1]
+            xBC_padded = F.pad(xBC.transpose(1, 2), (d_conv - 1, 0))
+            xBC_conv = F.conv1d(xBC_padded, conv1d_weight.unsqueeze(1), bias=conv1d_bias, groups=xBC.shape[-1])
+            if activation in ("silu", "swish"):
+                xBC_conv = F.silu(xBC_conv)
+            xBC = xBC_conv.transpose(1, 2)
+        x, B, C = torch.split(xBC, [d_inner, ngroups * d_state, ngroups * d_state], dim=-1)
+        y = mamba_chunk_scan_combined(
+            rearrange(x, "b l (h p) -> b l h p", p=hdim),
+            dt,
+            A,
+            rearrange(B, "b l (g n) -> b l g n", g=ngroups),
+            rearrange(C, "b l (g n) -> b l g n", g=ngroups),
+            chunk_size=chunk_size,
+            D=D,
+            z=None,
+            seq_idx=seq_idx,
+            initial_states=initial_states,
+            **kwargs,
+        )
+        y = rearrange(y, "b l h p -> b l (h p)")
+        if rmsnorm_weight is not None:
+            y = _rmsnorm_gated_ref(
+                y, rmsnorm_weight, z=z, eps=rmsnorm_eps,
+                group_size=d_inner // ngroups, norm_before_gate=norm_before_gate,
+            )
+        else:
+            y = y * F.silu(z)
+        if outproj_weight is None:
+            return y
+        return F.linear(y, outproj_weight, outproj_bias)
 
 from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
 from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
 
-from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
-from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
-
 from huggingface_hub import PyTorchModelHubMixin
 
 
diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py
index 77a6af28e..ea2569c72 100644
--- a/mamba_ssm/modules/mamba2_simple.py
+++ b/mamba_ssm/modules/mamba2_simple.py
@@ -17,8 +17,124 @@
 except ImportError:
     RMSNormGated, LayerNorm = None, None
 
-from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
-from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
+try:
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
+    from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
+except ImportError:
+    mamba_chunk_scan_combined = None
+    mamba_split_conv1d_scan_combined = None
+
+if RMSNormGated is None:
+    class RMSNormGated(nn.Module):
+        def __init__(self, dim, eps=1e-5, norm_before_gate=False, **kwargs):
+            super().__init__()
+            self.eps = eps
+            self.norm_before_gate = norm_before_gate
+            self.weight = nn.Parameter(torch.ones(dim))
+
+        def forward(self, x, z=None):
+            if z is not None:
+                gated = F.silu(z)
+                if self.norm_before_gate:
+                    variance = x.pow(2).mean(-1, keepdim=True)
+                    normed = x * torch.rsqrt(variance + self.eps) * self.weight
+                    return normed * gated
+                else:
+                    out = x * gated
+                    variance = out.pow(2).mean(-1, keepdim=True)
+                    return out * torch.rsqrt(variance + self.eps) * self.weight
+            else:
+                variance = x.pow(2).mean(-1, keepdim=True)
+                return x * torch.rsqrt(variance + self.eps) * self.weight
+
+if mamba_chunk_scan_combined is None:
+    def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, seq_idx=None, initial_states=None, **kwargs):
+        batch, seqlen, nheads, headdim = x.shape
+        ngroups = B.shape[2]
+
+        # Pad sequence length to a multiple of chunk_size
+        pad_len = (chunk_size - (seqlen % chunk_size)) % chunk_size
+        if pad_len > 0:
+            x = rearrange(F.pad(rearrange(x, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')
+            dt = rearrange(F.pad(rearrange(dt, 'b l h -> b h l'), (0, pad_len), value=0.0), 'b h l -> b l h')
+            B = rearrange(F.pad(rearrange(B, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
+            C = rearrange(F.pad(rearrange(C, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
+            if z is not None:
+                z = rearrange(F.pad(rearrange(z, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')
+
+        if ngroups < nheads:
+            B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
+            C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
+        A_discrete = A.view(1, 1, -1) * dt
+        X_discrete = x * dt.unsqueeze(-1)
+        from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete
+        y, final_state = ssd_minimal_discrete(X_discrete, A_discrete, B, C, chunk_size, initial_states=initial_states)
+
+        if pad_len > 0:
+            y = y[:, :-pad_len]
+            x = x[:, :-pad_len]
+            if z is not None:
+                z = z[:, :-pad_len]
+
+        if D is not None:
+            if D.dim() == 2:
+                y = y + x * D.unsqueeze(0).unsqueeze(0)
+            else:
+                y = y + x * D.view(1, 1, -1, 1)
+        if z is not None:
+            y = y * F.silu(z)
+        return y
+
+if mamba_split_conv1d_scan_combined is None:
+    def mamba_split_conv1d_scan_combined(
+        zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D=None, chunk_size=256,
+        seq_idx=None, activation="swish", rmsnorm_weight=None, rmsnorm_eps=1e-5,
+        outproj_weight=None, outproj_bias=None, headdim=128, ngroups=1,
+        norm_before_gate=False, initial_states=None, **kwargs
+    ):
+        batch, seqlen, _ = zxbcdt.shape
+        d_inner = rmsnorm_weight.shape[0]
+        nheads = A.shape[0]
+        d_state = (zxbcdt.shape[-1] - 2 * d_inner - nheads) // (2 * ngroups)
+        z, xBC, dt = torch.split(zxbcdt, [d_inner, d_inner + 2 * ngroups * d_state, nheads], dim=-1)
+        dt = F.softplus(dt + dt_bias)
+        if causal_conv1d_fn is not None:
+            xBC = causal_conv1d_fn(
+                x=xBC.transpose(1, 2),
+                weight=conv1d_weight,
+                bias=conv1d_bias,
+                activation=activation,
+            ).transpose(1, 2)
+        else:
+            d_conv = conv1d_weight.shape[-1]
+            xBC_padded = F.pad(xBC.transpose(1, 2), (d_conv - 1, 0))
+            xBC_conv = F.conv1d(xBC_padded, conv1d_weight.unsqueeze(1), bias=conv1d_bias, groups=xBC.shape[-1])
+            xBC = F.silu(xBC_conv).transpose(1, 2)
+        x, B, C = torch.split(xBC, [d_inner, ngroups * d_state, ngroups * d_state], dim=-1)
+        y = mamba_chunk_scan_combined(
+            rearrange(x, "b l (h p) -> b l h p", p=headdim),
+            dt,
+            A,
+            rearrange(B, "b l (g n) -> b l g n", g=ngroups),
+            rearrange(C, "b l (g n) -> b l g n", g=ngroups),
+            chunk_size=chunk_size,
+            D=D,
+            z=None,
+            seq_idx=seq_idx,
+            initial_states=initial_states,
+            **kwargs,
+        )
+        y = rearrange(y, "b l h p -> b l (h p)")
+        gated = F.silu(z)
+        if norm_before_gate:
+            variance = y.pow(2).mean(-1, keepdim=True)
+            normed = y * torch.rsqrt(variance + rmsnorm_eps) * rmsnorm_weight
+            y = normed * gated
+        else:
+            out = y * gated
+            variance = out.pow(2).mean(-1, keepdim=True)
+            y = out * torch.rsqrt(variance + rmsnorm_eps) * rmsnorm_weight
+        return F.linear(y, outproj_weight, outproj_bias)
 
 
 class Mamba2Simple(nn.Module):
diff --git a/mamba_ssm/modules/mamba3.py b/mamba_ssm/modules/mamba3.py
index 34c119c2c..523c22fdd 100644
--- a/mamba_ssm/modules/mamba3.py
+++ b/mamba_ssm/modules/mamba3.py
@@ -7,16 +7,53 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
+try:
+    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
+except ImportError:
+    RMSNormGated = None
 
 try:
     from mamba_ssm.ops.tilelang.mamba3.mamba3_mimo import mamba3_mimo as mamba3_mimo_combined
 except ImportError:
     mamba3_mimo_combined = None
 
-from mamba_ssm.ops.triton.mamba3.mamba3_siso_combined import mamba3_siso_combined
+try:
+    from mamba_ssm.ops.triton.mamba3.mamba3_siso_combined import mamba3_siso_combined
+except ImportError:
+    mamba3_siso_combined = None
 
-from mamba_ssm.ops.triton.mamba3.mamba3_mimo_rotary_step import apply_rotary_qk_inference_fwd
+try:
+    from mamba_ssm.ops.triton.mamba3.mamba3_mimo_rotary_step import apply_rotary_qk_inference_fwd
+except ImportError:
+    apply_rotary_qk_inference_fwd = None
+
+if RMSNormGated is None:
+    class RMSNormGated(nn.Module):
+        """Fallback gated RMSNorm used when the Triton kernel cannot be imported."""
+
+        def __init__(self, dim, eps=1e-5, norm_before_gate=False, group_size=None,
+                     device=None, dtype=None, **kwargs):
+            super().__init__()
+            self.eps = eps
+            self.norm_before_gate = norm_before_gate
+            self.group_size = group_size
+            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
+
+        def _normalize(self, x):
+            """RMS-normalize the last dim of ``x``, per group when group_size is set."""
+            if self.group_size is None:
+                return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+            grouped = x.reshape(*x.shape[:-1], -1, self.group_size)
+            grouped = grouped * torch.rsqrt(grouped.pow(2).mean(-1, keepdim=True) + self.eps)
+            return grouped.reshape(x.shape)
+
+        def forward(self, x, z=None):
+            if z is not None and not self.norm_before_gate:
+                x = x * F.silu(z)
+            out = self._normalize(x) * self.weight
+            if z is not None and self.norm_before_gate:
+                out = out * F.silu(z)
+            return out
 
 try:
     from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn
diff --git a/mamba_ssm/modules/ssd_minimal.py b/mamba_ssm/modules/ssd_minimal.py
index 9632ebd43..83e743ff5 100644
--- a/mamba_ssm/modules/ssd_minimal.py
+++ b/mamba_ssm/modules/ssd_minimal.py
@@ -8,7 +8,10 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
-from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
+try:
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
+except ImportError:
+    mamba_chunk_scan_combined = None
 
 
 def segsum_unstable(x):
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index a41f1359c..63b64196c 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -15,9 +15,15 @@
     causal_conv1d_bwd_function = None
     causal_conv1d_update_function = None
 
-from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
+try:
+    from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
+except ImportError:
+    _layer_norm_fwd = None
 
-import selective_scan_cuda
+try:
+    import selective_scan_cuda
+except ImportError:
+    selective_scan_cuda = None
 
 
 class SelectiveScanFn(torch.autograd.Function):
@@ -96,6 +102,20 @@ def rms_norm_forward(
     weight = weight.contiguous()
     if bias is not None:
         bias = bias.contiguous()
+    if _layer_norm_fwd is None:
+        if is_rms_norm:
+            variance = x.pow(2).mean(-1, keepdim=True)
+            output = x * torch.rsqrt(variance + eps) * weight
+            if bias is not None:
+                output = output + bias
+            return output
+        else:
+            mean = x.mean(-1, keepdim=True)
+            var = x.var(-1, keepdim=True, unbiased=False)
+            output = (x - mean) * torch.rsqrt(var + eps) * weight
+            if bias is not None:
+                output = output + bias
+            return output
     y = _layer_norm_fwd(
         x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
     )[0]
@@ -103,12 +123,162 @@ def rms_norm_forward(
     return y
 
 
+def hybrid_chunk_scan_pytorch(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
+                              return_last_state=False, chunk_size=64):
+    """
+    Pure PyTorch implementation of selective scan using a Hybrid Chunked Associative Scan.
+
+    Each chunk of ``chunk_size`` steps is scanned sequentially (vectorized over
+    batch, channels, state and chunks); the chunk-boundary states are then
+    combined with a parallel prefix scan. ``u`` and ``delta`` are upcast to
+    float32 and the output is cast back to ``u.dtype``. Returns ``out``, or
+    ``(out, last_state)`` when ``return_last_state`` is True.
+    """
+    dtype_in = u.dtype
+    u = u.float()
+    delta = delta.float()
+    if delta_bias is not None:
+        delta = delta + delta_bias[..., None].float()
+    if delta_softplus:
+        delta = F.softplus(delta)
+
+    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
+    is_variable_B = B.dim() >= 3
+    is_variable_C = C.dim() >= 3
+
+    if A.is_complex():
+        if is_variable_B:
+            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
+        if is_variable_C:
+            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
+    else:
+        B = B.float()
+        C = C.float()
+
+    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
+    if not is_variable_B:
+        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
+    else:
+        if B.dim() == 3:
+            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
+        else:
+            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
+            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
+
+    if is_variable_C and C.dim() == 4:
+        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
+
+    # Rearrange deltaA and deltaB_u to (B, D, N, L)
+    a = rearrange(deltaA, 'b d l n -> b d n l')
+    b = rearrange(deltaB_u, 'b d l n -> b d n l')
+
+    L = u.shape[2]
+
+    # Pad L to a multiple of chunk_size if necessary
+    pad_len = (chunk_size - (L % chunk_size)) % chunk_size
+    if pad_len > 0:
+        a = F.pad(a, (0, pad_len), value=1.0)
+        b = F.pad(b, (0, pad_len), value=0.0)
+
+    L_padded = a.shape[-1]
+    num_chunks = L_padded // chunk_size
+
+    # Reshape to (B, D, N, C, W)
+    a = rearrange(a, 'b d n (c w) -> b d n c w', w=chunk_size)
+    b = rearrange(b, 'b d n (c w) -> b d n c w', w=chunk_size)
+
+    # 1. Intra-chunk sequential scan (vectorized over chunks, batch, and dims)
+    A_loc = []
+    X_loc = []
+
+    curr_A = torch.ones((batch, dim, dstate, num_chunks), device=a.device, dtype=a.dtype)
+    curr_X = torch.zeros((batch, dim, dstate, num_chunks), device=b.device, dtype=b.dtype)
+
+    for w in range(chunk_size):
+        curr_A = a[..., w] * curr_A
+        curr_X = a[..., w] * curr_X + b[..., w]
+        A_loc.append(curr_A)
+        X_loc.append(curr_X)
+
+    # Shape of A_loc and X_loc is (chunk_size, B, D, N, C)
+    # Stack them to shape (B, D, N, C, W)
+    A_loc = torch.stack(A_loc, dim=-1)
+    X_loc = torch.stack(X_loc, dim=-1)
+
+    # Chunk summary:
+    # chunk_a: decay factor for each chunk (B, D, N, C)
+    # chunk_b: output term for each chunk (B, D, N, C)
+    chunk_a = A_loc[..., -1]
+    chunk_b = X_loc[..., -1]
+
+    # 2. Inter-chunk parallel scan (Kogge-Stone style scan over C chunks)
+    # We want to solve s_{c+1} = chunk_a_c * s_c + chunk_b_c with s_0 = 0,
+    # i.e. an inclusive prefix scan over the (chunk_a, chunk_b) pairs.
+    scan_a = chunk_a
+    scan_b = chunk_b
+
+    stride = 1
+    while stride < num_chunks:
+        # Shift and combine
+        a_left = scan_a[..., :-stride]
+        b_left = scan_b[..., :-stride]
+        a_right = scan_a[..., stride:]
+        b_right = scan_b[..., stride:]
+        # Build new tensors rather than assigning into slices: the slices are
+        # views, so writing scan_a in place would change a_right before scan_b
+        # is computed, and in-place writes also clash with autograd.
+        scan_b = torch.cat([scan_b[..., :stride], a_right * b_left + b_right], dim=-1)
+        scan_a = torch.cat([scan_a[..., :stride], a_right * a_left], dim=-1)
+        stride *= 2
+
+    # The state at the end of chunk c (scan_b[..., c]) is the state at the
+    # start of chunk c+1; chunk 0 starts from zero.
+    s = torch.cat([torch.zeros_like(scan_b[..., :1]), scan_b[..., :-1]], dim=-1)
+
+    # 3. Final state reconstruction for all steps in all chunks
+    # x_{c, w} = A_loc_{c, w} * s_c + X_loc_{c, w}
+    # A_loc: (B, D, N, C, W)
+    # s: (B, D, N, C) -> unsqueeze to (B, D, N, C, 1)
+    # X_loc: (B, D, N, C, W)
+    x = A_loc * s.unsqueeze(-1) + X_loc
+
+    # Reshape back to (B, D, N, L_padded) and truncate to (B, D, N, L)
+    x = rearrange(x, 'b d n c w -> b d n (c w)')
+    if pad_len > 0:
+        x = x[..., :-pad_len]
+
+    # Get last state if needed
+    last_state = x[..., -1]
+
+    # Compute output y:
+    # x: (B, D, N, L)
+    if not is_variable_C:
+        y = torch.einsum('bdnl,dn->bdl', x, C)
+    else:
+        if C.dim() == 3:
+            y = torch.einsum('bdnl,bnl->bdl', x, C)
+        else:
+            y = torch.einsum('bdnl,bdnl->bdl', x, C)
+
+    if y.is_complex():
+        y = y.real * 2
+
+    out = y if D is None else y + u * rearrange(D, "d -> d 1")
+    if z is not None:
+        out = out * F.silu(z)
+
+    out = out.to(dtype=dtype_in)
+    return out if not return_last_state else (out, last_state)
+
+
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
     """if return_last_state is True, returns (out, last_state)
     last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
     not considered in the backward pass.
     """
+    if selective_scan_cuda is None:
+        return hybrid_chunk_scan_pytorch(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
     return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
 
 
@@ -370,12 +540,76 @@ def backward(ctx, dout):
                 dB_proj_bias, dC_proj_bias, None, None, None, None, None, None)
 
 
+def mamba_inner_ref_pure(
+    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+    out_proj_weight, out_proj_bias,
+    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
+    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1,
+    b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-6
+):
+    L = xz.shape[-1]
+    delta_rank = delta_proj_weight.shape[1]
+    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
+    x, z = xz.chunk(2, dim=1)
+
+    if causal_conv1d_fn is not None:
+        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
+    else:
+        d_conv = conv1d_weight.shape[-1]
+        x_padded = F.pad(x, (d_conv - 1, 0))
+        x = F.conv1d(x_padded, conv1d_weight, bias=conv1d_bias, groups=x.shape[1])
+        x = F.silu(x)
+
+    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
+    delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
+
+    if B is None:  # variable B
+        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
+        if B_proj_bias is not None:
+            B = B + B_proj_bias.to(dtype=B.dtype)
+        if not A.is_complex():
+            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
+        else:
+            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
+    if C is None:  # variable C
+        C = x_dbl[:, -d_state:]  # (bl dstate)
+        if C_proj_bias is not None:
+            C = C + C_proj_bias.to(dtype=C.dtype)
+        if not A.is_complex():
+            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
+        else:
+            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
+
+    if b_rms_weight is not None:
+        B = rearrange(B, "b dstate l -> (b l) dstate").contiguous()
+        B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
+        B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
+    if c_rms_weight is not None:
+        C = rearrange(C, "b dstate l -> (b l) dstate").contiguous()
+        C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
+        C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
+    if dt_rms_weight is not None:
+        delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
+        delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
+        delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
+
+    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus)
+    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
+
+
 def mamba_inner_fn(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6
 ):
+    if selective_scan_cuda is None or causal_conv1d_fwd_function is None:
+        return mamba_inner_ref_pure(
+            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
+            out_proj_weight, out_proj_bias,
+            A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl,
+            b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
+        )
     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps)