# Patch summary: in setup.py, register the "selective_scan_cuda" CUDAExtension
# only when HIP_BUILD is false, by wrapping the existing ext_modules.append(...)
# call in `if not HIP_BUILD:`. The extension's name, source list,
# extra_compile_args and include_dirs are unchanged; they are only re-indented
# one level. The stated rationale is in the comment the patch adds.
#
# NOTE(review): with this change a HIP build no longer produces the
# selective_scan_cuda module. Any code that imports it unconditionally would
# then fail at import time -- confirm the Python side guards that import.
# NOTE(review): this patch was recovered from a single-line copy; line breaks
# and leading whitespace below were reconstructed from the hunk counts
# (25 old / 29 new lines) and the nesting of the code. Verify the indentation
# against setup.py before applying.
diff --git a/setup.py b/setup.py
index 6aba2de79..2a8e8eefc 100755
--- a/setup.py
+++ b/setup.py
@@ -244,25 +244,29 @@ def append_nvcc_threads(nvcc_extra_args):
             ),
         }
 
-    ext_modules.append(
-        CUDAExtension(
-            name="selective_scan_cuda",
-            sources=[
-                "csrc/selective_scan/selective_scan.cpp",
-                "csrc/selective_scan/selective_scan_fwd_fp32.cu",
-                "csrc/selective_scan/selective_scan_fwd_fp16.cu",
-                "csrc/selective_scan/selective_scan_fwd_bf16.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
-                "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
-                "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
-                "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
-            ],
-            extra_compile_args=extra_compile_args,
-            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
+    # selective_scan_cuda is a Mamba-1 CUDA extension that does not compile
+    # on ROCm/HIP. Mamba-2 and Mamba-3 use Triton kernels instead, so this
+    # extension is not needed on HIP builds.
+    if not HIP_BUILD:
+        ext_modules.append(
+            CUDAExtension(
+                name="selective_scan_cuda",
+                sources=[
+                    "csrc/selective_scan/selective_scan.cpp",
+                    "csrc/selective_scan/selective_scan_fwd_fp32.cu",
+                    "csrc/selective_scan/selective_scan_fwd_fp16.cu",
+                    "csrc/selective_scan/selective_scan_fwd_bf16.cu",
+                    "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
+                    "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
+                    "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
+                    "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
+                    "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
+                    "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
+                ],
+                extra_compile_args=extra_compile_args,
+                include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
+            )
         )
-    )
 
 
 def get_package_version():