Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a356049
feat: add float4_e2m1fn (NV FP4) support for SM100 TCGEN5 MMA
Hale423 Mar 9, 2026
44336d7
fix: python pipeline (LayoutInference + LowerTileOp) passes
Hale423 Mar 9, 2026
7cf06fa
fix: SM120 FP4 CUDA compilation – type bridge, codegen naming, MMA di…
Hale423 Mar 12, 2026
f068f2f
fix(buffer overflow, fp4 address misalignment): ldmatrix writes in by…
Hale423 Mar 13, 2026
d116372
feat: SM120 FP4 GEMM unpacked shared memory + numerical verification
Hale423 Mar 14, 2026
f13a6b7
feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120
Hale423 Mar 14, 2026
8c915cb
Merge origin/main into feat/gemm-fp4
Hale423 May 8, 2026
10a2710
fix: fix precision using smem unpacked layout, setting bit 2-5/8 bits…
Hale423 May 11, 2026
eeddaac
Implement SM100/SM110 FP4 and A8W4 examples using the TCGEN05/TMA unp…
Hale423 May 11, 2026
3fb9c9d
feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120
Hale423 Mar 14, 2026
63c9451
Resolve FP4 example conflicts and remove generated artifacts
Hale423 May 11, 2026
8732c4c
Merge latest main into FP4 feature branch
Hale423 May 11, 2026
70afbc3
rewrite: rewrite codes according to code rabbit's suggestions
Hale423 May 11, 2026
f816a4e
handle non-align16B fp4 shared layout gracefully
Hale423 May 15, 2026
57449a8
refresh the code for pre-commit check
Hale423 May 15, 2026
0d2a9cb
tmp save
Hale423 Jun 2, 2026
d153d7a
Merge remote-tracking branch 'origin/main' into feat/gemm-fp4
Hale423 Jun 2, 2026
edcd752
Merge remote-tracking branch 'origin/main' into feat/gemm-fp4
Hale423 Jun 2, 2026
6ed1e83
Merge branch 'feat/gemm-fp4' into feat/gemm-nvfp4
Hale423 Jun 2, 2026
48db769
feat(sm120): add T.nvfp4_gemm as frontend, lower to mma_blockscaled, …
Hale423 Jun 3, 2026
8aa292a
feat(sm100): add mxf4nvf4 feature for ptx code, meta/descriptor, conn…
Hale423 Jun 5, 2026
4dd1a5e
feat(sm100): based on blockscale fw,assign is_nvfp4 annotation to add…
Hale423 Jun 5, 2026
1290a96
fix(sm100): correct NVFP4 mxf4nvf4 block-scaled GEMM by staging FP4 o…
Hale423 Jun 12, 2026
1111bfd
feat(examples): add SM100 NVFP4 fused MoE example (mxf4nvf4 gate/up +…
Hale423 Jun 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Simplified A8W4 fused MoE shared expert on SM100/SM110.

This mirrors the SM120 demo shape, but uses TCGEN05/TMEM:
gate = input(FP8) x W_gate(FP4)
up = input(FP8) x W_up(FP4)
out = silu(gate) * up

The down projection is intentionally omitted here, matching the existing SM120
diagnostic example and keeping the kernel focused on mixed FP8xFP4 TCGEN05 GEMM.
"""

import os
import time

import torch
import tilelang
import tilelang.language as T


# Decode table for FP4 E2M1, indexed by the 4-bit code.
# Codes 8-15 are the sign-flipped counterparts of codes 0-7.
FP4_E2M1_TO_FLOAT = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
    -0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]


def unpack_fp4_to_float(packed_int8, rows, cols):
    """Decode a packed FP4 (E2M1) tensor into float32 values.

    Every byte of ``packed_int8`` carries two 4-bit codes: the low nibble is
    the element at the even column index and the high nibble is the element
    right after it.

    Args:
        packed_int8: Integer tensor holding ``rows * (cols // 2)`` bytes.
        rows: Number of logical rows.
        cols: Number of logical (unpacked) columns; must be even.

    Returns:
        A float32 tensor of shape ``(rows, cols)`` on the same device as
        ``packed_int8``.

    Raises:
        ValueError: If ``cols`` is odd, since two values share one byte.
    """
    if cols % 2 != 0:
        raise ValueError(
            f"cols must be even for packed FP4 storage (two values per byte), got {cols}"
        )

    lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device)
    # Work on unsigned bytes so the right shift cannot drag a sign bit along.
    flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2)
    lo = flat & 0x0F
    hi = (flat >> 4) & 0x0F
    # Interleave (lo, hi) pairs back into their logical column order.
    unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64)
    return lut[unpacked]
Comment on lines +40 to +46

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard TL_MOE_HIDDEN for packed FP4 storage.

Both expert weight tensors are packed as (d_expert, d_hidden // 2), so this example only works when d_hidden is even. Right now an odd TL_MOE_HIDDEN is accepted and the failure shows up later during unpack/reference instead of at config parsing.

Suggested guard
 num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128"))
 d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256"))
 d_expert = int(os.environ.get("TL_MOE_EXPERT", "256"))
@@
 block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128"))
 block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128"))
 block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64"))
+
+if d_hidden % 2 != 0:
+    raise ValueError("TL_MOE_HIDDEN must be even for packed FP4 expert weights")

Also applies to: 115-120, 141-155

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py` around lines 40 - 46, Add
an explicit guard that TL_MOE_HIDDEN is even before treating expert weight
tensors as packed FP4 (two 4-bit values per byte); validate this in the
config/parsing path and before any use of unpack_fp4_to_float and places where
expert tensors are created/reshaped (references: function unpack_fp4_to_float
and the code regions handling expert weight shapes around the other
occurrences). If TL_MOE_HIDDEN is odd, raise a clear ValueError (or
argparse/config error) explaining that packed FP4 storage requires an even
hidden dimension, so the failure happens early instead of during unpacking.



def fusedmoe_a8w4_sm100(
    num_tokens,
    d_hidden,
    d_expert,
    block_token=128,
    block_hidden=128,
    block_expert=64,
    threads=128,
    num_stages=1,
):
    """Build the TileLang program for the fused gate/up A8W4 expert kernel.

    The returned prim_func computes, per output tile,
    ``Output = up * (gate * sigmoid(gate))`` where
    ``gate = Input x W_gate^T`` and ``up = Input x W_up^T``. Both products are
    issued with ``T.tcgen05_gemm`` and accumulate into tensor-memory buffers.

    Args:
        num_tokens: Number of rows of ``Input`` / ``Output``.
        d_hidden: Reduction dimension shared by the input and both weights.
        d_expert: Number of rows of each weight and columns of ``Output``.
        block_token: Tile size along ``num_tokens``.
        block_hidden: Tile size along ``d_hidden`` (one step of the K loop).
        block_expert: Tile size along ``d_expert``.
        threads: Thread count passed to ``T.Kernel``.
        num_stages: Stage count passed to ``T.Pipelined``.

    Returns:
        The ``T.prim_func`` describing the kernel.

    Raises:
        ValueError: If ``d_hidden`` is odd. The host side of this example
            supplies FP4 weights packed two per byte as
            ``(d_expert, d_hidden // 2)``, which cannot hold an odd width.
    """
    if d_hidden % 2 != 0:
        raise ValueError(f"d_hidden must be even for packed FP4 expert weights, got {d_hidden}")

    scale = 1.44269504  # log2(e), so exp(-x) can be written as exp2(-x * scale)

    @T.prim_func
    def main(
        Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
        W_gate: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn),
        W_up: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn),
        Output: T.Tensor((num_tokens, d_expert), "float32"),
    ):
        # bx walks tiles of d_expert, by walks tiles of num_tokens.
        with T.Kernel(
            T.ceildiv(d_expert, block_expert),
            T.ceildiv(num_tokens, block_token),
            threads=threads,
        ) as (bx, by):
            input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn")
            gate_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn)
            up_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn)

            # One accumulator and one barrier per GEMM.
            gate_tmem = T.alloc_tmem([block_token, block_expert], "float32")
            up_tmem = T.alloc_tmem([block_token, block_expert], "float32")
            gate_mbar = T.alloc_barrier(1)
            up_mbar = T.alloc_barrier(1)

            gate_local = T.alloc_fragment((block_token, block_expert), "float32")
            up_local = T.alloc_fragment((block_token, block_expert), "float32")

            for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages):
                T.copy(Input[by * block_token, k * block_hidden], input_shared)
                T.copy(W_gate[bx * block_expert, k * block_hidden], gate_shared)
                T.copy(W_up[bx * block_expert, k * block_hidden], up_shared)

                # The first K step clears the accumulator instead of adding to it.
                T.tcgen05_gemm(
                    input_shared,
                    gate_shared,
                    gate_tmem,
                    transpose_A=False,
                    transpose_B=True,
                    mbar=gate_mbar,
                    clear_accum=(k == 0),
                )
                # The waited parity alternates with the K step.
                T.mbarrier_wait_parity(gate_mbar, k % 2)

                T.tcgen05_gemm(
                    input_shared,
                    up_shared,
                    up_tmem,
                    transpose_A=False,
                    transpose_B=True,
                    mbar=up_mbar,
                    clear_accum=(k == 0),
                )
                T.mbarrier_wait_parity(up_mbar, k % 2)

            # Bring both accumulators into fragments for the element-wise epilogue.
            T.copy(gate_tmem, gate_local)
            T.copy(up_tmem, up_local)

            # SiLU(gate) * up, with sigmoid expressed through exp2.
            for i, j in T.Parallel(block_token, block_expert):
                gate = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)))
                up_local[i, j] = up_local[i, j] * gate

            T.copy(up_local, Output[by * block_token, bx * block_expert])

    return main


# Problem and tile sizes; each one can be overridden through the environment.
num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128"))
d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256"))
d_expert = int(os.environ.get("TL_MOE_EXPERT", "256"))
block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128"))
block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128"))
block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64"))

# The expert weights below are packed two FP4 values per byte, i.e. stored as
# (d_expert, d_hidden // 2). Reject an odd hidden size here so the failure is
# reported at config parsing rather than during unpacking / reference checks.
if d_hidden % 2 != 0:
    raise ValueError(f"TL_MOE_HIDDEN must be even for packed FP4 expert weights, got {d_hidden}")

print(
    f"Running SM100 A8W4 fused MoE: tokens={num_tokens}, hidden={d_hidden}, "
    f"expert={d_expert}, block=({block_token},{block_hidden},{block_expert})"
)

# Build the TileLang program for the configured sizes and compile it for CUDA.
func = fusedmoe_a8w4_sm100(
    num_tokens,
    d_hidden,
    d_expert,
    block_token,
    block_hidden,
    block_expert,
)
pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
jit_kernel = tilelang.compile(func, out_idx=[3], target="cuda", pass_configs=pass_configs)
print("Compilation succeeded!")

# Random inputs. Weights are random bytes, i.e. two random FP4 codes per byte.
torch.manual_seed(42)
packed_shape = (d_expert, d_hidden // 2)
input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
w_gate = torch.randint(0, 256, packed_shape, device="cuda", dtype=torch.uint8).to(torch.int8)
w_up = torch.randint(0, 256, packed_shape, device="cuda", dtype=torch.uint8).to(torch.int8)

# Smoke test: all-zero operands must give an all-zero output.
z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(packed_shape, device="cuda", dtype=torch.int8)
z_up = torch.zeros(packed_shape, device="cuda", dtype=torch.int8)
c_zero = jit_kernel(z_input, z_gate, z_up)
assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
print("[PASS] zeros in -> zeros out")

# Numerical check against a float32 reference built from the decoded weights.
out = jit_kernel(input_fp8, w_gate, w_up)
input_f32 = input_fp8.to(torch.float32)
gate_weights = unpack_fp4_to_float(w_gate, d_expert, d_hidden)
up_weights = unpack_fp4_to_float(w_up, d_expert, d_hidden)
gate_logits = input_f32 @ gate_weights.T
up_logits = input_f32 @ up_weights.T
ref = up_logits * (gate_logits * torch.sigmoid(gate_logits))

diff = (out.float() - ref).abs()
max_diff = diff.max().item()
rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10)
print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}")
if rel_err < 0.05:
    print("[PASS] MoE numerical verification")
else:
    print("[WARN] large diff")

# Benchmark: average wall-clock latency over 100 launches, in milliseconds.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    jit_kernel(input_fp8, w_gate, w_up)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 100 * 1000
# Two GEMMs, each 2 * M * K * N floating-point operations.
total_flops = 2 * num_tokens * d_hidden * d_expert * 2
print(f"Latency: {elapsed:.4f} ms")
print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}")
174 changes: 174 additions & 0 deletions examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""FP4 Fused MoE shared expert kernel on SM120 (A8W4 mode).

Demonstrates the core MoE gate/up compute pattern with FP4 weights:
1. Gate GEMM: input(FP8) x W_gate(FP4) -> gate logits
2. Up GEMM: input(FP8) x W_up(FP4) -> up logits
3. SiLU(gate) * up

Uses SM120 native kind::f8f6f4 MMA (FP8 x FP4 -> FP32).
Expert weights are stored as unpacked uint8 (1 FP4 per byte, low nibble).

This is a simplified single-expert example. For full routing + grouped GEMM,
see examples/fusedmoe/example_fusedmoe_tilelang.py.
"""

import time
import torch
import tilelang
import tilelang.language as T


# FP4 E2M1 decode table indexed by the 4-bit code; the second row holds the
# sign-flipped counterparts of the first.
FP4_E2M1_TO_FLOAT = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]


def fp4_uint8_to_float(t):
    """Map a tensor of FP4 codes (one code per element) to float32 values.

    Each element of ``t`` is used directly as an index into
    ``FP4_E2M1_TO_FLOAT``; the result has the same shape and device as ``t``.
    """
    codes = t.long()
    table = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=t.device)
    return table[codes]


def moe_shared_expert_a8w4(
    num_tokens,
    d_hidden,
    d_expert,
    block_token=128,
    block_hidden=128,
    block_expert=128,
    threads=128,
    num_stages=1,
):
    """Build the TileLang program for one shared expert's fused gate/up stage.

    The returned prim_func runs two GEMMs that share the same input tile
    (``Input x W_gate^T`` and ``Input x W_up^T``) and then writes
    ``SiLU(gate) * up`` to ``output``. The down projection is not part of
    this kernel.

    Args:
        num_tokens: Number of rows of ``Input`` / ``output``.
        d_hidden: Reduction dimension shared by the input and both weights.
        d_expert: Number of rows of each weight and columns of ``output``.
        block_token: Tile size along ``num_tokens``.
        block_hidden: Tile size along ``d_hidden`` (one step of the K loop).
        block_expert: Tile size along ``d_expert``.
        threads: Thread count passed to ``T.Kernel``.
        num_stages: Stage count passed to ``T.Pipelined``.

    Returns:
        The ``T.prim_func`` describing the kernel.
    """
    scale = 1.44269504  # log2(e) for fast SiLU

    @T.prim_func
    def main(
        Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
        W_gate: T.Tensor((d_expert, d_hidden), "uint8"),
        W_up: T.Tensor((d_expert, d_hidden), "uint8"),
        output: T.Tensor((num_tokens, d_expert), "float32"),
    ):
        # Step 1: Gate + Up GEMMs (fused in one kernel launch)
        # bx walks tiles of num_tokens, by walks tiles of d_expert.
        with T.Kernel(
            T.ceildiv(num_tokens, block_token),
            T.ceildiv(d_expert, block_expert),
            threads=threads,
        ) as (bx, by):
            input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn")
            W_gate_shared = T.alloc_shared((block_expert, block_hidden), "uint8")
            W_up_shared = T.alloc_shared((block_expert, block_hidden), "uint8")

            gate_local = T.alloc_fragment((block_token, block_expert), "float32")
            up_local = T.alloc_fragment((block_token, block_expert), "float32")

            # Accumulators start at zero; both GEMMs add into them across K steps.
            T.clear(gate_local)
            T.clear(up_local)

            for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages):
                T.copy(Input[bx * block_token, k * block_hidden], input_shared)
                T.copy(W_gate[by * block_expert, k * block_hidden], W_gate_shared)
                T.copy(W_up[by * block_expert, k * block_hidden], W_up_shared)
                T.gemm(input_shared, W_gate_shared, gate_local, transpose_B=True)
                T.gemm(input_shared, W_up_shared, up_local, transpose_B=True)

            # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate
            # sigmoid is written with exp2 and the log2(e) factor held in `scale`.
            for i, j in T.Parallel(block_token, block_expert):
                gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)))
                up_local[i, j] = up_local[i, j] * gate_local[i, j]

            T.copy(up_local, output[bx * block_token, by * block_expert])

    return main


