Skip to content

[TileOP] Add SM70 GEMM FMA fallback#2339

Open
cklxx wants to merge 6 commits into
tile-ai:mainfrom
cklxx:fix/sm70-gemm-fma-fallback-v2
Open

[TileOP] Add SM70 GEMM FMA fallback#2339
cklxx wants to merge 6 commits into
tile-ai:mainfrom
cklxx:fix/sm70-gemm-fma-fallback-v2

Conversation

@cklxx

@cklxx cklxx commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Summary

This reopens the SM70 GEMM fallback work from closed PR #2279 on top of current main.

The main change is a scalar cuda.fma GEMM implementation for Volta cases that SM70 tensor-core MMA cannot legally cover, such as BF16 operands or shapes outside the SM70 MMA tile constraints. Supported FP16 Volta MMA cases still dispatch to the existing SM70 MMA path.

Compared with #2279, this version removes the controversial copy.cc Volta fragment staging path entirely. There is no fragment -> shared -> fragment fallback and no tl_frag_copy staging buffer. Fragment cast copy remains on the existing normal copy lowering path.

Changes

  • Add GemmFMA and register it as cuda.fma.
  • Route unsupported Volta GEMM combinations to cuda.fma at dispatch time.
  • Keep aligned FP16 SM70 SS/RS cases on cuda.mma.
  • Add full-region assertions for fragment A/B/C in GemmFMA.
  • Use explicit local accumulation in the FMA lowering.
  • Use separate A/B staging dtypes (a_dtype, b_dtype) instead of assuming one input dtype.
  • Add dispatch and SM70 kernel coverage tests.

Validation

On V100 (Tesla V100-SXM2-32GB, compute 7.0):

  • cmake --build build -j$(nproc) in ~/tilelang-sm70-copy: passed.
  • ~/tilelang/.venv/bin/python -m pytest testing/python/cuda/test_cuda_mma_sm75_dispatch.py -q: 8 passed.
  • Explicit sm_70 source+cubin checks for BF16 SS, BF16 RS, and FP16 fragment-copy RS:
    • BF16 SS/RS source has no mma_sync_sm70<...> and no tl_frag_copy.
    • FP16 fragment-copy RS source has no tl_frag_copy and still uses legal FP16 SM70 MMA.
  • BF16 SS SASS check: FFMA=128, HMMA=0.
  • ARLE AOT smoke using gen_tilelang_aot.py for gated_delta_rule_chunk_a_sm70: cubin generated, source has no BF16 SM70 MMA, SASS HMMA=0, FFMA=256.

The runtime CUDA pytest file was also invoked on the V100, but the local PyTorch install reports the NVIDIA driver as too old and those three GPU runtime tests were skipped before execution.

Summary by CodeRabbit

  • New Features

    • Added CUDA GEMM FMA (“cuda.fma”) fallback support with automatic selection on Volta/SM70 targets and updated tiling/warp partitioning for this execution path.
  • Tests

    • Extended CUDA GEMM dispatch tests to confirm the correct implementation is chosen.
    • Added new TileLang CUDA kernel tests for fragment copy/cast behavior and GEMM FMA fallback coverage (fp16/bf16, transposed-A), including checks that certain SM70 MMA instructions are not emitted and results match PyTorch references.

cklxx added 4 commits June 5, 2026 12:14
Address PR tile-ai#2279 review:
- coderabbitai nit: the T.sync_threads() before the cooperative A/B stage
  loads is unnecessary because A_stage / B_stage are freshly allocated
  shared buffers — no other thread holds references to them, and the
  later sync before the GEMM compute is sufficient.
- Docstring Coverage pre-merge check: add module-, helper-, method-, and
  prim_func-level docstrings so coverage clears the 80% threshold.

No behavior change; the prim_func still stages through shared memory and
accumulates via scalar FMA. SM70 pytest trio still expected to pass.
@github-actions

github-actions Bot commented Jun 5, 2026

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6f55fdf3-72a9-49f2-8148-f67d7811e83d

📥 Commits

Reviewing files that changed from the base of the PR and between dfca10b and 2c42042.

📒 Files selected for processing (4)
  • src/cuda/op/gemm.cc
  • testing/python/cuda/test_cuda_mma_sm75_dispatch.py
  • tilelang/cuda/op/gemm/__init__.py
  • tilelang/cuda/op/gemm/gemm_fma.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • testing/python/cuda/test_cuda_mma_sm75_dispatch.py
  • src/cuda/op/gemm.cc
  • tilelang/cuda/op/gemm/init.py

📝 Walkthrough

Walkthrough

This PR introduces a scalar fused-multiply-add fallback GEMM for CUDA: C++ selection adds kCudaFMA and Volta eligibility checks; a new Python GemmFMA lowers GEMM via shared staging and per-thread scalar accumulation; registration and multiple SM70 tests validate dispatch and numerical correctness.

Changes

CUDA FMA Fallback GEMM

