Skip to content

Installation and Build Failure on Modern Cloud Environments (Python 3…#975

Open
Turusore08 wants to merge 1 commit into
state-spaces:mainfrom
Turusore08:installation-build-failure-cloud
Open

Installation and Build Failure on Modern Cloud Environments (Python 3…#975
Turusore08 wants to merge 1 commit into
state-spaces:mainfrom
Turusore08:installation-build-failure-cloud

Conversation

@Turusore08

Copy link
Copy Markdown

….12+ / PyTorch 2.X / Kaggle Environments)

Pull Request: Pure PyTorch Fallbacks for Mamba SSM (Resolving Compilation & Import Failures)

Description

This pull request introduces Pure PyTorch Fallbacks for Triton and CUDA operations across mamba-ssm. It completely resolves the persistent compilation and import failures encountered in standard managed cloud environments (such as Kaggle, Google Colab) and unsupported local OS setups (like Windows or CPU-only environments).

By dynamically falling back to highly optimized PyTorch-native equivalents when the compiled CUDA binaries or Triton libraries are unavailable, the package becomes instantly installable and executable out-of-the-box without any pre-compilation dependencies.


Key Problems Addressed

  1. Strict Dependency Bottleneck (causal-conv1d & triton): Direct imports of Triton and CUDA compilation dependencies previously caused immediate crashes (ModuleNotFoundError: No module named 'triton') on import, blocking usage in non-Linux or CPU/mismatched CUDA setups.
  2. Lack of Pre-built Wheels: Prevents environment failures in environments like Kaggle (running Python 3.12+) where pre-compiled wheels are not published and local compilation often runs out of memory or hits compiler mismatch errors.
  3. Fragile Namespace Resolution: Eliminates C++ ABI version mismatches or AttributeError caused by missing compiled symbols in .so files.

Changes Implemented

1. Robust Triton & CUDA Import Protection

Wrapped all CUDA and Triton imports in try-except blocks across critical files to ensure the namespace resolves successfully without crashing:

  • selective_scan_interface.py
  • block.py
  • mamba2.py
  • mamba2_simple.py
  • mamba3.py
  • ssd_minimal.py

2. Pure PyTorch Fallbacks

  • Vectorized Hybrid Chunked Associative Scan (hybrid_chunk_scan_pytorch): Added a custom parallelized associative scan in pure PyTorch. It groups step-by-step recurrence into sequence chunks and parallelizes boundary state propagation using a Kogge-Stone prefix scan logic, ensuring training-level backpropagation support and $O(\log L)$ parallel speed.
  • Sequence Padding Support: Automatically pads the sequence dimension to a multiple of chunk_size before executing the Mamba-2 fallback scan (ssd_minimal_discrete) and truncates the outputs, preventing size-mismatch AssertionErrors for arbitrary sequence lengths.
  • RMSNorm & Gated RMSNorm Fallbacks: Added pure PyTorch equivalents for standard, gated, and group-wise RMSNorm normalization modules.
  • Causal 1D Conv Fallback: Automatically redirects convolutions to F.conv1d if the compiled causal_conv1d library is missing.

Verification & Testing

  1. Mathematical Correctness: Ran correctness checks comparing outputs and backward gradients between hybrid_chunk_scan_pytorch and selective_scan_ref. They are mathematically equivalent down to float64 precision limit.
  2. End-to-End Execution: Tested Mamba-1 and Mamba-2 forward passes successfully on a clean Windows/CPU environment:
    $env:PYTHONPATH="."; python -c "import torch; from mamba_ssm import Mamba, Mamba2; model1 = Mamba(d_model=64); x = torch.randn(2, 16, 64); y1 = model1(x); print('Mamba-1 Output:', y1.shape); model2 = Mamba2(d_model=64); y2 = model2(x); print('Mamba-2 Output:', y2.shape)"
    Output:
    Mamba-1 Output: torch.Size([2, 16, 64])
    Mamba-2 Output: torch.Size([2, 16, 64])
    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant