[TileOP] Add SM70 GEMM FMA fallback #2339
Address PR tile-ai#2279 review:
- coderabbitai nit: the T.sync_threads() before the cooperative A/B stage loads is unnecessary because A_stage / B_stage are freshly allocated shared buffers; no other thread holds references to them, and the later sync before the GEMM compute is sufficient.
- Docstring Coverage pre-merge check: add module-, helper-, method-, and prim_func-level docstrings so coverage clears the 80% threshold.

No behavior change; the prim_func still stages through shared memory and accumulates via scalar FMA. The SM70 pytest trio is still expected to pass.
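As a rough illustration of "stage through shared memory, then accumulate via scalar FMA", here is a minimal pure-Python sketch. It is not the TileLang prim_func: the function name `gemm_fma_reference`, the `block_k` parameter, and the list-of-lists layout are assumptions made for this example, and plain Python lists stand in for the shared `A_stage` / `B_stage` buffers.

```python
def gemm_fma_reference(a, b, m, n, k, block_k=4):
    """Emulate a tiled GEMM: C[m, n] = sum over k of A[m, k] * B[k, n].

    Hypothetical reference only; it mirrors the staging idea, not real GPU code.
    """
    c = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, block_k):
        k1 = min(k0 + block_k, k)
        # "Stage" the current K-slice of A and B (stand-in for shared memory).
        a_stage = [row[k0:k1] for row in a]
        b_stage = [row[:] for row in b[k0:k1]]
        # On a GPU, one barrier would sit here: after the cooperative loads
        # and before any thread reads the staged tiles.
        for i in range(m):
            for j in range(n):
                acc = c[i][j]
                for kk in range(k1 - k0):
                    # One scalar multiply-add per (i, j, kk) triple.
                    acc = acc + a_stage[i][kk] * b_stage[kk][j]
                c[i][j] = acc
    return c
```

The single barrier position in the comment matches the review point above: the loads write into freshly allocated buffers, so only the load-to-compute hand-off needs synchronization.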
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (3)
📝 Walkthrough
This PR introduces a scalar fused-multiply-add fallback GEMM for CUDA: C++ selection adds
Changes
CUDA FMA Fallback GEMM
Sequence Diagram
sequenceDiagram
participant Compiler
participant SelectInst
participant AllowVoltaMma
participant Registry
Compiler->>SelectInst: request instruction for GemmNode
SelectInst->>AllowVoltaMma: evaluate(op)
AllowVoltaMma-->>SelectInst: eligible? (true/false)
SelectInst->>Registry: choose "cuda.mma" or "cuda.fma"
Registry-->>Compiler: emit selected implementation
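The selection flow in the diagram can be sketched in a few lines of Python. This is only an illustration: the real logic lives in the C++ `SelectInst` code, and the `GemmOp` class, the `allow_volta_mma` body, and the specific eligibility rules below (FP16 operands, shapes divisible by a 16x16x4 tile) are assumptions made for the example, not the actual constraints.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmOp:
    """Hypothetical stand-in for the GEMM node the compiler inspects."""
    a_dtype: str
    b_dtype: str
    m: int
    n: int
    k: int


def allow_volta_mma(op: GemmOp) -> bool:
    # Assumed rules for illustration: both operands FP16 and tile-aligned shapes.
    fp16 = op.a_dtype == "float16" and op.b_dtype == "float16"
    aligned = op.m % 16 == 0 and op.n % 16 == 0 and op.k % 4 == 0
    return fp16 and aligned


def select_inst(op: GemmOp) -> str:
    # Eligible ops keep the tensor-core path; everything else falls back to FMA.
    return "cuda.mma" if allow_volta_mma(op) else "cuda.fma"
```

Under these assumed rules, a BF16 GEMM or an unaligned FP16 GEMM resolves to `"cuda.fma"`, while an aligned FP16 GEMM stays on `"cuda.mma"`.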
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (2)
tilelang/cuda/op/gemm/gemm_fma.py (1)

39-39: 💤 Low value. Consider adding `strict=True` to `zip()` for explicit length checking.

While `indices` and `strides` are guaranteed to have the same length (both derived from `shape`), adding `strict=True` makes this invariant explicit and helps catch potential future bugs if the code is refactored.

🔍 Proposed fix
- for index, stride in zip(indices, strides):
+ for index, stride in zip(indices, strides, strict=True):

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/op/gemm/gemm_fma.py` at line 39, the loop using zip(indices, strides) in gemm_fma.py should use zip(indices, strides, strict=True) to enforce that indices and strides have equal lengths at runtime; update the iteration (the for loop that binds index and stride) to pass strict=True to zip to make the invariant explicit and catch future refactor-induced mismatches.

tilelang/cuda/op/gemm/__init__.py (1)

27-28: 💤 Low value. Consider adding a clarifying comment.

The `_match_fma` predicate accepts all CUDA targets, which differs from architecture-specific predicates like `_match_mma_sm70` (Volta only). This is correct: the predicate checks capability (can the target execute FMA instructions?), while dispatch policy (should the target use FMA?) is determined by C++ `SelectInst` based on instruction constraints.

A brief comment would help future maintainers understand this design:

📝 Optional clarifying comment
  def _match_fma(target) -> bool:
+     # FMA capability check; C++ SelectInst determines when to dispatch to cuda.fma
      return target_is_cuda(target)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/op/gemm/__init__.py` around lines 27 - 28, add a brief clarifying comment above the _match_fma function explaining that it intentionally returns true for all CUDA targets because it checks capability (whether the target can execute FMA instructions) rather than selecting which architecture should use FMA; note that actual dispatch/policy decisions are made elsewhere (e.g., C++ SelectInst based on instruction constraints), so keep the predicate broad. Reference the _match_fma function name and SelectInst in the comment to guide future maintainers.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e7d2c63a-7fb5-457a-be5f-848974516dc6
📒 Files selected for processing (6)
src/cuda/op/gemm.cc
testing/python/cuda/test_cuda_mma_sm75_dispatch.py
testing/python/kernel/test_tilelang_kernel_sm70_fragment_copy.py
testing/python/kernel/test_tilelang_kernel_sm70_gemm_fma.py
tilelang/cuda/op/gemm/__init__.py
tilelang/cuda/op/gemm/gemm_fma.py
The build-linux-v100 job cleared the nvcc errors in v0.2.0 but then hit TileLang AOT LayoutInference 'Layout infer conflict between p and p_bf16': Volta sm_70 has no native bf16 tensor cores. scripts/sm70_tilelang.patch (upstream PR tile-ai/tilelang#2339, unmerged) adds the cuda.fma scalar-FMA GEMM fallback.

The patch is C++ (compiled into libtilelang.so), so the pinned wheel can't carry it. Build TileLang from source for this SM-pinned job only: clone v0.1.11 --recursive, apply the patch, then run a non-editable --no-build-isolation pip install so the result reproduces the official wheel data layout (src/tl_templates + 3rdparty/cutlass/include) that build.rs's probe requires.

Zero blast radius: SM-pinned (TORCH_CUDA_ARCH_LIST=7.0), own-runner, own-env job; T1 + Blackwell keep the pinned wheel untouched. The job stays continue-on-error / non-gating. Compile-green is the only CI-verifiable bar (no V100 hardware); runtime correctness remains the V100-host-owner's best-effort responsibility.

Bench-exempt: CI/build config only, no runtime/perf path change. Wins entry updated (V100 section: deferred -> wired in v0.2.1).

Refs: docs/experience/wins/2026-06-15-v0.2.0-cuda-release-unblock.md
…ted on V100

The fp16-MMA fix (96bc0fc) feeds the sm_70 attention AOT kernels' GEMM operands as fp16 in-kernel, so dispatch hits TileLang's stock GemmMMASm70 mma.sync tensor-core path. That makes scripts/sm70_tilelang.patch (cuda.fma scalar fallback, PR tile-ai/tilelang#2339) dead code for these kernels.

License-or-kill (per §0 SOLID): built STOCK unpatched TileLang 0.1.11 (git cd37ed5f, no kCudaFMA / no gemm_fma.py) on a real Tesla V100 and re-validated all 4 kernels: prefill/decode x hd128(q32/kv8) + hd256(q16/kv4), cosine 0.999999, max_rel <= 2.7e-3 vs torch f32, identical to the patched validation. No-patch codegen emits tl::mma_sync_sm70<...kFloat16,kFloat16,kFloat32,16,16,4...> (stock GemmMMASm70), confirming the fma fallback is unused.

Changes:
- rm scripts/sm70_tilelang.patch + scripts/patch_tilelang_sm70.sh
- release.yml build-linux-v100: drop the patch-apply; build stock 0.1.11 from source (still source, not the pinned wheel, for the src/tl_templates layout build.rs probes). CUDA 12.8 (already installed) satisfies stock 0.1.11's c++20 codegen.
- scripts/_v100_build_tilelang.sh: build stock + assert the patch is ABSENT; note the CUDA>=12 (c++20) requirement.
- de-dangle the one stale patch reference in the archived flashinfer plan.

Caveat: stock 0.1.11 JIT emits -std=c++20 -> needs CUDA >=12 nvcc. CI uses 12.8; build.rs AOT uses CUDA_HOME (12.4 on the box); both are fine. The box default /usr/bin/nvcc is 11.8 (lacks c++20).

Bench-exempt: CI / build-config / docs only, no runtime path changed (the 4 shipped kernels are byte-identical). Verification recorded in memory reference_sm70_tilelang_multiconflict.
Summary
This reopens the SM70 GEMM fallback work from closed PR #2279 on top of current `main`.

The main change is a scalar `cuda.fma` GEMM implementation for Volta cases that SM70 tensor-core MMA cannot legally cover, such as BF16 operands or shapes outside the SM70 MMA tile constraints. Supported FP16 Volta MMA cases still dispatch to the existing SM70 MMA path.

Compared with #2279, this version removes the controversial `copy.cc` Volta fragment staging path entirely. There is no fragment -> shared -> fragment fallback and no `tl_frag_copy` staging buffer. Fragment cast copy remains on the existing normal copy lowering path.

Changes
- … `GemmFMA` and register it as `cuda.fma`.
- … `cuda.fma` at dispatch time.
- … `cuda.mma`.
- … `GemmFMA`.
- … (`a_dtype`, `b_dtype`) instead of assuming one input dtype.

Validation
On V100 (`Tesla V100-SXM2-32GB`, compute 7.0):
- `cmake --build build -j$(nproc)` in `~/tilelang-sm70-copy`: passed.
- `~/tilelang/.venv/bin/python -m pytest testing/python/cuda/test_cuda_mma_sm75_dispatch.py -q`: 8 passed.
- `sm_70` source+cubin checks for BF16 SS, BF16 RS, and FP16 fragment-copy RS:
- … `mma_sync_sm70<...>` and no `tl_frag_copy`.
- … `tl_frag_copy` and still uses legal FP16 SM70 MMA.
- … `FFMA=128`, `HMMA=0`.
- `gen_tilelang_aot.py` for `gated_delta_rule_chunk_a_sm70`: cubin generated, source has no BF16 SM70 MMA, SASS `HMMA=0`, `FFMA=256`.

The runtime CUDA pytest file was also invoked on the V100, but the local PyTorch install reports the NVIDIA driver as too old and those three GPU runtime tests were skipped before execution.
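The `FFMA` / `HMMA` figures above are opcode counts in the generated SASS. The PR does not say how they were collected; one plausible approach is to disassemble the cubin (for example with `cuobjdump -sass`) and count the lines that contain each opcode. The sketch below covers only the counting step on already-captured text, and the function name `count_sass_opcodes` is an assumption made for this example.

```python
import re


def count_sass_opcodes(sass_text, opcodes=("FFMA", "HMMA")):
    """Count disassembly lines that contain each opcode as a whole word.

    Modifier suffixes such as FFMA.FTZ or HMMA.884.F32.F32 still match,
    because '.' is not a word character.
    """
    patterns = {op: re.compile(rf"\b{re.escape(op)}\b") for op in opcodes}
    counts = {op: 0 for op in opcodes}
    for line in sass_text.splitlines():
        for op, pattern in patterns.items():
            if pattern.search(line):
                counts[op] += 1
    return counts
```

A result with `HMMA` at zero and `FFMA` nonzero is consistent with the kernel using the scalar fallback rather than tensor-core instructions.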