Layer / File(s) Summary
C++ FMA instruction selection and warp layout
src/cuda/op/gemm.cc
Adds kCudaFMA identifier, implements AllowVoltaMma predicate checking buffer scope (shared/fragment for A, shared for B), rejecting transA_, restricting A/B dtypes to FP16, allowing C as FP16 or FP32, and enforcing (m % 16 == 0), (n % 16 == 0), (k % 4 == 0); Gemm::SelectInst selects kCudaFMA on Volta targets when ineligible for MMA; ComputeWarpPartition forces {m_warp = 1, n_warp = num_warps} when FMA is active.
Python FMA implementation and lowering
tilelang/cuda/op/gemm/gemm_fma.py
New GEMM_INST_FMA constant and GemmFMA class; _linear_fragment distributes fragment elements row-major across threads; infer_layout assigns linear layout to fragment operands; lower generates scalar-FMA prim_func that conditionally stages A and/or B into shared memory, then performs per-thread scalar accumulation over K with optional dtype casts and accumulator clearing, finally writing results to C.
FMA registration and dispatch
tilelang/cuda/op/gemm/__init__.py, testing/python/cuda/test_cuda_mma_sm75_dispatch.py
Imports and registers GemmFMA under "cuda.fma" for CUDA targets; dispatch test asserts resolve_gemm_impl("cuda.fma", sm_70) yields GemmFMA.
FMA fallback tests and validation
testing/python/kernel/test_tilelang_kernel_sm70_fragment_copy.py, testing/python/kernel/test_tilelang_kernel_sm70_gemm_fma.py
Adds SM70 CUDA tests: fragment-cast GEMM (float16→float32 casts feeding GEMM), bf16 SS GEMM with float32 accumulation, bf16 RS GEMM with bfloat16 recast, and fp16 transposed-A GEMM. All assert SM70 MMA sync intrinsics are absent, compile for compute capability 7.0, run on random inputs, and validate numeric correctness against float32 matmul references.

Sequence Diagram

sequenceDiagram
  participant Compiler
  participant SelectInst
  participant AllowVoltaMma
  participant Registry
  Compiler->>SelectInst: request instruction for GemmNode
  SelectInst->>AllowVoltaMma: evaluate(op)
  AllowVoltaMma-->>SelectInst: eligible? (true/false)
  SelectInst->>Registry: choose "cuda.mma" or "cuda.fma"
  Registry-->>Compiler: emit selected implementation
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 I nibble threads and hop through code,
Fragments line up in tidy mode,
When Volta says "no MMA today",
FMA hops in to save the play,
Scalars add and kernels sing hooray!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 25.93% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[TileOP] Add SM70 GEMM FMA fallback' clearly and specifically describes the main change: introducing an FMA fallback mechanism for SM70 GEMM operations.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
tilelang/cuda/op/gemm/gemm_fma.py (1)

39-39: 💤 Low value

Consider adding strict=True to zip() for explicit length checking.

While indices and strides are guaranteed to have the same length (both derived from shape), adding strict=True makes this invariant explicit and helps catch potential future bugs if the code is refactored.

🔍 Proposed fix
-        for index, stride in zip(indices, strides):
+        for index, stride in zip(indices, strides, strict=True):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/op/gemm/gemm_fma.py` at line 39, The loop using zip(indices,
strides) in gemm_fma.py should use zip(indices, strides, strict=True) to enforce
that indices and strides have equal lengths at runtime; update the iteration
(the for loop that binds index and stride) to pass strict=True to zip to make
the invariant explicit and catch future refactor-induced mismatches.
tilelang/cuda/op/gemm/__init__.py (1)

27-28: 💤 Low value

Consider adding a clarifying comment.

The _match_fma predicate accepts all CUDA targets, which differs from architecture-specific predicates like _match_mma_sm70 (Volta only). This is correct—the predicate checks capability (can the target execute FMA instructions?), while dispatch policy (should the target use FMA?) is determined by C++ SelectInst based on instruction constraints.

A brief comment would help future maintainers understand this design:

📝 Optional clarifying comment
 def _match_fma(target) -> bool:
+    # FMA capability check; C++ SelectInst determines when to dispatch to cuda.fma
     return target_is_cuda(target)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/op/gemm/__init__.py` around lines 27 - 28, Add a brief
