diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index b51a731e2..630806a6f 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -4,27 +4,8 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMScanParamsBase { - using index_t = uint64_t; - - int batch, seqlen, n_chunks; - index_t a_batch_stride; - index_t b_batch_stride; - index_t out_batch_stride; - - // Common data pointers. - void *__restrict__ a_ptr; - void *__restrict__ b_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - struct SSMParamsBase { - using index_t = uint32_t; + using index_t = uint64_t; int batch, dim, seqlen, dstate, n_groups, n_chunks; int dim_ngroups_ratio; diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index c720ba28c..26f2be3ee 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -139,7 +139,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; scan_t *x = params.x_ptr == nullptr ? nullptr - : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + : reinterpret_cast(params.x_ptr) + (int64_t)(batch_id * params.dim + dim_id) * params.n_chunks * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 0d65bb13c..4cb671cef 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -102,8 +102,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int batch_id = blockIdx.x; const int dim_id = blockIdx.y; const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + (int64_t)batch_id * params.u_batch_stride - + (int64_t)dim_id * kNRows * params.u_d_stride; + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + dim_id * kNRows * params.delta_d_stride; weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; @@ -111,7 +111,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + scan_t *x = reinterpret_cast(params.x_ptr) + (int64_t)(batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) {