Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ It can also be built from source with `pip install .` from this repository.
Try passing `--no-build-isolation` to `pip` if installation runs into difficulties, whether you are building from source or installing from PyPI. Common `pip` complaints that can be resolved this way include complaints about the PyTorch version, but other cases exist as well.

Other requirements:
- Linux
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+

Note: on Windows, it is recommended to build from the "x64 Native Tools Command Prompt for VS" with the environment variable `DISTUTILS_USE_SDK=1` set.

For AMD cards, see additional prerequisites below.

## Usage
Expand Down
17 changes: 9 additions & 8 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_reduce.cuh>
#define ROCM_ONLY(x)
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#define ROCM_ONLY(x) x
#endif

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif


#include "selective_scan.h"
#include "selective_scan_common.h"
#include "reverse_scan.cuh"
Expand Down Expand Up @@ -511,15 +518,9 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
auto kernel = &selective_scan_bwd_kernel<Ktraits>;

if (kSmemSize >= 48 * 1024) {

#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
ROCM_ONLY((void *)) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
ROCM_ONLY(std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl);

}

Expand Down
18 changes: 10 additions & 8 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#define ROCM_ONLY(x)
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#define ROCM_ONLY(x) x
#endif

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif


#include "selective_scan.h"
#include "selective_scan_common.h"
#include "static_switch.h"
Expand Down Expand Up @@ -311,7 +318,7 @@ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row.
constexpr int kNRows = 1;
static constexpr int kNRows = 1;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
Expand All @@ -329,14 +336,9 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {


if (kSmemSize >= 48 * 1024) {
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
ROCM_ONLY((void *)) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
ROCM_ONLY(std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl);
}

kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
Expand Down
4 changes: 2 additions & 2 deletions csrc/selective_scan/static_switch.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ classifiers = [
]
dependencies = [
"torch",
"triton",
"triton ; sys_platform != 'win32'",
"triton-windows ; sys_platform == 'win32'",
"ninja",
"einops",
"transformers",
Expand Down