Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a356049
feat: add float4_e2m1fn (NV FP4) support for SM100 TCGEN5 MMA
Hale423 Mar 9, 2026
44336d7
fix: python pipeline (LayoutInference + LowerTileOp) passes
Hale423 Mar 9, 2026
7cf06fa
fix: SM120 FP4 CUDA compilation – type bridge, codegen naming, MMA di…
Hale423 Mar 12, 2026
f068f2f
fix(buffer overflow, fp4 address misalignment): ldmatrix writes in by…
Hale423 Mar 13, 2026
d116372
feat: SM120 FP4 GEMM unpacked shared memory + numerical verification
Hale423 Mar 14, 2026
f13a6b7
feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120
Hale423 Mar 14, 2026
8c915cb
Merge origin/main into feat/gemm-fp4
Hale423 May 8, 2026
10a2710
fix: fix precision using smem unpacked layout, setting bit 2-5/8 bits…
Hale423 May 11, 2026
eeddaac
Implement SM100/SM110 FP4 and A8W4 examples using the TCGEN05/TMA unp…
Hale423 May 11, 2026
3fb9c9d
feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120
Hale423 Mar 14, 2026
63c9451
Resolve FP4 example conflicts and remove generated artifacts
Hale423 May 11, 2026
8732c4c
Merge latest main into FP4 feature branch
Hale423 May 11, 2026
70afbc3
rewrite: rewrite codes according to code rabbit's suggestions
Hale423 May 11, 2026
f816a4e
handle non-align16B fp4 shared layout gracefully
Hale423 May 15, 2026
57449a8
refresh the code for pre-commit check
Hale423 May 15, 2026
0d2a9cb
tmp save
Hale423 Jun 2, 2026
d153d7a
Merge remote-tracking branch 'origin/main' into feat/gemm-fp4
Hale423 Jun 2, 2026
edcd752
Merge remote-tracking branch 'origin/main' into feat/gemm-fp4
Hale423 Jun 2, 2026
6ed1e83
Merge branch 'feat/gemm-fp4' into feat/gemm-nvfp4
Hale423 Jun 2, 2026
48db769
feat(sm120): add T.nvfp4_gemm as frontend, lower to mma_blockscaled, …
Hale423 Jun 3, 2026
8aa292a
feat(sm100): add mxf4nvf4 feature for ptx code, meta/descriptor, conn…
Hale423 Jun 5, 2026
4dd1a5e
feat(sm100): based on blockscale fw,assign is_nvfp4 annotation to add…
Hale423 Jun 5, 2026
1290a96
fix(sm100): correct NVFP4 mxf4nvf4 block-scaled GEMM by staging FP4 o…
Hale423 Jun 12, 2026
1111bfd
feat(examples): add SM100 NVFP4 fused MoE example (mxf4nvf4 gate/up +…
Hale423 Jun 15, 2026
3a31b97
refresh: refresh code according coderabbat suggestions
Hale423 Jun 22, 2026
2abb53d
refresh: refresh code according coderabbat suggestions, and resolve c…
Hale423 Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Simplified A8W4 fused MoE shared expert on SM100/SM110.

This mirrors the SM120 demo shape, but uses TCGEN05/TMEM:
gate = input(FP8) x W_gate(FP4)
up = input(FP8) x W_up(FP4)
out = silu(gate) * up

The down projection is intentionally omitted here, matching the existing SM120
diagnostic example and keeping the kernel focused on mixed FP8xFP4 TCGEN05 GEMM.
"""

import os
import time

import torch
import tilelang
import tilelang.language as T


FP4_E2M1_TO_FLOAT = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]


def unpack_fp4_to_float(packed_int8, rows, cols):
lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device)
flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2)
lo = flat & 0x0F
hi = (flat >> 4) & 0x0F
unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64)
return lut[unpacked]
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token=128, block_hidden=128, block_expert=64, threads=128, num_stages=1):
scale = 1.44269504 # log2(e)

@T.prim_func
def main(
Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
W_gate: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn),
W_up: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn),
Output: T.Tensor((num_tokens, d_expert), "float32"),
):
with T.Kernel(
T.ceildiv(d_expert, block_expert),
T.ceildiv(num_tokens, block_token),
threads=threads,
) as (bx, by):
input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn")
gate_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn)
up_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn)

gate_tmem = T.alloc_tmem([block_token, block_expert], "float32")
up_tmem = T.alloc_tmem([block_token, block_expert], "float32")
gate_mbar = T.alloc_barrier(1)
up_mbar = T.alloc_barrier(1)

gate_local = T.alloc_fragment((block_token, block_expert), "float32")
up_local = T.alloc_fragment((block_token, block_expert), "float32")

for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages):
T.copy(Input[by * block_token, k * block_hidden], input_shared)
T.copy(W_gate[bx * block_expert, k * block_hidden], gate_shared)
T.copy(W_up[bx * block_expert, k * block_hidden], up_shared)

T.tcgen05_gemm(
input_shared,
gate_shared,
gate_tmem,
transpose_A=False,
transpose_B=True,
mbar=gate_mbar,
clear_accum=(k == 0),
)
T.mbarrier_wait_parity(gate_mbar, k % 2)

T.tcgen05_gemm(
input_shared,
up_shared,
up_tmem,
transpose_A=False,
transpose_B=True,
mbar=up_mbar,
clear_accum=(k == 0),
)
T.mbarrier_wait_parity(up_mbar, k % 2)

T.copy(gate_tmem, gate_local)
T.copy(up_tmem, up_local)

for i, j in T.Parallel(block_token, block_expert):
gate = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)))
up_local[i, j] = up_local[i, j] * gate

T.copy(up_local, Output[by * block_token, bx * block_expert])

return main


num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128"))
d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256"))
d_expert = int(os.environ.get("TL_MOE_EXPERT", "256"))
block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128"))
block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128"))
block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64"))

if d_hidden % 2 != 0:
raise ValueError("TL_MOE_HIDDEN must be even for packed FP4 expert weights")

print(
f"Running SM100 A8W4 fused MoE: tokens={num_tokens}, hidden={d_hidden}, "
f"expert={d_expert}, block=({block_token},{block_hidden},{block_expert})"
)

func = fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token, block_hidden, block_expert)
jit_kernel = tilelang.compile(
func,
out_idx=[3],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
print("Compilation succeeded!")

torch.manual_seed(42)
input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
w_gate = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8)
w_up = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8)

z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8)
z_up = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8)
c_zero = jit_kernel(z_input, z_gate, z_up)
assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
print("[PASS] zeros in -> zeros out")

out = jit_kernel(input_fp8, w_gate, w_up)
input_f32 = input_fp8.to(torch.float32)
gate_logits = input_f32 @ unpack_fp4_to_float(w_gate, d_expert, d_hidden).T
up_logits = input_f32 @ unpack_fp4_to_float(w_up, d_expert, d_hidden).T
ref = up_logits * (gate_logits * torch.sigmoid(gate_logits))

diff = (out.float() - ref).abs()
max_diff = diff.max().item()
rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10)
print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}")
print("[PASS] MoE numerical verification" if rel_err < 0.05 else "[WARN] large diff")

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
jit_kernel(input_fp8, w_gate, w_up)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 100 * 1000
total_flops = 2 * num_tokens * d_hidden * d_expert * 2
print(f"Latency: {elapsed:.4f} ms")
print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}")
175 changes: 175 additions & 0 deletions examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""FP4 Fused MoE shared expert kernel on SM120 (A8W4 mode).

Demonstrates the core MoE gate/up compute pattern with FP4 weights:
1. Gate GEMM: input(FP8) x W_gate(FP4) -> gate logits
2. Up GEMM: input(FP8) x W_up(FP4) -> up logits
3. SiLU(gate) * up

Uses SM120 native kind::f8f6f4 MMA (FP8 x FP4 -> FP32).
Expert weights are stored as unpacked uint8 (1 FP4 per byte, low nibble).

This is a simplified single-expert example. For full routing + grouped GEMM,
see examples/fusedmoe/example_fusedmoe_tilelang.py.
"""

import time
import torch
import tilelang
import tilelang.language as T


FP4_E2M1_TO_FLOAT = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
-0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]