clarifying comment above the _match_fma function explaining that it
intentionally returns true for all CUDA targets because it checks capability
(whether the target can execute FMA instructions) rather than selecting which
architecture should use FMA; note that actual dispatch/policy decisions are made
elsewhere (e.g., C++ SelectInst based on instruction constraints), so keep the
predicate broad. Reference the _match_fma function name and SelectInst in the
comment to guide future maintainers.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tilelang/cuda/op/gemm/__init__.py`:
- Around line 27-28: Add a brief clarifying comment above the _match_fma
function explaining that it intentionally returns true for all CUDA targets
because it checks capability (whether the target can execute FMA instructions)
rather than selecting which architecture should use FMA; note that actual
dispatch/policy decisions are made elsewhere (e.g., C++ SelectInst based on
instruction constraints), so keep the predicate broad. Reference the _match_fma
function name and SelectInst in the comment to guide future maintainers.

In `@tilelang/cuda/op/gemm/gemm_fma.py`:
- Line 39: The loop using zip(indices, strides) in gemm_fma.py should use
zip(indices, strides, strict=True) to enforce that indices and strides have
equal lengths at runtime; update the iteration (the for loop that binds index
and stride) to pass strict=True to zip to make the invariant explicit and catch
future refactor-induced mismatches.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7d2c63a-7fb5-457a-be5f-848974516dc6

📥 Commits

Reviewing files that changed from the base of the PR and between 550e25d and 0a21e37.

📒 Files selected for processing (6)
  • src/cuda/op/gemm.cc
  • testing/python/cuda/test_cuda_mma_sm75_dispatch.py
  • testing/python/kernel/test_tilelang_kernel_sm70_fragment_copy.py
  • testing/python/kernel/test_tilelang_kernel_sm70_gemm_fma.py
  • tilelang/cuda/op/gemm/__init__.py
  • tilelang/cuda/op/gemm/gemm_fma.py

@cklxx cklxx marked this pull request as draft June 5, 2026 06:02
@cklxx cklxx marked this pull request as ready for review June 5, 2026 12:17
cklxx added a commit to cklxx/arle that referenced this pull request Jun 15, 2026
The build-linux-v100 job cleared the nvcc errors in v0.2.0 but then hit
TileLang AOT LayoutInference 'Layout infer conflict between p and p_bf16':
Volta sm_70 has no native bf16 tensor cores. scripts/sm70_tilelang.patch
(upstream PR tile-ai/tilelang#2339, unmerged) adds the cuda.fma scalar-FMA
GEMM fallback. The patch is C++ (compiled into libtilelang.so) so the pinned
wheel can't carry it — build TileLang from source for this SM-pinned job only:
clone v0.1.11 --recursive, apply the patch, non-editable --no-build-isolation
pip install so the result reproduces the official wheel data layout
(src/tl_templates + 3rdparty/cutlass/include) that build.rs's probe requires.

Zero blast radius: SM-pinned (TORCH_CUDA_ARCH_LIST=7.0) own-runner own-env job;
T1 + Blackwell keep the pinned wheel untouched. Job stays continue-on-error /
non-gating. Compile-green is the only CI-verifiable bar (no V100 hardware);
runtime correctness remains the V100-host-owner's best-effort responsibility.

Bench-exempt: CI/build config only, no runtime/perf path change. Wins entry
updated (V100 section: deferred -> wired in v0.2.1).

Refs: docs/experience/wins/2026-06-15-v0.2.0-cuda-release-unblock.md
Comment thread tilelang/cuda/op/gemm/gemm_fma.py Outdated
@cklxx cklxx requested a review from LeiWang1999 June 15, 2026 10:02
cklxx added a commit to cklxx/arle that referenced this pull request Jun 15, 2026
…ted on V100

The fp16-MMA fix (96bc0fc) feeds the sm_70 attention AOT kernels' GEMM
operands as fp16 in-kernel, so dispatch hits TileLang's stock GemmMMASm70
mma.sync tensor-core path. That makes scripts/sm70_tilelang.patch (cuda.fma
scalar fallback, PR tile-ai/tilelang#2339) dead code for these kernels.

License-or-kill (per §0 SOLID): built STOCK unpatched TileLang 0.1.11
(gitcd37ed5f, no kCudaFMA / no gemm_fma.py) on a real Tesla V100 and
re-validated all 4 kernels — prefill/decode x hd128(q32/kv8) + hd256(q16/kv4),
cosine 0.999999, max_rel <= 2.7e-3 vs torch f32, identical to the patched
validation. No-patch codegen emits tl::mma_sync_sm70<...kFloat16,kFloat16,
kFloat32,16,16,4...> (stock GemmMMASm70), confirming the fma fallback is unused.

Changes:
- rm scripts/sm70_tilelang.patch + scripts/patch_tilelang_sm70.sh
- release.yml build-linux-v100: drop the patch-apply; build stock 0.1.11 from
  source (still source, not the pinned wheel, for the src/tl_templates layout
  build.rs probes). CUDA 12.8 (already installed) satisfies stock 0.1.11's
  c++20 codegen.
- scripts/_v100_build_tilelang.sh: build stock + assert the patch is ABSENT;
  note the CUDA>=12 (c++20) requirement.
- de-dangle the one stale patch reference in the archived flashinfer plan.

Caveat: stock 0.1.11 JIT emits -std=c++20 -> needs CUDA >=12 nvcc. CI uses
12.8; build.rs AOT uses CUDA_HOME (12.4 on the box) -- both fine. The box
default /usr/bin/nvcc is 11.8 (lacks c++20).

Bench-exempt: CI / build-config / docs only, no runtime path changed (the 4
shipped kernels are byte-identical). Verification recorded in memory
reference_sm70_tilelang_multiconflict.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants