Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion mamba_ssm/modules/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,48 @@
import torch
from torch import nn, Tensor

from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
except ImportError:
RMSNorm = None
layer_norm_fn = None

if RMSNorm is None:
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.register_parameter("bias", None)

def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
return x * torch.rsqrt(variance + self.eps) * self.weight

if layer_norm_fn is None:
def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False, is_rms_norm=False):
res = (x + residual) if residual is not None else x
if residual_in_fp32:
res_out = res.to(torch.float32)
else:
res_out = res

if is_rms_norm:
variance = res.pow(2).mean(-1, keepdim=True)
normed = res * torch.rsqrt(variance + eps) * weight
if bias is not None:
normed = normed + bias
else:
mean = res.mean(-1, keepdim=True)
var = res.var(-1, keepdim=True, unbiased=False)
normed = (res - mean) * torch.rsqrt(var + eps) * weight
if bias is not None:
normed = normed + bias

if prenorm:
return normed, res_out
else:
return normed


class Block(nn.Module):
Expand Down
162 changes: 158 additions & 4 deletions mamba_ssm/modules/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,168 @@
except ImportError:
selective_state_update = None

from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
except ImportError:
RMSNormGated = None

try:
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
except ImportError:
mamba_chunk_scan_combined = None
mamba_split_conv1d_scan_combined = None

if RMSNormGated is None:
class RMSNormGated(nn.Module):
def __init__(self, dim, eps=1e-5, norm_before_gate=False, group_size=None, **kwargs):
super().__init__()
self.eps = eps
self.norm_before_gate = norm_before_gate
self.group_size = group_size
self.weight = nn.Parameter(torch.ones(dim))

def forward(self, x, z=None):
if z is not None:
gated = F.silu(z)
if self.norm_before_gate:
if self.group_size is None:
variance = x.pow(2).mean(-1, keepdim=True)
else:
orig_shape = x.shape
x_reshaped = x.view(*orig_shape[:-1], -1, self.group_size)
variance = x_reshaped.pow(2).mean(-1, keepdim=True)
x_normed = x_reshaped * torch.rsqrt(variance + self.eps)
x = x_normed.view(orig_shape)
normed = x * self.weight
return normed * gated
else:
out = x * gated
if self.group_size is None:
variance = out.pow(2).mean(-1, keepdim=True)
else:
orig_shape = out.shape
out_reshaped = out.view(*orig_shape[:-1], -1, self.group_size)
variance = out_reshaped.pow(2).mean(-1, keepdim=True)
out_normed = out_reshaped * torch.rsqrt(variance + self.eps)
out = out_normed.view(orig_shape)
return out * self.weight
else:
if self.group_size is None:
variance = x.pow(2).mean(-1, keepdim=True)
else:
orig_shape = x.shape
x_reshaped = x.view(*orig_shape[:-1], -1, self.group_size)
variance = x_reshaped.pow(2).mean(-1, keepdim=True)
x_normed = x_reshaped * torch.rsqrt(variance + self.eps)
x = x_normed.view(orig_shape)
return x * self.weight

if mamba_chunk_scan_combined is None:
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, seq_idx=None, initial_states=None, return_final_states=False, **kwargs):
batch, seqlen, nheads, headdim = x.shape
ngroups = B.shape[2]

