[Stacked][Feature] Support NVFP4 Gemm on Blackwell arch (SM100,110,120) #2324

Open
Hale423 wants to merge 26 commits into tile-ai:main from Hale423:feat/gemm-nvfp4

Conversation

Hale423 commented Jun 3, 2026

Contributor

Summary

To address the feature requested in issue #1592, and stacked on PR #2182, this PR starts Blackwell NVFP4 support by adding a minimal block-scaled FP4 GEMM path and examples based on mxf4nvf4.block_scale.

The implementation introduces a thin T.nvfp4_gemm(...) frontend for SM120, lowers it through the existing CUDA MMA tile-op path, and emits mma.sync.aligned.kind::mxf4nvf4.block_scale with packed FP4 E2M1 operands and FP8 E4M3 scale factors.

It also adds NVFP4 GEMM and fused MoE examples that follow the FlashInfer/TRT-LLM style data contract: packed FP4 activations/weights plus per-block E4M3 scale factors.

What changed

  • SM120 (mma.sync, scale factors in registers)

    • New block-scaled MMA primitive path: T.ptx_mma_blockscaled(...), CUDA codegen for tl::mma_sync_blockscaled(...), and the SM120_16x8x64_TN_VS<e2m1, e2m1, f32, ue4m3> CUTLASS dispatch.
    • Thin frontend T.nvfp4_gemm(A, B, SFA, SFB, C, ...), lowered via TensorCoreIntrinEmitter.mma_blockscaled(...).
  • SM100/SM110 (TCGEN05, scale factors in TMEM)

    • NVFP4 path through tcgen05.mma.kind::mxf4nvf4.block_scale, reached via T.tcgen05_gemm_blockscaled(..., is_nvfp4=True).
    • FP4 operands staged dense (2 e2m1/byte) with the SW64 swizzle, matching CUTLASS's dense e2m1 K-major operand layout for mxf4nvf4 (a gap-expanded f8f6f4 layout is incompatible).
    • E4M3 scale factors moved into TMEM via warp transpose + UTCCP (tcgen05.cp); mxf4nvf4 instruction descriptor + dense SW64 SMEM descriptor.
    • FP4 packed load/store fixes across memory scopes and sub-byte TMA sizing for FP4 payloads.
  • Examples: unified --nvfp4 flag (default = plain FP4, --nvfp4 = NVFP4 block-scaled):

    • examples/gemm_fp4/example_gemm_fp4_sm120.py, example_gemm_fp4_sm100.py
    • NVFP4 data contract (FlashInfer/TRT-LLM style): packed FP4 E2M1, per-16-element E4M3 scales; one m16n8k64 (SM120) / K64 (SM100) atom consumes 4 packed E4M3 scales per operand.
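
The data contract above can be sketched in plain Python. This is an illustration only, not code from the PR: the helper names `unpack_fp4_row` and `dequantize_nvfp4_row` are hypothetical, scales are passed as ordinary floats rather than encoded E4M3 bytes, and the low-nibble-first packing order is an assumption borrowed from the `unpack_fp4_to_float` helper used in the examples.

```python
# E2M1 magnitudes for the three low bits; bit 3 of each code is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_E2M1_TO_FLOAT = E2M1_MAGNITUDES + [-m for m in E2M1_MAGNITUDES]

SCALE_BLOCK = 16  # logical elements covered by one scale factor
ATOM_K = 64       # logical K consumed by one MMA atom


def unpack_fp4_row(packed):
    """Expand bytes into 4-bit codes, low nibble first (assumed order)."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)
        codes.append((byte >> 4) & 0x0F)
    return codes


def dequantize_nvfp4_row(packed, scales):
    """Dequantize one packed row with one scale per 16 logical elements."""
    codes = unpack_fp4_row(packed)
    if len(codes) % SCALE_BLOCK != 0:
        raise ValueError("logical K must be a multiple of 16")
    if len(scales) != len(codes) // SCALE_BLOCK:
        raise ValueError("expected one scale per 16 elements")
    return [
        FP4_E2M1_TO_FLOAT[code] * scales[i // SCALE_BLOCK]
        for i, code in enumerate(codes)
    ]


# One K=64 atom: 32 packed bytes and 64 / 16 = 4 scales per operand.
row = bytes([0x21] * (ATOM_K // 2))  # low nibble 1 -> 0.5, high nibble 2 -> 1.0
scales = [1.0, 2.0, 0.5, 4.0]
values = dequantize_nvfp4_row(row, scales)
```

Each scale applies to 16 consecutive logical elements, so elements 16 to 31 of `values` are multiplied by the second entry of `scales`, and so on.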

Validation (nvfp4 only)

SM120

python examples/gemm_fp4/example_gemm_fp4_sm120.py --nvfp4    # NVFP4
python examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py

SM100/SM110

python examples/gemm_fp4/example_gemm_fp4_sm100.py --nvfp4    # NVFP4
python examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py

Summary by CodeRabbit

  • New Features
    • Added FP4 GEMM kernel examples for NVIDIA SM100 and SM120 architectures
    • Added NVFP4 block-scaled matrix multiplication support
    • Added fused Mixture-of-Experts (MoE) examples with FP4 quantized weights
    • Added A8W4 (8-bit activation, 4-bit weight) kernel examples
    • Extended support for block-scaled operations with per-block scale factors

Hale423 and others added 20 commits March 12, 2026 10:55
Add the plumbing required to route float4_e2m1fn through the TCGEN5 MMA
code-generation path so that FP4 GEMM kernels can be emitted on SM100.

Changes:
- ptx.h / ptx.cc: add kFloat4_e2m1fn enum, string tables, DTypeFromString
- common.h: add kFloat4_e2m1fn to device-side DataType enum
- tcgen5_meta.h: add FP4 branch in encode_dtype (format code 2)
- tcgen05mma.h: add kFloat4_e2m1fn specializations for SS/TS/WS_SS
  (delegates to the existing f8f6f4 PTX kind)
- mma_macro_generator.py: add dtype_abbrv mapping for float4_e2m1fn
- docs/GEMM_NV_FP4_FEATURE_STEPS.md: design doc and progress tracker

Addresses tile-ai#1592

Made-with: Cursor
… raw bytes, get_ldmatrix_offset add fp4 special layout, add SM120_FP4_FP4_F32_TN MmaDispatcher specialization + fp4 << 2 bit-shift.
Co-authored-by: Cursor <cursoragent@cursor.com>
… as payload, functionality verified valid for naive gemm-fp4
…acksmem path, functionalities from all examples were verified valid
Co-authored-by: Cursor <cursoragent@cursor.com>
…fix ptx_mma_blockscaled meta param must be of StringImm/IntImm, add and verified gemm_nvfp4 & fused moe examples

github-actions Bot commented Jun 3, 2026


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai Bot commented Jun 3, 2026

Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds FP4 (float4_e2m1fn) and NVFP4 block-scaled GEMM support across SM100/SM120 in TileLang. Changes span CUDA type definitions, MMA trait specializations, TCGEN05/MXF4NVF4 descriptor plumbing, scope-aware codegen for FP4 loads/stores, ALIGN16B swizzle layout for sub-byte TMA, Python GEMM frontend/lowering dispatch, and eight new runnable example scripts.

Changes

FP4 + block-scaled GEMM (SM100/SM120)

Layer / File(s) / Summary

  • Public FP4 contracts and API exports
    Files: src/op/builtin.h, src/op/builtin.cc, src/layout/layout.h, src/layout/layout.cc, src/tl_templates/cuda/common.h, tilelang/layout/swizzle.py, tilelang/layout/__init__.py, tilelang/language/tir/ir.py, tilelang/language/tir/op.py, tilelang/language/ast/ir.py, tilelang/language/__init__.py
    Summary: Declares ptx_mma_blockscaled builtin with 17 inputs, increments ptx_tcgen05_mma_blockscaled_ss arity to 17, adds makeAlign16BSwizzleLayout FFI export, introduces tl::float_e2m1_t type wrapper, and exposes nvfp4_gemm and ptx_mma_blockscaled via all Python re-export layers.

  • CUDA FP4 type, ldmatrix, and MMA dispatch
    Files: src/tl_templates/cuda/cuda_fp4.h, src/tl_templates/cuda/gemm_mma.h, src/tl_templates/cuda/gemm_sm100.h, src/tl_templates/cuda/instruction/mma.h, src/tl_templates/cuda/ldsm.h, src/tl_templates/cuda/instruction/tcgen05mma.h
    Summary: Aliases fp4_e2_t to tl::float_e2m1_t, adds SM100 FP4/A8W4 MMA_Traits specializations (SS/WS_SS/TS), extends SM120 MMA dispatch for FP4/FP8 mixes, adds BlockScaledMmaDispatcher and mma_sync_blockscaled, introduces ptx_ldmatrix_b4x16_x{1,2,4} PTX wrappers, and adds the tcgen05mma_mxf4nvf4_blockscaled_ss template for NVFP4 inline PTX.

  • ALIGN16B sub-byte swizzle layout and TMA sizing
    Files: src/layout/gemm_layouts.cc, src/cuda/op/copy.cc
    Summary: Implements MakeAlign16BSwizzleLayout2D/makeAlign16BSwizzleLayout using XOR-based 8×8 swizzle, integrates it into makeGemmABLayoutSm100 and DetectSwizzleMode, and updates Copy::LowerBulk to use the correct instruction_dim and total_bytes for FP4 TMA payloads.

  • TCGEN05 descriptor split and NVFP4 macro emission
    Files: src/op/tcgen5_meta.h, src/cuda/op/gemm.cc, tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
    Summary: Splits GetTCGEN5InstrDesc to accept separate a_dtype/b_dtype, adds GetTCGEN5MXF4NVF4BlockScaledInstrDesc, updates FFI wiring in gemm.cc, moves descriptor computation to pure-Python in the macro generator, adds tcgen05mma_mxf4nvf4_blockscaled entry point, and extends tcgen05_blockscaled_atom with use_mxf4nvf4 flag.

  • TileLang GEMM frontends, dtype inference, and lowering dispatch
    Files: tilelang/tileop/gemm/gemm_base.py, tilelang/language/gemm_op.py, tilelang/cuda/intrinsics/layout/mma_layout.py, tilelang/cuda/intrinsics/layout/utils.py, tilelang/cuda/intrinsics/macro/mma_macro_generator.py, tilelang/cuda/op/gemm/gemm_mma.py, tilelang/cuda/op/gemm/gemm_tcgen05.py, tilelang/layout/swizzle.py, tilelang/cuda/transform/__init__.py
    Summary: Adds nvfp4_gemm and is_nvfp4 parameter on tcgen05_gemm_blockscaled, extends GemmBase.in_dtype/in_dtype_b for mixed FP4/FP8, introduces FP4 ldmatrix layout helpers, extends TensorCoreIntrinEmitter with block-scaled mma_blockscaled/mma_blockscaled_atom, updates GemmTCGEN5.infer_shared_layout for sub-byte dtypes, and dispatches is_nvfp4 through both MMA and TCGEN05 lowering paths.

  • CUDA codegen scope-aware FP4 and ptx_mma_blockscaled lowering
    Files: src/cuda/codegen/codegen_cuda.cc, src/op/utils.cc
    Summary: Scopes packed FP4 GetBufferRef indexing to empty/global only, adds ptx_mma_blockscaled lowering that emits mma_sync_blockscaled, gates FP4 packed loads/stores by scope, fixes local scalar FP4 allocation, adds use_mxf4nvf4 path in tcgen05 blockscaled lowering, and updates float4_e2m1fn TMA data-type enum constant from 13 to 14.

  • SM100/SM110 example scripts
    Files: examples/gemm_fp4/example_gemm_a8w4_sm100.py, examples/gemm_fp4/example_gemm_fp4_sm100.py, examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py, examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py
    Summary: Adds four SM100/SM110 example scripts covering A8W4 GEMM, plain FP4 GEMM, NVFP4 GEMM (with block-scaled scale tensors), A8W4 fused MoE, and NVFP4 fused MoE, each with FP4 unpack helpers, kernel compilation, zero-input correctness checks, float reference validation, and latency/TFLOPS benchmarks.

  • SM120 example scripts
    Files: examples/gemm_fp4/example_gemm_a8w4_sm120.py, examples/gemm_fp4/example_gemm_fp4_sm120.py, examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py, examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py
    Summary: Adds four SM120 example scripts covering A8W4 GEMM, plain FP4 GEMM (with uint8 unpack path), NVFP4 GEMM (using T.nvfp4_gemm with swizzled layouts and scale registers), A8W4 fused MoE, and NVFP4 fused MoE, all with the same compile/validate/benchmark pattern.

Sequence Diagram(s)

sequenceDiagram
    rect rgba(173, 216, 230, 0.5)
        Note over User,CUDA: NVFP4 GEMM compilation and runtime
    end
    participant User as User script
    participant gemm_op as nvfp4_gemm / tcgen05_gemm_blockscaled
    participant gemm_tcgen05 as GemmTCGEN5 lowering
    participant macro as tcgen05_macro_generator
    participant codegen as codegen_cuda.cc
    participant Device as CUDA device

    User->>gemm_op: nvfp4_gemm(A_uint8, B_uint8, SFA, SFB, C)
    gemm_op->>gemm_op: compute logical_K=K*2, annotate is_nvfp4=1
    gemm_op->>gemm_tcgen05: infer_shared_layout(dtype=fp4, k_major)
    gemm_tcgen05->>macro: tcgen05mma_mxf4nvf4_blockscaled(SFA_tmem, SFB_tmem)
    macro->>macro: get_tcgen5_mxf4nvf4_blockscaled_instr_desc(is_mxfp4)
    macro->>macro: tcgen05_blockscaled_atom(use_mxf4nvf4=True)
    macro-->>codegen: ptx_tcgen05_mma_blockscaled_ss(..., use_mxf4nvf4=1)
    codegen->>codegen: emit tcgen05mma_mxf4nvf4_blockscaled_ss<use_2cta>
    codegen-->>Device: inline PTX tcgen05.mma.kind::mxf4nvf4.block_scale
    Device-->>User: FP32 output tensor

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • tile-ai/tilelang#2274: Refactors and adds the blockscaled TCGEN05 mxf4nvf4 interfaces that this PR's fusedmoe_nvfp4_sm100.py and tcgen05_gemm_blockscaled(is_nvfp4=True) path depend on directly.
  • tile-ai/tilelang#1949: Splits tcgen05_gemm as a distinct frontend in gemm_op.py, which this PR extends by adding the is_nvfp4 parameter and logical_K handling.
  • tile-ai/tilelang#1327: Modifies the same TCGEN05 MMA metadata/instruction selection plumbing (tcgen5_meta.h, tcgen05 lowering) that this PR further extends for the A/B split dtype and MXF4NVF4 descriptor.

Suggested reviewers

  • LeiWang1999

🐇 Hop, hop through nibbles and bits,
FP4 weights packed in tight-little splits!
SM100 and SM120 both bloom,
Gate and up in TMEM's room.
Block-scaled MMA finds its place —
A bunny's math, at blazing pace! ✨


Hale423 changed the title from "[Feature] Support NVFP4 Gemm on Blackwell arch (SM100,110,120)" to "[Stacked][Feature] Support NVFP4 Gemm on Blackwell arch (SM100,110,120)" Jun 3, 2026

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py (1)

399-418: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Scale IDs are constant across the entire K loop.

runtime_instr_desc is computed once from fixed sf_a_id/sf_b_id before the for ki loop, and tcgen05_blockscaled_atom() never adjusts them. When num_k_atoms > 1, later atoms still read the first scale vector even though the A/B descriptor offsets advance, so multi-block or tile-varying blockscaled GEMMs will miscompute.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py` around lines 399 -
418, runtime_instr_desc is computed once from sf_a_id/sf_b_id before the ki
loop, so all atoms use the same scale IDs even though descriptors/offsets
advance; move the computation of runtime_instr_desc (or otherwise recompute
sf_a_id/sf_b_id) inside the ki loop so each atom uses the current scale IDs for
its descriptor, i.e. update how runtime_instr_desc is built before calling
tcgen05_blockscaled_atom (referencing runtime_instr_desc, sf_a_id, sf_b_id,
tcgen05_blockscaled_atom, sfa_data, sfb_data, a_params, b_params) so multi-atom
iterations use the correct per-atom scale vectors.
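
The failure mode described above can be reproduced with a toy model in plain Python. Nothing here is TileLang API: `blockscaled_dot`, its arguments, and the one-scale-per-atom simplification are hypothetical stand-ins for the descriptor and scale-ID handling in `tcgen05_macro_generator.py`.

```python
def blockscaled_dot(a_atoms, b_atoms, sf_a, sf_b, per_atom_scales):
    """Accumulate (a_k * sf_a[id]) * (b_k * sf_b[id]) over the K atoms.

    per_atom_scales=False mimics choosing the scale ID once before the
    loop; per_atom_scales=True re-selects it for every atom.
    """
    acc = 0.0
    scale_id = 0  # chosen once, like a descriptor built outside the loop
    for ki, (a, b) in enumerate(zip(a_atoms, b_atoms)):
        if per_atom_scales:
            scale_id = ki  # advance the scale together with the operands
        acc += (a * sf_a[scale_id]) * (b * sf_b[scale_id])
    return acc


a_atoms, b_atoms = [1.0, 1.0], [1.0, 1.0]
sf_a, sf_b = [1.0, 2.0], [1.0, 4.0]

stale = blockscaled_dot(a_atoms, b_atoms, sf_a, sf_b, per_atom_scales=False)
fixed = blockscaled_dot(a_atoms, b_atoms, sf_a, sf_b, per_atom_scales=True)
single = blockscaled_dot(a_atoms[:1], b_atoms[:1], sf_a, sf_b, per_atom_scales=False)
```

With a single atom both modes agree, which is consistent with the review's point that the problem only appears when num_k_atoms > 1.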
🧹 Nitpick comments (2)
tilelang/cuda/intrinsics/layout/utils.py (1)

22-29: ⚡ Quick win

Expose FP4 in the public dtype contract.

get_ldmatrix_offset() now has a runtime FP4 path, but the signature still only advertises "int4". That makes type-checkers flag the new path and nudges callers toward the wrong dtype string for this feature.

Suggested fix
 def get_ldmatrix_offset(
     matrix: Literal["A", "B"],
     row_idx,
     col_idx,
     stride,
-    dtype: Literal["float16", "int8", "int4"] = "float16",
+    dtype: Literal["float16", "int8", "int4", "float4_e2m1fn"] = "float16",
     transposed: bool = False,
 ):

Also update the surrounding doc/error text so FP4 and INT4 stay clearly separated.

Also applies to: 52-63

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/layout/utils.py` around lines 22 - 29, The function
get_ldmatrix_offset currently advertises dtype: Literal["float16", "int8",
"int4"] but the implementation contains a runtime path for FP4; update the
public type contract to include "fp4" (or "FP4" matching project convention) so
type-checkers accept the code, and adjust any surrounding docstrings/error
messages to explicitly list and distinguish "fp4" and "int4" (and/or their
canonical casing) where dtype choices are described or validated (also update
the equivalent signature/annotations and messages in the similar helper at lines
~52-63 that handle small-int/fp4 cases). Ensure you reference and change the
dtype literal in get_ldmatrix_offset and the corresponding validation/error text
so FP4 and INT4 remain clearly separated.
examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py (1)

118-132: ⚡ Quick win

Reuse the staged X tile for both GEMMs.

The gate and up loops both execute T.copy(X_bytes[...], X_shared), so every ko rereads the same activation block twice. For the benchmarked FC1 path that doubles global-memory traffic on X and makes the “fused” example materially slower than it needs to be.

One-loop version
             T.clear(gate_local)
-            for ko in T.serial(hidden_blocks):
-                T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared)
-                T.copy(W_gate_bytes[bx * block_intermediate, ko * packed_block_hidden], gate_shared)
-                SFA_local[0] = X_scale[by, ko]
-                SFB_local[0] = W_gate_scale[bx, ko]
-                T.nvfp4_gemm(X_shared, gate_shared, SFA_local, SFB_local, gate_local, transpose_B=True, clear_accum=(ko == 0))
-
             T.clear(up_local)
             for ko in T.serial(hidden_blocks):
                 T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared)
+                SFA_local[0] = X_scale[by, ko]
+
+                T.copy(W_gate_bytes[bx * block_intermediate, ko * packed_block_hidden], gate_shared)
+                SFB_local[0] = W_gate_scale[bx, ko]
+                T.nvfp4_gemm(X_shared, gate_shared, SFA_local, SFB_local, gate_local, transpose_B=True, clear_accum=(ko == 0))
+
                 T.copy(W_up_bytes[bx * block_intermediate, ko * packed_block_hidden], up_shared)
-                SFA_local[0] = X_scale[by, ko]
                 SFB_local[0] = W_up_scale[bx, ko]
                 T.nvfp4_gemm(X_shared, up_shared, SFA_local, SFB_local, up_local, transpose_B=True, clear_accum=(ko == 0))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py` around lines 118 - 132,
The X activation tile is copied twice per ko; change the loops so X_bytes is
loaded into X_shared once and reused for both GEMMs: for each ko, do a single
T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared) then call
T.nvfp4_gemm for gate (using gate_shared, SFA_local/SFB_local, gate_local,
clear_accum=(ko == 0)) and then call T.nvfp4_gemm for up (using up_shared,
SFA_local/SFB_local, up_local, clear_accum=(ko == 0)), removing the duplicate
T.copy from the second loop (references: X_bytes, X_shared, W_gate_bytes,
W_up_bytes, T.nvfp4_gemm, gate_local, up_local).
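
The effect of the suggested loop fusion can be checked with a small host-side model. This is a sketch under simplifying assumptions: scalars stand in for tiles, a counter stands in for each `T.copy` of the X tile, and the function names `two_loops` and `one_loop` are hypothetical.

```python
def two_loops(x_tiles, w_gate_tiles, w_up_tiles):
    """Original shape: the X tile is fetched in both the gate and up loops."""
    x_loads = 0
    gate = 0.0
    for ko in range(len(x_tiles)):
        x = x_tiles[ko]  # stands in for copying the X tile to shared memory
        x_loads += 1
        gate += x * w_gate_tiles[ko]
    up = 0.0
    for ko in range(len(x_tiles)):
        x = x_tiles[ko]  # same tile fetched a second time
        x_loads += 1
        up += x * w_up_tiles[ko]
    return gate, up, x_loads


def one_loop(x_tiles, w_gate_tiles, w_up_tiles):
    """Fused shape: one X fetch per ko feeds both accumulators."""
    x_loads = 0
    gate = 0.0
    up = 0.0
    for ko in range(len(x_tiles)):
        x = x_tiles[ko]
        x_loads += 1
        gate += x * w_gate_tiles[ko]
        up += x * w_up_tiles[ko]
    return gate, up, x_loads


x = [1.0, 2.0, 3.0, 4.0]
wg = [0.5, 0.25, 2.0, 1.0]
wu = [1.0, 1.0, 0.5, 0.5]
split = two_loops(x, wg, wu)
fused = one_loop(x, wg, wu)
```

Both versions accumulate in the same order, so the results match exactly while the fused loop performs half as many X fetches.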
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py`:
- Around line 40-46: Add an explicit guard that TL_MOE_HIDDEN is even before
treating expert weight tensors as packed FP4 (two 4-bit values per byte);
validate this in the config/parsing path and before any use of
unpack_fp4_to_float and places where expert tensors are created/reshaped
(references: function unpack_fp4_to_float and the code regions handling expert
weight shapes around the other occurrences). If TL_MOE_HIDDEN is odd, raise a
clear ValueError (or argparse/config error) explaining that packed FP4 storage
requires an even hidden dimension, so the failure happens early instead of
during unpacking.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py`:
- Around line 136-141: The zero-input smoke test currently only prints PASS/FAIL
and can allow the script to exit successfully on failure; change it to a hard
assertion so failures stop execution: after calling jit_kernel(z_input, z_gate,
z_up) and computing c_zero, replace the print line with an assertion that
c_zero.abs().max().item() == 0.0 (or torch.equal(c_zero,
torch.zeros_like(c_zero))) and include a descriptive message (e.g., "zeros in ->
zeros out failed") so the test fails closed; refer to the variables z_input,
z_gate, z_up and the function jit_kernel to locate where to apply this change.

In `@examples/gemm_fp4/example_gemm_a8w4_sm100.py`:
- Around line 35-41: The code assumes K is even when packing/unpacking FP4 (see
unpack_fp4_to_float) which silently truncates when TL_A8W4_K is odd; add an
explicit check (raise ValueError or assert) that TL_A8W4_K % 2 == 0 before any
packing or buffer allocation/reshape, and apply the same guard near the FP4
packing logic and any other unpacking uses (the unpack_fp4_to_float function and
the FP4 weight-packing sites) so the script fails fast with a clear error
instead of allocating a truncated buffer.

In `@examples/gemm_fp4/example_gemm_fp4_sm100.py`:
- Around line 131-137: The unpack_fp4_to_float routine (and corresponding calls
like make_random_fp4) assume the packed axis length K//2 (or cols//2) is an
integer; add an upfront validation in unpack_fp4_to_float (and the code that
calls it, e.g., make_random_fp4 and any uses conditioned on
TL_FP4_TRANSPOSE_B/TL_FP4_K/TL_FP4_N) to check that K (and any cols passed as
packed length*2) is even and raise a clear ValueError if not (or alternatively
document/pad explicitly). Locate the unpack_fp4_to_float function and the
make_random_fp4 call sites referenced in the diff and enforce this guard so odd
TL_FP4_K or TL_FP4_N cannot silently truncate the buffer.

In `@examples/gemm_fp4/example_gemm_nvfp4_sm120.py`:
- Around line 71-109: The kernel assumes full 16x8x64 tiles but the launch dims
use ceildiv, so for any M, N, or K not divisible by block_M=16, block_N=8,
block_K=64 the kernel will read/write out-of-bounds; fix by rejecting
unsupported sizes early in matmul_nvfp4_sm120 (before building/returning main) —
check that M % block_M == 0, N % block_N == 0, and K % block_K == 0 and raise a
clear exception (e.g., ValueError) if not, so callers must provide padded inputs
or a different kernel; reference block_M/block_N/block_K and the prim func main
where copies (T.copy to/from A_shared/B_shared and final T.copy to C) currently
assume full tiles.
- Around line 112-124: The example lacks a runtime SM120 capability check before
calling tilelang.compile for the SM120-only frontend (T.nvfp4_gemm) in the
__main__ path; add a preflight that queries the current device capability (via
tilelang.utils.target.target_is_sm120 or equivalent helper) and fail fast with a
clear "SM120 required" message if the device is not SM120, returning/non-zero
exit instead of proceeding to tilelang.compile and kernel creation; update both
example_gemm_nvfp4_sm120 entrypoint logic and any similar SM120-only example
functions to guard before compilation and import/use of SM120-specific
primitives.

In `@src/cuda/codegen/codegen_cuda.cc`:
- Around line 4746-4754: The scope check using GetPtrStorageScope(buffer_var) is
incomplete—GetBufferRef() first consults alloc_storage_scope_ so allocated
shared/local FP4 buffers can be misclassified as global/packed; update the FP4
packed-load/store branches to resolve scope the same way GetBufferRef() does by
checking alloc_storage_scope_ for buffer_var.get() first and falling back to
GetPtrStorageScope(buffer_var), then use that resolved scope when computing
is_packed_fp4_scope; apply this change in both the load (tl_fp4_packed_load) and
store branches.

In `@src/cuda/op/copy.cc`:
- Around line 1721-1728: The code path handling is_align16b_subbyte_layout
currently calls TMABytesFromElements and so ignores transaction-byte accounting
(and the float4_e2m1_unpacked FP4 special case), causing mbarrier_expect_tx to
be overstated; modify the branch where total_bytes is set (the block that uses
inner_box_dim, instruction_dim, total_elements, and shared_tensor->dtype) to
call TMATransactionBytesFromElements(total_elements * loop_extent,
shared_tensor->dtype) instead of TMABytesFromElements so transaction-byte
accounting (and FP4 handling) is used for ALIGN16B loads.

In `@src/layout/gemm_layouts.cc`:
- Around line 1027-1035: DetectSwizzleMode() is collapsing ALIGN16B into
SwizzleMode::kFull causing MergeSwizzleLayouts() to reconstruct the wrong layout
with makeFullBankSwizzleLayout(buffer); add a distinct swizzle mode (e.g.,
SwizzleMode::kAlign16B) or, when merging, branch on info.element_size < 8 to
detect ALIGN16B and reconstruct it with makeAlign16BSwizzleLayout(buffer)
instead of makeFullBankSwizzleLayout(buffer); update the SwizzleMode enum and
the switch/merge logic in MergeSwizzleLayouts() to return/use the new kAlign16B
(or the element_size-based branch) so ALIGN16B mappings for sub-byte types
remain preserved.

In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 361-381: The FP4 values loaded by ptx_ldmatrix_b4x16_* end up in
the low nibble and must be repacked the same way as in mma_sync; update
mma_sync_blockscaled to apply the FP4 repack to the A and B operands before
calling Dispatcher::exec (i.e., perform the same left-shift/repack used in
mma_sync on the arrays of Dispatcher::ARegType and Dispatcher::BRegType prior to
dispatch) so BlockScaledMmaDispatcher::exec receives the correctly encoded FP4
data and produces correct accumulations.

In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py`:
- Around line 678-709: The scale operands are never advanced per k_inner, so
SFA_local_buf and SFB_local_buf always use base 0; compute SFA_offset and
SFB_offset analogous to A_offset/B_offset using k_inner (e.g. advancing by
k_inner * element_stride_or_vector_size used for scale storage) and pass those
offsets into T.ptx_mma_blockscaled instead of always 0; update the
_atom_mma_blockscaled macro to use the new SFA_offset and SFB_offset
(referencing k_inner, SFA_local_buf, SFB_local_buf, and the
_atom_mma_blockscaled/T.ptx_mma_blockscaled call) so each ki iteration uses the
correct per-k_inner scale vectors.
- Around line 126-136: The check for FP4 chunk sizes must reject non-multiples
rather than only enforcing a minimum: in the block-scaled path (when
self.is_blockscaled and str(a_dtype) == "float4_e2m1fn") validate that
self.chunk is a multiple of 64 (e.g. self.chunk % 64 == 0) and raise a
ValueError if not; in the plain FP4 path (when str(a_dtype) == "float4_e2m1fn")
validate that self.chunk is a multiple of 32 (self.chunk % 32 == 0) instead of
only checking self.chunk < 32, and raise a clear ValueError if the modulus check
fails; keep setting self.k_dim = 64 (block-scaled) or self.k_dim = min(32,
self.chunk) after the validation.

In `@tilelang/language/gemm_op.py`:
- Around line 269-320: The nvfp4_gemm path dropped the _gemm_impl base-offset
and rank guards so row-sliced packed tiles can lose their outer-dimension base
offset and malformed inputs raise IndexError; restore the same validation: after
computing A_offset = retrieve_offset(A_region) and B_offset =
retrieve_offset(B_region) (and their shapes via retrieve_shape if needed),
assert that A_offset and B_offset have the expected rank (e.g., len(A_offset) >=
2 and len(B_offset) >= 2) and assert A_offset[-2] == 0 and B_offset[-2] == 0 (or
alternatively call _gemm_impl to reuse its checks) before forwarding
A_offset[-1] and B_offset[-1] into tirx.call_intrin in nvfp4_gemm so
outer-dimension base offsets cannot be silently lost.

In `@tilelang/layout/swizzle.py`:
- Around line 91-120: This code currently flattens all leading dims into a
single stride and calls _ffi_api.make_tcgen05mma_swizzled_layout which lets the
swizzle cross outer-dimension boundaries; instead, compute the 2D layout from
the trailing two dims only (use shape[-2:] and the corresponding last-two
stride/continuity values from _get_stride_continuous/stride), call
_ffi_api.make_tcgen05mma_swizzled_layout to build the base 2D layout for those
last two dims, then expand that 2D base across the leading dimensions
(shape[:-2]) using the same ExpandLayout2D-style expansion used by the native
C++ factories so the Python path matches the buffer-based path for rank>2
buffers.

---

Outside diff comments:
In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py`:
- Around line 399-418: runtime_instr_desc is computed once from sf_a_id/sf_b_id
before the ki loop, so all atoms use the same scale IDs even though
descriptors/offsets advance; move the computation of runtime_instr_desc (or
otherwise recompute sf_a_id/sf_b_id) inside the ki loop so each atom uses the
current scale IDs for its descriptor, i.e. update how runtime_instr_desc is
built before calling tcgen05_blockscaled_atom (referencing runtime_instr_desc,
sf_a_id, sf_b_id, tcgen05_blockscaled_atom, sfa_data, sfb_data, a_params,
b_params) so multi-atom iterations use the correct per-atom scale vectors.

---

Nitpick comments:
In `@examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py`:
- Around line 118-132: The X activation tile is copied twice per ko; change the
loops so X_bytes is loaded into X_shared once and reused for both GEMMs: for
each ko, do a single T.copy(X_bytes[by * block_tokens, ko *
packed_block_hidden], X_shared) then call T.nvfp4_gemm for gate (using
gate_shared, SFA_local/SFB_local, gate_local, clear_accum=(ko == 0)) and then
call T.nvfp4_gemm for up (using up_shared, SFA_local/SFB_local, up_local,
clear_accum=(ko == 0)), removing the duplicate T.copy from the second loop
(references: X_bytes, X_shared, W_gate_bytes, W_up_bytes, T.nvfp4_gemm,
gate_local, up_local).

In `@tilelang/cuda/intrinsics/layout/utils.py`:
- Around line 22-29: The function get_ldmatrix_offset currently advertises
dtype: Literal["float16", "int8", "int4"] but the implementation contains a
runtime path for FP4; update the public type contract to include "fp4" (or "FP4"
matching project convention) so type-checkers accept the code, and adjust any
surrounding docstrings/error messages to explicitly list and distinguish "fp4"
and "int4" (and/or their canonical casing) where dtype choices are described or
validated (also update the equivalent signature/annotations and messages in the
similar helper at lines ~52-63 that handle small-int/fp4 cases). Ensure you
reference and change the dtype literal in get_ldmatrix_offset and the
corresponding validation/error text so FP4 and INT4 remain clearly separated.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 233d92a6-159d-4126-8152-e37c62dc19ad

📥 Commits

Reviewing files that changed from the base of the PR and between 3d95d65 and 48db769.

📒 Files selected for processing (38)
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
  • examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
  • examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py
  • examples/gemm_fp4/example_gemm_a8w4_sm100.py
  • examples/gemm_fp4/example_gemm_a8w4_sm120.py
  • examples/gemm_fp4/example_gemm_fp4_sm100.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • examples/gemm_fp4/example_gemm_nvfp4_sm120.py
  • src/cuda/codegen/codegen_cuda.cc
  • src/cuda/op/copy.cc
  • src/cuda/op/gemm.cc
  • src/layout/gemm_layouts.cc
  • src/layout/layout.cc
  • src/layout/layout.h
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/tcgen5_meta.h
  • src/op/utils.cc
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/cuda_fp4.h
  • src/tl_templates/cuda/gemm_mma.h
  • src/tl_templates/cuda/gemm_sm100.h
  • src/tl_templates/cuda/instruction/mma.h
  • src/tl_templates/cuda/ldsm.h
  • tilelang/cuda/intrinsics/layout/mma_layout.py
  • tilelang/cuda/intrinsics/layout/utils.py
  • tilelang/cuda/intrinsics/macro/mma_macro_generator.py
  • tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
  • tilelang/cuda/op/gemm/gemm_mma.py
  • tilelang/cuda/op/gemm/gemm_tcgen05.py
  • tilelang/language/__init__.py
  • tilelang/language/ast/ir.py
  • tilelang/language/gemm_op.py
  • tilelang/language/tir/ir.py
  • tilelang/language/tir/op.py
  • tilelang/layout/__init__.py
  • tilelang/layout/swizzle.py
  • tilelang/tileop/gemm/gemm_base.py

Comment thread examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
Comment thread examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py Outdated
Comment thread examples/gemm_fp4/example_gemm_a8w4_sm100.py
Comment on lines +131 to +137
def unpack_fp4_to_float(packed_int8, M, K):
    lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device)
    flat = packed_int8.to(torch.uint8).reshape(M, K // 2)
    lo = flat & 0x0F
    hi = (flat >> 4) & 0x0F
    unpacked = torch.stack([lo, hi], dim=-1).reshape(M, K).to(torch.int64)
    return lut[unpacked]

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate the packed FP4 dimensions up front.

make_random_fp4(..., cols // 2) and unpack_fp4_to_float(...reshape(..., K // 2)) both assume the packed axis is even. With TL_FP4_K odd — and with TL_FP4_N odd when TL_FP4_TRANSPOSE_B=0 — this truncates the byte buffer shape and the example fails later with a much less obvious reshape/reference error.

Suggested guard
 M = int(os.environ.get("TL_FP4_M", "256"))
 N = int(os.environ.get("TL_FP4_N", "256"))
 K = int(os.environ.get("TL_FP4_K", "256"))
@@
 input_mode = os.environ.get("TL_FP4_INPUT_MODE", "random")
 transpose_b = os.environ.get("TL_FP4_TRANSPOSE_B", "1") != "0"
+
+if K % 2 != 0:
+    raise ValueError("TL_FP4_K must be even for packed FP4 inputs")
+if not transpose_b and N % 2 != 0:
+    raise ValueError("TL_FP4_N must be even when TL_FP4_TRANSPOSE_B=0")
+
 print(f"Running FP4 GEMM (SM100/SM110 TCGEN05): M={M}, N={N}, K={K}, input_mode={input_mode}, transpose_b={transpose_b}")

Also applies to: 140-153, 186-206

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 131 - 137, The
unpack_fp4_to_float routine (and corresponding calls like make_random_fp4)
assume the packed axis length K//2 (or cols//2) is an integer; add an upfront
validation in unpack_fp4_to_float (and the code that calls it, e.g.,
make_random_fp4 and any uses conditioned on
TL_FP4_TRANSPOSE_B/TL_FP4_K/TL_FP4_N) to check that K (and any cols passed as
packed length*2) is even and raise a clear ValueError if not (or alternatively
document/pad explicitly). Locate the unpack_fp4_to_float function and the
make_random_fp4 call sites referenced in the diff and enforce this guard so odd
TL_FP4_K or TL_FP4_N cannot silently truncate the buffer.
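
To illustrate the guard this comment asks for, here is a minimal pure-Python sketch of the packing contract (two E2M1 codes per byte, low nibble first, as in `unpack_fp4_to_float` above). The lookup table holds the standard E2M1 values. `check_packed_fp4_dims` and `pack_fp4` are hypothetical helper names used only for this illustration; they are not part of the PR.

```python
# E2M1 code -> value (1 sign bit, 2 exponent bits, 1 mantissa bit).
FP4_E2M1_TO_FLOAT = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]


def check_packed_fp4_dims(k, n, transpose_b):
    """Hypothetical guard: fail early when the packed axis has odd length."""
    if k % 2 != 0:
        raise ValueError("K must be even for packed FP4 inputs")
    if not transpose_b and n % 2 != 0:
        raise ValueError("N must be even when B is not transposed")


def pack_fp4(codes):
    """Pack 4-bit codes two per byte, low nibble first."""
    if len(codes) % 2 != 0:
        raise ValueError("cannot pack an odd number of FP4 codes")
    return bytes(
        (codes[i] & 0x0F) | ((codes[i + 1] & 0x0F) << 4)
        for i in range(0, len(codes), 2)
    )


def unpack_fp4_to_float(packed, k):
    """Inverse of pack_fp4 for one row of k elements."""
    if k % 2 != 0 or len(packed) != k // 2:
        raise ValueError("packed length must equal K // 2 with K even")
    values = []
    for byte in packed:
        values.append(FP4_E2M1_TO_FLOAT[byte & 0x0F])         # low nibble
        values.append(FP4_E2M1_TO_FLOAT[(byte >> 4) & 0x0F])  # high nibble
    return values
```

With an odd `K`, `K // 2` silently drops the last element, which is why the check belongs before any buffer is allocated.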

Comment thread examples/gemm_fp4/example_gemm_nvfp4_sm120.py Outdated
Comment on lines +361 to +381
template <DataType AType, DataType BType, DataType CType, DataType SFType,
          int M, int N, int K, bool TransA, bool TransB, int VS>
TL_DEVICE void mma_sync_blockscaled(
    typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::CRegType *c,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::ARegType *a,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::BRegType *b,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::SFRegType
        *sfa,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::SFRegType
        *sfb) {
  using Dispatcher = detail::BlockScaledMmaDispatcher<
      AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>;
  static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
                "tl::mma_sync_blockscaled: unsupported configuration");

  Dispatcher::exec(c, a, b, c, sfa, sfb);

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Apply the FP4 repack in mma_sync_blockscaled too.

ptx_ldmatrix_b4x16_* leaves each FP4 value in the low nibble, and mma_sync() compensates with << 2 before dispatch. This path forwards A/B unchanged, so NVFP4 block-scaled GEMM consumes a different FP4 encoding than the plain SM120 FP4 path and will accumulate wrong values.

Suggested fix
 template <DataType AType, DataType BType, DataType CType, DataType SFType,
           int M, int N, int K, bool TransA, bool TransB, int VS>
 TL_DEVICE void mma_sync_blockscaled(
@@
   using Dispatcher = detail::BlockScaledMmaDispatcher<
       AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>;
   static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
                 "tl::mma_sync_blockscaled: unsupported configuration");
-
-  Dispatcher::exec(c, a, b, c, sfa, sfb);
+  if constexpr (AType == DataType::kFloat4_e2m1fn ||
+                BType == DataType::kFloat4_e2m1fn) {
+    using AReg = typename Dispatcher::ARegType;
+    using BReg = typename Dispatcher::BRegType;
+    constexpr int nA =
+        detail::BlockScaledMmaImplTraits<typename Dispatcher::Impl>::kARegs;
+    constexpr int nB =
+        detail::BlockScaledMmaImplTraits<typename Dispatcher::Impl>::kBRegs;
+    AReg as[nA];
+    BReg bs[nB];
+#pragma unroll
+    for (int i = 0; i < nA; ++i)
+      as[i] = (AType == DataType::kFloat4_e2m1fn) ? (a[i] << 2) : a[i];
+#pragma unroll
+    for (int i = 0; i < nB; ++i)
+      bs[i] = (BType == DataType::kFloat4_e2m1fn) ? (b[i] << 2) : b[i];
+    Dispatcher::exec(c, as, bs, c, sfa, sfb);
+  } else {
+    Dispatcher::exec(c, a, b, c, sfa, sfb);
+  }
 }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/tl_templates/cuda/instruction/mma.h` around lines 361 - 381, The FP4
values loaded by ptx_ldmatrix_b4x16_* end up in the low nibble and must be
repacked the same way as in mma_sync; update mma_sync_blockscaled to apply the
FP4 repack to the A and B operands before calling Dispatcher::exec (i.e.,
perform the same left-shift/repack used in mma_sync on the arrays of
Dispatcher::ARegType and Dispatcher::BRegType prior to dispatch) so
BlockScaledMmaDispatcher::exec receives the correctly encoded FP4 data and
produces correct accumulations.
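
The `<< 2` repack described above can be pictured with a small bit-level sketch. It assumes each 8-bit container in a 32-bit register holds one FP4 code in its low nibble with the upper nibble zero, so a whole-register shift moves every code to bits 2..5 of its byte without spilling into the neighbouring byte. That bit placement is an assumption read off the `<< 2` in this comment, not something the PR text confirms, and both function names are hypothetical.

```python
MASK32 = 0xFFFFFFFF


def shift_fp4_register(reg):
    """Shift a 32-bit operand register left by 2, as mma_sync() is said to do."""
    return (reg << 2) & MASK32


def fp4_fields(reg, lo_bit):
    """Read the 4-bit field starting at lo_bit in each of the four bytes."""
    return [(reg >> (8 * i + lo_bit)) & 0x0F for i in range(4)]


# Four FP4 codes, one per byte, each sitting in the low nibble.
codes = [0x3, 0xF, 0x0, 0x9]
reg = sum(code << (8 * i) for i, code in enumerate(codes))

shifted = shift_fp4_register(reg)
```

Reading the fields at bit 0 of `reg` and at bit 2 of `shifted` yields the same codes, which is the property the block-scaled path would need to share with `mma_sync()` if both consume the same load format.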

Comment thread tilelang/cuda/intrinsics/macro/mma_macro_generator.py
Comment on lines +678 to +709
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * self.warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * self.warp_cols * local_size_b if b_is_fragment else 0

A_offset = a_local_stride + inst_m_idx * local_size_a
B_offset = b_local_stride + inst_n_idx * local_size_b
C_offset = inst_m_idx * self.warp_cols * local_size_out + inst_n_idx * local_size_out

@T.macro
def _atom_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf):
    T.ptx_mma_blockscaled(
        accum_dtype,
        mma_prefix,
        "row",
        "col",
        a_dtype_abbrv,
        b_dtype_abbrv,
        accum_dtype_abbrv,
        self.scale_dtype,
        self.scale_vec_size,
        A_local_buf.data,
        A_offset,
        B_local_buf.data,
        B_offset,
        C_local_buf.data,
        C_offset,
        SFA_local_buf.data,
        SFB_local_buf.data,
    )

return _atom_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf)

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Advance the scale operands with k_inner.

A_offset and B_offset move per k_inner, but T.ptx_mma_blockscaled always receives SFA_local_buf.data and SFB_local_buf.data at base 0. The caller loops over ki, so any K > 64 or non-uniform per-block scales will keep reusing the first scale vector and accumulate with the wrong exponents.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py` around lines 678 -
709, The scale operands are never advanced per k_inner, so SFA_local_buf and
SFB_local_buf always use base 0; compute SFA_offset and SFB_offset analogous to
A_offset/B_offset using k_inner (e.g. advancing by k_inner *
element_stride_or_vector_size used for scale storage) and pass those offsets
into T.ptx_mma_blockscaled instead of always 0; update the _atom_mma_blockscaled
macro to use the new SFA_offset and SFB_offset (referencing k_inner,
SFA_local_buf, SFB_local_buf, and the
_atom_mma_blockscaled/T.ptx_mma_blockscaled call) so each ki iteration uses the
correct per-k_inner scale vectors.
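
The effect of a stale scale index can be reproduced with a toy block-scaled dot product. This is only a sketch of the arithmetic, not the emitter: `blockscaled_dot`, `block_k`, and the `advance_scales` flag are hypothetical, and the real scale-block size and register layout are not modelled.

```python
def blockscaled_dot(a, b, sfa, sfb, block_k, advance_scales=True):
    """Dot product where each K block of a and b carries its own scale factor."""
    acc = 0.0
    num_k_blocks = len(a) // block_k
    for k_inner in range(num_k_blocks):
        # The reported bug corresponds to always using index 0 here.
        sf_idx = k_inner if advance_scales else 0
        lo = k_inner * block_k
        partial = sum(x * y for x, y in zip(a[lo:lo + block_k], b[lo:lo + block_k]))
        acc += partial * sfa[sf_idx] * sfb[sf_idx]
    return acc


a = [1.0, 1.0, 1.0, 1.0]
b = [1.0, 1.0, 1.0, 1.0]
sfa = [1.0, 2.0]  # one scale per K block
sfb = [1.0, 4.0]

correct = blockscaled_dot(a, b, sfa, sfb, block_k=2)
stale = blockscaled_dot(a, b, sfa, sfb, block_k=2, advance_scales=False)
```

Here `correct` is 2*1*1 + 2*2*4 = 18.0 while `stale` is 4.0. With uniform scales the two agree, so a uniform-scale test would not expose the problem.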

Comment thread tilelang/language/gemm_op.py
Comment on lines +91 to +120
     global _WARNED_LEGACY_TCGEN05_LAYOUT_FFI
     buf, shape, _ = _get_buffer_info(buffer)
     stride, continuous = _get_stride_continuous(buffer)
     element_size = _get_element_size(buffer)
     if continuity is None:
-        continuity = -1
-    return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major)
         continuity = continuous
     try:
         base = _ffi_api.make_tcgen05mma_swizzled_layout(
             stride,
             continuous,
             continuity,
             element_size,
             k_major,
         )
         return base.reshape(shape)
     except TypeError as err:
         # Keep Python sources compatible with older built libs that still expose
         # the legacy FFI signature: (buffer, continuity, k_major).
         if "Mismatched number of arguments" not in str(err):
             raise
         if not _WARNED_LEGACY_TCGEN05_LAYOUT_FFI:
             warnings.warn(
                 "Detected legacy tcgen05 swizzle-layout FFI in the loaded native "
                 "TileLang library. Rebuild the native `build/` artifacts so FP4 "
                 "SM100/SM110 kernels use the current gap-aware ALIGN16B layout.",
                 RuntimeWarning,
                 stacklevel=2,
             )
             _WARNED_LEGACY_TCGEN05_LAYOUT_FFI = True
         return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

The new TCGEN05 Python path changes the layout for rank>2 buffers.

The native C++ factories only swizzle the last two dimensions and keep leading dims unchanged via ExpandLayout2D(...), but this code flattens every leading dim into stride and then calls base.reshape(shape). For 3D+ buffers that lets the swizzle span outer-dimension boundaries and produces a different mapping than the buffer-based path. Build the 2D base from shape[-2:] and expand across shape[:-2] instead; otherwise make_align16b_swizzled_layout() inherits the same bug whenever it falls back here.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/layout/swizzle.py` around lines 91 - 120, This code currently
flattens all leading dims into a single stride and calls
_ffi_api.make_tcgen05mma_swizzled_layout which lets the swizzle cross
outer-dimension boundaries; instead, compute the 2D layout from the trailing two
dims only (use shape[-2:] and the corresponding last-two stride/continuity
values from _get_stride_continuous/stride), call
_ffi_api.make_tcgen05mma_swizzled_layout to build the base 2D layout for those
last two dims, then expand that 2D base across the leading dimensions
(shape[:-2]) using the same ExpandLayout2D-style expansion used by the native
C++ factories so the Python path matches the buffer-based path for rank>2
buffers.
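
A minimal sketch of the "swizzle the last two dims, expand over the leading dims" idea follows. The XOR swizzle is a toy stand-in for the real TCGEN05 layout, and `swizzle_2d` and `expand_layout` are hypothetical names. The point is only that leading indices contribute a whole-tile offset, so the swizzle never crosses an outer-dimension boundary.

```python
from itertools import product


def swizzle_2d(row, col, cols):
    """Toy XOR swizzle inside one [rows, cols] tile (cols must be a power of two)."""
    return row * cols + (col ^ (row % cols))


def expand_layout(index, shape):
    """Map an N-D index to an offset: linear over leading dims, swizzled in the last two."""
    *lead, row, col = index
    *lead_shape, rows, cols = shape
    outer = 0
    for i, extent in zip(lead, lead_shape):
        outer = outer * extent + i
    return outer * rows * cols + swizzle_2d(row, col, cols)


shape = (2, 4, 4)
offsets = {idx: expand_layout(idx, shape) for idx in product(*map(range, shape))}
```

Every offset for leading index 0 stays within the first 16 slots and every offset for leading index 1 within the next 16, which is the invariant a flattened-then-reshaped layout can violate.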

…ected in codegen_cuda.cc, called at tcgen05_macro_generator.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/cuda/codegen/codegen_cuda.cc`:
- Around line 3411-3433: When use_mxf4nvf4 is true, validate that
kind_dtype/dtype_enum is the FP4 variant expected by the mxf4nvf4 intrinsic
before emitting tcgen05mma_mxf4nvf4_blockscaled_ss: call
tl::codegen::ptx::DTypeFromString(kind_dtype) (already present as dtype_enum)
and assert (ICHECK) that it matches the FP4 dtype enum the intrinsic requires,
otherwise fail with a clear error message (or avoid emitting the FP4 intrinsic);
update the branch that constructs tcgen05_call (and/or use_mxf4nvf4 decision) to
perform this dtype check so a mismatched caller errors at codegen rather than
silently lowering to the wrong intrinsic.

In `@src/op/tcgen5_meta.h`:
- Around line 358-359: The helper GetTCGEN5MXF4NVF4BlockScaledInstrDesc
currently writes a_sf_id into desc via set_bits(static_cast<uint32_t>(a_sf_id),
29, 2) even though a_sf_id is not a parameter and the field is fixed to zero for
scale_vec::4X; remove that stale write (or alternatively add an a_sf_id
parameter to the function signature and propagate it through if intended) so
that desc no longer attempts to set bits 29-30 from an undefined a_sf_id; update
any related comments to reflect the removal.

In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py`:
- Around line 468-529: The guard and descriptor for B are incorrect: replace the
assert that assumes B is non-transposed and ensure b_transposed is used
consistently when building b_params; specifically, update the top assertion to
require the correct B layout (use self.b_transposed as done by
tcgen05mma_blockscaled() and compute_tcgen05_b_desc_params()), compute
b_is_k_major = self.b_transposed and pass that into the TCGEN05DescriptorParams
for b_params (instead of hardcoding is_k_major=True), and adjust any related
leading/stride byte offset logic to follow the K-major vs N-major branch
consistent with the _determinate_swizzle_mode(B_buf, ...) handling.
- Around line 531-556: The code reads SFA_tmem.dtype before normalizing
BufferRegion/Buffer, which fails for sliced regions; update the function to
first normalize SFA_tmem and SFB_tmem into their underlying Buffer or
Buffer.data (using the existing BufferRegion/Buffer checks for SFA_tmem and
SFB_tmem) and only then read dtype to compute sf_dtype and is_mxfp4 and call
_ffi_api.get_tcgen5_mxf4nvf4_blockscaled_instr_desc; reference symbols:
SFA_tmem, SFB_tmem, BufferRegion, Buffer, sf_dtype, is_mxfp4, and
_ffi_api.get_tcgen5_mxf4nvf4_blockscaled_instr_desc to locate where to move the
dtype access and normalization logic.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c496c2e0-db4c-47d6-9afc-cf1243927384

📥 Commits

Reviewing files that changed from the base of the PR and between 48db769 and 8aa292a.

📒 Files selected for processing (6)
  • src/cuda/codegen/codegen_cuda.cc
  • src/cuda/op/gemm.cc
  • src/op/builtin.cc
  • src/op/tcgen5_meta.h
  • src/tl_templates/cuda/instruction/tcgen05mma.h
  • tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py

Comment on lines +3411 to +3433
     bool use_mxf4nvf4 = Downcast<IntImm>(op->args[13])->value != 0;
     // args[14] reserved for future mask/flags
     bool enable_ws = Downcast<Bool>(op->args[15])->value;
     bool enable_2cta = Downcast<Bool>(op->args[16])->value;
     ICHECK(!(enable_ws && enable_2cta))
         << "Block-scaled TCGEN05 does not support combining .ws and 2CTA";
     ICHECK(!use_mxf4nvf4 || !enable_ws)
         << "mxf4nvf4 block-scaled TCGEN05 currently supports SS only";

     auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype);

     need_tcgen05mma_instruction_h_ = true;
     this->PrintIndent();
-    std::string tcgen05_call =
-        "tl::(tcgen05_name)<(ABType), (USE_2CTA)>(uint64_t((desc_a) + "
-        "(A_offset)), "
-        "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
-        "+ (C_offset), "
-        "(scale_out), static_cast<uint32_t>((desc_val)), "
-        "(*reinterpret_cast<uint32_t*>((SFA))) + (SFA_offset), "
-        "(*reinterpret_cast<uint32_t*>((SFB))) + (SFB_offset));\n";
     std::string tcgen05_call;
     if (use_mxf4nvf4) {
       tcgen05_call =
           "tl::tcgen05mma_mxf4nvf4_blockscaled_ss<(USE_2CTA)>("
           "uint64_t((desc_a) + (A_offset)), "
           "uint64_t((desc_b) + (B_offset)), "
           "(*reinterpret_cast<uint32_t*>((C))) + (C_offset), "
           "(scale_out), static_cast<uint32_t>((desc_val)), "
           "(*reinterpret_cast<uint32_t*>((SFA))) + (SFA_offset), "
           "(*reinterpret_cast<uint32_t*>((SFB))) + (SFB_offset));\n";

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate the operand dtype before forcing the mxf4nvf4 intrinsic.

When use_mxf4nvf4 is true, this branch ignores kind_dtype and always emits tcgen05mma_mxf4nvf4_blockscaled_ss. A mismatched caller will silently lower to the FP4-specific intrinsic instead of failing at the codegen boundary.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/cuda/codegen/codegen_cuda.cc` around lines 3411 - 3433, When use_mxf4nvf4
is true, validate that kind_dtype/dtype_enum is the FP4 variant expected by the
mxf4nvf4 intrinsic before emitting tcgen05mma_mxf4nvf4_blockscaled_ss: call
tl::codegen::ptx::DTypeFromString(kind_dtype) (already present as dtype_enum)
and assert (ICHECK) that it matches the FP4 dtype enum the intrinsic requires,
otherwise fail with a clear error message (or avoid emitting the FP4 intrinsic);
update the branch that constructs tcgen05_call (and/or use_mxf4nvf4 decision) to
perform this dtype check so a mismatched caller errors at codegen rather than
silently lowering to the wrong intrinsic.

Comment thread src/op/tcgen5_meta.h Outdated
Comment thread tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py Outdated
Comment thread tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
…itional tcgen05mma_mxf4nvf4_blockscaled, add example to verify

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tilelang/cuda/op/gemm/gemm_tcgen05.py (1)

286-341: 💤 Low value

Consider extracting the duplicated MMA dispatch logic.

The is_nvfp4 conditional block is duplicated between _gemm_blockscaled_cond and _gemm_blockscaled. Extracting this to a helper (or using a lambda) would reduce maintenance burden when adding future dispatch variants.

That said, this is a minor improvement and acceptable as-is given the nested T.prim_func structure.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/cuda/op/gemm/gemm_tcgen05.py` around lines 286 - 341, Extract the
duplicated is_nvfp4 dispatch into a small helper function or lambda and call it
from both _gemm_blockscaled_cond and _gemm_blockscaled to avoid repeating the
MMA call pairs; specifically create a helper (e.g., _call_mma_blockscaled or a
local lambda) that takes (A_shared, B_shared, C_local, SFA_tmem, SFB_tmem,
mbarptr, clear_accum, sf_a_id=None, sf_b_id=None) and inside performs the
is_nvfp4 check to call mma_emitter.tcgen05mma_mxf4nvf4_blockscaled or
mma_emitter.tcgen05mma_blockscaled accordingly, then replace the duplicated
blocks in _gemm_blockscaled_cond and _gemm_blockscaled with a single call to
that helper.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/gemm_fp4/example_gemm_nvfp4_sm100.py`:
- Around line 84-86: SFB_tmem is allocated with the wrong tile dimension causing
incorrect sizing for non-square tiles; change the alloc_tmem call that creates
SFB_tmem (currently T.alloc_tmem([block_M, 4], "uint32")) to use block_N instead
of block_M so it becomes T.alloc_tmem([block_N, 4], "uint32"), ensuring SFB_tmem
matches per-row scales of B; update the allocation site where SFB_tmem is
defined and verify any uses of SFB_tmem assume the N-dimension.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4a71a3c5-5dc4-4374-9855-5fa2586336b2

📥 Commits

Reviewing files that changed from the base of the PR and between 8aa292a and 4dd1a5e.

📒 Files selected for processing (3)
  • examples/gemm_fp4/example_gemm_nvfp4_sm100.py
  • tilelang/cuda/op/gemm/gemm_tcgen05.py
  • tilelang/language/gemm_op.py

Comment on lines +84 to +86
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
SFA_tmem = T.alloc_tmem([block_M, 4], "uint32")
SFB_tmem = T.alloc_tmem([block_M, 4], "uint32")

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

SFB_tmem shape uses block_M but should likely be block_N.

SFB contains per-row scales for B (shape N×K), so SFB_tmem should be sized to [block_N, 4] rather than [block_M, 4]. With the default square blocks (block_M == block_N == 128) this works by coincidence, but will break for non-square tile configurations.

Proposed fix
             C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
             SFA_tmem = T.alloc_tmem([block_M, 4], "uint32")
-            SFB_tmem = T.alloc_tmem([block_M, 4], "uint32")
+            SFB_tmem = T.alloc_tmem([block_N, 4], "uint32")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
-SFA_tmem = T.alloc_tmem([block_M, 4], "uint32")
-SFB_tmem = T.alloc_tmem([block_M, 4], "uint32")
+C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
+SFA_tmem = T.alloc_tmem([block_M, 4], "uint32")
+SFB_tmem = T.alloc_tmem([block_N, 4], "uint32")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_nvfp4_sm100.py` around lines 84 - 86, SFB_tmem
is allocated with the wrong tile dimension causing incorrect sizing for
non-square tiles; change the alloc_tmem call that creates SFB_tmem (currently
T.alloc_tmem([block_M, 4], "uint32")) to use block_N instead of block_M so it
becomes T.alloc_tmem([block_N, 4], "uint32"), ensuring SFB_tmem matches per-row
scales of B; update the allocation site where SFB_tmem is defined and verify any
uses of SFB_tmem assume the N-dimension.

sepcnt added a commit to sepcnt/tilelang that referenced this pull request Jun 11, 2026
…perands dense (SW64) instead of gap-expanded align16b (uniform-SF verified, rel_err=0 on Thor)

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/layout/gemm_layouts.cc (1)

998-1011: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Wire dense FP4 into SM100 selection and swizzle merging.

MakeDenseFp4SwizzleLayout2D is added, but the TCGEN05 path still reaches makeGemmABLayoutSm100, where Line 998 routes sub-byte operands through ALIGN16B or linear layout. Line 1093 also returns kNone for dense FP4 layouts, so MergeSwizzleLayouts has no way to preserve/reconstruct the new dense layout. Thread an explicit dense-vs-ALIGN16B layout kind into SM100 selection and add a dense FP4 SwizzleMode merge case. This is required for the dense-packed mxf4nvf4 SM100 contract described by the PR objectives.

Also applies to: 1076-1084, 1093-1100, 1149-1156

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/layout/gemm_layouts.cc` around lines 998 - 1011, The TCGEN05 path in
makeGemmABLayoutSm100 at lines 998-1011 needs to distinguish between dense FP4
and ALIGN16B layouts for sub-byte operands, but currently it treats them
uniformly. Additionally, the logic at lines 1076-1084 and 1093-1100 returns
kNone for dense FP4 layouts without proper differentiation, and
MergeSwizzleLayouts at lines 1149-1156 lacks a merge case for the dense FP4
SwizzleMode. To fix this, introduce an explicit layout kind parameter that
distinguishes dense FP4 from ALIGN16B layouts and thread it through the SM100
selection logic at the anchor location, update the layout kind determination at
lines 1093-1100 to return the appropriate dense FP4 kind instead of kNone,
update the TCGEN05 routing at lines 1076-1084 to handle the dense case, and add
a corresponding dense FP4 SwizzleMode merge case in MergeSwizzleLayouts at lines
1149-1156 to preserve and reconstruct the dense layout.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/layout/gemm_layouts.cc`:
- Around line 543-556: The current ICHECK at line 543 only validates that the
FP4 count is even, but this is insufficient because the downstream code uses
16-byte packed vectors even when bank equals 1. Add an additional validation
that byte_continuous is divisible by vector_size (16 bytes) to reject dense FP4
shapes that do not cover a full 16-byte vector. This should be enforced in the
same ICHECK or a new one immediately following the existing check, ensuring that
continuous is at least 32 FP4 elements (since byte_continuous = continuous / 2,
requiring byte_continuous >= 16 means continuous >= 32).

In `@tilelang/layout/swizzle.py`:
- Around line 177-189: The function `make_dense_fp4_swizzled_layout` calls
`_ffi_api.make_dense_fp4_swizzled_layout` directly without checking if the FFI
symbol exists, unlike the similar `make_align16b_swizzled_layout` function which
guards the call with hasattr. Add a hasattr check on _ffi_api for the
make_dense_fp4_swizzled_layout symbol before attempting to call it, and either
provide a graceful fallback or raise a descriptive error message if the symbol
is not available. This ensures users receive a clear, helpful error instead of
an AttributeError when the native library lacks this symbol.
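
Both inline findings above reduce to small up-front checks. The sketch below shows one possible shape for each, using a `SimpleNamespace` as a stand-in for the native `_ffi_api` module. `check_dense_fp4_continuous` and `call_ffi_or_fail` are hypothetical helper names, and the 16-byte vector size is taken from the comment text rather than from the C++ source.

```python
from types import SimpleNamespace


def check_dense_fp4_continuous(continuous, vector_bytes=16):
    """Require an even FP4 count that fills whole 16-byte vectors (two FP4 per byte)."""
    if continuous % 2 != 0:
        raise ValueError("dense FP4 needs an even element count")
    byte_continuous = continuous // 2
    if byte_continuous % vector_bytes != 0:
        raise ValueError(
            f"dense FP4 needs a multiple of {2 * vector_bytes} elements, got {continuous}"
        )
    return byte_continuous


def call_ffi_or_fail(ffi_api, name, *args):
    """Call ffi_api.<name>(*args), or raise a descriptive error if the symbol is missing."""
    if not hasattr(ffi_api, name):
        raise RuntimeError(
            f"native library does not export '{name}'; rebuild it to use this layout"
        )
    return getattr(ffi_api, name)(*args)


# Stand-ins for an old and a new native library (assumption for illustration).
old_lib = SimpleNamespace()
new_lib = SimpleNamespace(make_dense_fp4_swizzled_layout=lambda *args: ("layout", args))
```

Raising `RuntimeError` instead of letting `AttributeError` escape is one design choice; a fallback to another layout would be the alternative the comment mentions.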

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9c04581d-1805-4f7b-b7ea-124fa84c11cb

📥 Commits

Reviewing files that changed from the base of the PR and between 4dd1a5e and 1290a96.

📒 Files selected for processing (13)
  • examples/gemm_fp4/example_gemm_nvfp4_sm100.py
  • examples/gemm_fp4/gemm_nvfp4_sm100.cu
  • src/cuda/op/gemm.cc
  • src/layout/gemm_layouts.cc
  • src/layout/layout.cc
  • src/layout/layout.h
  • src/op/tcgen5_meta.h
  • tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
  • tilelang/cuda/op/gemm/gemm_tcgen05.py
  • tilelang/cuda/transform/__init__.py
  • tilelang/language/gemm_op.py
  • tilelang/layout/__init__.py
  • tilelang/layout/swizzle.py
🚧 Files skipped from review as they are similar to previous changes (7)
  • tilelang/layout/__init__.py
  • src/layout/layout.cc
  • src/layout/layout.h
  • src/op/tcgen5_meta.h
  • tilelang/language/gemm_op.py
  • tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
  • tilelang/cuda/op/gemm/gemm_tcgen05.py

Comment thread src/layout/gemm_layouts.cc Outdated
Comment thread tilelang/layout/swizzle.py Outdated
… SiLU + routing); align SM100/SM120 NVFP4 MoE with FlashInfer compute paradigm

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
examples/gemm_fp4/example_gemm_fp4_sm100.py (1)

174-208: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Missing validation for even K dimension in FP4 mode.

Both make_random_fp4(M, K, ...) and unpack_fp4_to_float(..., K) use K // 2 for the packed axis. If an odd K is passed via --k or TL_FP4_K, the packed buffer will be truncated, leading to silent data loss or shape mismatches during reference verification.

🛡️ Suggested guard after args resolution
     args.m = args.m if args.m is not None else int(os.environ.get("TL_FP4_M", default_mnk))
     args.n = args.n if args.n is not None else int(os.environ.get("TL_FP4_N", default_mnk))
     args.k = args.k if args.k is not None else int(os.environ.get("TL_FP4_K", default_mnk))
+
+    if args.k % 2 != 0:
+        raise ValueError("K must be even for packed FP4 inputs")
 
     if args.nvfp4:
         run_nvfp4(args)

Also applies to: 280-305

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 174 - 208, The K
dimension must be validated to be even before use in FP4 mode, since
make_random_fp4 and unpack_fp4_to_float both use K // 2 for packed buffers. Add
a validation check after extracting K from args in the run_fp4 function to
ensure K is even; if odd, raise an error or print a clear message explaining the
requirement. Apply the same validation to any other entry points that accept K
as a parameter for FP4 operations.
🧹 Nitpick comments (2)
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)

105-112: 💤 Low value

Redundant accumulator clear.

T.clear(C_local) at line 105 and clear_accum=(ko == 0) at line 112 both zero the accumulator before the first K-tile is accumulated, so one of them is redundant: either remove the T.clear(C_local) call or pass clear_accum=False.

♻️ Suggested fix
-            T.clear(C_local)
             for ko in T.serial(T.ceildiv(K, block_K)):
                 T.copy(A[by * block_M, ko * packed_block_K], A_shared)
                 T.copy(B[bx * block_N, ko * packed_block_K], B_shared)
                 SFA_local[0] = SFA[by, ko]
                 SFB_local[0] = SFB[bx, ko]
                 T.nvfp4_gemm(A_shared, B_shared, SFA_local, SFB_local, C_local,
                              transpose_B=True, clear_accum=(ko == 0))
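The redundancy can be checked with a scalar model of the K loop. This is a sketch only: `tile_gemm_step` and `run_k_loop` are hypothetical stand-ins, and the `clear_accum` behavior modeled here (zero the accumulator before adding the tile product) is assumed from the description in this comment, not taken from the TileLang implementation.

```python
def tile_gemm_step(acc, a_tile, b_tile, clear_accum):
    """Toy stand-in for one K-tile step of a GEMM with a clear_accum flag."""
    base = 0.0 if clear_accum else acc
    return base + sum(x * y for x, y in zip(a_tile, b_tile))


def run_k_loop(a_tiles, b_tiles, pre_clear, clear_on_first):
    """Accumulate over K tiles with either, both, or neither clearing strategy."""
    acc = 123.0  # stale value standing in for an uninitialized fragment
    if pre_clear:
        acc = 0.0  # models the explicit clear before the loop
    for ko, (a, b) in enumerate(zip(a_tiles, b_tiles)):
        acc = tile_gemm_step(acc, a, b, clear_accum=(clear_on_first and ko == 0))
    return acc


if __name__ == "__main__":
    a_tiles = [[1.0, 2.0], [3.0, 4.0]]
    b_tiles = [[5.0, 6.0], [7.0, 8.0]]
    both = run_k_loop(a_tiles, b_tiles, pre_clear=True, clear_on_first=True)
    only_pre = run_k_loop(a_tiles, b_tiles, pre_clear=True, clear_on_first=False)
    only_flag = run_k_loop(a_tiles, b_tiles, pre_clear=False, clear_on_first=True)
    print(both, only_pre, only_flag)
```

Under that assumption, any one of the two mechanisms is sufficient; only dropping both leaves the stale value in the result.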
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 105 - 112, The
accumulator C_local is being cleared twice redundantly: once explicitly with
T.clear(C_local) before the loop and once implicitly during the first iteration
of the loop via the clear_accum=(ko == 0) parameter in the T.nvfp4_gemm call.
Remove the explicit T.clear(C_local) call since the T.nvfp4_gemm operation
already handles accumulator initialization on the first K-tile iteration through
its clear_accum parameter.
examples/gemm_fp4/example_gemm_fp4_sm100.py (1)

165-171: 💤 Low value

Default --cuda-target assumes aarch64, may fail on x86 hosts.

The default cuda_target is "aarch64-linux", which constructs an include path like /usr/local/cuda/targets/aarch64-linux/include. On x86_64 systems, the correct target directory is typically x86_64-linux. Users on x86 will need to explicitly pass --cuda-target=x86_64-linux or the compilation may fail with missing headers.

Consider auto-detecting the host platform or documenting this in the help text.

Also applies to: 292-292
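One possible shape for the auto-detection, as a hedged sketch: `default_cuda_target` and `cuda_include_dir` are hypothetical helpers, and the machine-name mapping is an assumption about common `platform.machine()` values rather than something taken from the example script. Unknown machines fall back to the current `aarch64-linux` default.

```python
import platform
import posixpath

# Assumed mapping from platform.machine() values to CUDA target directory names.
_CUDA_TARGET_BY_MACHINE = {
    "x86_64": "x86_64-linux",
    "amd64": "x86_64-linux",
    "aarch64": "aarch64-linux",
    "arm64": "aarch64-linux",
}


def default_cuda_target(machine=None):
    """Pick a default --cuda-target from the host architecture."""
    machine = (machine or platform.machine()).lower()
    # Fall back to the script's existing default when the host is not recognized.
    return _CUDA_TARGET_BY_MACHINE.get(machine, "aarch64-linux")


def cuda_include_dir(cuda_target, cuda_home="/usr/local/cuda"):
    """Build the include path in the form described in the comment above."""
    return posixpath.join(cuda_home, "targets", cuda_target, "include")


if __name__ == "__main__":
    target = default_cuda_target()
    print(target, cuda_include_dir(target))
```

The result could then be used as the argparse default for `--cuda-target`, so an explicit flag still overrides it.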

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 165 - 171, The
_arch_and_flags function uses a hardcoded default of aarch64-linux for
cuda_target, which fails on x86_64 systems that need x86_64-linux instead.
Auto-detect the host platform to determine the correct default cuda_target based
on the system's architecture (x86_64 vs aarch64). This detection should be
applied at the location where the cuda_target argument default is defined
(around line 292 where the argument parser is likely configured) and ensure the
detected value flows through to where it is used in the _arch_and_flags function
at line 165-171 when constructing the cuda_include path. The auto-detection
should happen early so that args.cuda_target receives the platform-appropriate
default unless explicitly overridden by the user.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 746c6f94-16c1-496b-9b4e-f7a397c35497

📥 Commits

Reviewing files that changed from the base of the PR and between 1290a96 and 1111bfd.

📒 Files selected for processing (8)
  • examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py
  • examples/gemm_fp4/example_gemm_fp4_sm100.py
  • examples/gemm_fp4/example_gemm_fp4_sm120.py
  • src/layout/gemm_layouts.cc
  • src/layout/layout.cc
  • src/layout/layout.h
  • tilelang/layout/__init__.py
  • tilelang/layout/swizzle.py
💤 Files with no reviewable changes (5)
  • tilelang/layout/__init__.py
  • src/layout/layout.h
  • src/layout/layout.cc
  • tilelang/layout/swizzle.py
  • src/layout/gemm_layouts.cc

Comment on lines +296 to +305
# Size defaults differ by variant; env vars TL_FP4_{M,N,K} override, --m/--n/--k win.
default_mnk = 128 if args.nvfp4 else 256
args.m = args.m if args.m is not None else int(os.environ.get("TL_FP4_M", default_mnk))
args.n = args.n if args.n is not None else int(os.environ.get("TL_FP4_N", default_mnk))
args.k = args.k if args.k is not None else int(os.environ.get("TL_FP4_K", default_mnk))

if args.nvfp4:
    run_nvfp4(args)
else:
    run_fp4(args)


⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add validation for even K dimension in both SM100 and SM120 examples.

Both example scripts use K // 2 for packed FP4 buffer dimensions without validating that K is even. Odd values of K (via --k or TL_FP4_K env var) will silently truncate the packed buffer, causing shape mismatches or incorrect reference calculations.

  • examples/gemm_fp4/example_gemm_fp4_sm100.py#L296-L305: Add if args.k % 2 != 0: raise ValueError("K must be even for packed FP4 inputs") after resolving args.
  • examples/gemm_fp4/example_gemm_fp4_sm120.py#L232-L240: Add the same validation after resolving args.
📍 Affects 2 files
  • examples/gemm_fp4/example_gemm_fp4_sm100.py#L296-L305 (this comment)
  • examples/gemm_fp4/example_gemm_fp4_sm120.py#L232-L240
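A sketch of how the guard could sit next to the size resolution, keeping the precedence described in the snippet above (CLI flag, then `TL_FP4_{M,N,K}`, then the variant default). `resolve_dims` is a hypothetical helper written for this illustration; it takes `argv` and `environ` as parameters only so the behavior can be exercised without touching the real process environment.

```python
import argparse


def resolve_dims(argv, environ, nvfp4=False):
    """Resolve (M, N, K): --m/--n/--k win, then TL_FP4_* env vars, then the default."""
    parser = argparse.ArgumentParser()
    for name in ("m", "n", "k"):
        parser.add_argument(f"--{name}", type=int, default=None)
    args = parser.parse_args(argv)

    default_mnk = 128 if nvfp4 else 256
    for name in ("m", "n", "k"):
        if getattr(args, name) is None:
            setattr(args, name, int(environ.get(f"TL_FP4_{name.upper()}", default_mnk)))

    # Packed FP4 stores two elements per byte, so the packed axis is K // 2.
    if args.k % 2 != 0:
        raise ValueError("K must be even for packed FP4 inputs")
    return args.m, args.n, args.k


if __name__ == "__main__":
    print(resolve_dims(["--k", "64"], {"TL_FP4_K": "33"}))
```

Placing the check after resolution covers both entry paths, so an odd value arriving through the environment variable is rejected the same way as one passed on the command line.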
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 296 - 305, Add
validation to ensure the K dimension is even in both FP4 example files. In
examples/gemm_fp4/example_gemm_fp4_sm100.py (lines 296-305), after the lines
that resolve args.m, args.n, and args.k from command-line arguments and
environment variables, add a check that raises ValueError with message "K must
be even for packed FP4 inputs" if args.k is odd. Apply the identical validation
check in examples/gemm_fp4/example_gemm_fp4_sm120.py (lines 232-240) after the
args resolution section to prevent silent truncation of packed FP4 buffers when
K values are not divisible by 2.

sepcnt added a commit to sepcnt/tilelang that referenced this pull request Jun 17, 2026
sepcnt added a commit to sepcnt/tilelang that referenced this pull request Jun 17, 2026
sepcnt added a commit to sepcnt/tilelang that referenced this pull request Jun 17, 2026