[Feature][Blackwell] Add SM120 T.float4_e2m1fn FP4 GEMM support.#2171
[Feature][Blackwell] Add SM120 T.float4_e2m1fn FP4 GEMM support.#2171TerminusAkivili wants to merge 1 commit into
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR implements SM120 (CUDA 12.0+) FP4 (float4_e2m1fn) GEMM support across TileLang: examples and host unpacking, CUDA/TI codegen FP4 storage/indexing/vector/scalar handling, FP4-aware cp.async injection, b4x16 ldmatrix helpers, CuTe SM120 MMA dispatch for FP4/mixed operands, layout/macro generation changes, and GemmMMA integration. ChangesSM120 FP4 GEMM Support
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
src/tl_templates/cuda/cuda_fp4.h (1)
166-187: ⚡ Quick winVerify register allocation for
fp4_e2_t values[64]in device code.The 64-element local array is constant-indexed throughout (
values[0]–values[63]), so nvcc at-O2+should scalar-replace it into registers. However, unlike the explicitly-parameterizedmake_fp4_e2_32_twhich guarantees register-only arguments, register spilling to local memory is possible at lower optimisation levels or with larger surrounding register pressure. Consider adding a__forceinline__annotation to maximise inlining and scalar replacement at call sites.Proposed annotation
-template <typename... Args> -TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(Args... args) { +template <typename... Args> +TL_DEVICE __forceinline__ fp4_e2_64_t make_fp4_e2_64_t(Args... args) {🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/tl_templates/cuda/cuda_fp4.h` around lines 166 - 187, The local array fp4_e2_t values[64] in make_fp4_e2_64_t may be spilled under some compile conditions; annotate the function to force inlining (e.g., add a __forceinline__/always-inline device inline attribute to make_fp4_e2_64_t) so nvcc can scalar-replace values[0]..values[63] into registers and inline the make_fp4_e2_32_t calls; update the function declaration for make_fp4_e2_64_t accordingly (keeping fp4_e2_t values[64] and the existing make_fp4_e2_32_t usages unchanged).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/cuda/codegen/codegen_cuda.cc`:
- Around line 1973-2003: The FP4 padded shared-memory vector path
(IsFp4PaddedSharedStorage + code using GetFp4PaddedSharedIndex and the
byte_offset lambda when constructing the reinterpret cast for t.lanes()) can
incorrectly span the padded 16-element row boundary; add a guard or split logic:
either assert the logical base alignment (e.g., Ensure base % 16 == 0 for the
requested load/store) or detect when the access crosses a 16-element row by
computing the start and end logical indices (base + offset and base + offset +
t.lanes()-1) and comparing their 16-element row indices (truncdiv(..., 16)); if
it crosses, split the operation into two row-aligned fragments (like the
existing t.lanes()==32 two-fragment approach) and merge them, otherwise keep the
current single contiguous byte reinterpretation; apply the same fix to the other
similar blocks identified (around the other ranges mentioned).
- Around line 4428-4444: The allocator treats only scope == "local" as the path
that emits local backing arrays but FP4 fragments use the semantic storage name
"local.fragment", so allocations for these still hit the unsupported-scope
branch; update the scope checks used around is_int4_scalar_local, the FP4
alignas(16) branch, and the place that prints/omits the storage scope to treat
"local.fragment" as equivalent to "local" (either normalize scope to "local"
earlier or change conditions from scope == "local" to (scope == "local" || scope
== "local.fragment")), ensuring PrintStorageScope/PrintType and the
backing-array emission path handle FP4 fragments the same as regular local
allocations (references: is_int4_scalar_local, op->dtype.is_float4_e2m1fn(),
PrintStorageScope, PrintType, and the "local.fragment" semantic storage).
In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py`:
- Around line 121-124: The FP4 fast-path in mma_macro_generator.py sets
self.k_dim = 32 without respecting self.chunk, causing micro_size_k to exceed
chunk when chunk < 32; update the FP4 branch in the initializer (the block
setting self.k_dim) to clamp k_dim by self.chunk (e.g., self.k_dim = min(32,
self.chunk)) and add the same clamp/guard in the subclass override (the code
around lines 873–877) so both places respect chunk; optionally emit a clear
ValueError or assertion if chunk < required minimum to fail early with a helpful
message referencing the dtype and chunk size.
---
Nitpick comments:
In `@src/tl_templates/cuda/cuda_fp4.h`:
- Around line 166-187: The local array fp4_e2_t values[64] in make_fp4_e2_64_t
may be spilled under some compile conditions; annotate the function to force
inlining (e.g., add a __forceinline__/always-inline device inline attribute to
make_fp4_e2_64_t) so nvcc can scalar-replace values[0]..values[63] into
registers and inline the make_fp4_e2_32_t calls; update the function declaration
for make_fp4_e2_64_t accordingly (keeping fp4_e2_t values[64] and the existing
make_fp4_e2_32_t usages unchanged).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a09f3145-ce2d-4b0d-bb75-d916a099b2be
📒 Files selected for processing (16)
examples/gemm_fp4/example_gemm_a8w4_sm120.pyexamples/gemm_fp4/example_gemm_fp4_sm120.pysrc/backend/cuda/codegen/codegen_cuda.ccsrc/backend/cuda/codegen/codegen_cuda.hsrc/backend/cuda/op/copy.ccsrc/backend/cuda/op/copy_analysis.ccsrc/tl_templates/cuda/cuda_fp4.hsrc/tl_templates/cuda/gemm_mma.hsrc/tl_templates/cuda/instruction/mma.hsrc/tl_templates/cuda/ldsm.hsrc/transform/lower_ptx_async_copy.ccsrc/transform/ptx_async_copy_injector.htilelang/cuda/intrinsics/layout/mma_layout.pytilelang/cuda/intrinsics/layout/utils.pytilelang/cuda/intrinsics/macro/mma_macro_generator.pytilelang/cuda/op/gemm/gemm_mma.py
3e5823d to
7f254a9
Compare
|
Hi @LeiWang1999, no rush at all. Feel free to check it whenever it's convenient for you. I'd love your feedback. Thank you! |
cb5bf3d to
795cb39
Compare
|
Oh, seems like this PR overlaps with part of my PR #2182. Just wanted to clarify that there was no intent to duplicate the work as it is carried over from the earlier FP4 branch (didn't notice this pr the moment I create my new one). I'm totally okey to coordinate scope if any feedback received from maintainers, thanks for your work. |
|
Thanks @Hale423 for the clarification! I took a closer look at the SM120 overlap, and I think the two PRs are taking slightly different directions. |
|
The downside is that it introduces FP4xFP4-specific logic across layout, copy analysis, pipeline planning, codegen, and CUDA helpers, which increases maintenance cost and expands the scope of this PR. |
|
Confirmed, feel free to implement any idea on SM120, I'm willing to coordinate scope, thanks for sharing! |
|
@Hale423 I also tried a separate RFC version that keeps the public API as |
Got it, thanks for your clarification |
|
Looks interesting, but for f8f6f4, I think we need to introduce a hidden type, float4_e2m1_unpacked – that's exactly what we're working on. Thanks. |
|
@LeiWang1999 Thanks for the update and for working on this. I’ll keep an eye on the progress. |
aa1a1c5 to
be0a064
Compare
3ca55aa to
7fa0f92
Compare
|
I’d appreciate it if you could review this PR when you have time. Thanks! @LeiWang1999 |
586a8e3 to
5d60b15
Compare
5d60b15 to
5e4ed5d
Compare
f0be868 to
241139a
Compare
Summary
This PR adds SM120 fragment-MMA GEMM support for semantic
T.float4_e2m1fnoperands.
TileLang programs continue to declare FP4 operands as
T.float4_e2m1fn.For the SM120 performance path, lowering maps those semantic operands onto the
hidden
T.float4_e2m1_unpackedbyte carrier internally. Users do not need towrite FP4 GEMM operands as
uint8tensors in TileLang programs.Supported SM120 GEMM combinations:
T.float4_e2m1fnT.float4_e2m1fnT.float32T.float8_e4m3fnT.float4_e2m1fnT.float32T.float4_e2m1fnT.float8_e4m3fnT.float32Design Goals
T.float4_e2m1fnatthe language level.
T.float4_e2m1_unpackedas the hidden physical carrier for the SM120 FP4GEMM path.
ldmatrixmodel used by [Feature] Support Blackwell FP4(float4_e2m1fn) GEMM for SM100 & SM120 #2182.ldmatrixbehavior.that silently skip a K tail.
Design
The key separation is between the public dtype and the physical carrier:
T.float4_e2m1fn.custom[float4_e2m1_unpacked]8shared/localcarrier buffers where the hardware path needs byte slots.
logical
(M, K)/(N, K)shapes.operands while preserving the public
T.float4_e2m1fnhandle dtype.calling the CuTe SM120 F8/F6/F4 atom.
Main Changes
CUDA Templates
cute::SM120_16x8x32_TNdispatch for FP4xFP4, FP8xFP4, andFP4xFP8 into FP32.
packed helper types.
CUDA Lowering
float4_e2m1_unpackedcarrier buffers where required by the performance path.tl::ptx_ldmatrix_x*for the hidden unpacked-carrier path.logical offsets by lowering those cases through per-lane nibble helpers.
dtype.
Python Lowering
T.float4_e2m1_unpackedas the hidden local/shared carrier for semanticSM120 FP4 operands.
public buffer dtypes semantic.
T.gemmblock K values that are not divisible by the selectedinstruction K tile.
Examples
examples/gemm_fp4/example_gemm_fp4_sm120.pyexamples/gemm_fp4/example_gemm_a8w4_sm120.pyT.float4_e2m1fnkernel signatures and usebyte-compatible host tensors only as an interoperability detail.
Tests
testing/python/language/test_tilelang_language_float4_e2m1_unpacked_gemm.py.carriers, ordinary
tl::ptx_ldmatrix_x*, and notl::ptx_ldmatrix_b4x16onthe main performance path.
Why The Review Fixes Matter
FP4 storage has two concerns that should not be conflated: the user-facing dtype
and the physical carrier used by the hardware path. The hidden
float4_e2m1_unpackedcarrier lets SM120 GEMM keep the semanticT.float4_e2m1fnAPI while matching the ordinary-ldmatrix byte-carrierperformance model used by #2182.
For packed fallback cases, FP4 byte storage is byte-addressed while logical FP4
elements are nibble-addressed. A vector reinterpret load/store is safe only when
the logical base offset is known to be even. If the offset is odd, or if codegen
cannot prove it is even, vectorized byte reinterpretation can read or write the
wrong nibble without producing a compilation error. This PR routes those cases
through per-lane nibble helpers.
SM120 FP4/A8W4/W4A8 MMA consumes K in fixed
m16n8k32instruction chunks.Allowing a
block_Ksuch as 48 would execute only the representable K=32portion and miss the K tail. This PR turns that silent numerical error into an
explicit unsupported-shape error.
Validation
Local SM120 validation used an RTX PRO 6000 / compute capability 12.0
environment.
Build and focused examples:
Generated CUDA and TIR were inspected for the expected SM120 FP4 markers:
Focused tests:
Observed focused-test results:
Performance comparison against #2182 used the same SM120 GEMM shapes for
FP4xFP4, A8W4, and W4A8, with 6 shapes, 2
block_Ksettings, and 9 repeats perpoint. No block-scale cases were included.
Latency delta is this PR divided by #2182 minus 1:
Equivalent TOPS delta is #2182 latency divided by this PR latency minus 1:
The overall result is effectively performance-neutral versus #2182: mean TOPS
is about 0.57% lower, while median TOPS is about 0.11% lower.
Notes And Non-Goals
float4_e2m1_unpackedis a hidden physical carrier for the SM120 path, notthe public GEMM dtype users are expected to write.
uint8remains a runtime storage/interoperability detail for byte-compatiblehost tensors, not the public TileLang FP4 GEMM dtype.
ldmatrixoffset behavior stays on the existing path.