From 9b4a52be04ffff26c64546aec422bbc6f3f4f8fe Mon Sep 17 00:00:00 2001 From: CPA <136311479+darxradi3nt@users.noreply.github.com> Date: Wed, 25 Mar 2026 23:25:07 +0200 Subject: [PATCH 1/4] fixes issue #880 --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 6 +++--- csrc/selective_scan/selective_scan_fwd_kernel.cuh | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index c720ba28c..3e3de10d3 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -114,9 +114,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); + const int64_t batch_id = blockIdx.x; + const int64_t dim_id = blockIdx.y; + const int64_t group_id = dim_id / (params.dim_ngroups_ratio); input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + dim_id * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 0d65bb13c..5a785b5ba 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -99,9 +99,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); + const int64_t batch_id = blockIdx.x; + const int64_t dim_id = blockIdx.y; + const int64_t group_id = dim_id / (params.dim_ngroups_ratio); input_t *u = reinterpret_cast(params.u_ptr) + (int64_t)batch_id * params.u_batch_stride + (int64_t)dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride From 1871459f1cce6e44c87020f222df67964e68a1a7 Mon Sep 17 00:00:00 2001 From: CPA <136311479+darxradi3nt@users.noreply.github.com> Date: Wed, 25 Mar 2026 23:27:02 +0200 Subject: [PATCH 2/4] fixes issue #880 --- csrc/selective_scan/selective_scan_fwd_kernel.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 5a785b5ba..0242730b7 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -102,8 +102,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int64_t batch_id = blockIdx.x; const int64_t dim_id = blockIdx.y; const int64_t group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + (int64_t)batch_id * params.u_batch_stride - + (int64_t)dim_id * kNRows * params.u_d_stride; + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + dim_id * kNRows * params.delta_d_stride; weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; From 1392be9ecb227fe83ab74f88db112160ad62a371 Mon Sep 17 00:00:00 2001 From: CPA <136311479+darxradi3nt@users.noreply.github.com> Date: Thu, 26 Mar 2026 18:48:07 +0200 Subject: [PATCH 3/4] fixes issue #880 --- csrc/selective_scan/selective_scan.h | 21 +------------------ .../selective_scan_bwd_kernel.cuh | 6 +++--- .../selective_scan_fwd_kernel.cuh | 6 +++--- 3 files changed, 7 insertions(+), 26 deletions(-) diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index b51a731e2..630806a6f 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -4,27 +4,8 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMScanParamsBase { - using index_t = uint64_t; - - int batch, seqlen, n_chunks; - index_t a_batch_stride; - index_t b_batch_stride; - index_t out_batch_stride; - - // Common data pointers. - void *__restrict__ a_ptr; - void *__restrict__ b_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - struct SSMParamsBase { - using index_t = uint32_t; + using index_t = uint64_t; int batch, dim, seqlen, dstate, n_groups, n_chunks; int dim_ngroups_ratio; diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index 3e3de10d3..c720ba28c 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -114,9 +114,9 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); - const int64_t batch_id = blockIdx.x; - const int64_t dim_id = blockIdx.y; - const int64_t group_id = dim_id / (params.dim_ngroups_ratio); + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + dim_id * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 0242730b7..c64098a3a 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -99,9 +99,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); - const int64_t batch_id = blockIdx.x; - const int64_t dim_id = blockIdx.y; - const int64_t group_id = dim_id / (params.dim_ngroups_ratio); + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride From d9a121da793795973600e6bbb4e27b1a9f19f60c Mon Sep 17 00:00:00 2001 From: CPA <136311479+darxradi3nt@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:18:43 +0200 Subject: [PATCH 4/4] fix: x buffer offset corrupts gradients silently for large batch/dstate configs --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 2 +- csrc/selective_scan/selective_scan_fwd_kernel.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index c720ba28c..26f2be3ee 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -139,7 +139,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; scan_t *x = params.x_ptr == nullptr ? nullptr - : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + : reinterpret_cast(params.x_ptr) + (int64_t)(batch_id * params.dim + dim_id) * params.n_chunks * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index c64098a3a..4cb671cef 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -111,7 +111,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + scan_t *x = reinterpret_cast(params.x_ptr) + (int64_t)(batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) {