# Problem sizes (small for testing)
num_tokens = 128
d_hidden = 256
d_expert = 256

print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}")

func = moe_shared_expert_a8w4(
    num_tokens,
    d_hidden,
    d_expert,
    block_token=128,
    block_hidden=128,
    block_expert=128,
    threads=128,
    num_stages=1,
)

jit_kernel = tilelang.compile(
    func,
    out_idx=[3],
    target="cuda",
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)

print("Compilation succeeded!")

torch.manual_seed(42)

# Create test data: one FP4 code (0..15) per uint8 element of each weight.
input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
W_gate_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8)
W_up_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8)

# --- Test 1: zeros ---
# Fail closed: an all-zero input producing anything but zeros means the kernel
# is fundamentally broken, so stop here instead of printing and carrying on.
z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
c_zero = jit_kernel(z_input, z_gate, z_up)
zero_max = c_zero.abs().max().item()
assert zero_max == 0.0, f"Zero test failed: max={zero_max}"
print("[PASS] zeros in -> zeros out")
Comment on lines +136 to +141

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make the zero-input smoke test fail closed.

Right now this only prints [FAIL] and keeps going, so the example still exits successfully even when the fused kernel is fundamentally broken. The other new example scripts already assert here; this one should too.

Suggested fix
 c_zero = jit_kernel(z_input, z_gate, z_up)
