[Stacked][Feature] Support NVFP4 Gemm on Blackwell arch (SM100,110,120) #2324
Hale423 wants to merge 24 commits into
Add the plumbing required to route float4_e2m1fn through the TCGEN5 MMA code-generation path so that FP4 GEMM kernels can be emitted on SM100.
Changes:
- ptx.h / ptx.cc: add kFloat4_e2m1fn enum, string tables, DTypeFromString
- common.h: add kFloat4_e2m1fn to device-side DataType enum
- tcgen5_meta.h: add FP4 branch in encode_dtype (format code 2)
- tcgen05mma.h: add kFloat4_e2m1fn specializations for SS/TS/WS_SS (delegates to the existing f8f6f4 PTX kind)
- mma_macro_generator.py: add dtype_abbrv mapping for float4_e2m1fn
- docs/GEMM_NV_FP4_FEATURE_STEPS.md: design doc and progress tracker
Addresses tile-ai#1592
Made-with: Cursor
… raw bytes, get_ldmatrix_offset add fp4 special layout, add SM120_FP4_FP4_F32_TN MmaDispatcher specialization + fp4 << 2 bit-shift.
Co-authored-by: Cursor <cursoragent@cursor.com>
… as payload, functionality verified valid for naive gemm-fp4
…acksmem path, functionalities from all examples were verified valid
Co-authored-by: Cursor <cursoragent@cursor.com>
…fix ptx_mma_blockscaled meta param must be of StringImm/IntImm, add and verified gemm_nvfp4 & fused moe examples
📝 Walkthrough
Adds FP4 (
Changes
FP4 + block-scaled GEMM (SM100/SM120)
Sequence Diagram(s)
sequenceDiagram
rect rgba(173, 216, 230, 0.5)
Note over User,CUDA: NVFP4 GEMM compilation and runtime
end
participant User as User script
participant gemm_op as nvfp4_gemm / tcgen05_gemm_blockscaled
participant gemm_tcgen05 as GemmTCGEN5 lowering
participant macro as tcgen05_macro_generator
participant codegen as codegen_cuda.cc
participant Device as CUDA device
User->>gemm_op: nvfp4_gemm(A_uint8, B_uint8, SFA, SFB, C)
gemm_op->>gemm_op: compute logical_K=K*2, annotate is_nvfp4=1
gemm_op->>gemm_tcgen05: infer_shared_layout(dtype=fp4, k_major)
gemm_tcgen05->>macro: tcgen05mma_mxf4nvf4_blockscaled(SFA_tmem, SFB_tmem)
macro->>macro: get_tcgen5_mxf4nvf4_blockscaled_instr_desc(is_mxfp4)
macro->>macro: tcgen05_blockscaled_atom(use_mxf4nvf4=True)
macro-->>codegen: ptx_tcgen05_mma_blockscaled_ss(..., use_mxf4nvf4=1)
codegen->>codegen: emit tcgen05mma_mxf4nvf4_blockscaled_ss<use_2cta>
codegen-->>Device: inline PTX tcgen05.mma.kind::mxf4nvf4.block_scale
Device-->>User: FP32 output tensor
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py (1)
399-418: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Scale IDs are constant across the entire K loop.
runtime_instr_desc is computed once from fixed sf_a_id/sf_b_id before the for ki loop, and tcgen05_blockscaled_atom() never adjusts them. When num_k_atoms > 1, later atoms still read the first scale vector even though the A/B descriptor offsets advance, so multi-block or tile-varying blockscaled GEMMs will miscompute.
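The effect of pinning the scale index can be seen in a small host-side model. The sketch below is not TileLang code: `blockscaled_dot` and its `per_atom_scales` flag are hypothetical names, and it assumes one scale per K atom for each operand, which may not match the real scale layout.

```python
def blockscaled_dot(a, b, sfa, sfb, atom_k, per_atom_scales=True):
    """Toy block-scaled dot product with K split into atoms of `atom_k` elements.

    Hypothetical model: one scale per atom for each operand (an assumption).
    """
    num_k_atoms = len(a) // atom_k
    acc = 0.0
    for ki in range(num_k_atoms):
        # Correct behaviour advances the scale index with ki;
        # the reported bug corresponds to keeping it pinned at 0.
        s = ki if per_atom_scales else 0
        lo, hi = ki * atom_k, (ki + 1) * atom_k
        partial = sum(x * y for x, y in zip(a[lo:hi], b[lo:hi]))
        acc += sfa[s] * sfb[s] * partial
    return acc


a = [1.0] * 8
b = [1.0] * 8
sfa = [1.0, 2.0]
sfb = [1.0, 4.0]

correct = blockscaled_dot(a, b, sfa, sfb, atom_k=4)
pinned = blockscaled_dot(a, b, sfa, sfb, atom_k=4, per_atom_scales=False)
print(correct, pinned)
```

With a single atom the two variants agree, which is consistent with the review's point that the problem only appears when num_k_atoms > 1.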
🧹 Nitpick comments (2)
tilelang/cuda/intrinsics/layout/utils.py (1)
22-29: ⚡ Quick win
Expose FP4 in the public dtype contract.
get_ldmatrix_offset() now has a runtime FP4 path, but the signature still only advertises "int4". That makes type-checkers flag the new path and nudges callers toward the wrong dtype string for this feature.
Suggested fix
 def get_ldmatrix_offset(
     matrix: Literal["A", "B"],
     row_idx,
     col_idx,
     stride,
-    dtype: Literal["float16", "int8", "int4"] = "float16",
+    dtype: Literal["float16", "int8", "int4", "float4_e2m1fn"] = "float16",
     transposed: bool = False,
 ):
Also update the surrounding doc/error text so FP4 and INT4 stay clearly separated.
Also applies to: 52-63
examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py (1)
118-132: ⚡ Quick win
Reuse the staged X tile for both GEMMs.
The gate and up loops both execute T.copy(X_bytes[...], X_shared), so every ko rereads the same activation block twice. For the benchmarked FC1 path that doubles global-memory traffic on X and makes the “fused” example materially slower than it needs to be.
One-loop version
 T.clear(gate_local)
-for ko in T.serial(hidden_blocks):
-    T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared)
-    T.copy(W_gate_bytes[bx * block_intermediate, ko * packed_block_hidden], gate_shared)
-    SFA_local[0] = X_scale[by, ko]
-    SFB_local[0] = W_gate_scale[bx, ko]
-    T.nvfp4_gemm(X_shared, gate_shared, SFA_local, SFB_local, gate_local, transpose_B=True, clear_accum=(ko == 0))
-
 T.clear(up_local)
 for ko in T.serial(hidden_blocks):
     T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared)
+    SFA_local[0] = X_scale[by, ko]
+
+    T.copy(W_gate_bytes[bx * block_intermediate, ko * packed_block_hidden], gate_shared)
+    SFB_local[0] = W_gate_scale[bx, ko]
+    T.nvfp4_gemm(X_shared, gate_shared, SFA_local, SFB_local, gate_local, transpose_B=True, clear_accum=(ko == 0))
+
     T.copy(W_up_bytes[bx * block_intermediate, ko * packed_block_hidden], up_shared)
-    SFA_local[0] = X_scale[by, ko]
     SFB_local[0] = W_up_scale[bx, ko]
     T.nvfp4_gemm(X_shared, up_shared, SFA_local, SFB_local, up_local, transpose_B=True, clear_accum=(ko == 0))
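The traffic argument behind this nitpick can be checked with a plain-Python model. This is only a sketch: `two_loops`, `one_loop`, and `tile_dot` are hypothetical helpers that count how often each X tile is fetched, and they stand in for the T.copy and T.nvfp4_gemm calls rather than reproducing them.

```python
def tile_dot(x_tile, w_tile):
    """Stand-in for one GEMM step on a staged tile (hypothetical helper)."""
    return sum(x * w for x, w in zip(x_tile, w_tile))


def two_loops(x_tiles, wg_tiles, wu_tiles):
    """Separate gate and up loops: each X tile is fetched twice."""
    loads, gate, up = 0, 0, 0
    for ko in range(len(x_tiles)):
        x = x_tiles[ko]
        loads += 1
        gate += tile_dot(x, wg_tiles[ko])
    for ko in range(len(x_tiles)):
        x = x_tiles[ko]
        loads += 1
        up += tile_dot(x, wu_tiles[ko])
    return gate, up, loads


def one_loop(x_tiles, wg_tiles, wu_tiles):
    """Fused loop: each X tile is fetched once and reused for both products."""
    loads, gate, up = 0, 0, 0
    for ko in range(len(x_tiles)):
        x = x_tiles[ko]
        loads += 1
        gate += tile_dot(x, wg_tiles[ko])
        up += tile_dot(x, wu_tiles[ko])
    return gate, up, loads


x_tiles = [[1, 2], [3, 4], [5, 6]]
wg_tiles = [[1, 0], [0, 1], [1, 1]]
wu_tiles = [[2, 2], [1, 0], [0, 3]]
print(two_loops(x_tiles, wg_tiles, wu_tiles))
print(one_loop(x_tiles, wg_tiles, wu_tiles))
```

Both variants produce the same gate and up accumulators; only the number of X tile fetches differs.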
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py`:
- Around line 40-46: Add an explicit guard that TL_MOE_HIDDEN is even before
treating expert weight tensors as packed FP4 (two 4-bit values per byte);
validate this in the config/parsing path and before any use of
unpack_fp4_to_float and places where expert tensors are created/reshaped
(references: function unpack_fp4_to_float and the code regions handling expert
weight shapes around the other occurrences). If TL_MOE_HIDDEN is odd, raise a
clear ValueError (or argparse/config error) explaining that packed FP4 storage
requires an even hidden dimension, so the failure happens early instead of
during unpacking.
In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py`:
- Around line 136-141: The zero-input smoke test currently only prints PASS/FAIL
and can allow the script to exit successfully on failure; change it to a hard
assertion so failures stop execution: after calling jit_kernel(z_input, z_gate,
z_up) and computing c_zero, replace the print line with an assertion that
c_zero.abs().max().item() == 0.0 (or torch.equal(c_zero,
torch.zeros_like(c_zero))) and include a descriptive message (e.g., "zeros in ->
zeros out failed") so the test fails closed; refer to the variables z_input,
z_gate, z_up and the function jit_kernel to locate where to apply this change.
In `@examples/gemm_fp4/example_gemm_a8w4_sm100.py`:
- Around line 35-41: The code assumes K is even when packing/unpacking FP4 (see
unpack_fp4_to_float) which silently truncates when TL_A8W4_K is odd; add an
explicit check (raise ValueError or assert) that TL_A8W4_K % 2 == 0 before any
packing or buffer allocation/reshape, and apply the same guard near the FP4
packing logic and any other unpacking uses (the unpack_fp4_to_float function and
the FP4 weight-packing sites) so the script fails fast with a clear error
instead of allocating a truncated buffer.
In `@examples/gemm_fp4/example_gemm_fp4_sm100.py`:
- Around line 131-137: The unpack_fp4_to_float routine (and corresponding calls
like make_random_fp4) assume the packed axis length K//2 (or cols//2) is an
integer; add an upfront validation in unpack_fp4_to_float (and the code that
calls it, e.g., make_random_fp4 and any uses conditioned on
TL_FP4_TRANSPOSE_B/TL_FP4_K/TL_FP4_N) to check that K (and any cols passed as
packed length*2) is even and raise a clear ValueError if not (or alternatively
document/pad explicitly). Locate the unpack_fp4_to_float function and the
make_random_fp4 call sites referenced in the diff and enforce this guard so odd
TL_FP4_K or TL_FP4_N cannot silently truncate the buffer.
In `@examples/gemm_fp4/example_gemm_nvfp4_sm120.py`:
- Around line 71-109: The kernel assumes full 16x8x64 tiles but the launch dims
use ceildiv, so for any M, N, or K not divisible by block_M=16, block_N=8,
block_K=64 the kernel will read/write out-of-bounds; fix by rejecting
unsupported sizes early in matmul_nvfp4_sm120 (before building/returning main) —
check that M % block_M == 0, N % block_N == 0, and K % block_K == 0 and raise a
clear exception (e.g., ValueError) if not, so callers must provide padded inputs
or a different kernel; reference block_M/block_N/block_K and the prim func main
where copies (T.copy to/from A_shared/B_shared and final T.copy to C) currently
assume full tiles.
- Around line 112-124: The example lacks a runtime SM120 capability check before
calling tilelang.compile for the SM120-only frontend (T.nvfp4_gemm) in the
__main__ path; add a preflight that queries the current device capability (via
tilelang.utils.target.target_is_sm120 or equivalent helper) and fail fast with a
clear "SM120 required" message if the device is not SM120, returning/non-zero
exit instead of proceeding to tilelang.compile and kernel creation; update both
example_gemm_nvfp4_sm120 entrypoint logic and any similar SM120-only example
functions to guard before compilation and import/use of SM120-specific
primitives.
In `@src/cuda/codegen/codegen_cuda.cc`:
- Around line 4746-4754: The scope check using GetPtrStorageScope(buffer_var) is
incomplete—GetBufferRef() first consults alloc_storage_scope_ so allocated
shared/local FP4 buffers can be misclassified as global/packed; update the FP4
packed-load/store branches to resolve scope the same way GetBufferRef() does by
checking alloc_storage_scope_ for buffer_var.get() first and falling back to
GetPtrStorageScope(buffer_var), then use that resolved scope when computing
is_packed_fp4_scope; apply this change in both the load (tl_fp4_packed_load) and
store branches.
In `@src/cuda/op/copy.cc`:
- Around line 1721-1728: The code path handling is_align16b_subbyte_layout
currently calls TMABytesFromElements and so ignores transaction-byte accounting
(and the float4_e2m1_unpacked FP4 special case), causing mbarrier_expect_tx to
be overstated; modify the branch where total_bytes is set (the block that uses
inner_box_dim, instruction_dim, total_elements, and shared_tensor->dtype) to
call TMATransactionBytesFromElements(total_elements * loop_extent,
shared_tensor->dtype) instead of TMABytesFromElements so transaction-byte
accounting (and FP4 handling) is used for ALIGN16B loads.
In `@src/layout/gemm_layouts.cc`:
- Around line 1027-1035: DetectSwizzleMode() is collapsing ALIGN16B into
SwizzleMode::kFull causing MergeSwizzleLayouts() to reconstruct the wrong layout
with makeFullBankSwizzleLayout(buffer); add a distinct swizzle mode (e.g.,
SwizzleMode::kAlign16B) or, when merging, branch on info.element_size < 8 to
detect ALIGN16B and reconstruct it with makeAlign16BSwizzleLayout(buffer)
instead of makeFullBankSwizzleLayout(buffer); update the SwizzleMode enum and
the switch/merge logic in MergeSwizzleLayouts() to return/use the new kAlign16B
(or the element_size-based branch) so ALIGN16B mappings for sub-byte types
remain preserved.
In `@src/tl_templates/cuda/instruction/mma.h`:
- Around line 361-381: The FP4 values loaded by ptx_ldmatrix_b4x16_* end up in
the low nibble and must be repacked the same way as in mma_sync; update
mma_sync_blockscaled to apply the FP4 repack to the A and B operands before
calling Dispatcher::exec (i.e., perform the same left-shift/repack used in
mma_sync on the arrays of Dispatcher::ARegType and Dispatcher::BRegType prior to
dispatch) so BlockScaledMmaDispatcher::exec receives the correctly encoded FP4
data and produces correct accumulations.
In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py`:
- Around line 678-709: The scale operands are never advanced per k_inner, so
SFA_local_buf and SFB_local_buf always use base 0; compute SFA_offset and
SFB_offset analogous to A_offset/B_offset using k_inner (e.g. advancing by
k_inner * element_stride_or_vector_size used for scale storage) and pass those
offsets into T.ptx_mma_blockscaled instead of always 0; update the
_atom_mma_blockscaled macro to use the new SFA_offset and SFB_offset
(referencing k_inner, SFA_local_buf, SFB_local_buf, and the
_atom_mma_blockscaled/T.ptx_mma_blockscaled call) so each ki iteration uses the
correct per-k_inner scale vectors.
- Around line 126-136: The check for FP4 chunk sizes must reject non-multiples
rather than only enforcing a minimum: in the block-scaled path (when
self.is_blockscaled and str(a_dtype) == "float4_e2m1fn") validate that
self.chunk is a multiple of 64 (e.g. self.chunk % 64 == 0) and raise a
ValueError if not; in the plain FP4 path (when str(a_dtype) == "float4_e2m1fn")
validate that self.chunk is a multiple of 32 (self.chunk % 32 == 0) instead of
only checking self.chunk < 32, and raise a clear ValueError if the modulus check
fails; keep setting self.k_dim = 64 (block-scaled) or self.k_dim = min(32,
self.chunk) after the validation.
In `@tilelang/language/gemm_op.py`:
- Around line 269-320: The nvfp4_gemm path dropped the _gemm_impl base-offset
and rank guards so row-sliced packed tiles can lose their outer-dimension base
offset and malformed inputs raise IndexError; restore the same validation: after
computing A_offset = retrieve_offset(A_region) and B_offset =
retrieve_offset(B_region) (and their shapes via retrieve_shape if needed),
assert that A_offset and B_offset have the expected rank (e.g., len(A_offset) >=
2 and len(B_offset) >= 2) and assert A_offset[-2] == 0 and B_offset[-2] == 0 (or
alternatively call _gemm_impl to reuse its checks) before forwarding
A_offset[-1] and B_offset[-1] into tirx.call_intrin in nvfp4_gemm so
outer-dimension base offsets cannot be silently lost.
In `@tilelang/layout/swizzle.py`:
- Around line 91-120: This code currently flattens all leading dims into a
single stride and calls _ffi_api.make_tcgen05mma_swizzled_layout which lets the
swizzle cross outer-dimension boundaries; instead, compute the 2D layout from
the trailing two dims only (use shape[-2:] and the corresponding last-two
stride/continuity values from _get_stride_continuous/stride), call
_ffi_api.make_tcgen05mma_swizzled_layout to build the base 2D layout for those
last two dims, then expand that 2D base across the leading dimensions
(shape[:-2]) using the same ExpandLayout2D-style expansion used by the native
C++ factories so the Python path matches the buffer-based path for rank>2
buffers.
---
Outside diff comments:
In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py`:
- Around line 399-418: runtime_instr_desc is computed once from sf_a_id/sf_b_id
before the ki loop, so all atoms use the same scale IDs even though
descriptors/offsets advance; move the computation of runtime_instr_desc (or
otherwise recompute sf_a_id/sf_b_id) inside the ki loop so each atom uses the
current scale IDs for its descriptor, i.e. update how runtime_instr_desc is
built before calling tcgen05_blockscaled_atom (referencing runtime_instr_desc,
sf_a_id, sf_b_id, tcgen05_blockscaled_atom, sfa_data, sfb_data, a_params,
b_params) so multi-atom iterations use the correct per-atom scale vectors.
---
Nitpick comments:
In `@examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py`:
- Around line 118-132: The X activation tile is copied twice per ko; change the
loops so X_bytes is loaded into X_shared once and reused for both GEMMs: for
each ko, do a single T.copy(X_bytes[by * block_tokens, ko *
packed_block_hidden], X_shared) then call T.nvfp4_gemm for gate (using
gate_shared, SFA_local/SFB_local, gate_local, clear_accum=(ko == 0)) and then
call T.nvfp4_gemm for up (using up_shared, SFA_local/SFB_local, up_local,
clear_accum=(ko == 0)), removing the duplicate T.copy from the second loop
(references: X_bytes, X_shared, W_gate_bytes, W_up_bytes, T.nvfp4_gemm,
gate_local, up_local).
In `@tilelang/cuda/intrinsics/layout/utils.py`:
- Around line 22-29: The function get_ldmatrix_offset currently advertises
dtype: Literal["float16", "int8", "int4"] but the implementation contains a
runtime path for FP4; update the public type contract to include "fp4" (or "FP4"
matching project convention) so type-checkers accept the code, and adjust any
surrounding docstrings/error messages to explicitly list and distinguish "fp4"
and "int4" (and/or their canonical casing) where dtype choices are described or
validated (also update the equivalent signature/annotations and messages in the
similar helper at lines ~52-63 that handle small-int/fp4 cases). Ensure you
reference and change the dtype literal in get_ldmatrix_offset and the
corresponding validation/error text so FP4 and INT4 remain clearly separated.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 233d92a6-159d-4126-8152-e37c62dc19ad
📒 Files selected for processing (38)
examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py
examples/gemm_fp4/example_gemm_a8w4_sm100.py
examples/gemm_fp4/example_gemm_a8w4_sm120.py
examples/gemm_fp4/example_gemm_fp4_sm100.py
examples/gemm_fp4/example_gemm_fp4_sm120.py
examples/gemm_fp4/example_gemm_nvfp4_sm120.py
src/cuda/codegen/codegen_cuda.cc
src/cuda/op/copy.cc
src/cuda/op/gemm.cc
src/layout/gemm_layouts.cc
src/layout/layout.cc
src/layout/layout.h
src/op/builtin.cc
src/op/builtin.h
src/op/tcgen5_meta.h
src/op/utils.cc
src/tl_templates/cuda/common.h
src/tl_templates/cuda/cuda_fp4.h
src/tl_templates/cuda/gemm_mma.h
src/tl_templates/cuda/gemm_sm100.h
src/tl_templates/cuda/instruction/mma.h
src/tl_templates/cuda/ldsm.h
tilelang/cuda/intrinsics/layout/mma_layout.py
tilelang/cuda/intrinsics/layout/utils.py
tilelang/cuda/intrinsics/macro/mma_macro_generator.py
tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
tilelang/cuda/op/gemm/gemm_mma.py
tilelang/cuda/op/gemm/gemm_tcgen05.py
tilelang/language/__init__.py
tilelang/language/ast/ir.py
tilelang/language/gemm_op.py
tilelang/language/tir/ir.py
tilelang/language/tir/op.py
tilelang/layout/__init__.py
tilelang/layout/swizzle.py
tilelang/tileop/gemm/gemm_base.py
def unpack_fp4_to_float(packed_int8, rows, cols):
    lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device)
    flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2)
    lo = flat & 0x0F
    hi = (flat >> 4) & 0x0F
    unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64)
    return lut[unpacked]
Guard TL_MOE_HIDDEN for packed FP4 storage.
Both expert weight tensors are packed as (d_expert, d_hidden // 2), so this example only works when d_hidden is even. Right now an odd TL_MOE_HIDDEN is accepted and the failure shows up later during unpack/reference instead of at config parsing.
Suggested guard
 num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128"))
 d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256"))
 d_expert = int(os.environ.get("TL_MOE_EXPERT", "256"))
@@
 block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128"))
 block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128"))
 block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64"))
+
+if d_hidden % 2 != 0:
+    raise ValueError("TL_MOE_HIDDEN must be even for packed FP4 expert weights")
Also applies to: 115-120, 141-155
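Why the packed axis has to be even is easiest to see in a round trip. The sketch below is a pure-Python stand-in for the torch-based helper quoted above: it uses the same nibble order (low nibble first), while `pack_fp4_codes` is a hypothetical helper and the lookup table is written out here from the standard E2M1 value set rather than copied from the PR.

```python
# E2M1 code -> value (sign bit, 2 exponent bits, 1 mantissa bit).
# Written out here as an assumption; the PR's own FP4_E2M1_TO_FLOAT is not shown.
FP4_E2M1_TO_FLOAT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def pack_fp4_codes(codes):
    """Pack 4-bit codes two per byte, low nibble first (hypothetical helper)."""
    if len(codes) % 2 != 0:
        raise ValueError("packed FP4 storage needs an even number of elements")
    return bytes((codes[i] & 0x0F) | ((codes[i + 1] & 0x0F) << 4)
                 for i in range(0, len(codes), 2))


def unpack_fp4_row(packed, cols):
    """Unpack one row of `cols` FP4 values from `cols // 2` bytes."""
    if cols % 2 != 0:
        raise ValueError("cols must be even for packed FP4 storage")
    if len(packed) != cols // 2:
        raise ValueError("byte buffer length does not match cols // 2")
    out = []
    for byte in packed:
        out.append(FP4_E2M1_TO_FLOAT[byte & 0x0F])         # low nibble
        out.append(FP4_E2M1_TO_FLOAT[(byte >> 4) & 0x0F])  # high nibble
    return out


packed = pack_fp4_codes([1, 2, 7, 9])
values = unpack_fp4_row(packed, 4)
print(packed.hex(), values)
```

An odd element count has no valid byte layout in this scheme, so rejecting it at configuration time gives a clearer error than a later reshape failure.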
# --- Test 1: zeros ---
z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
c_zero = jit_kernel(z_input, z_gate, z_up)
print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out")
Make the zero-input smoke test fail closed.
Right now this only prints [FAIL] and keeps going, so the example still exits successfully even when the fused kernel is fundamentally broken. The other new example scripts already assert here; this one should too.
Suggested fix
 c_zero = jit_kernel(z_input, z_gate, z_up)
-print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out")
+assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
+print("[PASS] zeros in -> zeros out")
def unpack_fp4_to_float(packed_int8, rows, cols):
    lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device)
    flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2)
    lo = flat & 0x0F
    hi = (flat >> 4) & 0x0F
    unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64)
    return lut[unpacked]
Reject odd K before building packed FP4 weights.
This example packs the FP4 weight matrix as (N, K // 2) and unpacks it with the same assumption. If TL_A8W4_K is odd, the script silently allocates a truncated buffer and then fails later in a less obvious place.
Suggested guard
 M = int(os.environ.get("TL_A8W4_M", "256"))
 N = int(os.environ.get("TL_A8W4_N", "256"))
 K = int(os.environ.get("TL_A8W4_K", "256"))
@@
 block_M = int(os.environ.get("TL_A8W4_BLOCK_M", "128"))
 block_N = int(os.environ.get("TL_A8W4_BLOCK_N", "64"))
 block_K = int(os.environ.get("TL_A8W4_BLOCK_K", "128"))
+
+if K % 2 != 0:
+    raise ValueError("TL_A8W4_K must be even for packed FP4 weights")
 print(f"Running SM100 A8W4 GEMM: M={M}, N={N}, K={K}, block=({block_M},{block_N},{block_K})")
Also applies to: 83-88, 106-115
def unpack_fp4_to_float(packed_int8, M, K):
    lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device)
    flat = packed_int8.to(torch.uint8).reshape(M, K // 2)
    lo = flat & 0x0F
    hi = (flat >> 4) & 0x0F
    unpacked = torch.stack([lo, hi], dim=-1).reshape(M, K).to(torch.int64)
    return lut[unpacked]
Validate the packed FP4 dimensions up front.
make_random_fp4(..., cols // 2) and unpack_fp4_to_float(...reshape(..., K // 2)) both assume the packed axis is even. With TL_FP4_K odd — and with TL_FP4_N odd when TL_FP4_TRANSPOSE_B=0 — this truncates the byte buffer shape and the example fails later with a much less obvious reshape/reference error.
Suggested guard
 M = int(os.environ.get("TL_FP4_M", "256"))
 N = int(os.environ.get("TL_FP4_N", "256"))
 K = int(os.environ.get("TL_FP4_K", "256"))
@@
 input_mode = os.environ.get("TL_FP4_INPUT_MODE", "random")
 transpose_b = os.environ.get("TL_FP4_TRANSPOSE_B", "1") != "0"
+
+if K % 2 != 0:
+    raise ValueError("TL_FP4_K must be even for packed FP4 inputs")
+if not transpose_b and N % 2 != 0:
+    raise ValueError("TL_FP4_N must be even when TL_FP4_TRANSPOSE_B=0")
+
 print(f"Running FP4 GEMM (SM100/SM110 TCGEN05): M={M}, N={N}, K={K}, input_mode={input_mode}, transpose_b={transpose_b}")
Also applies to: 140-153, 186-206
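The same checks can be tied to the byte-buffer shapes they protect. The sketch below assumes, based only on the review text, that A is always packed along K and that B is packed along K when transposed and along N otherwise; `packed_fp4_shapes` is a hypothetical helper, not a function from the PR.

```python
def packed_fp4_shapes(M, N, K, transpose_b):
    """Return (a_shape, b_shape) in bytes, or raise if a packed axis is odd.

    Assumed layout (not confirmed by the PR source shown here):
      A: (M, K // 2)
      B: (N, K // 2) if transpose_b else (K, N // 2)
    """
    if K % 2 != 0:
        raise ValueError("K must be even for packed FP4 inputs")
    if transpose_b:
        # K is the packed axis of B as well, already validated above.
        b_shape = (N, K // 2)
    else:
        if N % 2 != 0:
            raise ValueError("N must be even when B is not transposed")
        b_shape = (K, N // 2)
    return (M, K // 2), b_shape


print(packed_fp4_shapes(256, 256, 256, transpose_b=True))
print(packed_fp4_shapes(256, 254, 128, transpose_b=False))
```

Deriving the shapes and the guards in one place keeps the two from drifting apart when the layout option changes.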
template <DataType AType, DataType BType, DataType CType, DataType SFType,
          int M, int N, int K, bool TransA, bool TransB, int VS>
TL_DEVICE void mma_sync_blockscaled(
    typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::CRegType *c,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::ARegType *a,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::BRegType *b,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::SFRegType
        *sfa,
    const typename detail::BlockScaledMmaDispatcher<
        AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::SFRegType
        *sfb) {
  using Dispatcher = detail::BlockScaledMmaDispatcher<
      AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>;
  static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
                "tl::mma_sync_blockscaled: unsupported configuration");

  Dispatcher::exec(c, a, b, c, sfa, sfb);
There was a problem hiding this comment.
Apply the FP4 repack in mma_sync_blockscaled too.
ptx_ldmatrix_b4x16_* leaves each FP4 value in the low nibble, and mma_sync() compensates with << 2 before dispatch. This path forwards A/B unchanged, so NVFP4 block-scaled GEMM consumes a different FP4 encoding than the plain SM120 FP4 path and will accumulate wrong values.
Suggested fix
template <DataType AType, DataType BType, DataType CType, DataType SFType,
int M, int N, int K, bool TransA, bool TransB, int VS>
TL_DEVICE void mma_sync_blockscaled(
@@
using Dispatcher = detail::BlockScaledMmaDispatcher<
AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync_blockscaled: unsupported configuration");
-
- Dispatcher::exec(c, a, b, c, sfa, sfb);
+ if constexpr (AType == DataType::kFloat4_e2m1fn ||
+ BType == DataType::kFloat4_e2m1fn) {
+ using AReg = typename Dispatcher::ARegType;
+ using BReg = typename Dispatcher::BRegType;
+ constexpr int nA =
+ detail::BlockScaledMmaImplTraits<typename Dispatcher::Impl>::kARegs;
+ constexpr int nB =
+ detail::BlockScaledMmaImplTraits<typename Dispatcher::Impl>::kBRegs;
+ AReg as[nA];
+ BReg bs[nB];
+#pragma unroll
+ for (int i = 0; i < nA; ++i)
+ as[i] = (AType == DataType::kFloat4_e2m1fn) ? (a[i] << 2) : a[i];
+#pragma unroll
+ for (int i = 0; i < nB; ++i)
+ bs[i] = (BType == DataType::kFloat4_e2m1fn) ? (b[i] << 2) : b[i];
+ Dispatcher::exec(c, as, bs, c, sfa, sfb);
+ } else {
+ Dispatcher::exec(c, a, b, c, sfa, sfb);
+ }
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/tl_templates/cuda/instruction/mma.h` around lines 361 - 381, The FP4
values loaded by ptx_ldmatrix_b4x16_* end up in the low nibble and must be
repacked the same way as in mma_sync; update mma_sync_blockscaled to apply the
FP4 repack to the A and B operands before calling Dispatcher::exec (i.e.,
perform the same left-shift/repack used in mma_sync on the arrays of
Dispatcher::ARegType and Dispatcher::BRegType prior to dispatch) so
BlockScaledMmaDispatcher::exec receives the correctly encoded FP4 data and
produces correct accumulations.
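The effect of the `<< 2` repack can be checked with plain integer arithmetic. The sketch below is an illustration, not the kernel code: it assumes a 32-bit register holding four one-byte containers, each with its FP4 code in the low nibble (as the comment describes for `ptx_ldmatrix_b4x16_*`), and the function names are made up for this example.

```python
def pack_low_nibbles(values):
    """Pack four FP4 codes (0..15) into one 32-bit word, one code per byte,
    each sitting in the low nibble of its byte (assumed post-ldmatrix form)."""
    assert len(values) == 4 and all(0 <= v <= 0xF for v in values)
    word = 0
    for i, v in enumerate(values):
        word |= v << (8 * i)
    return word


def repack_fp4(word):
    """Shift the whole register left by 2, mirroring `a[i] << 2`.

    Each byte only uses bits 0..3, so the shifted code lands in bits 2..5
    of the same byte and never spills into the neighbouring byte.
    """
    return (word << 2) & 0xFFFFFFFF


codes = [0x1, 0x7, 0xF, 0x8]
shifted = repack_fp4(pack_low_nibbles(codes))
bytes_after = [(shifted >> (8 * i)) & 0xFF for i in range(4)]
recovered = [(b >> 2) & 0xF for b in bytes_after]
```

Because a single word-wide shift is enough, the repack costs one instruction per register, which is why applying it in both `mma_sync` and `mma_sync_blockscaled` is cheap compared with the risk of the two paths disagreeing on the encoding.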
if self.is_blockscaled:
    if str(a_dtype) != "float4_e2m1fn" or self.chunk < 64:
        raise ValueError("SM120 block-scaled NVFP4 MMA requires FP4 inputs and chunk >= 64")
    self.k_dim = 64
    return
# SM120 f8f6f4 FP4 MMA is m16n8k32. Although 256 / 4 would allow a
# k64 fragment by bit count, there is no k64 dispatcher for FP4.
if str(a_dtype) == "float4_e2m1fn":
    if self.chunk < 32:
        raise ValueError("FP4 MMA requires chunk to be a multiple of 32 (m16n8k32)")
    self.k_dim = min(32, self.chunk)
Reject non-multiple FP4 chunk sizes.
Line 127 and Line 134 only enforce a minimum size, but the lowering paths consume block_K // micro_size_k fixed-width atoms. A chunk of 96 in blockscaled mode or 48 in plain FP4 mode therefore emits one k64/k32 step and silently drops the tail instead of failing fast.
Proposed fix
 if self.is_blockscaled:
-    if str(a_dtype) != "float4_e2m1fn" or self.chunk < 64:
-        raise ValueError("SM120 block-scaled NVFP4 MMA requires FP4 inputs and chunk >= 64")
+    if str(a_dtype) != "float4_e2m1fn" or self.chunk % 64 != 0:
+        raise ValueError("SM120 block-scaled NVFP4 MMA requires FP4 inputs and chunk to be a multiple of 64")
     self.k_dim = 64
     return
 # SM120 f8f6f4 FP4 MMA is m16n8k32. Although 256 / 4 would allow a
 # k64 fragment by bit count, there is no k64 dispatcher for FP4.
 if str(a_dtype) == "float4_e2m1fn":
-    if self.chunk < 32:
+    if self.chunk % 32 != 0:
         raise ValueError("FP4 MMA requires chunk to be a multiple of 32 (m16n8k32)")
     self.k_dim = min(32, self.chunk)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py` around lines 126 -
136, The check for FP4 chunk sizes must reject non-multiples rather than only
enforcing a minimum: in the block-scaled path (when self.is_blockscaled and
str(a_dtype) == "float4_e2m1fn") validate that self.chunk is a multiple of 64
(e.g. self.chunk % 64 == 0) and raise a ValueError if not; in the plain FP4 path
(when str(a_dtype) == "float4_e2m1fn") validate that self.chunk is a multiple of
32 (self.chunk % 32 == 0) instead of only checking self.chunk < 32, and raise a
clear ValueError if the modulus check fails; keep setting self.k_dim = 64
(block-scaled) or self.k_dim = min(32, self.chunk) after the validation.
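To make the dropped-tail arithmetic concrete, here is a small standalone sketch. `fp4_k_steps` is a hypothetical function, not part of `mma_macro_generator.py`; the atom widths (64 for block-scaled, 32 for plain FP4) are taken from the comment above.

```python
def fp4_k_steps(chunk: int, blockscaled: bool) -> int:
    """Number of fixed-width MMA steps along K for one chunk.

    Hypothetical stand-in for `block_K // micro_size_k`: the atom is k64
    on the block-scaled path and k32 on the plain FP4 path.
    """
    micro_size_k = 64 if blockscaled else 32
    if chunk <= 0 or chunk % micro_size_k != 0:
        raise ValueError(
            f"chunk={chunk} must be a positive multiple of {micro_size_k}; "
            f"{chunk % micro_size_k if chunk > 0 else chunk} trailing "
            "K elements would otherwise be skipped"
        )
    return chunk // micro_size_k


# With only a minimum-size check, 96 // 64 == 1 step runs and
# 96 % 64 == 32 K elements are never multiplied.
tail_blockscaled = 96 % 64
tail_plain = 48 % 32
```

A modulus check also rejects zero-length tails implicitly, so no separate `chunk < micro_size_k` test is needed once the chunk is known to be positive.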
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * self.warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * self.warp_cols * local_size_b if b_is_fragment else 0

A_offset = a_local_stride + inst_m_idx * local_size_a
B_offset = b_local_stride + inst_n_idx * local_size_b
C_offset = inst_m_idx * self.warp_cols * local_size_out + inst_n_idx * local_size_out

@T.macro
def _atom_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf):
    T.ptx_mma_blockscaled(
        accum_dtype,
        mma_prefix,
        "row",
        "col",
        a_dtype_abbrv,
        b_dtype_abbrv,
        accum_dtype_abbrv,
        self.scale_dtype,
        self.scale_vec_size,
        A_local_buf.data,
        A_offset,
        B_local_buf.data,
        B_offset,
        C_local_buf.data,
        C_offset,
        SFA_local_buf.data,
        SFB_local_buf.data,
    )

return _atom_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf)
Advance the scale operands with k_inner.
A_offset and B_offset move per k_inner, but T.ptx_mma_blockscaled always receives SFA_local_buf.data and SFB_local_buf.data at base 0. The caller loops over ki, so any K > 64 or non-uniform per-block scales will keep reusing the first scale vector and accumulate with the wrong exponents.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py` around lines 678 -
709, The scale operands are never advanced per k_inner, so SFA_local_buf and
SFB_local_buf always use base 0; compute SFA_offset and SFB_offset analogous to
A_offset/B_offset using k_inner (e.g. advancing by k_inner *
element_stride_or_vector_size used for scale storage) and pass those offsets
into T.ptx_mma_blockscaled instead of always 0; update the _atom_mma_blockscaled
macro to use the new SFA_offset and SFB_offset (referencing k_inner,
SFA_local_buf, SFB_local_buf, and the
_atom_mma_blockscaled/T.ptx_mma_blockscaled call) so each ki iteration uses the
correct per-k_inner scale vectors.
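A toy scalar model shows what goes wrong when the scale index does not advance. This is a sketch under simplifying assumptions (one scale per K atom and plain Python floats instead of FP4 or packed scale vectors); `blockscaled_dot` and its `advance_scales` switch are invented for the illustration.

```python
def blockscaled_dot(a, b, sfa, sfb, k_atom, advance_scales=True):
    """Toy model of a block-scaled dot product along K.

    a, b: sequences of length K; sfa, sfb: one scale per `k_atom`-wide block.
    With advance_scales=False every block reuses scale index 0, which
    mimics passing the scale buffers at base offset 0 on every iteration.
    """
    assert len(a) == len(b) and len(a) % k_atom == 0
    acc = 0.0
    for ki in range(len(a) // k_atom):
        s = ki if advance_scales else 0
        lo = ki * k_atom
        partial = sum(x * y for x, y in zip(a[lo:lo + k_atom], b[lo:lo + k_atom]))
        acc += sfa[s] * sfb[s] * partial
    return acc


ones = [1.0] * 4
correct = blockscaled_dot(ones, ones, [1.0, 2.0], [1.0, 4.0], k_atom=2)
stale = blockscaled_dot(ones, ones, [1.0, 2.0], [1.0, 4.0], k_atom=2,
                        advance_scales=False)
uniform_ok = blockscaled_dot(ones, ones, [3.0, 3.0], [3.0, 3.0], k_atom=2)
uniform_stale = blockscaled_dot(ones, ones, [3.0, 3.0], [3.0, 3.0], k_atom=2,
                                advance_scales=False)
```

In this model the two variants agree whenever all scales are equal, so a test that only uses uniform scale factors would not expose the missing offset; a regression test should use distinct scales per K block.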
A_region = to_buffer_region(A)
B_region = to_buffer_region(B)
C_region = to_buffer_region(C)
SFA_region = to_buffer_region(SFA)
SFB_region = to_buffer_region(SFB)

A_shape = retrieve_shape(A_region)
B_shape = retrieve_shape(B_region)
C_shape = retrieve_shape(C_region)

M, N = C_shape
M_A = A_shape[-1] if transpose_A else A_shape[-2]
K = A_shape[-2] if transpose_A else A_shape[-1]
N_B = B_shape[-2] if transpose_B else B_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert prim_expr_equal(M_A, M), f"T.nvfp4_gemm M shape check failed: M_A = {M_A}, M_C = {M}"
assert prim_expr_equal(N_B, N), f"T.nvfp4_gemm N shape check failed: N_B = {N_B}, N_C = {N}"
assert prim_expr_equal(K, K_B), f"T.nvfp4_gemm K shape check failed: K_A = {K}, K_B = {K_B}"
# The current SM120 mxf4nvf4 ldmatrix path consumes packed bytes in shared
# memory (two FP4 values per byte). The GEMM node still needs the logical
# K so the emitter selects m16n8k64.
logical_K = K * 2 if str(A_region.buffer.dtype) == "uint8" and str(B_region.buffer.dtype) == "uint8" else K

A_stride = retrieve_stride(A_region)
B_stride = retrieve_stride(B_region)
A_offset = retrieve_offset(A_region)
B_offset = retrieve_offset(B_region)
C_coords = [r.min for r in C_region.region]

A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
SFA_arg = buffer_region_to_tile_region(SFA_region, "r", list(retrieve_shape(SFA_region)))
SFB_arg = buffer_region_to_tile_region(SFB_region, "r", list(retrieve_shape(SFB_region)))

return tirx.call_intrin(
    "handle",
    tirx.op.Op.get("tl.tileop.gemm"),
    A_arg,
    B_arg,
    C_arg,
    transpose_A,
    transpose_B,
    M,
    N,
    logical_K,
    policy,
    clear_accum,
    A_stride[-2],
    B_stride[-2],
    A_offset[-1],
    B_offset[-1],
Restore _gemm_impl’s A/B base-offset validation here.
nvfp4_gemm only forwards A_offset[-1] and B_offset[-1], but unlike _gemm_impl it no longer asserts A_offset[-2] == 0 / B_offset[-2] == 0. That means a row-sliced packed tile can silently lose its outer-dimension base offset and read the wrong data. This duplicate path also dropped the rank checks, so malformed inputs now fail with IndexError later instead of a targeted assertion. Reusing _gemm_impl or reinstating the same guards here would close both gaps.
Suggested fix
A_shape = retrieve_shape(A_region)
B_shape = retrieve_shape(B_region)
C_shape = retrieve_shape(C_region)
+
+ assert len(C_shape) == 2, "current only support C as a 2D tensor"
+ assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
+ assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
M, N = C_shape
M_A = A_shape[-1] if transpose_A else A_shape[-2]
K = A_shape[-2] if transpose_A else A_shape[-1]
N_B = B_shape[-2] if transpose_B else B_shape[-1]
@@
A_stride = retrieve_stride(A_region)
B_stride = retrieve_stride(B_region)
A_offset = retrieve_offset(A_region)
B_offset = retrieve_offset(B_region)
+ assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
+ assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
C_coords = [r.min for r in C_region.region]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/language/gemm_op.py` around lines 269 - 320, The nvfp4_gemm path
dropped the _gemm_impl base-offset and rank guards so row-sliced packed tiles
can lose their outer-dimension base offset and malformed inputs raise
IndexError; restore the same validation: after computing A_offset =
retrieve_offset(A_region) and B_offset = retrieve_offset(B_region) (and their
shapes via retrieve_shape if needed), assert that A_offset and B_offset have the
expected rank (e.g., len(A_offset) >= 2 and len(B_offset) >= 2) and assert
A_offset[-2] == 0 and B_offset[-2] == 0 (or alternatively call _gemm_impl to
reuse its checks) before forwarding A_offset[-1] and B_offset[-1] into
tirx.call_intrin in nvfp4_gemm so outer-dimension base offsets cannot be
silently lost.
| global _WARNED_LEGACY_TCGEN05_LAYOUT_FFI | ||
| buf, shape, _ = _get_buffer_info(buffer) | ||
| stride, continuous = _get_stride_continuous(buffer) | ||
| element_size = _get_element_size(buffer) | ||
| if continuity is None: | ||
| continuity = -1 | ||
| return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major) | ||
| continuity = continuous | ||
| try: | ||
| base = _ffi_api.make_tcgen05mma_swizzled_layout( | ||
| stride, | ||
| continuous, | ||
| continuity, | ||
| element_size, | ||
| k_major, | ||
| ) | ||
| return base.reshape(shape) | ||
| except TypeError as err: | ||
| # Keep Python sources compatible with older built libs that still expose | ||
| # the legacy FFI signature: (buffer, continuity, k_major). | ||
| if "Mismatched number of arguments" not in str(err): | ||
| raise | ||
| if not _WARNED_LEGACY_TCGEN05_LAYOUT_FFI: | ||
| warnings.warn( | ||
| "Detected legacy tcgen05 swizzle-layout FFI in the loaded native " | ||
| "TileLang library. Rebuild the native `build/` artifacts so FP4 " | ||
| "SM100/SM110 kernels use the current gap-aware ALIGN16B layout.", | ||
| RuntimeWarning, | ||
| stacklevel=2, | ||
| ) | ||
| _WARNED_LEGACY_TCGEN05_LAYOUT_FFI = True | ||
| return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major) |
The new TCGEN05 Python path changes the layout for rank>2 buffers.
The native C++ factories only swizzle the last two dimensions and keep leading dims unchanged via ExpandLayout2D(...), but this code flattens every leading dim into stride and then calls base.reshape(shape). For 3D+ buffers that lets the swizzle span outer-dimension boundaries and produces a different mapping than the buffer-based path. Build the 2D base from shape[-2:] and expand across shape[:-2] instead; otherwise make_align16b_swizzled_layout() inherits the same bug whenever it falls back here.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/layout/swizzle.py` around lines 91 - 120, This code currently
flattens all leading dims into a single stride and calls
_ffi_api.make_tcgen05mma_swizzled_layout which lets the swizzle cross
outer-dimension boundaries; instead, compute the 2D layout from the trailing two
dims only (use shape[-2:] and the corresponding last-two stride/continuity
values from _get_stride_continuous/stride), call
_ffi_api.make_tcgen05mma_swizzled_layout to build the base 2D layout for those
last two dims, then expand that 2D base across the leading dimensions
(shape[:-2]) using the same ExpandLayout2D-style expansion used by the native
C++ factories so the Python path matches the buffer-based path for rank>2
buffers.
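The difference between "expand a 2D base across leading dims" and "flatten leading dims, then swizzle" can be seen with a toy XOR swizzle. This is not the TCGEN05 layout and none of these functions exist in TileLang; the period of 8 rows and the XOR pattern are assumptions chosen only to make the effect visible.

```python
def xor_swizzle_2d(row: int, col: int, period: int = 8) -> int:
    """Toy 2D swizzle: permute the column by XOR with (row mod period)."""
    return col ^ (row % period)


def expand_then_swizzle(idx, shape):
    """Swizzle only the trailing two dims; leading dims just select a slab."""
    *lead, r, c = idx
    *lead_shape, rows, cols = shape
    slab = 0
    for i, n in zip(lead, lead_shape):
        slab = slab * n + i
    return (slab * rows + r) * cols + xor_swizzle_2d(r, c)


def flatten_then_swizzle(idx, shape):
    """Fold the leading dims into the row index before swizzling."""
    *lead, r, c = idx
    *lead_shape, rows, cols = shape
    flat_row = 0
    for i, n in zip(lead + [r], lead_shape + [rows]):
        flat_row = flat_row * n + i
    return flat_row * cols + xor_swizzle_2d(flat_row, c)


shape = (2, 4, 8)  # 4 rows per slab is not a multiple of the swizzle period
first_elem_expand = expand_then_swizzle((1, 0, 0), shape)
first_elem_flatten = flatten_then_swizzle((1, 0, 0), shape)
```

In this toy the two mappings coincide when the row count per slab is a multiple of the swizzle period and diverge otherwise, which is one way a flattened path could pass tests on common tile shapes and still disagree with the buffer-based path on others.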
…ected in codegen_cuda.cc, called at tcgen05_macro_generator.py
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/cuda/codegen/codegen_cuda.cc`:
- Around line 3411-3433: When use_mxf4nvf4 is true, validate that
kind_dtype/dtype_enum is the FP4 variant expected by the mxf4nvf4 intrinsic
before emitting tcgen05mma_mxf4nvf4_blockscaled_ss: call
tl::codegen::ptx::DTypeFromString(kind_dtype) (already present as dtype_enum)
and assert (ICHECK) that it matches the FP4 dtype enum the intrinsic requires,
otherwise fail with a clear error message (or avoid emitting the FP4 intrinsic);
update the branch that constructs tcgen05_call (and/or use_mxf4nvf4 decision) to
perform this dtype check so a mismatched caller errors at codegen rather than
silently lowering to the wrong intrinsic.
In `@src/op/tcgen5_meta.h`:
- Around line 358-359: The helper GetTCGEN5MXF4NVF4BlockScaledInstrDesc
currently writes a_sf_id into desc via set_bits(static_cast<uint32_t>(a_sf_id),
29, 2) even though a_sf_id is not a parameter and the field is fixed to zero for
scale_vec::4X; remove that stale write (or alternatively add an a_sf_id
parameter to the function signature and propagate it through if intended) so
that desc no longer attempts to set bits 29-30 from an undefined a_sf_id; update
any related comments to reflect the removal.
In `@tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py`:
- Around line 468-529: The guard and descriptor for B are incorrect: replace the
assert that assumes B is non-transposed and ensure b_transposed is used
consistently when building b_params; specifically, update the top assertion to
require the correct B layout (use self.b_transposed as done by
tcgen05mma_blockscaled() and compute_tcgen05_b_desc_params()), compute
b_is_k_major = self.b_transposed and pass that into the TCGEN05DescriptorParams
for b_params (instead of hardcoding is_k_major=True), and adjust any related
leading/stride byte offset logic to follow the K-major vs N-major branch
consistent with the _determinate_swizzle_mode(B_buf, ...) handling.
- Around line 531-556: The code reads SFA_tmem.dtype before normalizing
BufferRegion/Buffer, which fails for sliced regions; update the function to
first normalize SFA_tmem and SFB_tmem into their underlying Buffer or
Buffer.data (using the existing BufferRegion/Buffer checks for SFA_tmem and
SFB_tmem) and only then read dtype to compute sf_dtype and is_mxfp4 and call
_ffi_api.get_tcgen5_mxf4nvf4_blockscaled_instr_desc; reference symbols:
SFA_tmem, SFB_tmem, BufferRegion, Buffer, sf_dtype, is_mxfp4, and
_ffi_api.get_tcgen5_mxf4nvf4_blockscaled_instr_desc to locate where to move the
dtype access and normalization logic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c496c2e0-db4c-47d6-9afc-cf1243927384
📒 Files selected for processing (6)
- src/cuda/codegen/codegen_cuda.cc
- src/cuda/op/gemm.cc
- src/op/builtin.cc
- src/op/tcgen5_meta.h
- src/tl_templates/cuda/instruction/tcgen05mma.h
- tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
| bool use_mxf4nvf4 = Downcast<IntImm>(op->args[13])->value != 0; | ||
| // args[14] reserved for future mask/flags | ||
| bool enable_ws = Downcast<Bool>(op->args[15])->value; | ||
| bool enable_2cta = Downcast<Bool>(op->args[16])->value; | ||
| ICHECK(!(enable_ws && enable_2cta)) | ||
| << "Block-scaled TCGEN05 does not support combining .ws and 2CTA"; | ||
| ICHECK(!use_mxf4nvf4 || !enable_ws) | ||
| << "mxf4nvf4 block-scaled TCGEN05 currently supports SS only"; | ||
|
|
||
| auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype); | ||
|
|
||
| need_tcgen05mma_instruction_h_ = true; | ||
| this->PrintIndent(); | ||
| std::string tcgen05_call = | ||
| "tl::(tcgen05_name)<(ABType), (USE_2CTA)>(uint64_t((desc_a) + " | ||
| "(A_offset)), " | ||
| "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) " | ||
| "+ (C_offset), " | ||
| "(scale_out), static_cast<uint32_t>((desc_val)), " | ||
| "(*reinterpret_cast<uint32_t*>((SFA))) + (SFA_offset), " | ||
| "(*reinterpret_cast<uint32_t*>((SFB))) + (SFB_offset));\n"; | ||
| std::string tcgen05_call; | ||
| if (use_mxf4nvf4) { | ||
| tcgen05_call = | ||
| "tl::tcgen05mma_mxf4nvf4_blockscaled_ss<(USE_2CTA)>(" | ||
| "uint64_t((desc_a) + (A_offset)), " | ||
| "uint64_t((desc_b) + (B_offset)), " | ||
| "(*reinterpret_cast<uint32_t*>((C))) + (C_offset), " | ||
| "(scale_out), static_cast<uint32_t>((desc_val)), " | ||
| "(*reinterpret_cast<uint32_t*>((SFA))) + (SFA_offset), " | ||
| "(*reinterpret_cast<uint32_t*>((SFB))) + (SFB_offset));\n"; |
Validate the operand dtype before forcing the mxf4nvf4 intrinsic.
When use_mxf4nvf4 is true, this branch ignores kind_dtype and always emits tcgen05mma_mxf4nvf4_blockscaled_ss. A mismatched caller will silently lower to the FP4-specific intrinsic instead of failing at the codegen boundary.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/cuda/codegen/codegen_cuda.cc` around lines 3411 - 3433, When use_mxf4nvf4
is true, validate that kind_dtype/dtype_enum is the FP4 variant expected by the
mxf4nvf4 intrinsic before emitting tcgen05mma_mxf4nvf4_blockscaled_ss: call
tl::codegen::ptx::DTypeFromString(kind_dtype) (already present as dtype_enum)
and assert (ICHECK) that it matches the FP4 dtype enum the intrinsic requires,
otherwise fail with a clear error message (or avoid emitting the FP4 intrinsic);
update the branch that constructs tcgen05_call (and/or use_mxf4nvf4 decision) to
perform this dtype check so a mismatched caller errors at codegen rather than
silently lowering to the wrong intrinsic.
…itional tcgen05mma_mxf4nvf4_blockscaled, add example to verify
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/cuda/op/gemm/gemm_tcgen05.py (1)
286-341: 💤 Low value
Consider extracting the duplicated MMA dispatch logic.
The is_nvfp4 conditional block is duplicated between _gemm_blockscaled_cond and _gemm_blockscaled. Extracting this to a helper (or using a lambda) would reduce maintenance burden when adding future dispatch variants.
That said, this is a minor improvement and acceptable as-is given the nested T.prim_func structure.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/cuda/op/gemm/gemm_tcgen05.py` around lines 286 - 341, Extract the
duplicated is_nvfp4 dispatch into a small helper function or lambda and call it
from both _gemm_blockscaled_cond and _gemm_blockscaled to avoid repeating the
MMA call pairs; specifically create a helper (e.g., _call_mma_blockscaled or a
local lambda) that takes (A_shared, B_shared, C_local, SFA_tmem, SFB_tmem,
mbarptr, clear_accum, sf_a_id=None, sf_b_id=None) and inside performs the
is_nvfp4 check to call mma_emitter.tcgen05mma_mxf4nvf4_blockscaled or
mma_emitter.tcgen05mma_blockscaled accordingly, then replace the duplicated
blocks in _gemm_blockscaled_cond and _gemm_blockscaled with a single call to
that helper.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/gemm_fp4/example_gemm_nvfp4_sm100.py`:
- Around line 84-86: SFB_tmem is allocated with the wrong tile dimension causing
incorrect sizing for non-square tiles; change the alloc_tmem call that creates
SFB_tmem (currently T.alloc_tmem([block_M, 4], "uint32")) to use block_N instead
of block_M so it becomes T.alloc_tmem([block_N, 4], "uint32"), ensuring SFB_tmem
matches per-row scales of B; update the allocation site where SFB_tmem is
defined and verify any uses of SFB_tmem assume the N-dimension.
---
Nitpick comments:
In `@tilelang/cuda/op/gemm/gemm_tcgen05.py`:
- Around line 286-341: Extract the duplicated is_nvfp4 dispatch into a small
helper function or lambda and call it from both _gemm_blockscaled_cond and
_gemm_blockscaled to avoid repeating the MMA call pairs; specifically create a
helper (e.g., _call_mma_blockscaled or a local lambda) that takes (A_shared,
B_shared, C_local, SFA_tmem, SFB_tmem, mbarptr, clear_accum, sf_a_id=None,
sf_b_id=None) and inside performs the is_nvfp4 check to call
mma_emitter.tcgen05mma_mxf4nvf4_blockscaled or
mma_emitter.tcgen05mma_blockscaled accordingly, then replace the duplicated
blocks in _gemm_blockscaled_cond and _gemm_blockscaled with a single call to
that helper.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4a71a3c5-5dc4-4374-9855-5fa2586336b2
📒 Files selected for processing (3)
- examples/gemm_fp4/example_gemm_nvfp4_sm100.py
- tilelang/cuda/op/gemm/gemm_tcgen05.py
- tilelang/language/gemm_op.py
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
SFA_tmem = T.alloc_tmem([block_M, 4], "uint32")
SFB_tmem = T.alloc_tmem([block_M, 4], "uint32")
SFB_tmem shape uses block_M but should likely be block_N.
SFB contains per-row scales for B (shape N×K), so SFB_tmem should be sized to [block_N, 4] rather than [block_M, 4]. With the default square blocks (block_M == block_N == 128) this works by coincidence, but will break for non-square tile configurations.
Proposed fix
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
SFA_tmem = T.alloc_tmem([block_M, 4], "uint32")
- SFB_tmem = T.alloc_tmem([block_M, 4], "uint32")
+ SFB_tmem = T.alloc_tmem([block_N, 4], "uint32")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/gemm_fp4/example_gemm_nvfp4_sm100.py` around lines 84 - 86, SFB_tmem
is allocated with the wrong tile dimension causing incorrect sizing for
non-square tiles; change the alloc_tmem call that creates SFB_tmem (currently
T.alloc_tmem([block_M, 4], "uint32")) to use block_N instead of block_M so it
becomes T.alloc_tmem([block_N, 4], "uint32"), ensuring SFB_tmem matches per-row
scales of B; update the allocation site where SFB_tmem is defined and verify any
uses of SFB_tmem assume the N-dimension.
…perands dense (SW64) instead of gap-expanded align16b (uniform-SF verified, rel_err=0 on Thor)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/layout/gemm_layouts.cc (1)
998-1011: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Wire dense FP4 into SM100 selection and swizzle merging.
MakeDenseFp4SwizzleLayout2D is added, but the TCGEN05 path still reaches makeGemmABLayoutSm100, where Line 998 routes sub-byte operands through ALIGN16B or linear layout. Line 1093 also returns kNone for dense FP4 layouts, so MergeSwizzleLayouts has no way to preserve/reconstruct the new dense layout. Thread an explicit dense-vs-ALIGN16B layout kind into SM100 selection and add a dense FP4 SwizzleMode merge case. This is required for the dense-packed mxf4nvf4 SM100 contract described by the PR objectives.
Also applies to: 1076-1084, 1093-1100, 1149-1156
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/layout/gemm_layouts.cc` around lines 998 - 1011, The TCGEN05 path in
makeGemmABLayoutSm100 at lines 998-1011 needs to distinguish between dense FP4
and ALIGN16B layouts for sub-byte operands, but currently it treats them
uniformly. Additionally, the logic at lines 1076-1084 and 1093-1100 returns
kNone for dense FP4 layouts without proper differentiation, and
MergeSwizzleLayouts at lines 1149-1156 lacks a merge case for the dense FP4
SwizzleMode. To fix this, introduce an explicit layout kind parameter that
distinguishes dense FP4 from ALIGN16B layouts and thread it through the SM100
selection logic at the anchor location, update the layout kind determination at
lines 1093-1100 to return the appropriate dense FP4 kind instead of kNone,
update the TCGEN05 routing at lines 1076-1084 to handle the dense case, and add
a corresponding dense FP4 SwizzleMode merge case in MergeSwizzleLayouts at lines
1149-1156 to preserve and reconstruct the dense layout.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/layout/gemm_layouts.cc`:
- Around line 543-556: The current ICHECK at line 543 only validates that the
FP4 count is even, but this is insufficient because the downstream code uses
16-byte packed vectors even when bank equals 1. Add an additional validation
that byte_continuous is divisible by vector_size (16 bytes) to reject dense FP4
shapes that do not cover a full 16-byte vector. This should be enforced in the
same ICHECK or a new one immediately following the existing check, ensuring that
continuous is at least 32 FP4 elements (since byte_continuous = continuous / 2,
requiring byte_continuous >= 16 means continuous >= 32).
In `@tilelang/layout/swizzle.py`:
- Around line 177-189: The function `make_dense_fp4_swizzled_layout` calls
`_ffi_api.make_dense_fp4_swizzled_layout` directly without checking if the FFI
symbol exists, unlike the similar `make_align16b_swizzled_layout` function which
guards the call with hasattr. Add a hasattr check on _ffi_api for the
make_dense_fp4_swizzled_layout symbol before attempting to call it, and either
provide a graceful fallback or raise a descriptive error message if the symbol
is not available. This ensures users receive a clear, helpful error instead of
an AttributeError when the native library lacks this symbol.
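A guarded lookup along the lines requested for `make_dense_fp4_swizzled_layout` might look like the sketch below. It is illustrative only: `call_optional_ffi` is a hypothetical helper, and the `SimpleNamespace` objects stand in for the real `_ffi_api` module so the snippet runs without the native library.

```python
from types import SimpleNamespace


def call_optional_ffi(ffi_api, symbol: str, *args):
    """Call `ffi_api.<symbol>(*args)` if the loaded native library has it.

    Hypothetical helper: raises a descriptive RuntimeError instead of a
    bare AttributeError when the symbol is missing from an older build.
    """
    fn = getattr(ffi_api, symbol, None)
    if fn is None:
        raise RuntimeError(
            f"The loaded native library does not expose `{symbol}`; "
            "rebuild the native artifacts to use this layout."
        )
    return fn(*args)


# Stand-ins for an outdated and an up-to-date native library.
old_lib = SimpleNamespace()
new_lib = SimpleNamespace(
    make_dense_fp4_swizzled_layout=lambda *a: ("dense_fp4_layout", a)
)
result = call_optional_ffi(new_lib, "make_dense_fp4_swizzled_layout", 128, 64)
```

Using `getattr` with a default keeps the check and the call in one place, so a symbol that disappears between the check and the call cannot produce a second, less helpful error.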
---
Outside diff comments:
In `@src/layout/gemm_layouts.cc`:
- Around line 998-1011: The TCGEN05 path in makeGemmABLayoutSm100 at lines
998-1011 needs to distinguish between dense FP4 and ALIGN16B layouts for
sub-byte operands, but currently it treats them uniformly. Additionally, the
logic at lines 1076-1084 and 1093-1100 returns kNone for dense FP4 layouts
without proper differentiation, and MergeSwizzleLayouts at lines 1149-1156 lacks
a merge case for the dense FP4 SwizzleMode. To fix this, introduce an explicit
layout kind parameter that distinguishes dense FP4 from ALIGN16B layouts and
thread it through the SM100 selection logic at the anchor location, update the
layout kind determination at lines 1093-1100 to return the appropriate dense FP4
kind instead of kNone, update the TCGEN05 routing at lines 1076-1084 to handle
the dense case, and add a corresponding dense FP4 SwizzleMode merge case in
MergeSwizzleLayouts at lines 1149-1156 to preserve and reconstruct the dense
layout.
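For the dense FP4 size check discussed at the start of this prompt (byte_continuous = continuous / 2, with a 16-byte packed vector), the arithmetic can be sketched in Python as follows. The real check would live in C++ as an ICHECK; the function name `check_dense_fp4_continuous` and the error messages are hypothetical.

```python
VECTOR_SIZE_BYTES = 16  # packed vector width named in the review comment


def check_dense_fp4_continuous(continuous: int) -> int:
    """Validate a dense FP4 continuous extent and return its size in bytes.

    Two FP4 elements share one byte, so the element count must be even and the
    resulting byte count must cover whole 16-byte vectors.
    """
    if continuous <= 0 or continuous % 2 != 0:
        raise ValueError("FP4 element count must be positive and even")
    byte_continuous = continuous // 2
    if byte_continuous % VECTOR_SIZE_BYTES != 0:
        raise ValueError(
            f"byte_continuous={byte_continuous} is not a multiple of "
            f"{VECTOR_SIZE_BYTES}; need a multiple of 32 FP4 elements"
        )
    return byte_continuous
```

Divisibility by 16 bytes already implies the stated minimum of 32 FP4 elements, so a single modulo check covers both conditions.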
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9c04581d-1805-4f7b-b7ea-124fa84c11cb
📒 Files selected for processing (13)
- examples/gemm_fp4/example_gemm_nvfp4_sm100.py
- examples/gemm_fp4/gemm_nvfp4_sm100.cu
- src/cuda/op/gemm.cc
- src/layout/gemm_layouts.cc
- src/layout/layout.cc
- src/layout/layout.h
- src/op/tcgen5_meta.h
- tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
- tilelang/cuda/op/gemm/gemm_tcgen05.py
- tilelang/cuda/transform/__init__.py
- tilelang/language/gemm_op.py
- tilelang/layout/__init__.py
- tilelang/layout/swizzle.py
🚧 Files skipped from review as they are similar to previous changes (7)
- tilelang/layout/__init__.py
- src/layout/layout.cc
- src/layout/layout.h
- src/op/tcgen5_meta.h
- tilelang/language/gemm_op.py
- tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py
- tilelang/cuda/op/gemm/gemm_tcgen05.py
… SiLU + routing); align SM100/SM120 NVFP4 MoE with FlashInfer compute paradigm
Actionable comments posted: 1
♻️ Duplicate comments (1)
examples/gemm_fp4/example_gemm_fp4_sm100.py (1)
174-208: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Missing validation for even K dimension in FP4 mode.
Both `make_random_fp4(M, K, ...)` and `unpack_fp4_to_float(..., K)` use `K // 2` for the packed axis. If an odd `K` is passed via `--k` or `TL_FP4_K`, the packed buffer will be truncated, leading to silent data loss or shape mismatches during reference verification.
🛡️ Suggested guard after args resolution
      args.m = args.m if args.m is not None else int(os.environ.get("TL_FP4_M", default_mnk))
      args.n = args.n if args.n is not None else int(os.environ.get("TL_FP4_N", default_mnk))
      args.k = args.k if args.k is not None else int(os.environ.get("TL_FP4_K", default_mnk))
    +
    + if args.k % 2 != 0:
    +     raise ValueError("K must be even for packed FP4 inputs")

      if args.nvfp4:
          run_nvfp4(args)
Also applies to: 280-305
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 174 - 208, The K dimension must be validated to be even before use in FP4 mode, since make_random_fp4 and unpack_fp4_to_float both use K // 2 for packed buffers. Add a validation check after extracting K from args in the run_fp4 function to ensure K is even; if odd, raise an error or print a clear message explaining the requirement. Apply the same validation to any other entry points that accept K as a parameter for FP4 operations.
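To see why an odd K is a problem, here is a minimal pure-Python sketch of two-per-byte packing sized with `len(codes) // 2`. The helper names `pack_fp4_codes` and `unpack_fp4_codes` and the low-nibble-first order are assumptions for illustration; they are not the example scripts' actual helpers.

```python
def pack_fp4_codes(codes):
    """Pack 4-bit codes two per byte; the buffer is sized as len(codes) // 2."""
    packed = bytearray(len(codes) // 2)
    for i in range(len(packed)):
        lo = codes[2 * i] & 0xF
        hi = codes[2 * i + 1] & 0xF
        packed[i] = lo | (hi << 4)  # assumed order: even index in the low nibble
    return bytes(packed)


def unpack_fp4_codes(packed):
    """Inverse of pack_fp4_codes: two 4-bit codes per byte, low nibble first."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)
        codes.append(byte >> 4)
    return codes


even = [1, 2, 3, 4]
odd = [1, 2, 3, 4, 5]
print(unpack_fp4_codes(pack_fp4_codes(even)))  # round-trips all four codes
print(unpack_fp4_codes(pack_fp4_codes(odd)))   # the fifth code is silently dropped
```

The odd-length input loses its last element without any error, which is the silent truncation the guard is meant to prevent.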
🧹 Nitpick comments (2)
examples/gemm_fp4/example_gemm_fp4_sm120.py (1)
105-112: 💤 Low value
Redundant accumulator clear.
`T.clear(C_local)` at line 105 and `clear_accum=(ko == 0)` at line 112 both zero the accumulator on the first K-tile iteration. This is redundant: either remove the `T.clear(C_local)` call or change `clear_accum` to `False`.
♻️ Suggested fix
    - T.clear(C_local)
      for ko in T.serial(T.ceildiv(K, block_K)):
          T.copy(A[by * block_M, ko * packed_block_K], A_shared)
          T.copy(B[bx * block_N, ko * packed_block_K], B_shared)
          SFA_local[0] = SFA[by, ko]
          SFB_local[0] = SFB[bx, ko]
          T.nvfp4_gemm(A_shared, B_shared, SFA_local, SFB_local, C_local, transpose_B=True, clear_accum=(ko == 0))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/gemm_fp4/example_gemm_fp4_sm120.py` around lines 105 - 112, The accumulator C_local is being cleared twice redundantly: once explicitly with T.clear(C_local) before the loop and once implicitly during the first iteration of the loop via the clear_accum=(ko == 0) parameter in the T.nvfp4_gemm call. Remove the explicit T.clear(C_local) call since the T.nvfp4_gemm operation already handles accumulator initialization on the first K-tile iteration through its clear_accum parameter.
examples/gemm_fp4/example_gemm_fp4_sm100.py (1)
165-171: 💤 Low value
Default `--cuda-target` assumes aarch64, may fail on x86 hosts.
The default `cuda_target` is `"aarch64-linux"`, which constructs an include path like `/usr/local/cuda/targets/aarch64-linux/include`. On x86_64 systems, the correct target directory is typically `x86_64-linux`. Users on x86 will need to explicitly pass `--cuda-target=x86_64-linux` or the compilation may fail with missing headers.
Consider auto-detecting the host platform or documenting this in the help text.
Also applies to: 292-292
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 165 - 171, The _arch_and_flags function uses a hardcoded default of aarch64-linux for cuda_target, which fails on x86_64 systems that need x86_64-linux instead. Auto-detect the host platform to determine the correct default cuda_target based on the system's architecture (x86_64 vs aarch64). This detection should be applied at the location where the cuda_target argument default is defined (around line 292 where the argument parser is likely configured) and ensure the detected value flows through to where it is used in the _arch_and_flags function at line 165-171 when constructing the cuda_include path. The auto-detection should happen early so that args.cuda_target receives the platform-appropriate default unless explicitly overridden by the user.
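One possible shape for the auto-detection is sketched below using `platform.machine()`. The helper name `default_cuda_target` is hypothetical, and falling back to the existing `aarch64-linux` default for unrecognized machines is an assumption; some ARM hosts may use a different target directory name.

```python
import argparse
import platform


def default_cuda_target(machine=None):
    """Map the host machine string to a CUDA targets/ directory name."""
    machine = (machine or platform.machine()).lower()
    if machine in ("x86_64", "amd64"):
        return "x86_64-linux"
    if machine in ("aarch64", "arm64"):
        return "aarch64-linux"
    # Assumption: keep the script's current default when the host is unrecognized.
    return "aarch64-linux"


parser = argparse.ArgumentParser()
# An explicit --cuda-target still overrides the detected default.
parser.add_argument("--cuda-target", default=default_cuda_target())
```

Because the detected value is only the argparse default, `--cuda-target=...` on the command line keeps its current meaning.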
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm100.py`:
- Around line 296-305: Add validation to ensure the K dimension is even in both
FP4 example files. In examples/gemm_fp4/example_gemm_fp4_sm100.py (lines
296-305), after the lines that resolve args.m, args.n, and args.k from
command-line arguments and environment variables, add a check that raises
ValueError with message "K must be even for packed FP4 inputs" if args.k is odd.
Apply the identical validation check in
examples/gemm_fp4/example_gemm_fp4_sm120.py (lines 232-240) after the args
resolution section to prevent silent truncation of packed FP4 buffers when K
values are not divisible by 2.
---
Duplicate comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm100.py`:
- Around line 174-208: The K dimension must be validated to be even before use
in FP4 mode, since make_random_fp4 and unpack_fp4_to_float both use K // 2 for
packed buffers. Add a validation check after extracting K from args in the
run_fp4 function to ensure K is even; if odd, raise an error or print a clear
message explaining the requirement. Apply the same validation to any other entry
points that accept K as a parameter for FP4 operations.
---
Nitpick comments:
In `@examples/gemm_fp4/example_gemm_fp4_sm100.py`:
- Around line 165-171: The _arch_and_flags function uses a hardcoded default of
aarch64-linux for cuda_target, which fails on x86_64 systems that need
x86_64-linux instead. Auto-detect the host platform to determine the correct
default cuda_target based on the system's architecture (x86_64 vs aarch64). This
detection should be applied at the location where the cuda_target argument
default is defined (around line 292 where the argument parser is likely
configured) and ensure the detected value flows through to where it is used in
the _arch_and_flags function at line 165-171 when constructing the cuda_include
path. The auto-detection should happen early so that args.cuda_target receives
the platform-appropriate default unless explicitly overridden by the user.
In `@examples/gemm_fp4/example_gemm_fp4_sm120.py`:
- Around line 105-112: The accumulator C_local is being cleared twice
redundantly: once explicitly with T.clear(C_local) before the loop and once
implicitly during the first iteration of the loop via the clear_accum=(ko == 0)
parameter in the T.nvfp4_gemm call. Remove the explicit T.clear(C_local) call
since the T.nvfp4_gemm operation already handles accumulator initialization on
the first K-tile iteration through its clear_accum parameter.
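The redundancy can be illustrated with a toy model of the accumulate loop. This is not TileLang code: `tile_gemm` is a hypothetical stand-in that only mimics the `clear_accum` behaviour described in the comment above, namely zeroing the accumulator before accumulating.

```python
def tile_gemm(a_tile, b_tile, acc, clear_accum=False):
    """Toy stand-in: acc[0] += dot(a_tile, b_tile), zeroing acc first if asked."""
    if clear_accum:
        acc[0] = 0.0
    acc[0] += sum(x * y for x, y in zip(a_tile, b_tile))


def run(k_tiles, pre_clear, clear_on_first):
    acc = [123.0]  # pretend the accumulator starts with garbage
    if pre_clear:
        acc[0] = 0.0  # plays the role of the explicit clear before the loop
    for ko, (a, b) in enumerate(k_tiles):
        tile_gemm(a, b, acc, clear_accum=(clear_on_first and ko == 0))
    return acc[0]


tiles = [([1, 2], [3, 4]), ([5, 6], [7, 8])]
print(run(tiles, pre_clear=True, clear_on_first=True))    # both clears
print(run(tiles, pre_clear=False, clear_on_first=True))   # flag only
print(run(tiles, pre_clear=True, clear_on_first=False))   # explicit clear only
print(run(tiles, pre_clear=False, clear_on_first=False))  # neither: garbage leaks in
```

In this model, any one of the two clears is enough, and only dropping both changes the result.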
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 746c6f94-16c1-496b-9b4e-f7a397c35497
📒 Files selected for processing (8)
- examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py
- examples/gemm_fp4/example_gemm_fp4_sm100.py
- examples/gemm_fp4/example_gemm_fp4_sm120.py
- src/layout/gemm_layouts.cc
- src/layout/layout.cc
- src/layout/layout.h
- tilelang/layout/__init__.py
- tilelang/layout/swizzle.py
💤 Files with no reviewable changes (5)
- tilelang/layout/__init__.py
- src/layout/layout.h
- src/layout/layout.cc
- tilelang/layout/swizzle.py
- src/layout/gemm_layouts.cc
    # Size defaults differ by variant; env vars TL_FP4_{M,N,K} override, --m/--n/--k win.
    default_mnk = 128 if args.nvfp4 else 256
    args.m = args.m if args.m is not None else int(os.environ.get("TL_FP4_M", default_mnk))
    args.n = args.n if args.n is not None else int(os.environ.get("TL_FP4_N", default_mnk))
    args.k = args.k if args.k is not None else int(os.environ.get("TL_FP4_K", default_mnk))

    if args.nvfp4:
        run_nvfp4(args)
    else:
        run_fp4(args)
Add validation for even K dimension in both SM100 and SM120 examples.
Both example scripts use K // 2 for packed FP4 buffer dimensions without validating that K is even. Odd values of K (via --k or TL_FP4_K env var) will silently truncate the packed buffer, causing shape mismatches or incorrect reference calculations.
- examples/gemm_fp4/example_gemm_fp4_sm100.py#L296-L305: Add `if args.k % 2 != 0: raise ValueError("K must be even for packed FP4 inputs")` after resolving args.
- examples/gemm_fp4/example_gemm_fp4_sm120.py#L232-L240: Add the same validation after resolving args.
📍 Affects 2 files
- examples/gemm_fp4/example_gemm_fp4_sm100.py#L296-L305 (this comment)
- examples/gemm_fp4/example_gemm_fp4_sm120.py#L232-L240
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/gemm_fp4/example_gemm_fp4_sm100.py` around lines 296 - 305, Add
validation to ensure the K dimension is even in both FP4 example files. In
examples/gemm_fp4/example_gemm_fp4_sm100.py (lines 296-305), after the lines
that resolve args.m, args.n, and args.k from command-line arguments and
environment variables, add a check that raises ValueError with message "K must
be even for packed FP4 inputs" if args.k is odd. Apply the identical validation
check in examples/gemm_fp4/example_gemm_fp4_sm120.py (lines 232-240) after the
args resolution section to prevent silent truncation of packed FP4 buffers when
K values are not divisible by 2.
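Put together, the resolution order (CLI flag, then environment variable, then variant default) and the even-K guard might look like the self-contained sketch below. `resolve_dim` and `resolve_args` are hypothetical names, and the `environ` parameter is added here only so the logic can be exercised without touching the real environment.

```python
import argparse
import os


def resolve_dim(cli_value, env_name, default, environ=os.environ):
    """CLI value wins; otherwise the env var; otherwise the variant default."""
    if cli_value is not None:
        return cli_value
    return int(environ.get(env_name, default))


def resolve_args(argv, environ=os.environ):
    parser = argparse.ArgumentParser()
    parser.add_argument("--nvfp4", action="store_true")
    for name in ("m", "n", "k"):
        parser.add_argument(f"--{name}", type=int, default=None)
    args = parser.parse_args(argv)

    default_mnk = 128 if args.nvfp4 else 256
    args.m = resolve_dim(args.m, "TL_FP4_M", default_mnk, environ)
    args.n = resolve_dim(args.n, "TL_FP4_N", default_mnk, environ)
    args.k = resolve_dim(args.k, "TL_FP4_K", default_mnk, environ)

    # Guard from the review: packed FP4 stores two elements per byte along K.
    if args.k % 2 != 0:
        raise ValueError("K must be even for packed FP4 inputs")
    return args
```

Placing the guard after resolution means it catches odd values from either `--k` or `TL_FP4_K`.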
Summary
To address the feature requested in issue #1592, and stacked on PR #2182, this PR starts Blackwell NVFP4 support by adding a minimal block-scaled FP4 GEMM path and examples based on `mxf4nvf4.block_scale`.
The implementation introduces a thin `T.nvfp4_gemm(...)` frontend for SM120, lowers it through the existing CUDA MMA tile-op path, and emits `mma.sync.aligned.kind::mxf4nvf4.block_scale` with packed FP4 E2M1 operands and FP8 E4M3 scale factors.
It also adds NVFP4 GEMM and fused MoE examples that follow the FlashInfer/TRT-LLM style data contract: packed FP4 activations/weights plus per-block E4M3 scale factors.
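As a rough reference for the data contract, the sketch below decodes packed E2M1 codes and applies one scale per block. Several things are assumptions made for illustration: the low-nibble-first packing order, the default block size of 16, and passing scales as already-decoded Python floats rather than raw E4M3 bytes. The function names are hypothetical and not part of this PR.

```python
# Magnitudes of the 3 low bits of an E2M1 code (1 sign, 2 exponent, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code to a float."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def dequantize_block_scaled(packed, scales, block_size=16):
    """Unpack two FP4 codes per byte and multiply each block by its scale.

    `scales[i]` applies to elements [i * block_size, (i + 1) * block_size).
    """
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0xF))  # assumed: low nibble first
        values.append(decode_e2m1(byte >> 4))
    return [v * scales[i // block_size] for i, v in enumerate(values)]


# Two bytes, four elements, one block of four sharing the scale 2.0.
print(dequantize_block_scaled(bytes([0x21, 0xF7]), [2.0], block_size=4))
```

A reference GEMM for verification can then be computed on the dequantized floats and compared against the kernel output.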
What changed
SM120 (mma.sync, scale factors in registers)
SM100/SM110 (TCGEN05, scale factors in TMEM)
Examples: unified --nvfp4 flag (default = plain FP4, --nvfp4 = NVFP4 block-scaled):
Validation (nvfp4 only)
SM120
SM100/SM110
Summary by CodeRabbit