diff --git a/mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py b/mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py index 5de93aea..5807db92 100644 --- a/mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py +++ b/mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py @@ -648,28 +648,25 @@ def compute_dacs_segsum_triton_varlen( da_cs_rev = torch.empty_like(da) segsum = torch.zeros(B, H, nchunks, chunk_size, chunk_size, device=da.device, dtype=da.dtype) - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - chunks_per_seq = (seq_lens // chunk_size) + 1 - # Build mapping tensors: both have length nchunks_global. # Inactive (padding) slots are given a sentinel local-chunk index that # places chunk_start >= seq_end, making the kernel mask all-False. - state_seq_mapping = torch.zeros(nchunks, dtype=torch.int32, device=da.device) - state_chunk_in_seq = torch.zeros(nchunks, dtype=torch.int32, device=da.device) - default_seq_len = int(seq_lens[0].item()) if num_sequences > 0 else 0 - default_inactive_local_chunk = (default_seq_len + chunk_size - 1) // chunk_size - state_chunk_in_seq.fill_(default_inactive_local_chunk) - for i in range(num_sequences): - start = int(cu_seqlens[i].item()) - n = int(chunks_per_seq[i].item()) - chunk_start = (start // chunk_size) + i - chunk_end = chunk_start + n - assert chunk_end <= nchunks, ( - f"Chunk mapping overflow for seq {i}: [{chunk_start}, {chunk_end}) vs nchunks={nchunks}" - ) - state_seq_mapping[chunk_start:chunk_end] = i - state_chunk_in_seq[chunk_start:chunk_end] = torch.arange(n, dtype=torch.int32, device=da.device) + range_starts = cu_seqlens[:-1] // chunk_size + seq_idx + range_ends = range_starts + seq_lens // chunk_size + 1 + g = torch.arange(nchunks, dtype=cu_seqlens.dtype, device=da.device) + owner = torch.searchsorted(range_starts, g, right=True) - 1 + active = g < range_ends[owner] + # Sentinel for inactive slots: ceil(len_0 / C), matching the original. + default_inactive_local_chunk = ( + (seq_lens[0] + chunk_size - 1) // chunk_size + ) + state_seq_mapping = torch.where( + active, owner, torch.zeros_like(owner) + ).to(torch.int32) + state_chunk_in_seq = torch.where( + active, g - range_starts[owner], default_inactive_local_chunk + ).to(torch.int32) grid = (B, H, nchunks) dacs_segsum_kernel_varlen[grid](