Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,28 +648,25 @@ def compute_dacs_segsum_triton_varlen(
da_cs_rev = torch.empty_like(da)
segsum = torch.zeros(B, H, nchunks, chunk_size, chunk_size, device=da.device, dtype=da.dtype)

seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
chunks_per_seq = (seq_lens // chunk_size) + 1

# Build mapping tensors: both have length nchunks_global.
# Inactive (padding) slots are given a sentinel local-chunk index that
# places chunk_start >= seq_end, making the kernel mask all-False.
state_seq_mapping = torch.zeros(nchunks, dtype=torch.int32, device=da.device)
state_chunk_in_seq = torch.zeros(nchunks, dtype=torch.int32, device=da.device)
default_seq_len = int(seq_lens[0].item()) if num_sequences > 0 else 0
default_inactive_local_chunk = (default_seq_len + chunk_size - 1) // chunk_size
state_chunk_in_seq.fill_(default_inactive_local_chunk)

for i in range(num_sequences):
start = int(cu_seqlens[i].item())
n = int(chunks_per_seq[i].item())
chunk_start = (start // chunk_size) + i
chunk_end = chunk_start + n
assert chunk_end <= nchunks, (
f"Chunk mapping overflow for seq {i}: [{chunk_start}, {chunk_end}) vs nchunks={nchunks}"
)
state_seq_mapping[chunk_start:chunk_end] = i
state_chunk_in_seq[chunk_start:chunk_end] = torch.arange(n, dtype=torch.int32, device=da.device)
range_starts = cu_seqlens[:-1] // chunk_size + seq_idx
range_ends = range_starts + seq_lens // chunk_size + 1
g = torch.arange(nchunks, dtype=cu_seqlens.dtype, device=da.device)
owner = torch.searchsorted(range_starts, g, right=True) - 1
active = g < range_ends[owner]
# Sentinel for inactive slots: ceil(len_0 / C), matching the original.
default_inactive_local_chunk = (
(seq_lens[0] + chunk_size - 1) // chunk_size
)
state_seq_mapping = torch.where(
active, owner, torch.zeros_like(owner)
).to(torch.int32)
state_chunk_in_seq = torch.where(
active, g - range_starts[owner], default_inactive_local_chunk
).to(torch.int32)

grid = (B, H, nchunks)
dacs_segsum_kernel_varlen[grid](
Expand Down