-
Notifications
You must be signed in to change notification settings - Fork 609
[Stacked][Feature] Support NVFP4 Gemm on Blackwell arch (SM100,110,120) #2324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Hale423
wants to merge
26
commits into
tile-ai:main
Choose a base branch
from
Hale423:feat/gemm-nvfp4
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
a356049
feat: add float4_e2m1fn (NV FP4) support for SM100 TCGEN5 MMA
Hale423 44336d7
fix: python pipeline (LayoutInference + LowerTileOp) passes
Hale423 7cf06fa
fix: SM120 FP4 CUDA compilation – type bridge, codegen naming, MMA di…
Hale423 f068f2f
fix(buffer overflow, fp4 address misalignment): ldmatrix writes in by…
Hale423 d116372
feat: SM120 FP4 GEMM unpacked shared memory + numerical verification
Hale423 f13a6b7
feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120
Hale423 8c915cb
Merge origin/main into feat/gemm-fp4
Hale423 10a2710
fix: fix precision using smem unpacked layout, setting bit 2-5/8 bits…
Hale423 eeddaac
Implement SM100/SM110 FP4 and A8W4 examples using the TCGEN05/TMA unp…
Hale423 3fb9c9d
feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120
Hale423 63c9451
Resolve FP4 example conflicts and remove generated artifacts
Hale423 8732c4c
Merge latest main into FP4 feature branch
Hale423 70afbc3
rewrite: rewrite codes according to code rabbit's suggestions
Hale423 f816a4e
handle non-align16B fp4 shared layout gracefully
Hale423 57449a8
refresh the code for pre-commit check
Hale423 0d2a9cb
tmp save
Hale423 d153d7a
Merge remote-tracking branch 'origin/main' into feat/gemm-fp4
Hale423 edcd752
Merge remote-tracking branch 'origin/main' into feat/gemm-fp4
Hale423 6ed1e83
Merge branch 'feat/gemm-fp4' into feat/gemm-nvfp4
Hale423 48db769
feat(sm120): add T.nvfp4_gemm as frontend, lower to mma_blockscaled, …
Hale423 8aa292a
feat(sm100): add mxf4nvf4 feature for ptx code, meta/descriptor, conn…
Hale423 4dd1a5e
feat(sm100): based on blockscale fw,assign is_nvfp4 annotation to add…
Hale423 1290a96
fix(sm100): correct NVFP4 mxf4nvf4 block-scaled GEMM by staging FP4 o…
Hale423 1111bfd
feat(examples): add SM100 NVFP4 fused MoE example (mxf4nvf4 gate/up +…
Hale423 3a31b97
refresh: refresh code according coderabbat suggestions
Hale423 2abb53d
refresh: refresh code according coderabbat suggestions, and resolve c…
Hale423 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| """Simplified A8W4 fused MoE shared expert on SM100/SM110. | ||
|
|
||
| This mirrors the SM120 demo shape, but uses TCGEN05/TMEM: | ||
| gate = input(FP8) x W_gate(FP4) | ||
| up = input(FP8) x W_up(FP4) | ||
| out = silu(gate) * up | ||
|
|
||
| The down projection is intentionally omitted here, matching the existing SM120 | ||
| diagnostic example and keeping the kernel focused on mixed FP8xFP4 TCGEN05 GEMM. | ||
| """ | ||
|
|
||
| import os | ||
| import time | ||
|
|
||
| import torch | ||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
| FP4_E2M1_TO_FLOAT = [ | ||
| 0.0, | ||
| 0.5, | ||
| 1.0, | ||
| 1.5, | ||
| 2.0, | ||
| 3.0, | ||
| 4.0, | ||
| 6.0, | ||
| -0.0, | ||
| -0.5, | ||
| -1.0, | ||
| -1.5, | ||
| -2.0, | ||
| -3.0, | ||
| -4.0, | ||
| -6.0, | ||
| ] | ||
|
|
||
|
|
||
| def unpack_fp4_to_float(packed_int8, rows, cols): | ||
| lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) | ||
| flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) | ||
| lo = flat & 0x0F | ||
| hi = (flat >> 4) & 0x0F | ||
| unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) | ||
| return lut[unpacked] | ||
|
|
||
|
|
||
| def fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token=128, block_hidden=128, block_expert=64, threads=128, num_stages=1): | ||
| scale = 1.44269504 # log2(e) | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), | ||
| W_gate: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn), | ||
| W_up: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn), | ||
| Output: T.Tensor((num_tokens, d_expert), "float32"), | ||
| ): | ||
| with T.Kernel( | ||
| T.ceildiv(d_expert, block_expert), | ||
| T.ceildiv(num_tokens, block_token), | ||
| threads=threads, | ||
| ) as (bx, by): | ||
| input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn") | ||
| gate_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn) | ||
| up_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn) | ||
|
|
||
| gate_tmem = T.alloc_tmem([block_token, block_expert], "float32") | ||
| up_tmem = T.alloc_tmem([block_token, block_expert], "float32") | ||
| gate_mbar = T.alloc_barrier(1) | ||
| up_mbar = T.alloc_barrier(1) | ||
|
|
||
| gate_local = T.alloc_fragment((block_token, block_expert), "float32") | ||
| up_local = T.alloc_fragment((block_token, block_expert), "float32") | ||
|
|
||
| for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): | ||
| T.copy(Input[by * block_token, k * block_hidden], input_shared) | ||
| T.copy(W_gate[bx * block_expert, k * block_hidden], gate_shared) | ||
| T.copy(W_up[bx * block_expert, k * block_hidden], up_shared) | ||
|
|
||
| T.tcgen05_gemm( | ||
| input_shared, | ||
| gate_shared, | ||
| gate_tmem, | ||
| transpose_A=False, | ||
| transpose_B=True, | ||
| mbar=gate_mbar, | ||
| clear_accum=(k == 0), | ||
| ) | ||
| T.mbarrier_wait_parity(gate_mbar, k % 2) | ||
|
|
||
| T.tcgen05_gemm( | ||
| input_shared, | ||
| up_shared, | ||
| up_tmem, | ||
| transpose_A=False, | ||
| transpose_B=True, | ||
| mbar=up_mbar, | ||
| clear_accum=(k == 0), | ||
| ) | ||
| T.mbarrier_wait_parity(up_mbar, k % 2) | ||
|
|
||
| T.copy(gate_tmem, gate_local) | ||
| T.copy(up_tmem, up_local) | ||
|
|
||
| for i, j in T.Parallel(block_token, block_expert): | ||
| gate = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) | ||
| up_local[i, j] = up_local[i, j] * gate | ||
|
|
||
| T.copy(up_local, Output[by * block_token, bx * block_expert]) | ||
|
|
||
| return main | ||
|
|
||
|
|
||
| num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128")) | ||
| d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256")) | ||
| d_expert = int(os.environ.get("TL_MOE_EXPERT", "256")) | ||
| block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128")) | ||
| block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128")) | ||
| block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64")) | ||
|
|
||
| if d_hidden % 2 != 0: | ||
| raise ValueError("TL_MOE_HIDDEN must be even for packed FP4 expert weights") | ||
|
|
||
| print( | ||
| f"Running SM100 A8W4 fused MoE: tokens={num_tokens}, hidden={d_hidden}, " | ||
| f"expert={d_expert}, block=({block_token},{block_hidden},{block_expert})" | ||
| ) | ||
|
|
||
| func = fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token, block_hidden, block_expert) | ||
| jit_kernel = tilelang.compile( | ||
| func, | ||
| out_idx=[3], | ||
| target="cuda", | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, | ||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||
| }, | ||
| ) | ||
| print("Compilation succeeded!") | ||
|
|
||
| torch.manual_seed(42) | ||
| input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) | ||
| w_gate = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8) | ||
| w_up = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8) | ||
|
|
||
| z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) | ||
| z_gate = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8) | ||
| z_up = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8) | ||
| c_zero = jit_kernel(z_input, z_gate, z_up) | ||
| assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" | ||
| print("[PASS] zeros in -> zeros out") | ||
|
|
||
| out = jit_kernel(input_fp8, w_gate, w_up) | ||
| input_f32 = input_fp8.to(torch.float32) | ||
| gate_logits = input_f32 @ unpack_fp4_to_float(w_gate, d_expert, d_hidden).T | ||
| up_logits = input_f32 @ unpack_fp4_to_float(w_up, d_expert, d_hidden).T | ||
| ref = up_logits * (gate_logits * torch.sigmoid(gate_logits)) | ||
|
|
||
| diff = (out.float() - ref).abs() | ||
| max_diff = diff.max().item() | ||
| rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) | ||
| print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") | ||
| print("[PASS] MoE numerical verification" if rel_err < 0.05 else "[WARN] large diff") | ||
|
|
||
| torch.cuda.synchronize() | ||
| start = time.perf_counter() | ||
| for _ in range(100): | ||
| jit_kernel(input_fp8, w_gate, w_up) | ||
| torch.cuda.synchronize() | ||
| elapsed = (time.perf_counter() - start) / 100 * 1000 | ||
| total_flops = 2 * num_tokens * d_hidden * d_expert * 2 | ||
| print(f"Latency: {elapsed:.4f} ms") | ||
| print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| """FP4 Fused MoE shared expert kernel on SM120 (A8W4 mode). | ||
|
|
||
| Demonstrates the core MoE gate/up compute pattern with FP4 weights: | ||
| 1. Gate GEMM: input(FP8) x W_gate(FP4) -> gate logits | ||
| 2. Up GEMM: input(FP8) x W_up(FP4) -> up logits | ||
| 3. SiLU(gate) * up | ||
|
|
||
| Uses SM120 native kind::f8f6f4 MMA (FP8 x FP4 -> FP32). | ||
| Expert weights are stored as unpacked uint8 (1 FP4 per byte, low nibble). | ||
|
|
||
| This is a simplified single-expert example. For full routing + grouped GEMM, | ||
| see examples/fusedmoe/example_fusedmoe_tilelang.py. | ||
| """ | ||
|
|
||
| import time | ||
| import torch | ||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
| FP4_E2M1_TO_FLOAT = [ | ||
| 0.0, | ||
| 0.5, | ||
| 1.0, | ||
| 1.5, | ||
| 2.0, | ||
| 3.0, | ||
| 4.0, | ||
| 6.0, | ||
| -0.0, | ||
| -0.5, | ||
| -1.0, | ||
| -1.5, | ||
| -2.0, | ||
| -3.0, | ||
| -4.0, | ||
| -6.0, | ||
| ] | ||
|
|
||
|
|
||
| def fp4_uint8_to_float(t): | ||
| lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=t.device) | ||
| return lut[t.to(torch.int64)] | ||
|
|
||
|
|
||
| def moe_shared_expert_a8w4( | ||
| num_tokens, | ||
| d_hidden, | ||
| d_expert, | ||
| block_token=128, | ||
| block_hidden=128, | ||
| block_expert=128, | ||
| threads=128, | ||
| num_stages=1, | ||
| ): | ||
| """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM.""" | ||
| scale = 1.44269504 # log2(e) for fast SiLU | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), | ||
| W_gate: T.Tensor((d_expert, d_hidden), "uint8"), | ||
| W_up: T.Tensor((d_expert, d_hidden), "uint8"), | ||
| output: T.Tensor((num_tokens, d_expert), "float32"), | ||
| ): | ||
| # Step 1: Gate + Up GEMMs (fused in one kernel launch) | ||
| with T.Kernel( | ||
| T.ceildiv(num_tokens, block_token), | ||
| T.ceildiv(d_expert, block_expert), | ||
| threads=threads, | ||
| ) as (bx, by): | ||
| input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn") | ||
| W_gate_shared = T.alloc_shared((block_expert, block_hidden), "uint8") | ||
| W_up_shared = T.alloc_shared((block_expert, block_hidden), "uint8") | ||
|
|
||
| gate_local = T.alloc_fragment((block_token, block_expert), "float32") | ||
| up_local = T.alloc_fragment((block_token, block_expert), "float32") | ||
|
|
||
| T.clear(gate_local) | ||
| T.clear(up_local) | ||
|
|
||
| for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): | ||
| T.copy(Input[bx * block_token, k * block_hidden], input_shared) | ||
| T.copy(W_gate[by * block_expert, k * block_hidden], W_gate_shared) | ||
| T.copy(W_up[by * block_expert, k * block_hidden], W_up_shared) | ||
| T.gemm(input_shared, W_gate_shared, gate_local, transpose_B=True) | ||
| T.gemm(input_shared, W_up_shared, up_local, transpose_B=True) | ||
|
|
||
| # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate | ||
| for i, j in T.Parallel(block_token, block_expert): | ||
| gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) | ||
| up_local[i, j] = up_local[i, j] * gate_local[i, j] | ||
|
|
||
| T.copy(up_local, output[bx * block_token, by * block_expert]) | ||
|
|
||
| return main | ||
|
|
||
|
|
||
| # Problem sizes (small for testing) | ||
| num_tokens = 128 | ||
| d_hidden = 256 | ||
| d_expert = 256 | ||
|
|
||
| print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}") | ||
|
|
||
| func = moe_shared_expert_a8w4( | ||
| num_tokens, | ||
| d_hidden, | ||
| d_expert, | ||
| block_token=128, | ||
| block_hidden=128, | ||
| block_expert=128, | ||
| threads=128, | ||
| num_stages=1, | ||
| ) | ||
|
|
||
| jit_kernel = tilelang.compile( | ||
| func, | ||
| out_idx=[3], | ||
| target="cuda", | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, | ||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||
| }, | ||
| ) | ||
|
|
||
| print("Compilation succeeded!") | ||
|
|
||
| torch.manual_seed(42) | ||
|
|
||
| # Create test data | ||
| input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) | ||
| W_gate_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) | ||
| W_up_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) | ||
|
|
||
| # --- Test 1: zeros --- | ||
| z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) | ||
| z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) | ||
| z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) | ||
| c_zero = jit_kernel(z_input, z_gate, z_up) | ||
| assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" | ||
| print("[PASS] zeros in -> zeros out") | ||
|
|
||
| # --- Test 2: numerical verification (gate+up only, no down GEMM in this kernel) --- | ||
| out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8) | ||
|
|
||
| # Reference | ||
| input_f32 = input_fp8.to(torch.float32) | ||
| gate_f32 = fp4_uint8_to_float(W_gate_uint8) | ||
| up_f32 = fp4_uint8_to_float(W_up_uint8) | ||
|
|
||
| gate_logits = input_f32 @ gate_f32.T | ||
| up_logits = input_f32 @ up_f32.T | ||
| gate_activated = gate_logits * torch.sigmoid(gate_logits) | ||
| ref_out = up_logits * gate_activated | ||
|
|
||
| diff = (out.float() - ref_out).abs() | ||
| max_diff = diff.max().item() | ||
| rel_err = diff.sum().item() / (ref_out.abs().sum().item() + 1e-10) | ||
| print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") | ||
| if rel_err < 0.05: | ||
| print("[PASS] MoE gate+up fusion numerical verification") | ||
| else: | ||
| print("[WARN] large diff") | ||
|
|
||
| # --- Benchmark --- | ||
| torch.cuda.synchronize() | ||
| start = time.perf_counter() | ||
| for _ in range(100): | ||
| jit_kernel(input_fp8, W_gate_uint8, W_up_uint8) | ||
| torch.cuda.synchronize() | ||
| elapsed = (time.perf_counter() - start) / 100 * 1000 | ||
| total_flops = 2 * num_tokens * d_hidden * d_expert * 2 # 2 GEMMs | ||
| print(f"Latency: {elapsed:.4f} ms") | ||
| print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.