-
Notifications
You must be signed in to change notification settings - Fork 609
[Stacked][Feature] Support NVFP4 Gemm on Blackwell arch (SM100,110,120) #2324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 20 commits
a356049
44336d7
7cf06fa
f068f2f
d116372
f13a6b7
8c915cb
10a2710
eeddaac
3fb9c9d
63c9451
8732c4c
70afbc3
f816a4e
57449a8
0d2a9cb
d153d7a
edcd752
6ed1e83
48db769
8aa292a
4dd1a5e
1290a96
1111bfd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,171 @@ | ||
| """Simplified A8W4 fused MoE shared expert on SM100/SM110. | ||
|
|
||
| This mirrors the SM120 demo shape, but uses TCGEN05/TMEM: | ||
| gate = input(FP8) x W_gate(FP4) | ||
| up = input(FP8) x W_up(FP4) | ||
| out = silu(gate) * up | ||
|
|
||
| The down projection is intentionally omitted here, matching the existing SM120 | ||
| diagnostic example and keeping the kernel focused on mixed FP8xFP4 TCGEN05 GEMM. | ||
| """ | ||
|
|
||
| import os | ||
| import time | ||
|
|
||
| import torch | ||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
| FP4_E2M1_TO_FLOAT = [ | ||
| 0.0, | ||
| 0.5, | ||
| 1.0, | ||
| 1.5, | ||
| 2.0, | ||
| 3.0, | ||
| 4.0, | ||
| 6.0, | ||
| -0.0, | ||
| -0.5, | ||
| -1.0, | ||
| -1.5, | ||
| -2.0, | ||
| -3.0, | ||
| -4.0, | ||
| -6.0, | ||
| ] | ||
|
|
||
|
|
||
| def unpack_fp4_to_float(packed_int8, rows, cols): | ||
| lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) | ||
| flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) | ||
| lo = flat & 0x0F | ||
| hi = (flat >> 4) & 0x0F | ||
| unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) | ||
| return lut[unpacked] | ||
|
|
||
|
|
||
| def fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token=128, block_hidden=128, block_expert=64, threads=128, num_stages=1): | ||
| scale = 1.44269504 # log2(e) | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), | ||
| W_gate: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn), | ||
| W_up: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn), | ||
| Output: T.Tensor((num_tokens, d_expert), "float32"), | ||
| ): | ||
| with T.Kernel( | ||
| T.ceildiv(d_expert, block_expert), | ||
| T.ceildiv(num_tokens, block_token), | ||
| threads=threads, | ||
| ) as (bx, by): | ||
| input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn") | ||
| gate_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn) | ||
| up_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn) | ||
|
|
||
| gate_tmem = T.alloc_tmem([block_token, block_expert], "float32") | ||
| up_tmem = T.alloc_tmem([block_token, block_expert], "float32") | ||
| gate_mbar = T.alloc_barrier(1) | ||
| up_mbar = T.alloc_barrier(1) | ||
|
|
||
| gate_local = T.alloc_fragment((block_token, block_expert), "float32") | ||
| up_local = T.alloc_fragment((block_token, block_expert), "float32") | ||
|
|
||
| for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): | ||
| T.copy(Input[by * block_token, k * block_hidden], input_shared) | ||
| T.copy(W_gate[bx * block_expert, k * block_hidden], gate_shared) | ||
| T.copy(W_up[bx * block_expert, k * block_hidden], up_shared) | ||
|
|
||
| T.tcgen05_gemm( | ||
| input_shared, | ||
| gate_shared, | ||
| gate_tmem, | ||
| transpose_A=False, | ||
| transpose_B=True, | ||
| mbar=gate_mbar, | ||
| clear_accum=(k == 0), | ||
| ) | ||
| T.mbarrier_wait_parity(gate_mbar, k % 2) | ||
|
|
||
| T.tcgen05_gemm( | ||
| input_shared, | ||
| up_shared, | ||
| up_tmem, | ||
| transpose_A=False, | ||
| transpose_B=True, | ||
| mbar=up_mbar, | ||
| clear_accum=(k == 0), | ||
| ) | ||
| T.mbarrier_wait_parity(up_mbar, k % 2) | ||
|
|
||
| T.copy(gate_tmem, gate_local) | ||
| T.copy(up_tmem, up_local) | ||
|
|
||
| for i, j in T.Parallel(block_token, block_expert): | ||
| gate = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) | ||
| up_local[i, j] = up_local[i, j] * gate | ||
|
|
||
| T.copy(up_local, Output[by * block_token, bx * block_expert]) | ||
|
|
||
| return main | ||
|
|
||
|
|
||
| num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128")) | ||
| d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256")) | ||
| d_expert = int(os.environ.get("TL_MOE_EXPERT", "256")) | ||
| block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128")) | ||
| block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128")) | ||
| block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64")) | ||
|
|
||
| print( | ||
| f"Running SM100 A8W4 fused MoE: tokens={num_tokens}, hidden={d_hidden}, " | ||
| f"expert={d_expert}, block=({block_token},{block_hidden},{block_expert})" | ||
| ) | ||
|
|
||
| func = fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token, block_hidden, block_expert) | ||
| jit_kernel = tilelang.compile( | ||
| func, | ||
| out_idx=[3], | ||
| target="cuda", | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, | ||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||
| }, | ||
| ) | ||
| print("Compilation succeeded!") | ||
|
|
||
| torch.manual_seed(42) | ||
| input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) | ||
| w_gate = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8) | ||
| w_up = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8) | ||
|
|
||
| z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) | ||
| z_gate = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8) | ||
| z_up = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8) | ||
| c_zero = jit_kernel(z_input, z_gate, z_up) | ||
| assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" | ||
| print("[PASS] zeros in -> zeros out") | ||
|
|
||
| out = jit_kernel(input_fp8, w_gate, w_up) | ||
| input_f32 = input_fp8.to(torch.float32) | ||
| gate_logits = input_f32 @ unpack_fp4_to_float(w_gate, d_expert, d_hidden).T | ||
| up_logits = input_f32 @ unpack_fp4_to_float(w_up, d_expert, d_hidden).T | ||
| ref = up_logits * (gate_logits * torch.sigmoid(gate_logits)) | ||
|
|
||
| diff = (out.float() - ref).abs() | ||
| max_diff = diff.max().item() | ||
| rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) | ||
| print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") | ||
| print("[PASS] MoE numerical verification" if rel_err < 0.05 else "[WARN] large diff") | ||
|
|
||
| torch.cuda.synchronize() | ||
| start = time.perf_counter() | ||
| for _ in range(100): | ||
| jit_kernel(input_fp8, w_gate, w_up) | ||
| torch.cuda.synchronize() | ||
| elapsed = (time.perf_counter() - start) / 100 * 1000 | ||
| total_flops = 2 * num_tokens * d_hidden * d_expert * 2 | ||
| print(f"Latency: {elapsed:.4f} ms") | ||
| print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,174 @@ | ||||||||||||||||||||||||||||
| """FP4 Fused MoE shared expert kernel on SM120 (A8W4 mode). | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Demonstrates the core MoE gate/up compute pattern with FP4 weights: | ||||||||||||||||||||||||||||
| 1. Gate GEMM: input(FP8) x W_gate(FP4) -> gate logits | ||||||||||||||||||||||||||||
| 2. Up GEMM: input(FP8) x W_up(FP4) -> up logits | ||||||||||||||||||||||||||||
| 3. SiLU(gate) * up | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Uses SM120 native kind::f8f6f4 MMA (FP8 x FP4 -> FP32). | ||||||||||||||||||||||||||||
| Expert weights are stored as unpacked uint8 (1 FP4 per byte, low nibble). | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| This is a simplified single-expert example. For full routing + grouped GEMM, | ||||||||||||||||||||||||||||
| see examples/fusedmoe/example_fusedmoe_tilelang.py. | ||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||
| import tilelang | ||||||||||||||||||||||||||||
| import tilelang.language as T | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| FP4_E2M1_TO_FLOAT = [ | ||||||||||||||||||||||||||||
| 0.0, | ||||||||||||||||||||||||||||
| 0.5, | ||||||||||||||||||||||||||||
| 1.0, | ||||||||||||||||||||||||||||
| 1.5, | ||||||||||||||||||||||||||||
| 2.0, | ||||||||||||||||||||||||||||
| 3.0, | ||||||||||||||||||||||||||||
| 4.0, | ||||||||||||||||||||||||||||
| 6.0, | ||||||||||||||||||||||||||||
| -0.0, | ||||||||||||||||||||||||||||
| -0.5, | ||||||||||||||||||||||||||||
| -1.0, | ||||||||||||||||||||||||||||
| -1.5, | ||||||||||||||||||||||||||||
| -2.0, | ||||||||||||||||||||||||||||
| -3.0, | ||||||||||||||||||||||||||||
| -4.0, | ||||||||||||||||||||||||||||
| -6.0, | ||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def fp4_uint8_to_float(t): | ||||||||||||||||||||||||||||
| lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=t.device) | ||||||||||||||||||||||||||||
| return lut[t.to(torch.int64)] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| def moe_shared_expert_a8w4( | ||||||||||||||||||||||||||||
| num_tokens, | ||||||||||||||||||||||||||||
| d_hidden, | ||||||||||||||||||||||||||||
| d_expert, | ||||||||||||||||||||||||||||
| block_token=128, | ||||||||||||||||||||||||||||
| block_hidden=128, | ||||||||||||||||||||||||||||
| block_expert=128, | ||||||||||||||||||||||||||||
| threads=128, | ||||||||||||||||||||||||||||
| num_stages=1, | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM.""" | ||||||||||||||||||||||||||||
| scale = 1.44269504 # log2(e) for fast SiLU | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||
| def main( | ||||||||||||||||||||||||||||
| Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), | ||||||||||||||||||||||||||||
| W_gate: T.Tensor((d_expert, d_hidden), "uint8"), | ||||||||||||||||||||||||||||
| W_up: T.Tensor((d_expert, d_hidden), "uint8"), | ||||||||||||||||||||||||||||
| output: T.Tensor((num_tokens, d_expert), "float32"), | ||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||
| # Step 1: Gate + Up GEMMs (fused in one kernel launch) | ||||||||||||||||||||||||||||
| with T.Kernel( | ||||||||||||||||||||||||||||
| T.ceildiv(num_tokens, block_token), | ||||||||||||||||||||||||||||
| T.ceildiv(d_expert, block_expert), | ||||||||||||||||||||||||||||
| threads=threads, | ||||||||||||||||||||||||||||
| ) as (bx, by): | ||||||||||||||||||||||||||||
| input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn") | ||||||||||||||||||||||||||||
| W_gate_shared = T.alloc_shared((block_expert, block_hidden), "uint8") | ||||||||||||||||||||||||||||
| W_up_shared = T.alloc_shared((block_expert, block_hidden), "uint8") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| gate_local = T.alloc_fragment((block_token, block_expert), "float32") | ||||||||||||||||||||||||||||
| up_local = T.alloc_fragment((block_token, block_expert), "float32") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| T.clear(gate_local) | ||||||||||||||||||||||||||||
| T.clear(up_local) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): | ||||||||||||||||||||||||||||
| T.copy(Input[bx * block_token, k * block_hidden], input_shared) | ||||||||||||||||||||||||||||
| T.copy(W_gate[by * block_expert, k * block_hidden], W_gate_shared) | ||||||||||||||||||||||||||||
| T.copy(W_up[by * block_expert, k * block_hidden], W_up_shared) | ||||||||||||||||||||||||||||
| T.gemm(input_shared, W_gate_shared, gate_local, transpose_B=True) | ||||||||||||||||||||||||||||
| T.gemm(input_shared, W_up_shared, up_local, transpose_B=True) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate | ||||||||||||||||||||||||||||
| for i, j in T.Parallel(block_token, block_expert): | ||||||||||||||||||||||||||||
| gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) | ||||||||||||||||||||||||||||
| up_local[i, j] = up_local[i, j] * gate_local[i, j] | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| T.copy(up_local, output[bx * block_token, by * block_expert]) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| return main | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Problem sizes (small for testing) | ||||||||||||||||||||||||||||
| num_tokens = 128 | ||||||||||||||||||||||||||||
| d_hidden = 256 | ||||||||||||||||||||||||||||
| d_expert = 256 | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| func = moe_shared_expert_a8w4( | ||||||||||||||||||||||||||||
| num_tokens, | ||||||||||||||||||||||||||||
| d_hidden, | ||||||||||||||||||||||||||||
| d_expert, | ||||||||||||||||||||||||||||
| block_token=128, | ||||||||||||||||||||||||||||
| block_hidden=128, | ||||||||||||||||||||||||||||
| block_expert=128, | ||||||||||||||||||||||||||||
| threads=128, | ||||||||||||||||||||||||||||
| num_stages=1, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| jit_kernel = tilelang.compile( | ||||||||||||||||||||||||||||
| func, | ||||||||||||||||||||||||||||
| out_idx=[3], | ||||||||||||||||||||||||||||
| target="cuda", | ||||||||||||||||||||||||||||
| pass_configs={ | ||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, | ||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| print("Compilation succeeded!") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| torch.manual_seed(42) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Create test data | ||||||||||||||||||||||||||||
| input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) | ||||||||||||||||||||||||||||
| W_gate_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) | ||||||||||||||||||||||||||||
| W_up_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # --- Test 1: zeros --- | ||||||||||||||||||||||||||||
| z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) | ||||||||||||||||||||||||||||
| z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) | ||||||||||||||||||||||||||||
| z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) | ||||||||||||||||||||||||||||
| c_zero = jit_kernel(z_input, z_gate, z_up) | ||||||||||||||||||||||||||||
| print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out") | ||||||||||||||||||||||||||||
|
Comment on lines
+136
to
+141
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make the zero-input smoke test fail closed. Right now this only prints Suggested fix c_zero = jit_kernel(z_input, z_gate, z_up)
-print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out")
+assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}"
+print("[PASS] zeros in -> zeros out")📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # --- Test 2: numerical verification (gate+up only, no down GEMM in this kernel) --- | ||||||||||||||||||||||||||||
| out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Reference | ||||||||||||||||||||||||||||
| input_f32 = input_fp8.to(torch.float32) | ||||||||||||||||||||||||||||
| gate_f32 = fp4_uint8_to_float(W_gate_uint8) | ||||||||||||||||||||||||||||
| up_f32 = fp4_uint8_to_float(W_up_uint8) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| gate_logits = input_f32 @ gate_f32.T | ||||||||||||||||||||||||||||
| up_logits = input_f32 @ up_f32.T | ||||||||||||||||||||||||||||
| gate_activated = gate_logits * torch.sigmoid(gate_logits) | ||||||||||||||||||||||||||||
| ref_out = up_logits * gate_activated | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| diff = (out.float() - ref_out).abs() | ||||||||||||||||||||||||||||
| max_diff = diff.max().item() | ||||||||||||||||||||||||||||
| rel_err = diff.sum().item() / (ref_out.abs().sum().item() + 1e-10) | ||||||||||||||||||||||||||||
| print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") | ||||||||||||||||||||||||||||
| if rel_err < 0.05: | ||||||||||||||||||||||||||||
| print("[PASS] MoE gate+up fusion numerical verification") | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| print("[WARN] large diff") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # --- Benchmark --- | ||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||
| start = time.perf_counter() | ||||||||||||||||||||||||||||
| for _ in range(100): | ||||||||||||||||||||||||||||
| jit_kernel(input_fp8, W_gate_uint8, W_up_uint8) | ||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||
| elapsed = (time.perf_counter() - start) / 100 * 1000 | ||||||||||||||||||||||||||||
| total_flops = 2 * num_tokens * d_hidden * d_expert * 2 # 2 GEMMs | ||||||||||||||||||||||||||||
| print(f"Latency: {elapsed:.4f} ms") | ||||||||||||||||||||||||||||
| print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") | ||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard
TL_MOE_HIDDENfor packed FP4 storage.Both expert weight tensors are packed as
(d_expert, d_hidden // 2), so this example only works whend_hiddenis even. Right now an oddTL_MOE_HIDDENis accepted and the failure shows up later during unpack/reference instead of at config parsing.Suggested guard
num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128")) d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256")) d_expert = int(os.environ.get("TL_MOE_EXPERT", "256")) @@ block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128")) block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128")) block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64")) + +if d_hidden % 2 != 0: + raise ValueError("TL_MOE_HIDDEN must be even for packed FP4 expert weights")Also applies to: 115-120, 141-155
🤖 Prompt for AI Agents