# Pad sequence length to a multiple of chunk_size
pad_len = (chunk_size - (seqlen % chunk_size)) % chunk_size
if pad_len > 0:
x = rearrange(F.pad(rearrange(x, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')
dt = rearrange(F.pad(rearrange(dt, 'b l h -> b h l'), (0, pad_len), value=0.0), 'b h l -> b l h')
B = rearrange(F.pad(rearrange(B, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
C = rearrange(F.pad(rearrange(C, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
if z is not None:
z = rearrange(F.pad(rearrange(z, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')

if ngroups < nheads:
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
A_discrete = A.view(1, 1, -1) * dt
X_discrete = x * dt.unsqueeze(-1)
from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete
y, final_state = ssd_minimal_discrete(X_discrete, A_discrete, B, C, chunk_size, initial_states=initial_states)

if pad_len > 0:
y = y[:, :-pad_len]
x = x[:, :-pad_len]
if z is not None:
z = z[:, :-pad_len]

if D is not None:
if D.dim() == 2:
y = y + x * D.unsqueeze(0).unsqueeze(0)
else:
y = y + x * D.view(1, 1, -1, 1)
if z is not None:
y = y * F.silu(z)
if return_final_states:
return y, final_state
return y

if mamba_split_conv1d_scan_combined is None:
def mamba_split_conv1d_scan_combined(
zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D=None, chunk_size=256,
seq_idx=None, activation="swish", rmsnorm_weight=None, rmsnorm_eps=1e-5,
outproj_weight=None, outproj_bias=None, headdim=128, ngroups=1,
norm_before_gate=False, initial_states=None, **kwargs
):
batch, seqlen, _ = zxbcdt.shape
d_inner = rmsnorm_weight.shape[0] if rmsnorm_weight is not None else (zxbcdt.shape[-1] - A.shape[0]) // 2
# Let's infer parameters:
# In Mamba-2, headdim is self.headdim
# We can try to get it from context or kwargs
hdim = headdim if headdim is not None else 64
nheads = A.shape[0]
d_state = (zxbcdt.shape[-1] - 2 * d_inner - nheads) // (2 * ngroups) if ngroups > 0 else 64
z, xBC, dt = torch.split(zxbcdt, [d_inner, d_inner + 2 * ngroups * d_state, nheads], dim=-1)
dt = F.softplus(dt + dt_bias)
if causal_conv1d_fn is not None:
xBC = causal_conv1d_fn(
x=xBC.transpose(1, 2),
weight=conv1d_weight,
bias=conv1d_bias,
activation="silu",
).transpose(1, 2)
else:
d_conv = conv1d_weight.shape[-1]
xBC_padded = F.pad(xBC.transpose(1, 2), (d_conv - 1, 0))
xBC_conv = F.conv1d(xBC_padded, conv1d_weight.unsqueeze(1), bias=conv1d_bias, groups=xBC.shape[-1])
xBC = F.silu(xBC_conv).transpose(1, 2)
x, B, C = torch.split(xBC, [d_inner, ngroups * d_state, ngroups * d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=hdim),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=ngroups),
rearrange(C, "b l (g n) -> b l g n", g=ngroups),
chunk_size=chunk_size,
D=D,
z=None,
seq_idx=seq_idx,
initial_states=initial_states,
**kwargs,
)
y = rearrange(y, "b l h p -> b l (h p)")
gated = F.silu(z)
if rmsnorm_weight is not None:
if norm_before_gate:
# Group-size logic: self.norm has group_size
# Since self.norm is RMSNormGated, we can call it directly
# but since we are in the standalone function, we can do group norm or call a custom norm
# We can just construct a temporary RMSNormGated to do it:
norm_layer = RMSNormGated(d_inner, eps=rmsnorm_eps, norm_before_gate=True, group_size=d_inner // ngroups)
norm_layer.weight = nn.Parameter(rmsnorm_weight)
y = norm_layer(y, z)
else:
norm_layer = RMSNormGated(d_inner, eps=rmsnorm_eps, norm_before_gate=False, group_size=d_inner // ngroups)
norm_layer.weight = nn.Parameter(rmsnorm_weight)
y = norm_layer(y, z)
else:
y = y * gated
return F.linear(y, outproj_weight, outproj_bias)

from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from huggingface_hub import PyTorchModelHubMixin


Expand Down
120 changes: 118 additions & 2 deletions mamba_ssm/modules/mamba2_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,124 @@
except ImportError:
RMSNormGated, LayerNorm = None, None

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
try:
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
except ImportError:
mamba_chunk_scan_combined = None
mamba_split_conv1d_scan_combined = None

if RMSNormGated is None:
class RMSNormGated(nn.Module):
def __init__(self, dim, eps=1e-5, norm_before_gate=False, **kwargs):
super().__init__()
self.eps = eps
self.norm_before_gate = norm_before_gate
self.weight = nn.Parameter(torch.ones(dim))

def forward(self, x, z=None):
if z is not None:
gated = F.silu(z)
if self.norm_before_gate:
variance = x.pow(2).mean(-1, keepdim=True)
normed = x * torch.rsqrt(variance + self.eps) * self.weight
return normed * gated
else:
out = x * gated
variance = out.pow(2).mean(-1, keepdim=True)
return out * torch.rsqrt(variance + self.eps) * self.weight
else:
variance = x.pow(2).mean(-1, keepdim=True)
return x * torch.rsqrt(variance + self.eps) * self.weight

if mamba_chunk_scan_combined is None:
def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, seq_idx=None, initial_states=None, **kwargs):
batch, seqlen, nheads, headdim = x.shape
ngroups = B.shape[2]

# Pad sequence length to a multiple of chunk_size
pad_len = (chunk_size - (seqlen % chunk_size)) % chunk_size
if pad_len > 0:
x = rearrange(F.pad(rearrange(x, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')
dt = rearrange(F.pad(rearrange(dt, 'b l h -> b h l'), (0, pad_len), value=0.0), 'b h l -> b l h')
B = rearrange(F.pad(rearrange(B, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
C = rearrange(F.pad(rearrange(C, 'b l g d -> b g d l'), (0, pad_len), value=0.0), 'b g d l -> b l g d')
if z is not None:
z = rearrange(F.pad(rearrange(z, 'b l h d -> b h d l'), (0, pad_len), value=0.0), 'b h d l -> b l h d')

if ngroups < nheads:
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
A_discrete = A.view(1, 1, -1) * dt
X_discrete = x * dt.unsqueeze(-1)
from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete
y, final_state = ssd_minimal_discrete(X_discrete, A_discrete, B, C, chunk_size, initial_states=initial_states)

if pad_len > 0:
y = y[:, :-pad_len]
x = x[:, :-pad_len]
if z is not None:
z = z[:, :-pad_len]

if D is not None:
if D.dim() == 2:
y = y + x * D.unsqueeze(0).unsqueeze(0)
else:
y = y + x * D.view(1, 1, -1, 1)
if z is not None:
y = y * F.silu(z)
return y

if mamba_split_conv1d_scan_combined is None:
def mamba_split_conv1d_scan_combined(
zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D=None, chunk_size=256,
seq_idx=None, activation="swish", rmsnorm_weight=None, rmsnorm_eps=1e-5,
outproj_weight=None, outproj_bias=None, headdim=128, ngroups=1,
norm_before_gate=False, initial_states=None, **kwargs
):
batch, seqlen, _ = zxbcdt.shape
d_inner = rmsnorm_weight.shape[0]
nheads = A.shape[0]
d_state = (zxbcdt.shape[-1] - 2 * d_inner - nheads) // (2 * ngroups)
z, xBC, dt = torch.split(zxbcdt, [d_inner, d_inner + 2 * ngroups * d_state, nheads], dim=-1)
dt = F.softplus(dt + dt_bias)
if causal_conv1d_fn is not None:
xBC = causal_conv1d_fn(
x=xBC.transpose(1, 2),
weight=conv1d_weight,
bias=conv1d_bias,
activation=activation,
).transpose(1, 2)
else:
d_conv = conv1d_weight.shape[-1]
xBC_padded = F.pad(xBC.transpose(1, 2), (d_conv - 1, 0))
xBC_conv = F.conv1d(xBC_padded, conv1d_weight.unsqueeze(1), bias=conv1d_bias, groups=xBC.shape[-1])
xBC = F.silu(xBC_conv).transpose(1, 2)
x, B, C = torch.split(xBC, [d_inner, ngroups * d_state, ngroups * d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=headdim),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=ngroups),
rearrange(C, "b l (g n) -> b l g n", g=ngroups),
chunk_size=chunk_size,
D=D,
z=None,
seq_idx=seq_idx,
initial_states=initial_states,
**kwargs,
)
y = rearrange(y, "b l h p -> b l (h p)")
gated = F.silu(z)
if norm_before_gate:
variance = y.pow(2).mean(-1, keepdim=True)
normed = y * torch.rsqrt(variance + rmsnorm_eps) * rmsnorm_weight
y = normed * gated
else:
out = y * gated
variance = out.pow(2).mean(-1, keepdim=True)
y = out * torch.rsqrt(variance + rmsnorm_eps) * rmsnorm_weight
return F.linear(y, outproj_weight, outproj_bias)


class Mamba2Simple(nn.Module):
Expand Down
Loading