def fp4_uint8_to_float(t):
lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=t.device)
return lut[t.to(torch.int64)]


def moe_shared_expert_a8w4(
num_tokens,
d_hidden,
d_expert,
block_token=128,
block_hidden=128,
block_expert=128,
threads=128,
num_stages=1,
):
"""Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM."""
scale = 1.44269504 # log2(e) for fast SiLU

@T.prim_func
def main(
Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"),
W_gate: T.Tensor((d_expert, d_hidden), "uint8"),
W_up: T.Tensor((d_expert, d_hidden), "uint8"),
output: T.Tensor((num_tokens, d_expert), "float32"),
):
# Step 1: Gate + Up GEMMs (fused in one kernel launch)
with T.Kernel(
T.ceildiv(num_tokens, block_token),
T.ceildiv(d_expert, block_expert),
threads=threads,
) as (bx, by):
input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn")
W_gate_shared = T.alloc_shared((block_expert, block_hidden), "uint8")
W_up_shared = T.alloc_shared((block_expert, block_hidden), "uint8")

gate_local = T.alloc_fragment((block_token, block_expert), "float32")
up_local = T.alloc_fragment((block_token, block_expert), "float32")

T.clear(gate_local)
T.clear(up_local)

for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages):
T.copy(Input[bx * block_token, k * block_hidden], input_shared)
T.copy(W_gate[by * block_expert, k * block_hidden], W_gate_shared)
T.copy(W_up[by * block_expert, k * block_hidden], W_up_shared)
T.gemm(input_shared, W_gate_shared, gate_local, transpose_B=True)
T.gemm(input_shared, W_up_shared, up_local, transpose_B=True)

# Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate
for i, j in T.Parallel(block_token, block_expert):
gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)))
up_local[i, j] = up_local[i, j] * gate_local[i, j]

T.copy(up_local, output[bx * block_token, by * block_expert])

return main


# Problem sizes (small for testing)
num_tokens = 128
d_hidden = 256
d_expert = 256

print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}")

func = moe_shared_expert_a8w4(
num_tokens,
d_hidden,
d_expert,
block_token=128,
block_hidden=128,
block_expert=128,
threads=128,
num_stages=1,
)

jit_kernel = tilelang.compile(
func,
out_idx=[3],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)

print("Compilation succeeded!")

torch.manual_seed(42)

# Create test data
input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
W_gate_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8)
W_up_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8)

# --- Test 1: zeros ---
z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn)
z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8)
c_zero = jit_kernel(z_input, z_gate, z_up)
assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
print("[PASS] zeros in -> zeros out")

# --- Test 2: numerical verification (gate+up only, no down GEMM in this kernel) ---
out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8)

# Reference
input_f32 = input_fp8.to(torch.float32)
gate_f32 = fp4_uint8_to_float(W_gate_uint8)
up_f32 = fp4_uint8_to_float(W_up_uint8)

gate_logits = input_f32 @ gate_f32.T
up_logits = input_f32 @ up_f32.T
gate_activated = gate_logits * torch.sigmoid(gate_logits)
ref_out = up_logits * gate_activated

diff = (out.float() - ref_out).abs()
max_diff = diff.max().item()
rel_err = diff.sum().item() / (ref_out.abs().sum().item() + 1e-10)
print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}")
if rel_err < 0.05:
print("[PASS] MoE gate+up fusion numerical verification")
else:
print("[WARN] large diff")

# --- Benchmark ---
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
jit_kernel(input_fp8, W_gate_uint8, W_up_uint8)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 100 * 1000
total_flops = 2 * num_tokens * d_hidden * d_expert * 2 # 2 GEMMs
print(f"Latency: {elapsed:.4f} ms")
print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}")
Loading
Loading