-print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out")
+assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
+print("[PASS] zeros in -> zeros out")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# --- Test 1: zeros ---
z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
c_zero = jit_kernel(z_input, z_gate, z_up)
print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out")
# --- Test 1: zeros ---
z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
c_zero = jit_kernel(z_input, z_gate, z_up)
assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
print("[PASS] zeros in -> zeros out")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py` around lines 136 - 141, The
zero-input smoke test currently only prints PASS/FAIL and can allow the script
to exit successfully on failure; change it to a hard assertion so failures stop
execution: after calling jit_kernel(z_input, z_gate, z_up) and computing c_zero,
replace the print line with an assertion that c_zero.abs().max().item() == 0.0
(or torch.equal(c_zero, torch.zeros_like(c_zero))) and include a descriptive
message (e.g., "zeros in -> zeros out failed") so the test fails closed; refer
to the variables z_input, z_gate, z_up and the function jit_kernel to locate
where to apply this change.


# --- Test 2: numerical verification (gate+up only, no down GEMM in this kernel) ---
out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8)

# Float32 reference built from the decoded FP4 weights.
input_f32 = input_fp8.to(torch.float32)
gate_f32 = fp4_uint8_to_float(W_gate_uint8)
up_f32 = fp4_uint8_to_float(W_up_uint8)

gate_logits = torch.matmul(input_f32, gate_f32.T)
up_logits = torch.matmul(input_f32, up_f32.T)
gate_activated = gate_logits * torch.sigmoid(gate_logits)  # SiLU(gate)
ref_out = up_logits * gate_activated

# Error metrics: worst absolute error and L1 error relative to the reference.
diff = (out.float() - ref_out).abs()
max_diff = diff.max().item()
rel_err = diff.sum().item() / (ref_out.abs().sum().item() + 1e-10)
print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}")
print("[PASS] MoE gate+up fusion numerical verification" if rel_err < 0.05 else "[WARN] large diff")

# --- Benchmark ---
# Average wall-clock latency over 100 launches, reported in milliseconds.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    jit_kernel(input_fp8, W_gate_uint8, W_up_uint8)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 100 * 1000
total_flops = 2 * num_tokens * d_hidden * d_expert * 2  # 2 GEMMs
print(f"Latency: {elapsed:.4f} ms")
print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}")
Loading
Loading