From a356049564e2ddb4a62498f8feff2639586f1040 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 9 Mar 2026 15:14:25 +0800 Subject: [PATCH 01/19] feat: add float4_e2m1fn (NV FP4) support for SM100 TCGEN5 MMA Add the plumbing required to route float4_e2m1fn through the TCGEN5 MMA code-generation path so that FP4 GEMM kernels can be emitted on SM100. Changes: - ptx.h / ptx.cc: add kFloat4_e2m1fn enum, string tables, DTypeFromString - common.h: add kFloat4_e2m1fn to device-side DataType enum - tcgen5_meta.h: add FP4 branch in encode_dtype (format code 2) - tcgen05mma.h: add kFloat4_e2m1fn specializations for SS/TS/WS_SS (delegates to the existing f8f6f4 PTX kind) - mma_macro_generator.py: add dtype_abbrv mapping for float4_e2m1fn - docs/GEMM_NV_FP4_FEATURE_STEPS.md: design doc and progress tracker Addresses #1592 Made-with: Cursor --- docs/GEMM_NV_FP4_FEATURE_STEPS.md | 107 ++++++++++++++++++ src/op/tcgen5_meta.h | 2 + src/target/ptx.cc | 8 +- src/target/ptx.h | 3 +- src/tl_templates/cuda/common.h | 3 +- .../cuda/instruction/tcgen05mma.h | 30 +++++ tilelang/intrinsics/mma_macro_generator.py | 1 + 7 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 docs/GEMM_NV_FP4_FEATURE_STEPS.md diff --git a/docs/GEMM_NV_FP4_FEATURE_STEPS.md b/docs/GEMM_NV_FP4_FEATURE_STEPS.md new file mode 100644 index 0000000000..5fb808496a --- /dev/null +++ b/docs/GEMM_NV_FP4_FEATURE_STEPS.md @@ -0,0 +1,107 @@ +# GEMM NV FP4 Feature – 现状 Review 与实现步骤评估 + +## 一、当前代码库现状 + +### 1. 已有基础(可直接复用) + +| 模块 | 现状 | +|------|------| +| **FP4 类型与转换** | `src/tl_templates/cuda/cuda_fp4.h` 已有 `__nv_fp4_e2m1`、`fp4_e2_t`、fp4↔half/float/bf16 转换,以及 packed load/store 辅助函数。 | +| **Load/Store** | `lower_ldg_stg.cc` 已对 `float4_e2m1fn` 做 4-bit 处理(bits()、total_bits 等)。 | +| **CuTeDSL codegen** | `codegen_cutedsl.cc` 已支持 `Float4E2M1FN`。 | +| **Python dtype** | `tilelang/language/dtypes.py` 已有 `float4_e2m1fn` 及 x2/x4/…/x64 向量类型;与 torch 的映射在 `dtypes.py` / `tensor.py`。 | +| **TCGEN5 窄精度 meta** | `src/op/tcgen5_meta.h` 中 `GetTCGEN5MMAMeta` 已把 `float4_e2m1fn` 与 float8/float6 一起纳入窄精度分支:K%32==0,M/N 与现有 F8/F6 同规则,C 为 float32 或 float16。 | +| **SM100 MMA 模板** | `gemm_sm100.h` 已有 `SM100_MMA_F8F6F4_WS_SS`(PTX `kind::f8f6f4`),支持 ≤8bit 类型;当前仅对 `cute::float_e4m3_t` / `float_e5m2_t` 做了 `DispatchInstruction`。 | +| **Dequant 示例** | `examples/dequantize_gemm/` 下有 FP4 相关示例(如 `example_dequant_gemm_fp4_hopper.py`),用于参考 FP4 数据流与 scaling。 | + +### 2. 缺口(需要补齐) + +| 模块 | 问题 | +|------|------| +| **指令描述符** | `GetTCGEN5InstrDesc()` 中 `encode_dtype()` 未处理 `float4_e2m1fn`,会走到 `LOG(FATAL)`。需为 FP4 分配正确的 format 编码(需查 CUTLASS/PTX 描述符定义)。 | +| **PTX 类型枚举与字符串** | `src/target/ptx.cc` 的 `DTypeFromString()` 无 `"float4_e2m1fn"`;`common.h` 的 `DataType` 枚举无 `kFloat4_e2m1fn`;`enum_to_str` / `num_bits` 等表也需扩展。 | +| **GEMM 到 codegen 的 dtype 传递** | 当前 GEMM Lower 时把 `ab_dtype` 转成字符串传给后端;需保证 TVM `float4_e2m1fn` → 字符串 `"float4_e2m1fn"` → PTX/codegen 能识别并在 C++ 端选用 FP4 路径。 | +| **C++ 侧 FP4 类型与 Dispatch** | `gemm_sm100.h` 中 `DispatchInstruction` 仅有 `float_e4m3_t` / `float_e5m2_t`;需增加 FP4 的 C++ 类型(如对接 `cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的 FP4 类型)以及对应的 `DispatchInstruction` 特化。 | +| **CuTe/CUTLASS 对 FP4 的支持** | `UMMA::make_instr_desc`、`sm100_smem_selector` 等是否支持 4-bit 类型需确认;若 CUTLASS 已有 FP4 描述符/布局,需在 TileLang 侧对齐。 | +| **Block-scaled FP4(可选)** | NVIDIA 文档中的 `mxf4.block_scale` / `mxf4nvf4.block_scale` 为单独路径,若要做 block-scaled GEMM,需额外设计 scale 的存储与传入方式,可作二期。 | + +--- + +## 二、推荐实现步骤(按依赖顺序) + +### Phase 1:类型与描述符贯通 + +1. **PTX / common 数据类型** + - 在 `common.h` 的 `DataType` 中增加 `kFloat4_e2m1fn`(或与现有命名规范一致)。 + - 在 `ptx.cc` 的 `DTypeFromString` 中增加 `"float4_e2m1fn"`(及可选别名)映射到该枚举。 + - 在 `enum_to_str`、`num_bits`、`dtype_str` 等表中为 FP4 填好条目,保证 codegen 能正确生成类型名与位宽。 + +2. **TCGEN5 指令描述符** + - 查 CUTLASS/PTX 文档中 tcgen05.mma `kind::f8f6f4` 下 FP4 的 descriptor 编码(与 e4m3/e5m2 的差异)。 + - 在 `tcgen5_meta.h` 的 `GetTCGEN5InstrDesc()` 里为 `float4_e2m1fn` 增加 `encode_dtype` 分支,返回正确的 format 码,保证现有 GEMM Lower 路径在 A/B 为 FP4 时能生成合法描述符。 + +3. **验证路径** + - 写一个最小用例:只做「分配 FP4 buffer + 从 Python 传 dtype 到 codegen」,确认: + - TVM dtype → 字符串 → `DTypeFromString` → `DTypeEnumToString` 一致; + - 若已有 FP4 的 LDG/STG,可顺带跑一条简单 load/store 或 dequant 示例,确认无回归。 + +### Phase 2:C++ GEMM 模板与 Dispatch + +4. **C++ FP4 类型与 CuTe 对接** + - 确定 GEMM 使用的 C++ 类型:`cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的等价类型。 + - 若 CUTLASS 有 `make_instr_desc` 对 FP4 的重载,在 TileLang 的 `to_cute_type`(或当前用于 F8/F6 的映射处)增加 TVM float4_e2m1fn → 该 C++ 类型的映射。 + - 若 CUTLASS 的 `sm100_smem_selector` 对 4-bit 有特殊布局(例如 packed 2×fp4/byte),在 `GemmTensorOp` 的 A/B layout 路径中接入,保证与 descriptor 一致。 + +5. **DispatchInstruction 与 SS/WS** + - 在 `gemm_sm100.h` 中为 FP4 增加 `DispatchInstruction`(以及 C=half 若需要),指向已有的 `SM100_MMA_F8F6F4_SS` / `SM100_MMA_F8F6F4_WS_SS`。 + - 确认 PTX 的 `kind::f8f6f4` 对 FP4 的 M/N/K 与 tile 约束与当前 meta 一致(K=32 等);必要时对照 CUTLASS 72a_blackwell_nvfp4_bf16_gemm 示例。 + +6. **端到端 GEMM 调用** + - 在 `gemm.cc` / `gemm_py.cc` 的 Lower 中,当 `a_`/`b_` 为 float4_e2m1fn 时,传入的 `kind_dtype` 字符串需与 `DTypeFromString` 一致(如 `"float4_e2m1fn"`)。 + - 跑一次 SM100 上 A/B 为 FP4、C 为 float32 的 GEMM kernel,对比参考实现(如 CUTLASS 72a)做数值与正确性检查。 + +### Phase 3:示例与文档 + +7. **Example** + - 在 `examples/gemm_sm100/` 或新建 `examples/gemm_fp4_sm100/` 增加一个 NV FP4 GEMM 示例:T.float4_e2m1fn 作为 A/B dtype,accum float32,与 ref(先 dequant 再 bf16/f32 GEMM)对比。 + - 若后续做 block-scaled,可在同一目录下加另一个示例,区分「无 scale」与「block scale」两种用法。 + +8. **文档与 CI** + - 在 `docs/programming_guides/type_system.md` 或 GEMM 相关文档中注明:SM100 TCGEN5 MMA 支持 float4_e2m1fn,以及当前限制(K%32、M/N 与 meta 一致、是否支持 block scale)。 + - CI 中为 SM100 增加一条 FP4 GEMM 的编译/运行(若环境有 SM100 或通过 skip 条件处理)。 + +--- + +## 三、风险与依赖 + +- **CUTLASS 版本**:FP4 描述符与 `make_instr_desc` 的细节可能随 CUTLASS 版本变化,需以当前 submodule 版本为准核对。 +- **PTXAS / CUDA**:`cuda_fp4.h` 中已有对 CUDA < 13 的 workaround(如 `__tl_cvt_fp4_to_halfraw_naive`),若运行环境 CUDA 版本较新,可再确认是否仍需要这些分支。 +- **Block-scaled FP4**:若产品需要 4× 吞吐的 block-scaled 路径,需单独设计 scale 张量、描述符与 API,工作量大于「仅非 block-scaled FP4 GEMM」。 + +--- + +## 四、与 Flash Attention SM100 的优先级 + +- Flash Attention SM100 的 example(含 wasp)可后续再迭代;当前不作为重点。 +- GEMM NV FP4 作为独立 feature,与 FA 的改动解耦;完成上述 Phase 1–2 即可得到可用的 FP4 GEMM 路径,再视需求加 block scale 或集成到更高层算子。 + +--- + +## 五、本次已完成的修改(CuTeDSL/TCGEN5 路径) + +以下修改已落地,使 **CuTeDSL 生成的 SM100 GEMM kernel** 在 A/B 为 `float4_e2m1fn` 时能正确走 `kind::f8f6f4` 的 FP4 路径。当前未改 `gemm_sm100.h` 的 CUTLASS `DispatchInstruction`(该路径由 CUTLASS 模板实例化,与 CuTeDSL 生成 TIR 再 codegen 的路径分离)。 + +| 文件 | 修改内容 | +|------|----------| +| **src/target/ptx.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | +| **src/target/ptx.cc** | `enum_to_str` / `dtype_str` / `num_bits` 增加第 24 项;`DTypeFromString` 增加 `"float4_e2m1fn"`、`".e2m1"` → `kFloat4_e2m1fn`。 | +| **src/tl_templates/cuda/common.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | +| **src/op/tcgen5_meta.h** | `GetTCGEN5InstrDesc()` 中 `encode_dtype` 增加 `dtype.is_float4_e2m1fn()` 分支,返回 `2`(FP4 format 编码)。 | +| **src/tl_templates/cuda/instruction/tcgen05mma.h** | 为 `DataType::kFloat4_e2m1fn` 增加 `tcgen05mma_ss`、`tcgen05mma_ts`、`tcgen05mma_ws_ss` 特化,均转发到 `kFloat8_e4m3` 的 f8f6f4 实现。 | +| **tilelang/intrinsics/mma_macro_generator.py** | `dtype_abbrv` 增加 `"float4_e2m1fn": "float4_e2m1fn"`,保证 lowering 时 `a_dtype_abbrv` 传入 `ptx_tcgen05_mma_ss("float4_e2m1fn", ...)`。 | + +**后续建议**(在有 SM100 / nvcc 环境时): + +1. 本地或 CI 执行完整构建并跑 `examples/gemm_sm100/` 中 BF16/F8 用例,确认无回归。 +2. 新增 FP4 GEMM 示例:`in_dtype=accum_dtype=T.float4_e2m1fn, T.float`,`block_K=128`(K%32 已由 meta 保证),与参考实现做数值对比。 +3. 若需走 CUTLASS 模板路径的 FP4 GEMM,再补 `gemm_sm100.h` 的 `to_cute_type` 与 `DispatchInstruction`。 diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 2b17b7b542..f45cfa63e6 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -147,6 +147,8 @@ inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, return static_cast(0); } else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) { return static_cast(1); + } else if (dtype.is_float4_e2m1fn()) { + return static_cast(2); } else if (dtype.is_int() && dtype.bits() == 8) { return static_cast(1); } else if (dtype.is_uint() && dtype.bits() == 8) { diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 9bf7c5a2e7..9c242952e3 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -40,15 +40,15 @@ static const char *enum_to_str[] = { "kUInt16", "kInt32", "kUInt32", "kInt64", "kUInt64", "kFloat8_e4m3", "kFloat8_e5m2", "kFloat16", "kBFloat16", "kFloat16x2", "kFloat32", "kTensorFloat32", "kFloat64", "kBit1", "kBit8", - "kBit16", "kBit32", "kBit64"}; + "kBit16", "kBit32", "kBit64", "kFloat4_e2m1fn"}; static const char *dtype_str[] = { ".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", ".u32", ".s64", ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32", - ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"}; + ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64", ".e2m1"}; static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 8, 8, 16, 16, 32, 32, - 32, 64, 1, 8, 16, 32, 64}; + 32, 64, 1, 8, 16, 32, 64, 4}; /*! * \brief Create PTX data type from string. @@ -80,6 +80,8 @@ DataType DTypeFromString(const std::string str) { } else if (str == "float8_e5m2" || str == "float8_e5m2fn" || str == "e5m2" || str == ".e5m2") { return DataType::kFloat8_e5m2; + } else if (str == "float4_e2m1fn" || str == ".e2m1") { + return DataType::kFloat4_e2m1fn; } else if (str == "float16" || str == "fp16" || str == ".f16") { return DataType::kFloat16; } else if (str == "bfloat16" || str == "bf16") { diff --git a/src/target/ptx.h b/src/target/ptx.h index 85d9b947bc..04812c0a26 100644 --- a/src/target/ptx.h +++ b/src/target/ptx.h @@ -65,7 +65,8 @@ enum class DataType : int { kBit8 = 19, kBit16 = 20, kBit32 = 21, - kBit64 = 22 + kBit64 = 22, + kFloat4_e2m1fn = 23 }; /*! diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index e8aa592b18..c6fdd55c99 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -281,7 +281,8 @@ enum class DataType : int { kBit8 = 19, kBit16 = 20, kBit32 = 21, - kBit64 = 22 + kBit64 = 22, + kFloat4_e2m1fn = 23 }; union GmmaDescriptor { diff --git a/src/tl_templates/cuda/instruction/tcgen05mma.h b/src/tl_templates/cuda/instruction/tcgen05mma.h index d69e8326a9..8422e67126 100644 --- a/src/tl_templates/cuda/instruction/tcgen05mma.h +++ b/src/tl_templates/cuda/instruction/tcgen05mma.h @@ -142,6 +142,16 @@ TL_DEVICE void tcgen05mma_ts( desc_val, mask0, mask1, mask2, mask3); } +// FP4 (e2m1) uses same f8f6f4 instruction kind as FP8 +template <> +TL_DEVICE void tcgen05mma_ts( + uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ts(tmem_a, desc_b, tmem_c, scalec, + desc_val, mask0, mask1, mask2, mask3); +} + // F16/BF16 instruction kind (maps to kind::f16) template <> TL_DEVICE void tcgen05mma_ss( @@ -250,6 +260,16 @@ TL_DEVICE void tcgen05mma_ss( desc_val, mask0, mask1, mask2, mask3); } +// FP4 (e2m1) uses same f8f6f4 instruction kind as FP8 +template <> +TL_DEVICE void tcgen05mma_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ss(desc_a, desc_b, tmem_c, scalec, + desc_val, mask0, mask1, mask2, mask3); +} + // WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx // Generic declaration falls back to static assert template @@ -364,4 +384,14 @@ TL_DEVICE void tcgen05mma_ws_ss( desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3); } +// FP4 (e2m1) ws, same f8f6f4 instruction kind +template <> +TL_DEVICE void tcgen05mma_ws_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, + int const &mask1, int const &mask2, int const &mask3) { + tcgen05mma_ws_ss( + desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3); +} + } // namespace tl diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index f1932245d1..e8423240dd 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -54,6 +54,7 @@ class TensorCoreIntrinEmitter: "float8_e4m3fnuz": "e4m3", "float8_e5m2": "e5m2", "float8_e5m2fnuz": "e5m2", + "float4_e2m1fn": "float4_e2m1fn", } # Represent the thread binding in the form of (tx, warp_n, warp_m) From 44336d79211ed36d45bc8e8673354afab28fae45 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 9 Mar 2026 23:04:44 +0800 Subject: [PATCH 02/19] fix: python pipeline (LayoutInference + LowerTileOp) passes --- examples/gemm_fp4/example_gemm_fp4_sm120.py | 96 +++++++++++++++++++++ src/tl_templates/cuda/common.h | 17 ++++ src/tl_templates/cuda/cuda_fp4.h | 40 +-------- src/tl_templates/cuda/gemm_mma.h | 2 + tilelang/intrinsics/mma_macro_generator.py | 3 +- tilelang/intrinsics/utils.py | 4 +- 6 files changed, 121 insertions(+), 41 deletions(-) create mode 100644 examples/gemm_fp4/example_gemm_fp4_sm120.py diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py new file mode 100644 index 0000000000..86c338c39d --- /dev/null +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -0,0 +1,96 @@ +"""FP4 (float4_e2m1fn) GEMM on SM120 (RTX 5080/5090) using fragment-based MMA. + +Uses mma.sync.aligned.kind::f8f6f4 instructions (not TCGEN05/TMEM). +Addresses https://github.com/tile-ai/tilelang/issues/1592 +""" + +import torch +import tilelang +import tilelang.language as T + + +def matmul_fp4( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, +): + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 256, 256, 256 +block_M, block_N, block_K = 128, 128, 128 +in_dtype = T.float4_e2m1fn +out_dtype = T.float32 +accum_dtype = T.float32 +num_stages = 2 +threads = 128 + +print(f"Running FP4 GEMM: M={M}, N={N}, K={K}") +print(f" block_M={block_M}, block_N={block_N}, block_K={block_K}") +print(f" in_dtype={in_dtype}, out_dtype={out_dtype}, accum_dtype={accum_dtype}") + +func = matmul_fp4( + M, N, K, + block_M, block_N, block_K, + in_dtype, out_dtype, accum_dtype, + num_stages, threads, +) + +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) + +print("Compilation succeeded!") +print(jit_kernel.get_kernel_source()) + +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +print(f"TFLOPS: {2 * M * N * K / (latency / 1e3) / 1e12:.2f}") diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index c6fdd55c99..6f3802f26a 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -578,6 +578,20 @@ struct float_e5m2_t : public cute::float_e5m2_t { : cute::float_e5m2_t(*reinterpret_cast(&x)) {} }; +struct float_e2m1_t : public cute::float_e2m1_t { + using cute::float_e2m1_t::float_e2m1_t; + CUTLASS_HOST_DEVICE + float_e2m1_t() = default; + + CUTLASS_HOST_DEVICE + explicit float_e2m1_t(__nv_bfloat16 x) + : float_e2m1_t(static_cast(x)) {} + + CUTLASS_HOST_DEVICE + float_e2m1_t(cutlass::float_e2m1_t x) + : cute::float_e2m1_t(*reinterpret_cast(&x)) {} +}; + template struct to_cute_type { using type = T; }; @@ -587,6 +601,9 @@ template <> struct to_cute_type { template <> struct to_cute_type { using type = cute::float_e5m2_t; }; +template <> struct to_cute_type { + using type = cute::float_e2m1_t; +}; // Packed FP32x2 math helpers // diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index d2efef24ae..4a74a6485a 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -1,48 +1,12 @@ #pragma once #include "common.h" +#include #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #include -// Wrapper for __nv_fp4_e2m1 with implicit conversions -struct fp4_e2_t { - __nv_fp4_storage_t __x; - - TL_DEVICE fp4_e2_t() = default; - - // Constructor from __nv_fp4_e2m1 - TL_DEVICE fp4_e2_t(__nv_fp4_e2m1 x) : __x(x.__x) {} - - // Constructor from storage type - TL_DEVICE fp4_e2_t(__nv_fp4_storage_t x) : __x(x) {} - - // Constructor from float - TL_DEVICE explicit fp4_e2_t(float x) { - __nv_fp4_e2m1 tmp(x); - __x = tmp.__x; - } - - // Conversion to __nv_fp4_e2m1 - TL_DEVICE operator __nv_fp4_e2m1() const { - __nv_fp4_e2m1 tmp; - tmp.__x = __x; - return tmp; - } - - // Conversion to float - TL_DEVICE operator float() const { - __nv_fp4_e2m1 tmp; - tmp.__x = __x; - return float(tmp); - } - - // Implicit conversion to half_t (cutlass::half_t) - TL_DEVICE operator half_t() const { return half_t(float(*this)); } - - // Implicit conversion to __half - TL_DEVICE operator __half() const { return __half(float(*this)); } -}; +using fp4_e2_t = tl::float_e2m1_t; class fp4_e2_2_t { public: diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index ea01fa9aab..c4ba99af58 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -42,8 +42,10 @@ using _X = Underscore; #ifdef __CUDA_ARCH_LIST__ #if __CUDA_ARCH_LIST__ >= 1200 #include "cuda_fp8.h" +#include "cuda_fp4.h" #include #include +TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index e8423240dd..95e4c46f9a 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -115,7 +115,8 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) - self.k_dim = 256 // a_dtype.bits + # MMA k_dim caps at 32 (m16n8k32 is the widest K for FP8/FP4) + self.k_dim = min(256 // a_dtype.bits, 32) def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index fb24a4add2..79afebae64 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -49,7 +49,7 @@ def get_ldmatrix_offset( else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx - elif dtype_bits == 8: + elif dtype_bits == 8 or dtype_bits == 4: if matrix == "B" and transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) @@ -59,7 +59,7 @@ def get_ldmatrix_offset( new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx else: - raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") + raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8/fp4") else: raise ValueError(f"Unsupported dtype {dtype}") From 7cf06fa04d5241abc52a6974759fc4185b83e839 Mon Sep 17 00:00:00 2001 From: wahao Date: Thu, 12 Mar 2026 17:59:23 +0800 Subject: [PATCH 03/19] =?UTF-8?q?fix:=20SM120=20FP4=20CUDA=20compilation?= =?UTF-8?q?=20=E2=80=93=20type=20bridge,=20codegen=20naming,=20MMA=20dispa?= =?UTF-8?q?tch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md | 107 ++++++++++++++++++ examples/gemm_fp4/example_gemm_fp4_sm120.py | 21 ++-- examples/gemm_fp4/gemm_fp4_sm120.cu | 74 ++++++++++++ src/target/codegen_cuda.cc | 2 +- src/tl_templates/cuda/cuda_fp4.h | 6 +- src/tl_templates/cuda/instruction/mma.h | 6 + 6 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md create mode 100644 examples/gemm_fp4/gemm_fp4_sm120.cu diff --git a/examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md b/examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md new file mode 100644 index 0000000000..5fb808496a --- /dev/null +++ b/examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md @@ -0,0 +1,107 @@ +# GEMM NV FP4 Feature – 现状 Review 与实现步骤评估 + +## 一、当前代码库现状 + +### 1. 已有基础(可直接复用) + +| 模块 | 现状 | +|------|------| +| **FP4 类型与转换** | `src/tl_templates/cuda/cuda_fp4.h` 已有 `__nv_fp4_e2m1`、`fp4_e2_t`、fp4↔half/float/bf16 转换,以及 packed load/store 辅助函数。 | +| **Load/Store** | `lower_ldg_stg.cc` 已对 `float4_e2m1fn` 做 4-bit 处理(bits()、total_bits 等)。 | +| **CuTeDSL codegen** | `codegen_cutedsl.cc` 已支持 `Float4E2M1FN`。 | +| **Python dtype** | `tilelang/language/dtypes.py` 已有 `float4_e2m1fn` 及 x2/x4/…/x64 向量类型;与 torch 的映射在 `dtypes.py` / `tensor.py`。 | +| **TCGEN5 窄精度 meta** | `src/op/tcgen5_meta.h` 中 `GetTCGEN5MMAMeta` 已把 `float4_e2m1fn` 与 float8/float6 一起纳入窄精度分支:K%32==0,M/N 与现有 F8/F6 同规则,C 为 float32 或 float16。 | +| **SM100 MMA 模板** | `gemm_sm100.h` 已有 `SM100_MMA_F8F6F4_WS_SS`(PTX `kind::f8f6f4`),支持 ≤8bit 类型;当前仅对 `cute::float_e4m3_t` / `float_e5m2_t` 做了 `DispatchInstruction`。 | +| **Dequant 示例** | `examples/dequantize_gemm/` 下有 FP4 相关示例(如 `example_dequant_gemm_fp4_hopper.py`),用于参考 FP4 数据流与 scaling。 | + +### 2. 缺口(需要补齐) + +| 模块 | 问题 | +|------|------| +| **指令描述符** | `GetTCGEN5InstrDesc()` 中 `encode_dtype()` 未处理 `float4_e2m1fn`,会走到 `LOG(FATAL)`。需为 FP4 分配正确的 format 编码(需查 CUTLASS/PTX 描述符定义)。 | +| **PTX 类型枚举与字符串** | `src/target/ptx.cc` 的 `DTypeFromString()` 无 `"float4_e2m1fn"`;`common.h` 的 `DataType` 枚举无 `kFloat4_e2m1fn`;`enum_to_str` / `num_bits` 等表也需扩展。 | +| **GEMM 到 codegen 的 dtype 传递** | 当前 GEMM Lower 时把 `ab_dtype` 转成字符串传给后端;需保证 TVM `float4_e2m1fn` → 字符串 `"float4_e2m1fn"` → PTX/codegen 能识别并在 C++ 端选用 FP4 路径。 | +| **C++ 侧 FP4 类型与 Dispatch** | `gemm_sm100.h` 中 `DispatchInstruction` 仅有 `float_e4m3_t` / `float_e5m2_t`;需增加 FP4 的 C++ 类型(如对接 `cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的 FP4 类型)以及对应的 `DispatchInstruction` 特化。 | +| **CuTe/CUTLASS 对 FP4 的支持** | `UMMA::make_instr_desc`、`sm100_smem_selector` 等是否支持 4-bit 类型需确认;若 CUTLASS 已有 FP4 描述符/布局,需在 TileLang 侧对齐。 | +| **Block-scaled FP4(可选)** | NVIDIA 文档中的 `mxf4.block_scale` / `mxf4nvf4.block_scale` 为单独路径,若要做 block-scaled GEMM,需额外设计 scale 的存储与传入方式,可作二期。 | + +--- + +## 二、推荐实现步骤(按依赖顺序) + +### Phase 1:类型与描述符贯通 + +1. **PTX / common 数据类型** + - 在 `common.h` 的 `DataType` 中增加 `kFloat4_e2m1fn`(或与现有命名规范一致)。 + - 在 `ptx.cc` 的 `DTypeFromString` 中增加 `"float4_e2m1fn"`(及可选别名)映射到该枚举。 + - 在 `enum_to_str`、`num_bits`、`dtype_str` 等表中为 FP4 填好条目,保证 codegen 能正确生成类型名与位宽。 + +2. **TCGEN5 指令描述符** + - 查 CUTLASS/PTX 文档中 tcgen05.mma `kind::f8f6f4` 下 FP4 的 descriptor 编码(与 e4m3/e5m2 的差异)。 + - 在 `tcgen5_meta.h` 的 `GetTCGEN5InstrDesc()` 里为 `float4_e2m1fn` 增加 `encode_dtype` 分支,返回正确的 format 码,保证现有 GEMM Lower 路径在 A/B 为 FP4 时能生成合法描述符。 + +3. **验证路径** + - 写一个最小用例:只做「分配 FP4 buffer + 从 Python 传 dtype 到 codegen」,确认: + - TVM dtype → 字符串 → `DTypeFromString` → `DTypeEnumToString` 一致; + - 若已有 FP4 的 LDG/STG,可顺带跑一条简单 load/store 或 dequant 示例,确认无回归。 + +### Phase 2:C++ GEMM 模板与 Dispatch + +4. **C++ FP4 类型与 CuTe 对接** + - 确定 GEMM 使用的 C++ 类型:`cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的等价类型。 + - 若 CUTLASS 有 `make_instr_desc` 对 FP4 的重载,在 TileLang 的 `to_cute_type`(或当前用于 F8/F6 的映射处)增加 TVM float4_e2m1fn → 该 C++ 类型的映射。 + - 若 CUTLASS 的 `sm100_smem_selector` 对 4-bit 有特殊布局(例如 packed 2×fp4/byte),在 `GemmTensorOp` 的 A/B layout 路径中接入,保证与 descriptor 一致。 + +5. **DispatchInstruction 与 SS/WS** + - 在 `gemm_sm100.h` 中为 FP4 增加 `DispatchInstruction`(以及 C=half 若需要),指向已有的 `SM100_MMA_F8F6F4_SS` / `SM100_MMA_F8F6F4_WS_SS`。 + - 确认 PTX 的 `kind::f8f6f4` 对 FP4 的 M/N/K 与 tile 约束与当前 meta 一致(K=32 等);必要时对照 CUTLASS 72a_blackwell_nvfp4_bf16_gemm 示例。 + +6. **端到端 GEMM 调用** + - 在 `gemm.cc` / `gemm_py.cc` 的 Lower 中,当 `a_`/`b_` 为 float4_e2m1fn 时,传入的 `kind_dtype` 字符串需与 `DTypeFromString` 一致(如 `"float4_e2m1fn"`)。 + - 跑一次 SM100 上 A/B 为 FP4、C 为 float32 的 GEMM kernel,对比参考实现(如 CUTLASS 72a)做数值与正确性检查。 + +### Phase 3:示例与文档 + +7. **Example** + - 在 `examples/gemm_sm100/` 或新建 `examples/gemm_fp4_sm100/` 增加一个 NV FP4 GEMM 示例:T.float4_e2m1fn 作为 A/B dtype,accum float32,与 ref(先 dequant 再 bf16/f32 GEMM)对比。 + - 若后续做 block-scaled,可在同一目录下加另一个示例,区分「无 scale」与「block scale」两种用法。 + +8. **文档与 CI** + - 在 `docs/programming_guides/type_system.md` 或 GEMM 相关文档中注明:SM100 TCGEN5 MMA 支持 float4_e2m1fn,以及当前限制(K%32、M/N 与 meta 一致、是否支持 block scale)。 + - CI 中为 SM100 增加一条 FP4 GEMM 的编译/运行(若环境有 SM100 或通过 skip 条件处理)。 + +--- + +## 三、风险与依赖 + +- **CUTLASS 版本**:FP4 描述符与 `make_instr_desc` 的细节可能随 CUTLASS 版本变化,需以当前 submodule 版本为准核对。 +- **PTXAS / CUDA**:`cuda_fp4.h` 中已有对 CUDA < 13 的 workaround(如 `__tl_cvt_fp4_to_halfraw_naive`),若运行环境 CUDA 版本较新,可再确认是否仍需要这些分支。 +- **Block-scaled FP4**:若产品需要 4× 吞吐的 block-scaled 路径,需单独设计 scale 张量、描述符与 API,工作量大于「仅非 block-scaled FP4 GEMM」。 + +--- + +## 四、与 Flash Attention SM100 的优先级 + +- Flash Attention SM100 的 example(含 wasp)可后续再迭代;当前不作为重点。 +- GEMM NV FP4 作为独立 feature,与 FA 的改动解耦;完成上述 Phase 1–2 即可得到可用的 FP4 GEMM 路径,再视需求加 block scale 或集成到更高层算子。 + +--- + +## 五、本次已完成的修改(CuTeDSL/TCGEN5 路径) + +以下修改已落地,使 **CuTeDSL 生成的 SM100 GEMM kernel** 在 A/B 为 `float4_e2m1fn` 时能正确走 `kind::f8f6f4` 的 FP4 路径。当前未改 `gemm_sm100.h` 的 CUTLASS `DispatchInstruction`(该路径由 CUTLASS 模板实例化,与 CuTeDSL 生成 TIR 再 codegen 的路径分离)。 + +| 文件 | 修改内容 | +|------|----------| +| **src/target/ptx.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | +| **src/target/ptx.cc** | `enum_to_str` / `dtype_str` / `num_bits` 增加第 24 项;`DTypeFromString` 增加 `"float4_e2m1fn"`、`".e2m1"` → `kFloat4_e2m1fn`。 | +| **src/tl_templates/cuda/common.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | +| **src/op/tcgen5_meta.h** | `GetTCGEN5InstrDesc()` 中 `encode_dtype` 增加 `dtype.is_float4_e2m1fn()` 分支,返回 `2`(FP4 format 编码)。 | +| **src/tl_templates/cuda/instruction/tcgen05mma.h** | 为 `DataType::kFloat4_e2m1fn` 增加 `tcgen05mma_ss`、`tcgen05mma_ts`、`tcgen05mma_ws_ss` 特化,均转发到 `kFloat8_e4m3` 的 f8f6f4 实现。 | +| **tilelang/intrinsics/mma_macro_generator.py** | `dtype_abbrv` 增加 `"float4_e2m1fn": "float4_e2m1fn"`,保证 lowering 时 `a_dtype_abbrv` 传入 `ptx_tcgen05_mma_ss("float4_e2m1fn", ...)`。 | + +**后续建议**(在有 SM100 / nvcc 环境时): + +1. 本地或 CI 执行完整构建并跑 `examples/gemm_sm100/` 中 BF16/F8 用例,确认无回归。 +2. 新增 FP4 GEMM 示例:`in_dtype=accum_dtype=T.float4_e2m1fn, T.float`,`block_K=128`(K%32 已由 meta 保证),与参考实现做数值对比。 +3. 若需走 CUTLASS 模板路径的 FP4 GEMM,再补 `gemm_sm100.h` 的 `to_cute_type` 与 `DispatchInstruction`。 diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py index 86c338c39d..fe92957397 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -7,7 +7,7 @@ import torch import tilelang import tilelang.language as T - +import os def matmul_fp4( M, @@ -88,9 +88,16 @@ def calc_diff(x, y): ) print("Compilation succeeded!") -print(jit_kernel.get_kernel_source()) - -profiler = jit_kernel.get_profiler() -latency = profiler.do_bench() -print(f"Latency: {latency} ms") -print(f"TFLOPS: {2 * M * N * K / (latency / 1e3) / 1e12:.2f}") +with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f: + f.write(jit_kernel.get_kernel_source()) + +print("Kernel source written to gemm_fp4_sm120.cu") + +# NOTE: Runtime execution of FP4 kernels requires proper sub-byte pointer +# arithmetic in TVM's codegen. Currently, codegen treats fp4_e2_t as 1-byte +# (sizeof=1) but actual data is packed 2-per-byte, causing 2x address +# overshoot. This is a known TVM sub-byte codegen limitation to be fixed +# upstream. The compilation pipeline (LayoutInference → LowerTileOp → +# CUDA codegen → nvcc) works end-to-end. +print("\n[OK] FP4 GEMM kernel compiled and CUDA source generated successfully.") +print("[TODO] Runtime execution pending TVM sub-byte pointer arithmetic fix.") diff --git a/examples/gemm_fp4/gemm_fp4_sm120.cu b/examples/gemm_fp4/gemm_fp4_sm120.cu new file mode 100644 index 0000000000..85f0abc5c9 --- /dev/null +++ b/examples/gemm_fp4/gemm_fp4_sm120.cu @@ -0,0 +1,74 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B, float* __restrict__ C); +extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B, float* __restrict__ C) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + float C_local[128]; + fp4_e2_2_t A_local_packed[32]; + fp4_e2_2_t B_local_packed[32]; + #pragma unroll + for (int i = 0; i < 32; ++i) { + float broadcast_var = 0x0p+0f/*0.000000e+00*/; + *(float4*)(C_local + (i * 4)) = make_float4(broadcast_var, broadcast_var, broadcast_var, broadcast_var); + } + #pragma unroll + for (int i_1 = 0; i_1 < 4; ++i_1) { + *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + ((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16))) = *(fp4_e2_32_t*)(A + ((((((int)blockIdx.y) * 16384) + (i_1 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16))); + } + #pragma unroll + for (int i_2 = 0; i_2 < 4; ++i_2) { + *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + (((((i_2 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384)) = *(fp4_e2_32_t*)(B + ((((((int)blockIdx.x) * 16384) + (i_2 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16))); + } + #pragma unroll + for (int i_3 = 0; i_3 < 4; ++i_3) { + *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + (((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 8192)) = *(fp4_e2_32_t*)(A + (((((((int)blockIdx.y) * 16384) + (i_3 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16)) + 64)); + } + #pragma unroll + for (int i_4 = 0; i_4 < 4; ++i_4) { + *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + (((((i_4 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576)) = *(fp4_e2_32_t*)(B + (((((((int)blockIdx.x) * 16384) + (i_4 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16)) + 64)); + } + __syncthreads(); + for (int ki = 0; ki < 4; ++ki) { + for (int i_5 = 0; i_5 < 4; ++i_5) { + tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[(((((((((int)threadIdx.x) & 63) >> 5) * 8192) + (i_5 * 2048)) + (((((int)threadIdx.x) & 15) >> 3) * 1024)) + ((((((((int)threadIdx.x) & 15) * 128) + (((((((int)threadIdx.x) & 7) >> 2) + (ki >> 1)) & 1) * 64)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki & 1)) & 1) * 32)) + (((((int)threadIdx.x) & 31) >> 4) * 16)) & 1023)) / 2)])) + 0, A_local_packed + (i_5 * 16)); + } + for (int i_6 = 0; i_6 < 4; ++i_6) { + tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[(((((((((((int)threadIdx.x) >> 6) * 4096) + (i_6 * 1024)) + (((((int)threadIdx.x) & 31) >> 4) * 512)) + ((((int)threadIdx.x) & 7) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + (ki >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki & 1)) & 1) * 16)) + (((((int)threadIdx.x) & 15) >> 3) * 8)) + 16384)])) + 0, B_local_packed + (i_6 * 16)); + } + for (int i_7 = 0; i_7 < 4; ++i_7) { + for (int j = 0; j < 4; ++j) { + tl::mma_sync(reinterpret_cast(C_local + ((i_7 * 32) + (j * 8))), reinterpret_cast(A_local_packed + (i_7 * 16)), reinterpret_cast(B_local_packed + (j * 16))); + tl::mma_sync(reinterpret_cast(C_local + (((i_7 * 32) + (j * 8)) + 4)), reinterpret_cast(A_local_packed + (i_7 * 16)), reinterpret_cast(B_local_packed + ((j * 16) + 8))); + } + } + } + for (int ki_1 = 0; ki_1 < 4; ++ki_1) { + for (int i_8 = 0; i_8 < 4; ++i_8) { + tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[((((((((((int)threadIdx.x) & 63) >> 5) * 8192) + (i_8 * 2048)) + (((((int)threadIdx.x) & 15) >> 3) * 1024)) + ((((((((int)threadIdx.x) & 15) * 128) + (((((((int)threadIdx.x) & 7) >> 2) + (ki_1 >> 1)) & 1) * 64)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki_1 & 1)) & 1) * 32)) + (((((int)threadIdx.x) & 31) >> 4) * 16)) & 1023)) + 16384) / 2)])) + 0, A_local_packed + (i_8 * 16)); + } + for (int i_9 = 0; i_9 < 4; ++i_9) { + tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[(((((((((((int)threadIdx.x) >> 6) * 4096) + (i_9 * 1024)) + (((((int)threadIdx.x) & 31) >> 4) * 512)) + ((((int)threadIdx.x) & 7) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + (ki_1 >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki_1 & 1)) & 1) * 16)) + (((((int)threadIdx.x) & 15) >> 3) * 8)) + 24576)])) + 0, B_local_packed + (i_9 * 16)); + } + for (int i_10 = 0; i_10 < 4; ++i_10) { + for (int j_1 = 0; j_1 < 4; ++j_1) { + tl::mma_sync(reinterpret_cast(C_local + ((i_10 * 32) + (j_1 * 8))), reinterpret_cast(A_local_packed + (i_10 * 16)), reinterpret_cast(B_local_packed + (j_1 * 16))); + tl::mma_sync(reinterpret_cast(C_local + (((i_10 * 32) + (j_1 * 8)) + 4)), reinterpret_cast(A_local_packed + (i_10 * 16)), reinterpret_cast(B_local_packed + ((j_1 * 16) + 8))); + } + } + } + #pragma unroll + for (int i_11 = 0; i_11 < 64; ++i_11) { + *(float2*)(C + (((((((((((int)blockIdx.y) * 32768) + (((((int)threadIdx.x) & 63) >> 5) * 16384)) + ((i_11 >> 4) * 4096)) + ((i_11 & 1) * 2048)) + (((((int)threadIdx.x) & 31) >> 2) * 256)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((i_11 & 15) >> 1) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(float2*)(C_local + (i_11 * 2)); + } +} + diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 0aada15044..31b7371c83 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3395,8 +3395,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { auto vid_packed = vid + "_packed"; stream << "fp4_e2_2_t " << vid_packed << '[' << (constant_size + 1) / 2 << "];\n"; - // Record mapping from original buffer to packed buffer name fp4_packed_buffers_[op->buffer_var.get()] = vid_packed; + var_idmap_[op->buffer_var.get()] = vid_packed; } else { stream << ' ' << vid << '[' << constant_size << "];\n"; } diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index 4a74a6485a..d2a24bd98a 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -27,11 +27,11 @@ class fp4_e2_2_t { } // Set low 4 bits (first fp4) - TL_DEVICE void set_x(fp4_e2_t val) { __x = (__x & 0xF0) | (val.__x & 0x0F); } + TL_DEVICE void set_x(fp4_e2_t val) { __x = (__x & 0xF0) | (val.raw() & 0x0F); } // Set high 4 bits (second fp4) TL_DEVICE void set_y(fp4_e2_t val) { - __x = (__x & 0x0F) | ((val.__x & 0x0F) << 4); + __x = (__x & 0x0F) | ((val.raw() & 0x0F) << 4); } }; @@ -70,7 +70,7 @@ struct __CUDA_ALIGN__(32) fp4_e2_64_t { // Pack two fp4_e2_t values. TL_DEVICE fp4_e2_2_t make_fp4_e2_2_t(fp4_e2_t x, fp4_e2_t y) { - __nv_fp4x2_storage_t packed = (x.__x & 0x0F) | ((y.__x & 0x0F) << 4); + __nv_fp4x2_storage_t packed = (x.raw() & 0x0F) | ((y.raw() & 0x0F) << 4); fp4_e2_2_t result; result.__x = packed; return result; diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index 869fa777bc..92ac6d0ebd 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -3,6 +3,7 @@ #include "../common.h" #include #include +#include #ifndef __CUDACC_RTC__ #include @@ -142,6 +143,11 @@ TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8, TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true, false, cute::SM80_8x8x4_F64F64F64F64_TN) +// FP4 inputs (k32, SM120 kind::f8f6f4) +using SM120_FP4_FP4_F32_TN = cute::SM120_16x8x32_TN; +TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, 16, 8, 32, + false, true, false, SM120_FP4_FP4_F32_TN) + #undef TL_DEFINE_MMA_DISPATCHER } // namespace detail From f068f2f763932a3eec03dd9e9527e3da634a5259 Mon Sep 17 00:00:00 2001 From: wahao Date: Fri, 13 Mar 2026 17:27:42 +0800 Subject: [PATCH 04/19] fix(buffer overflow, fp4 address misalignment): ldmatrix writes in by raw bytes, get_ldmatrix_offset add fp4 special layout, add SM120_FP4_FP4_F32_TN MmaDispatcher specialization + fp4 << 2 bit-shift. --- docs/GEMM_NV_FP4_FEATURE_STEPS.md | 107 -------------------- examples/gemm_fp4/example_gemm_fp4_sm120.py | 98 +++++++++++------- src/target/codegen_cuda.cc | 29 ++---- src/tl_templates/cuda/instruction/mma.h | 16 ++- tilelang/intrinsics/mma_layout.py | 15 +++ tilelang/intrinsics/utils.py | 17 +++- 6 files changed, 115 insertions(+), 167 deletions(-) delete mode 100644 docs/GEMM_NV_FP4_FEATURE_STEPS.md diff --git a/docs/GEMM_NV_FP4_FEATURE_STEPS.md b/docs/GEMM_NV_FP4_FEATURE_STEPS.md deleted file mode 100644 index 5fb808496a..0000000000 --- a/docs/GEMM_NV_FP4_FEATURE_STEPS.md +++ /dev/null @@ -1,107 +0,0 @@ -# GEMM NV FP4 Feature – 现状 Review 与实现步骤评估 - -## 一、当前代码库现状 - -### 1. 已有基础(可直接复用) - -| 模块 | 现状 | -|------|------| -| **FP4 类型与转换** | `src/tl_templates/cuda/cuda_fp4.h` 已有 `__nv_fp4_e2m1`、`fp4_e2_t`、fp4↔half/float/bf16 转换,以及 packed load/store 辅助函数。 | -| **Load/Store** | `lower_ldg_stg.cc` 已对 `float4_e2m1fn` 做 4-bit 处理(bits()、total_bits 等)。 | -| **CuTeDSL codegen** | `codegen_cutedsl.cc` 已支持 `Float4E2M1FN`。 | -| **Python dtype** | `tilelang/language/dtypes.py` 已有 `float4_e2m1fn` 及 x2/x4/…/x64 向量类型;与 torch 的映射在 `dtypes.py` / `tensor.py`。 | -| **TCGEN5 窄精度 meta** | `src/op/tcgen5_meta.h` 中 `GetTCGEN5MMAMeta` 已把 `float4_e2m1fn` 与 float8/float6 一起纳入窄精度分支:K%32==0,M/N 与现有 F8/F6 同规则,C 为 float32 或 float16。 | -| **SM100 MMA 模板** | `gemm_sm100.h` 已有 `SM100_MMA_F8F6F4_WS_SS`(PTX `kind::f8f6f4`),支持 ≤8bit 类型;当前仅对 `cute::float_e4m3_t` / `float_e5m2_t` 做了 `DispatchInstruction`。 | -| **Dequant 示例** | `examples/dequantize_gemm/` 下有 FP4 相关示例(如 `example_dequant_gemm_fp4_hopper.py`),用于参考 FP4 数据流与 scaling。 | - -### 2. 缺口(需要补齐) - -| 模块 | 问题 | -|------|------| -| **指令描述符** | `GetTCGEN5InstrDesc()` 中 `encode_dtype()` 未处理 `float4_e2m1fn`,会走到 `LOG(FATAL)`。需为 FP4 分配正确的 format 编码(需查 CUTLASS/PTX 描述符定义)。 | -| **PTX 类型枚举与字符串** | `src/target/ptx.cc` 的 `DTypeFromString()` 无 `"float4_e2m1fn"`;`common.h` 的 `DataType` 枚举无 `kFloat4_e2m1fn`;`enum_to_str` / `num_bits` 等表也需扩展。 | -| **GEMM 到 codegen 的 dtype 传递** | 当前 GEMM Lower 时把 `ab_dtype` 转成字符串传给后端;需保证 TVM `float4_e2m1fn` → 字符串 `"float4_e2m1fn"` → PTX/codegen 能识别并在 C++ 端选用 FP4 路径。 | -| **C++ 侧 FP4 类型与 Dispatch** | `gemm_sm100.h` 中 `DispatchInstruction` 仅有 `float_e4m3_t` / `float_e5m2_t`;需增加 FP4 的 C++ 类型(如对接 `cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的 FP4 类型)以及对应的 `DispatchInstruction` 特化。 | -| **CuTe/CUTLASS 对 FP4 的支持** | `UMMA::make_instr_desc`、`sm100_smem_selector` 等是否支持 4-bit 类型需确认;若 CUTLASS 已有 FP4 描述符/布局,需在 TileLang 侧对齐。 | -| **Block-scaled FP4(可选)** | NVIDIA 文档中的 `mxf4.block_scale` / `mxf4nvf4.block_scale` 为单独路径,若要做 block-scaled GEMM,需额外设计 scale 的存储与传入方式,可作二期。 | - ---- - -## 二、推荐实现步骤(按依赖顺序) - -### Phase 1:类型与描述符贯通 - -1. **PTX / common 数据类型** - - 在 `common.h` 的 `DataType` 中增加 `kFloat4_e2m1fn`(或与现有命名规范一致)。 - - 在 `ptx.cc` 的 `DTypeFromString` 中增加 `"float4_e2m1fn"`(及可选别名)映射到该枚举。 - - 在 `enum_to_str`、`num_bits`、`dtype_str` 等表中为 FP4 填好条目,保证 codegen 能正确生成类型名与位宽。 - -2. **TCGEN5 指令描述符** - - 查 CUTLASS/PTX 文档中 tcgen05.mma `kind::f8f6f4` 下 FP4 的 descriptor 编码(与 e4m3/e5m2 的差异)。 - - 在 `tcgen5_meta.h` 的 `GetTCGEN5InstrDesc()` 里为 `float4_e2m1fn` 增加 `encode_dtype` 分支,返回正确的 format 码,保证现有 GEMM Lower 路径在 A/B 为 FP4 时能生成合法描述符。 - -3. **验证路径** - - 写一个最小用例:只做「分配 FP4 buffer + 从 Python 传 dtype 到 codegen」,确认: - - TVM dtype → 字符串 → `DTypeFromString` → `DTypeEnumToString` 一致; - - 若已有 FP4 的 LDG/STG,可顺带跑一条简单 load/store 或 dequant 示例,确认无回归。 - -### Phase 2:C++ GEMM 模板与 Dispatch - -4. **C++ FP4 类型与 CuTe 对接** - - 确定 GEMM 使用的 C++ 类型:`cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的等价类型。 - - 若 CUTLASS 有 `make_instr_desc` 对 FP4 的重载,在 TileLang 的 `to_cute_type`(或当前用于 F8/F6 的映射处)增加 TVM float4_e2m1fn → 该 C++ 类型的映射。 - - 若 CUTLASS 的 `sm100_smem_selector` 对 4-bit 有特殊布局(例如 packed 2×fp4/byte),在 `GemmTensorOp` 的 A/B layout 路径中接入,保证与 descriptor 一致。 - -5. **DispatchInstruction 与 SS/WS** - - 在 `gemm_sm100.h` 中为 FP4 增加 `DispatchInstruction`(以及 C=half 若需要),指向已有的 `SM100_MMA_F8F6F4_SS` / `SM100_MMA_F8F6F4_WS_SS`。 - - 确认 PTX 的 `kind::f8f6f4` 对 FP4 的 M/N/K 与 tile 约束与当前 meta 一致(K=32 等);必要时对照 CUTLASS 72a_blackwell_nvfp4_bf16_gemm 示例。 - -6. **端到端 GEMM 调用** - - 在 `gemm.cc` / `gemm_py.cc` 的 Lower 中,当 `a_`/`b_` 为 float4_e2m1fn 时,传入的 `kind_dtype` 字符串需与 `DTypeFromString` 一致(如 `"float4_e2m1fn"`)。 - - 跑一次 SM100 上 A/B 为 FP4、C 为 float32 的 GEMM kernel,对比参考实现(如 CUTLASS 72a)做数值与正确性检查。 - -### Phase 3:示例与文档 - -7. **Example** - - 在 `examples/gemm_sm100/` 或新建 `examples/gemm_fp4_sm100/` 增加一个 NV FP4 GEMM 示例:T.float4_e2m1fn 作为 A/B dtype,accum float32,与 ref(先 dequant 再 bf16/f32 GEMM)对比。 - - 若后续做 block-scaled,可在同一目录下加另一个示例,区分「无 scale」与「block scale」两种用法。 - -8. **文档与 CI** - - 在 `docs/programming_guides/type_system.md` 或 GEMM 相关文档中注明:SM100 TCGEN5 MMA 支持 float4_e2m1fn,以及当前限制(K%32、M/N 与 meta 一致、是否支持 block scale)。 - - CI 中为 SM100 增加一条 FP4 GEMM 的编译/运行(若环境有 SM100 或通过 skip 条件处理)。 - ---- - -## 三、风险与依赖 - -- **CUTLASS 版本**:FP4 描述符与 `make_instr_desc` 的细节可能随 CUTLASS 版本变化,需以当前 submodule 版本为准核对。 -- **PTXAS / CUDA**:`cuda_fp4.h` 中已有对 CUDA < 13 的 workaround(如 `__tl_cvt_fp4_to_halfraw_naive`),若运行环境 CUDA 版本较新,可再确认是否仍需要这些分支。 -- **Block-scaled FP4**:若产品需要 4× 吞吐的 block-scaled 路径,需单独设计 scale 张量、描述符与 API,工作量大于「仅非 block-scaled FP4 GEMM」。 - ---- - -## 四、与 Flash Attention SM100 的优先级 - -- Flash Attention SM100 的 example(含 wasp)可后续再迭代;当前不作为重点。 -- GEMM NV FP4 作为独立 feature,与 FA 的改动解耦;完成上述 Phase 1–2 即可得到可用的 FP4 GEMM 路径,再视需求加 block scale 或集成到更高层算子。 - ---- - -## 五、本次已完成的修改(CuTeDSL/TCGEN5 路径) - -以下修改已落地,使 **CuTeDSL 生成的 SM100 GEMM kernel** 在 A/B 为 `float4_e2m1fn` 时能正确走 `kind::f8f6f4` 的 FP4 路径。当前未改 `gemm_sm100.h` 的 CUTLASS `DispatchInstruction`(该路径由 CUTLASS 模板实例化,与 CuTeDSL 生成 TIR 再 codegen 的路径分离)。 - -| 文件 | 修改内容 | -|------|----------| -| **src/target/ptx.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | -| **src/target/ptx.cc** | `enum_to_str` / `dtype_str` / `num_bits` 增加第 24 项;`DTypeFromString` 增加 `"float4_e2m1fn"`、`".e2m1"` → `kFloat4_e2m1fn`。 | -| **src/tl_templates/cuda/common.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | -| **src/op/tcgen5_meta.h** | `GetTCGEN5InstrDesc()` 中 `encode_dtype` 增加 `dtype.is_float4_e2m1fn()` 分支,返回 `2`(FP4 format 编码)。 | -| **src/tl_templates/cuda/instruction/tcgen05mma.h** | 为 `DataType::kFloat4_e2m1fn` 增加 `tcgen05mma_ss`、`tcgen05mma_ts`、`tcgen05mma_ws_ss` 特化,均转发到 `kFloat8_e4m3` 的 f8f6f4 实现。 | -| **tilelang/intrinsics/mma_macro_generator.py** | `dtype_abbrv` 增加 `"float4_e2m1fn": "float4_e2m1fn"`,保证 lowering 时 `a_dtype_abbrv` 传入 `ptx_tcgen05_mma_ss("float4_e2m1fn", ...)`。 | - -**后续建议**(在有 SM100 / nvcc 环境时): - -1. 本地或 CI 执行完整构建并跑 `examples/gemm_sm100/` 中 BF16/F8 用例,确认无回归。 -2. 新增 FP4 GEMM 示例:`in_dtype=accum_dtype=T.float4_e2m1fn, T.float`,`block_K=128`(K%32 已由 meta 保证),与参考实现做数值对比。 -3. 若需走 CUTLASS 模板路径的 FP4 GEMM,再补 `gemm_sm100.h` 的 `to_cute_type` 与 `DispatchInstruction`。 diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py index fe92957397..104b727c02 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -4,23 +4,17 @@ Addresses https://github.com/tile-ai/tilelang/issues/1592 """ +import os +import time import torch import tilelang import tilelang.language as T -import os + def matmul_fp4( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - num_stages=2, - threads=128, + M, N, K, block_M, block_N, block_K, + in_dtype, out_dtype, accum_dtype, + num_stages=2, threads=128, ): A_shape = (M, K) B_shape = (N, K) @@ -49,13 +43,25 @@ def main( return main -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - if denominator == 0: - return 0.0 - sim = 2 * (x * y).sum() / denominator - return 1 - sim +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Unpack (M, K//2) int8 tensor → (M, K) float32 tensor.""" + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + flat = packed_int8.to(torch.uint8).flatten() + lo = (flat & 0x0F).to(torch.int64) + hi = ((flat >> 4) & 0x0F).to(torch.int64) + pairs = torch.stack([lut[lo], lut[hi]], dim=-1) + return pairs.reshape(M, K) + + +def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor: + """Create random packed FP4 tensor as (M, K//2) int8.""" + return torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device).to(torch.int8) M, N, K = 256, 256, 256 @@ -63,18 +69,14 @@ def calc_diff(x, y): in_dtype = T.float4_e2m1fn out_dtype = T.float32 accum_dtype = T.float32 -num_stages = 2 -threads = 128 print(f"Running FP4 GEMM: M={M}, N={N}, K={K}") print(f" block_M={block_M}, block_N={block_N}, block_K={block_K}") -print(f" in_dtype={in_dtype}, out_dtype={out_dtype}, accum_dtype={accum_dtype}") func = matmul_fp4( - M, N, K, - block_M, block_N, block_K, + M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, - num_stages, threads, + num_stages=2, threads=128, ) jit_kernel = tilelang.compile( @@ -91,13 +93,39 @@ def calc_diff(x, y): with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f: f.write(jit_kernel.get_kernel_source()) -print("Kernel source written to gemm_fp4_sm120.cu") - -# NOTE: Runtime execution of FP4 kernels requires proper sub-byte pointer -# arithmetic in TVM's codegen. Currently, codegen treats fp4_e2_t as 1-byte -# (sizeof=1) but actual data is packed 2-per-byte, causing 2x address -# overshoot. This is a known TVM sub-byte codegen limitation to be fixed -# upstream. The compilation pipeline (LayoutInference → LowerTileOp → -# CUDA codegen → nvcc) works end-to-end. -print("\n[OK] FP4 GEMM kernel compiled and CUDA source generated successfully.") -print("[TODO] Runtime execution pending TVM sub-byte pointer arithmetic fix.") +# --- Test 1: zeros in → zeros out --- +a_zero = torch.zeros(M, K // 2, device="cuda", dtype=torch.int8) +b_zero = torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) +c_zero = jit_kernel(a_zero, b_zero) +assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" +print("[PASS] zeros in → zeros out") + +# --- Test 2: numerical verification with random FP4 data --- +torch.manual_seed(42) +a_packed = make_fp4_tensor(M, K) +b_packed = make_fp4_tensor(N, K) + +c = jit_kernel(a_packed, b_packed) + +a_float = unpack_fp4_to_float(a_packed, M, K) +b_float = unpack_fp4_to_float(b_packed, N, K) +ref_c = a_float @ b_float.T + +diff = (c.float() - ref_c).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) +print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") +if max_diff < 1.0: + print("[PASS] numerical verification (max_abs_diff < 1.0)") +else: + print(f"[WARN] large diff — may indicate layout or data flow issue") + +# --- Benchmark --- +torch.cuda.synchronize() +start = time.perf_counter() +for _ in range(100): + jit_kernel(a_packed, b_packed) +torch.cuda.synchronize() +elapsed = (time.perf_counter() - start) / 100 * 1000 +print(f"Latency: {elapsed:.4f} ms") +print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 31b7371c83..ca388269d9 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -3352,15 +3352,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { } else if (scope == "local.descriptor.tcgen05_instr") { stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n"; } else { - // For FP4 scalar local buffers, we use packed storage type, - // so skip type declaration here (will be handled in the local scope section - // below) - bool is_fp4_scalar_local = op->dtype.is_float4() && op->dtype.is_scalar() && - (scope == "local" || scope.empty()); - if (!is_fp4_scalar_local) { - PrintStorageScope(scope, stream); - PrintType(op->dtype, stream); - } + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); } if (scope == "shared.dyn") { @@ -3387,19 +3380,11 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_ << "*>(" << v_id_mem << ");\n"; } else if (scope == "local") { - // For FP4 types, use packed storage type to avoid wasting registers. - // fp4_e2_t uses int8 as storage but only needs 4 bits per element. - // By using fp4_e2_2_t (which stores 2 fp4 values in 1 byte), we halve the - // storage. - if (op->dtype.is_float4() && op->dtype.is_scalar()) { - auto vid_packed = vid + "_packed"; - stream << "fp4_e2_2_t " << vid_packed << '[' << (constant_size + 1) / 2 - << "];\n"; - fp4_packed_buffers_[op->buffer_var.get()] = vid_packed; - var_idmap_[op->buffer_var.get()] = vid_packed; - } else { - stream << ' ' << vid << '[' << constant_size << "];\n"; - } + // For FP4 local fragment buffers (MMA operands), do NOT pack into + // fp4_e2_2_t. ldmatrix loads raw bytes regardless of element type, so + // the buffer needs the full byte count (1 byte per fp4_e2_t element). + // Packing would halve the buffer and cause stack overflow. + stream << ' ' << vid << '[' << constant_size << "];\n"; } else if (scope == "local.var") { PrimExpr init = tir::make_const(op->dtype, 0); auto init_it = op->annotations.find(tl::attr::kLocalVarInit); diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index 92ac6d0ebd..0007ccb3af 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -165,7 +165,21 @@ TL_DEVICE void mma_sync( TransB, Saturate>; static_assert(!std::is_void_v, "tl::mma_sync: unsupported configuration"); - Dispatcher::exec(c, a, b, c); + if constexpr (AType == DataType::kFloat4_e2m1fn || + BType == DataType::kFloat4_e2m1fn) { + using AReg = typename Dispatcher::ARegType; + using BReg = typename Dispatcher::BRegType; + constexpr int nA = detail::MmaImplTraits::kARegs; + constexpr int nB = detail::MmaImplTraits::kBRegs; + AReg as[nA]; BReg bs[nB]; + #pragma unroll + for (int i = 0; i < nA; ++i) as[i] = a[i] << 2; + #pragma unroll + for (int i = 0; i < nB; ++i) bs[i] = b[i] << 2; + Dispatcher::exec(c, as, bs, c); + } else { + Dispatcher::exec(c, a, b, c); + } } } // namespace tl diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index 2eb575f0ca..5ec887a1f0 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -39,6 +39,21 @@ def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): return row, col +def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id): + """FP4 variant: each row is 16 bytes (32 FP4 elements), ldmatrix covers + the full row in one 128-bit load. No half-row selection needed.""" + row = thread_id % 16 + col = local_id + return row, col + + +def ldmatrix_32x16_to_shared_16x32_fp4_layout_b(thread_id, local_id): + """FP4 variant: same reasoning — full row covered by one ldmatrix load.""" + row = (thread_id // 16) * 8 + (thread_id % 8) + col = local_id + return row, col + + def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): row = 8 * (local_id % 4 // 2) + (thread_id // 4) col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index 79afebae64..9c5751ff3a 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -7,6 +7,8 @@ ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x16_to_shared_16x32_fp4_layout_a, + ldmatrix_32x16_to_shared_16x32_fp4_layout_b, mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) @@ -49,7 +51,18 @@ def get_ldmatrix_offset( else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx - elif dtype_bits == 8 or dtype_bits == 4: + elif dtype_bits == 4: + if matrix == "B" and transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_fp4_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_fp4_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx * stride + new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for fp4") + elif dtype_bits == 8: if matrix == "B" and transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) @@ -59,7 +72,7 @@ def get_ldmatrix_offset( new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx * stride + new_col_idx else: - raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8/fp4") + raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") else: raise ValueError(f"Unsupported dtype {dtype}") From d116372253e51f5e9f58390725525dc3987d4985 Mon Sep 17 00:00:00 2001 From: wahao Date: Sat, 14 Mar 2026 22:30:53 +0800 Subject: [PATCH 05/19] feat: SM120 FP4 GEMM unpacked shared memory + numerical verification --- examples/gemm_fp4/example_gemm_fp4_sm120.py | 68 ++++++++++++--------- src/target/codegen_cuda.cc | 12 ++-- src/tl_templates/cuda/instruction/mma.h | 14 ++++- src/tl_templates/cuda/ldsm.h | 33 ++++++++++ tilelang/intrinsics/mma_layout.py | 10 +-- tilelang/tileop/gemm/gemm_base.py | 18 +++++- 6 files changed, 113 insertions(+), 42 deletions(-) diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py index 104b727c02..7d85319d49 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -1,6 +1,11 @@ """FP4 (float4_e2m1fn) GEMM on SM120 (RTX 5080/5090) using fragment-based MMA. Uses mma.sync.aligned.kind::f8f6f4 instructions (not TCGEN05/TMEM). +FP4 data is pre-unpacked to uint8 on the host (1 byte per element, low +nibble holds the 4-bit value). Shared memory stores uint8, making the +layout identical to INT8 and satisfying ldmatrix 16B alignment. +The << 2 bit-shift before MMA places FP4 data at bits 2-5 as required. + Addresses https://github.com/tile-ai/tilelang/issues/1592 """ @@ -13,7 +18,7 @@ def matmul_fp4( M, N, K, block_M, block_N, block_K, - in_dtype, out_dtype, accum_dtype, + out_dtype, accum_dtype, num_stages=2, threads=128, ): A_shape = (M, K) @@ -23,13 +28,13 @@ def matmul_fp4( @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), + A: T.Tensor(A_shape, "uint8"), + B: T.Tensor(B_shape, "uint8"), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_shared = T.alloc_shared(A_shared_shape, "uint8") + B_shared = T.alloc_shared(B_shared_shape, "uint8") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) @@ -49,24 +54,23 @@ def main( ] -def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: - """Unpack (M, K//2) int8 tensor → (M, K) float32 tensor.""" - lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) - flat = packed_int8.to(torch.uint8).flatten() - lo = (flat & 0x0F).to(torch.int64) - hi = ((flat >> 4) & 0x0F).to(torch.int64) - pairs = torch.stack([lut[lo], lut[hi]], dim=-1) - return pairs.reshape(M, K) +def unpack_fp4_to_uint8(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Unpack (M, K//2) int8 -> (M, K) uint8, 1 FP4 per byte in low nibble.""" + flat = packed_int8.to(torch.uint8).reshape(M, K // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + return torch.stack([lo, hi], dim=-1).reshape(M, K).contiguous() -def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor: - """Create random packed FP4 tensor as (M, K//2) int8.""" - return torch.randint(0, 256, (M, K // 2), dtype=torch.uint8, device=device).to(torch.int8) +def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Unpack (M, K//2) int8 -> (M, K) float32 via e2m1 lookup table.""" + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + unpacked = unpack_fp4_to_uint8(packed_int8, M, K).to(torch.int64) + return lut[unpacked] M, N, K = 256, 256, 256 block_M, block_N, block_K = 128, 128, 128 -in_dtype = T.float4_e2m1fn out_dtype = T.float32 accum_dtype = T.float32 @@ -75,7 +79,7 @@ def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor: func = matmul_fp4( M, N, K, block_M, block_N, block_K, - in_dtype, out_dtype, accum_dtype, + out_dtype, accum_dtype, num_stages=2, threads=128, ) @@ -93,19 +97,23 @@ def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor: with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f: f.write(jit_kernel.get_kernel_source()) -# --- Test 1: zeros in → zeros out --- -a_zero = torch.zeros(M, K // 2, device="cuda", dtype=torch.int8) -b_zero = torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) +torch.manual_seed(42) + +# Create packed FP4 data (2 per byte), then pre-unpack to uint8 +a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) +b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) +a_unpacked = unpack_fp4_to_uint8(a_packed, M, K) +b_unpacked = unpack_fp4_to_uint8(b_packed, N, K) + +# --- Test 1: zeros --- +a_zero = torch.zeros(M, K, device="cuda", dtype=torch.uint8) +b_zero = torch.zeros(N, K, device="cuda", dtype=torch.uint8) c_zero = jit_kernel(a_zero, b_zero) assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" -print("[PASS] zeros in → zeros out") - -# --- Test 2: numerical verification with random FP4 data --- -torch.manual_seed(42) -a_packed = make_fp4_tensor(M, K) -b_packed = make_fp4_tensor(N, K) +print("[PASS] zeros in -> zeros out") -c = jit_kernel(a_packed, b_packed) +# --- Test 2: numerical verification --- +c = jit_kernel(a_unpacked, b_unpacked) a_float = unpack_fp4_to_float(a_packed, M, K) b_float = unpack_fp4_to_float(b_packed, N, K) @@ -118,13 +126,13 @@ def make_fp4_tensor(M: int, K: int, device="cuda") -> torch.Tensor: if max_diff < 1.0: print("[PASS] numerical verification (max_abs_diff < 1.0)") else: - print(f"[WARN] large diff — may indicate layout or data flow issue") + print(f"[WARN] large diff -- may indicate layout or data flow issue") # --- Benchmark --- torch.cuda.synchronize() start = time.perf_counter() for _ in range(100): - jit_kernel(a_packed, b_packed) + jit_kernel(a_unpacked, b_unpacked) torch.cuda.synchronize() elapsed = (time.perf_counter() - start) / 100 * 1000 print(f"Latency: {elapsed:.4f} ms") diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index ca388269d9..25bdfd7096 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1621,17 +1621,21 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, os << "*((" << ptr_cast(t) << vid << ")" << " + " << index_str << ")"; } else if (t == buffer_element_dtype) { int div_factor = 1; - if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) { + // FP4 div_factor=2 only for global memory (packed 2/byte). + // Shared/local stores unpacked FP4 (1 byte per element, sizeof=1). + bool is_packed_scope = scope.empty() || scope == "global"; + if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 + && is_packed_scope) { div_factor = 2; } index_str = PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor))); os << buffer_str << "[" << index_str << "]"; } else { - // Fix fp4 pointer arithmetic: fp4 elements are 4-bit packed 2 per byte. - // fp4* + n incorrectly advances n bytes (skipping 2n elements). int div_factor = 1; - if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) { + bool is_packed_scope = scope.empty() || scope == "global"; + if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 + && is_packed_scope) { div_factor = 2; } index_str = diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index 0007ccb3af..b5cf1696de 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -148,6 +148,14 @@ using SM120_FP4_FP4_F32_TN = cute::SM120_16x8x32_TN; +using SM120_FP4_FP8_F32_TN = cute::SM120_16x8x32_TN; +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat4_e2m1fn, kFloat32, 16, 8, 32, + false, true, false, SM120_FP8_FP4_F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat8_e4m3, kFloat32, 16, 8, 32, + false, true, false, SM120_FP4_FP8_F32_TN) + #undef TL_DEFINE_MMA_DISPATCHER } // namespace detail @@ -173,9 +181,11 @@ TL_DEVICE void mma_sync( constexpr int nB = detail::MmaImplTraits::kBRegs; AReg as[nA]; BReg bs[nB]; #pragma unroll - for (int i = 0; i < nA; ++i) as[i] = a[i] << 2; + for (int i = 0; i < nA; ++i) + as[i] = (AType == DataType::kFloat4_e2m1fn) ? (a[i] << 2) : a[i]; #pragma unroll - for (int i = 0; i < nB; ++i) bs[i] = b[i] << 2; + for (int i = 0; i < nB; ++i) + bs[i] = (BType == DataType::kFloat4_e2m1fn) ? (b[i] << 2) : b[i]; Dispatcher::exec(c, as, bs, c); } else { Dispatcher::exec(c, a, b, c); diff --git a/src/tl_templates/cuda/ldsm.h b/src/tl_templates/cuda/ldsm.h index a20746dff9..abd1471e5c 100644 --- a/src/tl_templates/cuda/ldsm.h +++ b/src/tl_templates/cuda/ldsm.h @@ -51,6 +51,39 @@ TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr, : "r"(smem_int_ptr)); } +// ldmatrix for 4-bit sub-byte types (FP4/INT4 on SM120+). +// Reads packed 4-bit data from shared memory and unpacks each 4-bit value +// into the low 4 bits of an 8-bit container (upper 4 bits zeroed). +TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n" + : "=r"(value[0]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n" + : "=r"(value[0]), "=r"(value[1]) + : "r"(smem_int_ptr)); +} + +TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr, + void *const local_ptr) { + uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); + int32_t *value = reinterpret_cast(local_ptr); + asm volatile( + "ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); +} + TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr, void *const local_ptr) { uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index 5ec887a1f0..89eb717ad7 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -40,17 +40,17 @@ def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): def ldmatrix_32x16_to_shared_16x32_fp4_layout_a(thread_id, local_id): - """FP4 variant: each row is 16 bytes (32 FP4 elements), ldmatrix covers - the full row in one 128-bit load. No half-row selection needed.""" + """FP4 with unpacked shared memory (1 byte/element) uses the same + layout as INT8 — shared memory rows are 32 bytes for K=32.""" row = thread_id % 16 - col = local_id + col = local_id + (thread_id // 16) * 16 return row, col def ldmatrix_32x16_to_shared_16x32_fp4_layout_b(thread_id, local_id): - """FP4 variant: same reasoning — full row covered by one ldmatrix load.""" + """FP4 with unpacked shared memory — same as INT8.""" row = (thread_id // 16) * 8 + (thread_id % 8) - col = local_id + col = local_id + 16 * ((thread_id % 16) // 8) return row, col diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 488f742cb4..92db197db4 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -75,12 +75,28 @@ def in_dtype(self) -> str: For the TS variant, A resides in TMEM with the accumulator dtype, so the actual input dtype is derived from B. + + For FP4/FP8 on SM120, input buffers use uint8 storage (unpacked, + 1 element per byte) but the MMA instruction uses float4_e2m1fn or + float8_e4m3fn. Mixed types (e.g. FP8 x FP4) are supported. """ if is_tensor_memory(self.A): return self.B.dtype - assert self.A.dtype == self.B.dtype, "A and B must have the same dtype" return self.A.dtype + @property + def in_dtype_b(self) -> str: + """Input data type for B operand (may differ from A for mixed-type MMA). + + uint8 B with float32 accumulator maps to float4_e2m1fn (unpacked FP4). + """ + if is_tensor_memory(self.A): + return self.B.dtype + dtype = self.B.dtype + if dtype == "uint8" and self.C.dtype in ("float32", "float16"): + return "float4_e2m1fn" + return dtype + @property def accum_dtype(self) -> str: return self.C.dtype From f13a6b714f91d0310d3da54fb8ad1701f0eba61a Mon Sep 17 00:00:00 2001 From: wahao Date: Sat, 14 Mar 2026 22:31:09 +0800 Subject: [PATCH 06/19] feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120 --- .../gemm_fp4/example_fusedmoe_a8w4_sm120.py | 157 ++++++++++++++++++ examples/gemm_fp4/example_gemm_a8w4_sm120.py | 126 ++++++++++++++ src/tl_templates/cuda/gemm_mma.h | 2 + tilelang/tileop/gemm/gemm_mma.py | 9 +- 4 files changed, 290 insertions(+), 4 deletions(-) create mode 100644 examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py create mode 100644 examples/gemm_fp4/example_gemm_a8w4_sm120.py diff --git a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py new file mode 100644 index 0000000000..0035786347 --- /dev/null +++ b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py @@ -0,0 +1,157 @@ +"""FP4 Fused MoE shared expert kernel on SM120 (A8W4 mode). + +Demonstrates the core MoE compute pattern with FP4 weights: + 1. Gate GEMM: input(FP8) x W_gate(FP4) -> gate logits + 2. Up GEMM: input(FP8) x W_up(FP4) -> up logits + 3. SiLU(gate) * up + 4. Down GEMM: activated(FP8) x W_down(FP4) -> output + +Uses SM120 native kind::f8f6f4 MMA (FP8 x FP4 -> FP32). +Expert weights are stored as unpacked uint8 (1 FP4 per byte, low nibble). + +This is a simplified single-expert example. For full routing + grouped GEMM, +see examples/fusedmoe/example_fusedmoe_tilelang.py. +""" + +import os +import time +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def fp4_uint8_to_float(t): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=t.device) + return lut[t.to(torch.int64)] + + +def moe_shared_expert_a8w4( + num_tokens, d_hidden, d_expert, + block_token=128, block_hidden=128, block_expert=128, + threads=128, num_stages=1, +): + """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM.""" + scale = 1.44269504 # log2(e) for fast SiLU + + @T.prim_func + def main( + input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), + W_gate: T.Tensor((d_expert, d_hidden), "uint8"), + W_up: T.Tensor((d_expert, d_hidden), "uint8"), + W_down: T.Tensor((d_hidden, d_expert), "uint8"), + output: T.Tensor((num_tokens, d_hidden), "float32"), + ): + # Step 1: Gate + Up GEMMs (fused in one kernel launch) + with T.Kernel( + T.ceildiv(num_tokens, block_token), + T.ceildiv(d_expert, block_expert), + threads=threads, + ) as (bx, by): + input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn") + W_gate_shared = T.alloc_shared((block_expert, block_hidden), "uint8") + W_up_shared = T.alloc_shared((block_expert, block_hidden), "uint8") + + gate_local = T.alloc_fragment((block_token, block_expert), "float32") + up_local = T.alloc_fragment((block_token, block_expert), "float32") + + T.clear(gate_local) + T.clear(up_local) + + for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): + T.copy(input[bx * block_token, k * block_hidden], input_shared) + T.copy(W_gate[by * block_expert, k * block_hidden], W_gate_shared) + T.copy(W_up[by * block_expert, k * block_hidden], W_up_shared) + T.gemm(input_shared, W_gate_shared, gate_local, transpose_B=True) + T.gemm(input_shared, W_up_shared, up_local, transpose_B=True) + + # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate + for i, j in T.Parallel(block_token, block_expert): + gate_local[i, j] = gate_local[i, j] * ( + 1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)) + ) + up_local[i, j] = up_local[i, j] * gate_local[i, j] + + T.copy(up_local, output[bx * block_token, by * block_expert]) + + return main + + +# Problem sizes (small for testing) +num_tokens = 128 +d_hidden = 256 +d_expert = 256 + +print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}") + +func = moe_shared_expert_a8w4( + num_tokens, d_hidden, d_expert, + block_token=128, block_hidden=128, block_expert=128, + threads=128, num_stages=1, +) + +jit_kernel = tilelang.compile( + func, + out_idx=[4], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) + +print("Compilation succeeded!") + +torch.manual_seed(42) + +# Create test data +input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) +W_gate_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) +W_up_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) +W_down_uint8 = torch.randint(0, 16, (d_hidden, d_expert), device="cuda", dtype=torch.uint8) + +# --- Test 1: zeros --- +z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) +z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) +z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) +z_down = torch.zeros(d_hidden, d_expert, device="cuda", dtype=torch.uint8) +c_zero = jit_kernel(z_input, z_gate, z_up, z_down) +print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out") + +# --- Test 2: numerical verification (gate+up only, no down GEMM in this kernel) --- +out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8, W_down_uint8) + +# Reference +input_f32 = input_fp8.to(torch.float32) +gate_f32 = fp4_uint8_to_float(W_gate_uint8) +up_f32 = fp4_uint8_to_float(W_up_uint8) + +gate_logits = input_f32 @ gate_f32.T +up_logits = input_f32 @ up_f32.T +gate_activated = gate_logits * torch.sigmoid(gate_logits) +ref_out = up_logits * gate_activated + +diff = (out.float() - ref_out).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref_out.abs().sum().item() + 1e-10) +print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") +if rel_err < 0.05: + print("[PASS] MoE gate+up fusion numerical verification") +else: + print("[WARN] large diff") + +# --- Benchmark --- +torch.cuda.synchronize() +start = time.perf_counter() +for _ in range(100): + jit_kernel(input_fp8, W_gate_uint8, W_up_uint8, W_down_uint8) +torch.cuda.synchronize() +elapsed = (time.perf_counter() - start) / 100 * 1000 +total_flops = 2 * num_tokens * d_hidden * d_expert * 2 # 2 GEMMs +print(f"Latency: {elapsed:.4f} ms") +print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") diff --git a/examples/gemm_fp4/example_gemm_a8w4_sm120.py b/examples/gemm_fp4/example_gemm_a8w4_sm120.py new file mode 100644 index 0000000000..a5174e7493 --- /dev/null +++ b/examples/gemm_fp4/example_gemm_a8w4_sm120.py @@ -0,0 +1,126 @@ +"""A8W4 GEMM on SM120: FP8 (e4m3) activation x FP4 (e2m1) weight. + +Uses SM120 native mma.sync.kind::f8f6f4 with mixed-type operands: + A: float8_e4m3fn (activation, 1 byte/element) + B: float4_e2m1fn stored as uint8 (weight, unpacked, 1 byte/element) + C: float32 (accumulator/output) + +No block scaling. Direct FP8 x FP4 tensor core multiply-accumulate. +""" + +import os +import time +import torch +import tilelang +import tilelang.language as T + + +def gemm_a8w4( + M, N, K, block_M, block_N, block_K, + out_dtype, accum_dtype, + num_stages=2, threads=128, +): + A_shape = (M, K) + B_shape = (N, K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, "float8_e4m3fn"), + B: T.Tensor(B_shape, "uint8"), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), "float8_e4m3fn") + B_shared = T.alloc_shared((block_N, block_K), "uint8") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def fp4_uint8_to_float(tensor_uint8): + """Convert uint8 tensor (low nibble = FP4 e2m1) to float32.""" + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=tensor_uint8.device) + return lut[tensor_uint8.to(torch.int64)] + + +M, N, K = 256, 256, 256 +block_M, block_N, block_K = 128, 128, 128 +out_dtype = T.float32 +accum_dtype = T.float32 + +print(f"Running A8W4 GEMM: M={M}, N={N}, K={K}") +print(f" A: float8_e4m3fn, B: FP4 (unpacked uint8)") + +func = gemm_a8w4( + M, N, K, block_M, block_N, block_K, + out_dtype, accum_dtype, + num_stages=2, threads=128, +) + +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) + +print("Compilation succeeded!") + +torch.manual_seed(42) + +# A: FP8 e4m3fn values (create as float then convert) +a_float = torch.randn(M, K, device="cuda", dtype=torch.float16) +a_fp8 = a_float.to(torch.float8_e4m3fn) + +# B: FP4 e2m1 values (random nibbles 0-15, stored as uint8) +b_uint8 = torch.randint(0, 16, (N, K), device="cuda", dtype=torch.uint8) + +# --- Test 1: zeros --- +a_zero = torch.zeros(M, K, device="cuda", dtype=torch.float8_e4m3fn) +b_zero = torch.zeros(N, K, device="cuda", dtype=torch.uint8) +c_zero = jit_kernel(a_zero, b_zero) +assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" +print("[PASS] zeros in -> zeros out") + +# --- Test 2: numerical verification --- +c = jit_kernel(a_fp8, b_uint8) + +a_ref = a_fp8.to(torch.float32) +b_ref = fp4_uint8_to_float(b_uint8) +ref_c = a_ref @ b_ref.T + +diff = (c.float() - ref_c).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) +print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") +if rel_err < 0.01: + print("[PASS] numerical verification (rel_err < 0.01)") +else: + print(f"[WARN] large diff -- may indicate layout or data flow issue") + +# --- Benchmark --- +torch.cuda.synchronize() +start = time.perf_counter() +for _ in range(100): + jit_kernel(a_fp8, b_uint8) +torch.cuda.synchronize() +elapsed = (time.perf_counter() - start) / 100 * 1000 +print(f"Latency: {elapsed:.4f} ms") +print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index c4ba99af58..586b459f4c 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -46,6 +46,8 @@ using _X = Underscore; #include #include TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp4_e2_t, float, SM120_16x8x32_TN) +TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp8_e4_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA_TEMPLATE(fp8_e4_t, fp8_e4_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA_TEMPLATE(fp8_e5_t, fp8_e5_t, float, SM120_16x8x32_TN) TL_DISPATCH_MMA(half_t, half_t, half_t, SM80_16x8x16_F16F16F16F16_TN) diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index 52b08c146e..f91da59b31 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -20,7 +20,7 @@ def infer_layout(self, target: Target, thread_nums: int): warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, - b_dtype=self.in_dtype, + b_dtype=self.in_dtype_b, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, @@ -64,7 +64,7 @@ def lower(self, layout_map: dict, target: Target, thread_bounds: Range, thread_v warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, - b_dtype=self.in_dtype, + b_dtype=self.in_dtype_b, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, @@ -77,6 +77,7 @@ def lower(self, layout_map: dict, target: Target, thread_bounds: Range, thread_v ) in_dtype = self.in_dtype + in_dtype_b = self.in_dtype_b warp_rows = mma_emitter.warp_rows warp_cols = mma_emitter.warp_cols local_size_a = mma_emitter.local_size_a @@ -109,7 +110,7 @@ def _gemm_ssr() -> None: accumulating into C_local. """ A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype_b) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): @@ -173,7 +174,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. """ - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype_b) if clear_accum: T.clear(C_buf) for ki in T.serial(0, (block_K // micro_size_k)): From 10a271029409fd85c4a73cf362ab9c5ed71876d2 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 11 May 2026 14:14:46 +0800 Subject: [PATCH 07/19] fix: fix precision using smem unpacked layout, setting bit 2-5/8 bits as payload, functionality verified valid for naive gemm-fp4 --- examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s | 2193 +++++++++++++++++ examples/amd_mxfp4_mm/asm-dump/src.s | 1283 ++++++++++ examples/gemm_fp4/gemm_fp4_sm100.cu | 100 + examples/gemm_sm100/gemm_tcgen5mma_fp4.py | 214 ++ examples/gemm_sm100/test_fp4_diagonal.py | 216 ++ .../gemm_sm100/test_fp4_single_tile_probe.py | 276 +++ ...est_fp4_single_tile_probe.tma.generated.cu | 102 + ...ile_probe.tma.linear_epilogue.generated.cu | 90 + ...probe.tma.tma_only.descriptor.generated.cu | 65 + ...ingle_tile_probe.tma.tma_only.generated.cu | 61 + examples/gemm_sm100/test_tma_fp4_roundtrip.py | 207 ++ src/backend/cuda/codegen/codegen_cuda.cc | 28 +- src/backend/cuda/op/copy.cc | 24 +- src/backend/cuda/op/gemm.cc | 12 +- src/layout/gemm_layouts.cc | 56 + src/layout/layout.cc | 4 + src/layout/layout.h | 1 + src/op/tcgen5_meta.h | 7 +- src/op/utils.cc | 8 +- src/tl_templates/cuda/gemm_mma.h | 1 + src/tl_templates/cuda/gemm_sm100.h | 461 ++++ .../cuda/instruction/tcgen05mma.h | 20 - .../intrinsics/tcgen05_macro_generator.py | 7 + tilelang/layout/__init__.py | 1 + tilelang/layout/swizzle.py | 54 +- tilelang/tileop/gemm/gemm_tcgen05.py | 32 +- 26 files changed, 5465 insertions(+), 58 deletions(-) create mode 100644 examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s create mode 100644 examples/amd_mxfp4_mm/asm-dump/src.s create mode 100644 examples/gemm_fp4/gemm_fp4_sm100.cu create mode 100644 examples/gemm_sm100/gemm_tcgen5mma_fp4.py create mode 100644 examples/gemm_sm100/test_fp4_diagonal.py create mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.py create mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu create mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu create mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu create mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu create mode 100644 examples/gemm_sm100/test_tma_fp4_roundtrip.py diff --git a/examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s b/examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s new file mode 100644 index 0000000000..a724dd05e7 --- /dev/null +++ b/examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s @@ -0,0 +1,2193 @@ + +# __CLANG_OFFLOAD_BUNDLE____START__ hip-amdgcn-amd-amdhsa--gfx950 + .amdgcn_target "amdgcn-amd-amdhsa--gfx950" + .amdhsa_code_object_version 6 + .text + .protected _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii ; -- Begin function _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .globl _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .p2align 8 + .type _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii,@function +_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii: ; @_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii +; %bb.0: + s_load_dwordx2 s[6:7], s[0:1], 0x10 + s_load_dwordx8 s[8:15], s[0:1], 0x28 + v_and_b32_e32 v1, 31, v0 + s_waitcnt lgkmcnt(0) + s_lshl_b32 s15, s2, 5 + s_lshl_b32 s17, s3, 5 + v_or_b32_e32 v2, s17, v1 + s_mul_i32 s5, s11, s4 + s_add_i32 s2, s5, s11 + s_min_i32 s11, s2, s10 + v_lshrrev_b32_e32 v14, 5, v0 + v_cmp_gt_i32_e64 s[2:3], s9, v2 + v_ashrrev_i32_e32 v3, 31, v2 + v_mov_b32_e32 v5, 0 + v_accvgpr_write_b32 a0, 0 + v_accvgpr_write_b32 a1, 0 + v_accvgpr_write_b32 a2, 0 + v_accvgpr_write_b32 a3, 0 + v_accvgpr_write_b32 a4, 0 + v_accvgpr_write_b32 a5, 0 + v_accvgpr_write_b32 a6, 0 + v_accvgpr_write_b32 a7, 0 + v_accvgpr_write_b32 a8, 0 + v_accvgpr_write_b32 a9, 0 + v_accvgpr_write_b32 a10, 0 + v_accvgpr_write_b32 a11, 0 + v_accvgpr_write_b32 a12, 0 + v_accvgpr_write_b32 a13, 0 + v_accvgpr_write_b32 a14, 0 + s_cmp_ge_i32 s5, s11 + v_accvgpr_write_b32 a15, 0 + s_cbranch_scc1 .LBB0_11 +; %bb.1: ; %.preheader123 + s_load_dwordx4 s[20:23], s[0:1], 0x0 + s_load_dwordx4 s[24:27], s[0:1], 0x18 + v_lshlrev_b32_e32 v0, 2, v0 + s_ashr_i32 s16, s10, 5 + v_or_b32_e32 v12, s15, v1 + v_lshrrev_b32_e32 v4, 4, v1 + v_and_b32_e32 v10, 60, v0 + s_lshl_b32 s0, s14, 5 + s_ashr_i32 s1, s17, 5 + s_ashr_i32 s10, s10, 1 + s_waitcnt lgkmcnt(0) + v_mov_b64_e32 v[0:1], s[20:21] + v_mov_b64_e32 v[6:7], s[22:23] + v_mov_b64_e32 v[8:9], s[24:25] + s_mul_hi_i32 s14, s0, s1 + s_mul_i32 s17, s0, s1 + v_mad_i64_i32 v[0:1], s[0:1], s10, v12, v[0:1] + v_mad_i64_i32 v[6:7], s[0:1], s10, v2, v[6:7] + v_mad_i64_i32 v[8:9], s[0:1], s13, v12, v[8:9] + s_add_u32 s0, s26, s17 + v_mov_b32_e32 v11, v5 + s_addc_u32 s1, s27, s14 + v_lshl_add_u64 v[10:11], s[0:1], 0, v[10:11] + v_cmp_gt_i32_e32 vcc, s8, v12 + v_accvgpr_write_b32 a15, 0 + v_accvgpr_write_b32 a14, 0 + v_accvgpr_write_b32 a13, 0 + v_accvgpr_write_b32 a12, 0 + v_accvgpr_write_b32 a11, 0 + v_accvgpr_write_b32 a10, 0 + v_accvgpr_write_b32 a9, 0 + v_accvgpr_write_b32 a8, 0 + v_accvgpr_write_b32 a7, 0 + v_accvgpr_write_b32 a6, 0 + v_accvgpr_write_b32 a5, 0 + v_accvgpr_write_b32 a4, 0 + v_accvgpr_write_b32 a3, 0 + v_accvgpr_write_b32 a2, 0 + v_accvgpr_write_b32 a1, 0 + v_accvgpr_write_b32 a0, 0 + v_lshlrev_b32_e32 v15, 4, v14 + v_lshl_add_u64 v[10:11], v[10:11], 0, v[4:5] + s_branch .LBB0_3 +.LBB0_2: ; %_Z14load_frag_gmemPKhS0_S0_S0_iiiillllibbRDv32_hS2_RhS3_.exit + ; in Loop: Header=BB0_3 Depth=1 + s_or_b64 exec, exec, s[0:1] + s_waitcnt vmcnt(0) + v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[16:19], v[20:23], a[0:15], v13, v4 op_sel_hi:[0,0,0] cbsz:4 blgp:4 + s_add_i32 s5, s5, 64 + s_cmp_lt_i32 s5, s11 + s_cbranch_scc0 .LBB0_11 +.LBB0_3: ; =>This Inner Loop Header: Depth=1 + s_ashr_i32 s0, s5, 1 + v_add_u32_e32 v12, s0, v15 + v_mov_b32_e32 v16, 0 + v_mov_b32_e32 v17, 0 + v_mov_b32_e32 v18, 0 + v_mov_b32_e32 v20, 0 + v_ashrrev_i32_e32 v13, 31, v12 + v_mov_b32_e32 v19, 0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_5 +; %bb.4: ; %.loopexit45.i + ; in Loop: Header=BB0_3 Depth=1 + v_lshl_add_u64 v[16:17], v[0:1], 0, v[12:13] + global_load_dwordx4 v[16:19], v[16:17], off +.LBB0_5: ; in Loop: Header=BB0_3 Depth=1 + s_or_b64 exec, exec, s[0:1] + v_mov_b32_e32 v21, 0 + v_mov_b32_e32 v22, 0 + v_mov_b32_e32 v23, 0 + s_and_saveexec_b64 s[0:1], s[2:3] + s_cbranch_execz .LBB0_7 +; %bb.6: ; %.loopexit.i + ; in Loop: Header=BB0_3 Depth=1 + v_lshl_add_u64 v[12:13], v[6:7], 0, v[12:13] + global_load_dwordx4 v[20:23], v[12:13], off +.LBB0_7: ; in Loop: Header=BB0_3 Depth=1 + s_or_b64 exec, exec, s[0:1] + s_ashr_i32 s0, s5, 5 + v_add_u32_e32 v12, s0, v14 + v_cmp_gt_i32_e64 s[0:1], s13, v12 + s_and_b64 s[18:19], vcc, s[0:1] + v_mov_b32_e32 v4, 0x7f + v_mov_b32_e32 v13, 0x7f + s_and_saveexec_b64 s[0:1], s[18:19] + s_cbranch_execz .LBB0_9 +; %bb.8: ; in Loop: Header=BB0_3 Depth=1 + v_ashrrev_i32_e32 v13, 31, v12 + v_lshl_add_u64 v[24:25], v[8:9], 0, v[12:13] + global_load_ubyte v13, v[24:25], off +.LBB0_9: ; in Loop: Header=BB0_3 Depth=1 + s_or_b64 exec, exec, s[0:1] + v_cmp_gt_i32_e64 s[0:1], s16, v12 + s_and_b64 s[18:19], s[2:3], s[0:1] + s_and_saveexec_b64 s[0:1], s[18:19] + s_cbranch_execz .LBB0_2 +; %bb.10: ; in Loop: Header=BB0_3 Depth=1 + v_lshlrev_b32_e32 v4, 5, v12 + v_and_b32_e32 v24, 0xffffff00, v4 + v_ashrrev_i32_e32 v25, 31, v24 + v_lshlrev_b32_e32 v4, 6, v12 + v_and_b32_e32 v4, 0xc0, v4 + v_lshrrev_b32_e32 v12, 1, v12 + v_lshl_add_u64 v[24:25], v[10:11], 0, v[24:25] + v_and_b32_e32 v26, 2, v12 + v_mov_b32_e32 v27, v5 + v_lshl_add_u64 v[24:25], v[24:25], 0, v[4:5] + v_lshl_add_u64 v[24:25], v[24:25], 0, v[26:27] + global_load_ubyte v4, v[24:25], off + s_branch .LBB0_2 +.LBB0_11: ; %.loopexit124 + s_cmp_lg_u32 s12, 1 + s_mov_b64 s[0:1], -1 + s_cbranch_scc0 .LBB0_46 +; %bb.12: + s_and_saveexec_b64 s[0:1], s[2:3] + s_cbranch_execz .LBB0_45 +; %bb.13: ; %.preheader121 + v_lshl_add_u32 v4, v14, 2, s15 + s_mul_hi_i32 s5, s8, s4 + s_mul_i32 s4, s8, s4 + s_ashr_i32 s12, s9, 31 + v_lshl_add_u64 v[0:1], v[2:3], 2, s[6:7] + v_cmp_gt_i32_e32 vcc, s8, v4 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_15 +; %bb.14: + v_ashrrev_i32_e32 v5, 31, v4 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[4:5] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a0, off +.LBB0_15: + s_or_b64 exec, exec, s[10:11] + v_or_b32_e32 v6, 1, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_17 +; %bb.16: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a1, off +.LBB0_17: + s_or_b64 exec, exec, s[10:11] + v_or_b32_e32 v6, 2, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_19 +; %bb.18: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a2, off +.LBB0_19: + s_or_b64 exec, exec, s[10:11] + v_or_b32_e32 v6, 3, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_21 +; %bb.20: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a3, off +.LBB0_21: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 8, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_23 +; %bb.22: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a4, off +.LBB0_23: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 9, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_25 +; %bb.24: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a5, off +.LBB0_25: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 10, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_27 +; %bb.26: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a6, off +.LBB0_27: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 11, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_29 +; %bb.28: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a7, off +.LBB0_29: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 16, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_31 +; %bb.30: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a8, off +.LBB0_31: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 17, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_33 +; %bb.32: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a9, off +.LBB0_33: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 18, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_35 +; %bb.34: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a10, off +.LBB0_35: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 19, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_37 +; %bb.36: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a11, off +.LBB0_37: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 24, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_39 +; %bb.38: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a12, off +.LBB0_39: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 25, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_41 +; %bb.40: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a13, off +.LBB0_41: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 26, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_43 +; %bb.42: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a14, off +.LBB0_43: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v4, 27, v4 + v_cmp_gt_i32_e32 vcc, s8, v4 + s_and_b64 exec, exec, vcc + s_cbranch_execz .LBB0_45 +; %bb.44: + v_ashrrev_i32_e32 v5, 31, v4 + v_lshl_add_u64 v[4:5], s[4:5], 0, v[4:5] + v_mul_lo_u32 v6, v5, s9 + v_mul_lo_u32 v7, v4, s12 + v_mad_u64_u32 v[4:5], s[4:5], v4, s9, 0 + v_add3_u32 v5, v5, v7, v6 + v_lshl_add_u64 v[0:1], v[4:5], 2, v[0:1] + global_store_dword v[0:1], a15, off +.LBB0_45: ; %Flow403 + s_or_b64 exec, exec, s[0:1] + s_mov_b64 s[0:1], 0 +.LBB0_46: ; %Flow406 + s_andn2_b64 vcc, exec, s[0:1] + s_cbranch_vccnz .LBB0_80 +; %bb.47: + s_and_saveexec_b64 s[0:1], s[2:3] + s_cbranch_execz .LBB0_80 +; %bb.48: ; %.preheader + v_lshl_add_u32 v4, v14, 2, s15 + v_lshl_add_u64 v[0:1], v[2:3], 1, s[6:7] + v_cmp_gt_i32_e32 vcc, s8, v4 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_50 +; %bb.49: + v_mad_i64_i32 v[2:3], s[2:3], v4, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a0, off +.LBB0_50: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v2, 1, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_52 +; %bb.51: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a1, off +.LBB0_52: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v2, 2, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_54 +; %bb.53: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a2, off +.LBB0_54: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v2, 3, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_56 +; %bb.55: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a3, off +.LBB0_56: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 8, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_58 +; %bb.57: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a4, off +.LBB0_58: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 9, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_60 +; %bb.59: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a5, off +.LBB0_60: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 10, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_62 +; %bb.61: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a6, off +.LBB0_62: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 11, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_64 +; %bb.63: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a7, off +.LBB0_64: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 16, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_66 +; %bb.65: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a8, off +.LBB0_66: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 17, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_68 +; %bb.67: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a9, off +.LBB0_68: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 18, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_70 +; %bb.69: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a10, off +.LBB0_70: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 19, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_72 +; %bb.71: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a11, off +.LBB0_72: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 24, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_74 +; %bb.73: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a12, off +.LBB0_74: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 25, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_76 +; %bb.75: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a13, off +.LBB0_76: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 26, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_78 +; %bb.77: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a14, off +.LBB0_78: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 27, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_b64 exec, exec, vcc + s_cbranch_execz .LBB0_80 +; %bb.79: + v_mad_i64_i32 v[2:3], s[0:1], v2, s9, 0 + v_lshl_add_u64 v[0:1], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[0:1], a15, off +.LBB0_80: ; %.loopexit + s_endpgm + .section .rodata,"a",@progbits + .p2align 6, 0x0 + .amdhsa_kernel _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .amdhsa_group_segment_fixed_size 0 + .amdhsa_private_segment_fixed_size 0 + .amdhsa_kernarg_size 68 + .amdhsa_user_sgpr_count 2 + .amdhsa_user_sgpr_dispatch_ptr 0 + .amdhsa_user_sgpr_queue_ptr 0 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_user_sgpr_dispatch_id 0 + .amdhsa_user_sgpr_kernarg_preload_length 0 + .amdhsa_user_sgpr_kernarg_preload_offset 0 + .amdhsa_user_sgpr_private_segment_size 0 + .amdhsa_uses_dynamic_stack 0 + .amdhsa_enable_private_segment 0 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_sgpr_workgroup_info 0 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 44 + .amdhsa_next_free_sgpr 28 + .amdhsa_accum_offset 28 + .amdhsa_reserve_vcc 1 + .amdhsa_float_round_mode_32 0 + .amdhsa_float_round_mode_16_64 0 + .amdhsa_float_denorm_mode_32 3 + .amdhsa_float_denorm_mode_16_64 3 + .amdhsa_dx10_clamp 1 + .amdhsa_ieee_mode 1 + .amdhsa_fp16_overflow 0 + .amdhsa_tg_split 0 + .amdhsa_exception_fp_ieee_invalid_op 0 + .amdhsa_exception_fp_denorm_src 0 + .amdhsa_exception_fp_ieee_div_zero 0 + .amdhsa_exception_fp_ieee_overflow 0 + .amdhsa_exception_fp_ieee_underflow 0 + .amdhsa_exception_fp_ieee_inexact 0 + .amdhsa_exception_int_div_zero 0 + .end_amdhsa_kernel + .text +.Lfunc_end0: + .size _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + ; -- End function + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.num_vgpr, 28 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.num_agpr, 16 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.numbered_sgpr, 28 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.num_named_barrier, 0 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.private_seg_size, 0 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.uses_vcc, 1 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.uses_flat_scratch, 0 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.has_dyn_sized_stack, 0 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.has_recursion, 0 + .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.has_indirect_call, 0 + .section .AMDGPU.csdata,"",@progbits +; Kernel info: +; codeLenInByte = 2828 +; TotalNumSgprs: 34 +; NumVgprs: 28 +; NumAgprs: 16 +; TotalNumVgprs: 44 +; ScratchSize: 0 +; MemoryBound: 0 +; FloatMode: 240 +; IeeeMode: 1 +; LDSByteSize: 0 bytes/workgroup (compile time only) +; SGPRBlocks: 4 +; VGPRBlocks: 5 +; NumSGPRsForWavesPerEU: 34 +; NumVGPRsForWavesPerEU: 44 +; AccumOffset: 28 +; Occupancy: 8 +; WaveLimiterHint : 0 +; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 +; COMPUTE_PGM_RSRC2:USER_SGPR: 2 +; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 +; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Y_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Z_EN: 1 +; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 +; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 6 +; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 + .text + .protected _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii ; -- Begin function _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .globl _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .p2align 8 + .type _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii,@function +_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii: ; @_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii +; %bb.0: + s_load_dwordx2 s[10:11], s[0:1], 0x10 + s_load_dwordx8 s[12:19], s[0:1], 0x28 + s_lshl_b32 s3, s3, 6 + v_lshrrev_b32_e32 v2, 1, v0 + v_and_b32_e32 v1, 31, v0 + v_and_or_b32 v5, v2, 32, s3 + s_waitcnt lgkmcnt(0) + s_mul_i32 s5, s15, s4 + s_add_i32 s6, s5, s15 + s_lshl_b32 s2, s2, 6 + v_lshrrev_b32_e32 v4, 2, v0 + v_or_b32_e32 v18, v5, v1 + s_min_i32 s20, s6, s14 + v_bfe_u32 v28, v0, 5, 1 + v_and_or_b32 v29, v4, 32, s2 + s_cmp_ge_i32 s5, s20 + v_cmp_gt_i32_e64 s[8:9], s13, v18 + s_cbranch_scc1 .LBB1_26 +; %bb.1: ; %.lr.ph211 + s_load_dwordx4 s[24:27], s[0:1], 0x0 + s_load_dwordx4 s[28:31], s[0:1], 0x18 + s_ashr_i32 s21, s14, 5 + v_lshlrev_b32_e32 v11, 5, v0 + v_mov_b32_e32 v3, 0 + v_ashrrev_i32_e32 v12, 5, v5 + s_ashr_i32 s14, s14, 1 + v_add_u32_e32 v5, s2, v4 + s_waitcnt lgkmcnt(0) + v_mov_b64_e32 v[8:9], s[24:25] + v_and_b32_e32 v32, 0x60, v11 + v_mad_i64_i32 v[8:9], s[6:7], s14, v5, v[8:9] + v_mov_b32_e32 v33, v3 + v_add_u32_e32 v13, s3, v4 + v_lshl_add_u64 v[20:21], v[8:9], 0, v[32:33] + v_mov_b64_e32 v[8:9], s[26:27] + v_mad_i64_i32 v[8:9], s[6:7], s14, v13, v[8:9] + v_or_b32_e32 v10, v29, v1 + v_cmp_gt_i32_e64 s[2:3], s12, v5 + v_lshl_add_u64 v[22:23], v[8:9], 0, v[32:33] + v_lshlrev_b32_e32 v33, 7, v4 + v_mov_b64_e32 v[4:5], s[28:29] + v_lshlrev_b32_e32 v6, 2, v0 + s_lshl_b32 s15, s18, 5 + v_mad_i64_i32 v[24:25], s[6:7], s17, v10, v[4:5] + v_mov_b64_e32 v[4:5], s[30:31] + v_and_b32_e32 v6, 60, v6 + v_mov_b32_e32 v7, v3 + v_mad_i64_i32 v[4:5], s[6:7], s15, v12, v[4:5] + v_lshrrev_b32_e32 v2, 4, v1 + v_lshl_add_u64 v[4:5], v[4:5], 0, v[6:7] + v_lshlrev_b32_e32 v0, 6, v0 + v_lshlrev_b32_e32 v1, 7, v1 + s_movk_i32 s6, 0x1000 + s_movk_i32 s14, 0x2000 + v_lshl_add_u64 v[26:27], v[4:5], 0, v[2:3] + v_lshlrev_b32_e32 v2, 4, v28 + v_and_or_b32 v0, v0, s6, v1 + v_or3_b32 v30, v0, v2, s14 + v_and_b32_e32 v0, 0x1000, v11 + v_or3_b32 v31, v0, v1, v2 + v_mov_b32_e32 v2, v3 + v_cmp_gt_i32_e32 vcc, s12, v10 + v_cmp_gt_i32_e64 s[0:1], s13, v13 + v_or_b32_e32 v34, 0x2000, v33 + v_mov_b32_e32 v4, v3 + v_mov_b32_e32 v5, v3 + v_mov_b32_e32 v6, v3 + v_mov_b32_e32 v8, v3 + v_mov_b32_e32 v9, v3 + v_mov_b32_e32 v10, v3 + v_mov_b32_e32 v11, v3 + v_mov_b32_e32 v12, v3 + v_mov_b32_e32 v13, v3 + v_mov_b32_e32 v14, v3 + v_mov_b32_e32 v15, v3 + v_mov_b32_e32 v16, v3 + v_mov_b32_e32 v17, v3 + v_accvgpr_write_b32 a0, v2 + v_add_u32_e32 v19, 32, v32 + v_accvgpr_write_b32 a1, v3 + v_accvgpr_write_b32 a2, v4 + v_accvgpr_write_b32 a3, v5 + v_accvgpr_write_b32 a4, v6 + v_accvgpr_write_b32 a5, v7 + v_accvgpr_write_b32 a6, v8 + v_accvgpr_write_b32 a7, v9 + v_accvgpr_write_b32 a8, v10 + v_accvgpr_write_b32 a9, v11 + v_accvgpr_write_b32 a10, v12 + v_accvgpr_write_b32 a11, v13 + v_accvgpr_write_b32 a12, v14 + v_accvgpr_write_b32 a13, v15 + v_accvgpr_write_b32 a14, v16 + v_accvgpr_write_b32 a15, v17 + v_add_u32_e32 v4, v33, v32 + v_add_u32_e32 v5, v34, v32 + s_branch .LBB1_3 +.LBB1_2: ; %._crit_edge + ; in Loop: Header=BB1_3 Depth=1 + s_addk_i32 s5, 0x100 + s_cmp_ge_i32 s5, s20 + s_barrier + s_cbranch_scc1 .LBB1_27 +.LBB1_3: ; =>This Loop Header: Depth=1 + ; Child Loop BB1_25 Depth 2 + s_sub_i32 s23, s20, s5 + s_min_i32 s22, s23, 0x100 + s_lshr_b32 s6, s22, 1 + v_cmp_ge_u32_e64 s[6:7], s6, v19 + s_ashr_i32 s14, s5, 1 + s_and_b64 s[24:25], s[2:3], s[6:7] + s_ashr_i32 s15, s14, 31 + s_waitcnt vmcnt(0) + v_mov_b32_e32 v6, 0 + v_mov_b32_e32 v7, 0 + v_mov_b32_e32 v8, 0 + v_mov_b32_e32 v9, 0 + v_mov_b32_e32 v10, 0 + v_mov_b32_e32 v11, 0 + v_mov_b32_e32 v12, 0 + v_mov_b32_e32 v13, 0 + s_and_saveexec_b64 s[18:19], s[24:25] + s_cbranch_execz .LBB1_5 +; %bb.4: ; in Loop: Header=BB1_3 Depth=1 + v_lshl_add_u64 v[0:1], v[20:21], 0, s[14:15] + global_load_dwordx4 v[6:9], v[0:1], off + global_load_dwordx4 v[10:13], v[0:1], off offset:16 +.LBB1_5: ; %_Z8copy_32bPhPKhb.exit + ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[18:19] + s_and_b64 s[18:19], s[0:1], s[6:7] + s_waitcnt vmcnt(1) + ds_write_b128 v4, v[6:9] + s_waitcnt vmcnt(0) + ds_write_b128 v4, v[10:13] offset:16 + v_mov_b32_e32 v6, 0 + v_mov_b32_e32 v7, 0 + v_mov_b32_e32 v8, 0 + v_mov_b32_e32 v9, 0 + v_mov_b32_e32 v10, 0 + v_mov_b32_e32 v11, 0 + v_mov_b32_e32 v12, 0 + v_mov_b32_e32 v13, 0 + s_and_saveexec_b64 s[6:7], s[18:19] + s_cbranch_execz .LBB1_7 +; %bb.6: ; in Loop: Header=BB1_3 Depth=1 + v_lshl_add_u64 v[0:1], v[22:23], 0, s[14:15] + global_load_dwordx4 v[6:9], v[0:1], off + global_load_dwordx4 v[10:13], v[0:1], off offset:16 +.LBB1_7: ; %_Z8copy_32bPhPKhb.exit197 + ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + s_ashr_i32 s6, s5, 5 + v_add_u32_e32 v0, s6, v28 + v_cmp_gt_i32_e64 s[6:7], s17, v0 + s_waitcnt vmcnt(1) + ds_write_b128 v5, v[6:9] + s_waitcnt vmcnt(0) + ds_write_b128 v5, v[10:13] offset:16 + s_and_b64 s[14:15], vcc, s[6:7] + v_mov_b32_e32 v7, 0x7f + v_mov_b32_e32 v6, 0x7f + s_waitcnt lgkmcnt(0) + s_barrier + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_9 +; %bb.8: ; in Loop: Header=BB1_3 Depth=1 + v_ashrrev_i32_e32 v1, 31, v0 + v_lshl_add_u64 v[8:9], v[24:25], 0, v[0:1] + global_load_ubyte v6, v[8:9], off +.LBB1_9: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_cmp_gt_i32_e64 s[6:7], s21, v0 + s_and_b64 s[14:15], s[8:9], s[6:7] + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_11 +; %bb.10: ; in Loop: Header=BB1_3 Depth=1 + v_lshlrev_b32_e32 v1, 5, v0 + v_and_b32_e32 v8, 0xffffff00, v1 + v_ashrrev_i32_e32 v9, 31, v8 + v_lshlrev_b32_e32 v1, 6, v0 + v_and_b32_e32 v2, 0xc0, v1 + v_lshrrev_b32_e32 v1, 1, v0 + v_lshl_add_u64 v[8:9], v[26:27], 0, v[8:9] + v_and_b32_e32 v10, 2, v1 + v_mov_b32_e32 v11, v3 + v_lshl_add_u64 v[8:9], v[8:9], 0, v[2:3] + v_lshl_add_u64 v[8:9], v[8:9], 0, v[10:11] + global_load_ubyte v7, v[8:9], off +.LBB1_11: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_add_u32_e32 v10, 2, v0 + v_cmp_gt_i32_e64 s[6:7], s17, v10 + s_and_b64 s[14:15], vcc, s[6:7] + v_mov_b32_e32 v9, 0x7f00 + v_mov_b32_e32 v8, 0x7f00 + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_13 +; %bb.12: ; in Loop: Header=BB1_3 Depth=1 + v_ashrrev_i32_e32 v1, 31, v0 + v_lshl_add_u64 v[12:13], v[24:25], 0, v[0:1] + global_load_ubyte v1, v[12:13], off offset:2 + s_waitcnt vmcnt(0) + v_lshlrev_b32_e32 v8, 8, v1 +.LBB1_13: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_cmp_gt_i32_e64 s[6:7], s21, v10 + s_and_b64 s[14:15], s[8:9], s[6:7] + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_15 +; %bb.14: ; in Loop: Header=BB1_3 Depth=1 + v_lshlrev_b32_e32 v1, 5, v10 + v_and_b32_e32 v12, 0xffffff00, v1 + v_ashrrev_i32_e32 v13, 31, v12 + v_lshlrev_b32_e32 v1, 6, v10 + v_and_b32_e32 v2, 0xc0, v1 + v_lshrrev_b32_e32 v1, 1, v10 + v_lshl_add_u64 v[12:13], v[26:27], 0, v[12:13] + v_and_b32_e32 v10, 2, v1 + v_mov_b32_e32 v11, v3 + v_lshl_add_u64 v[12:13], v[12:13], 0, v[2:3] + v_lshl_add_u64 v[10:11], v[12:13], 0, v[10:11] + global_load_ubyte v1, v[10:11], off + s_waitcnt vmcnt(0) + v_lshlrev_b32_e32 v9, 8, v1 +.LBB1_15: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_add_u32_e32 v12, 4, v0 + v_cmp_gt_i32_e64 s[6:7], s17, v12 + s_and_b64 s[14:15], vcc, s[6:7] + v_mov_b32_e32 v11, 0x7f0000 + v_mov_b32_e32 v10, 0x7f0000 + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_17 +; %bb.16: ; in Loop: Header=BB1_3 Depth=1 + v_ashrrev_i32_e32 v1, 31, v0 + v_lshl_add_u64 v[14:15], v[24:25], 0, v[0:1] + global_load_ubyte v1, v[14:15], off offset:4 + s_waitcnt vmcnt(0) + v_lshlrev_b32_e32 v10, 16, v1 +.LBB1_17: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_cmp_gt_i32_e64 s[6:7], s21, v12 + s_and_b64 s[14:15], s[8:9], s[6:7] + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_19 +; %bb.18: ; in Loop: Header=BB1_3 Depth=1 + v_lshlrev_b32_e32 v1, 5, v12 + v_and_b32_e32 v14, 0xffffff00, v1 + v_ashrrev_i32_e32 v15, 31, v14 + v_lshlrev_b32_e32 v1, 6, v0 + v_and_b32_e32 v2, 0xc0, v1 + v_lshrrev_b32_e32 v1, 1, v12 + v_lshl_add_u64 v[14:15], v[26:27], 0, v[14:15] + v_and_b32_e32 v12, 2, v1 + v_mov_b32_e32 v13, v3 + v_lshl_add_u64 v[14:15], v[14:15], 0, v[2:3] + v_lshl_add_u64 v[12:13], v[14:15], 0, v[12:13] + global_load_ubyte v1, v[12:13], off + s_waitcnt vmcnt(0) + v_lshlrev_b32_e32 v11, 16, v1 +.LBB1_19: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_add_u32_e32 v12, 6, v0 + v_cmp_gt_i32_e64 s[6:7], s17, v12 + s_and_b64 s[14:15], vcc, s[6:7] + v_mov_b32_e32 v2, 0x7f000000 + v_mov_b32_e32 v1, 0x7f000000 + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_21 +; %bb.20: ; in Loop: Header=BB1_3 Depth=1 + v_ashrrev_i32_e32 v1, 31, v0 + v_lshl_add_u64 v[0:1], v[24:25], 0, v[0:1] + global_load_ubyte v0, v[0:1], off offset:6 + s_waitcnt vmcnt(0) + v_lshlrev_b32_e32 v1, 24, v0 +.LBB1_21: ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + v_cmp_gt_i32_e64 s[6:7], s21, v12 + s_and_b64 s[14:15], s[8:9], s[6:7] + s_and_saveexec_b64 s[6:7], s[14:15] + s_cbranch_execz .LBB1_23 +; %bb.22: ; in Loop: Header=BB1_3 Depth=1 + v_lshlrev_b32_e32 v0, 5, v12 + v_and_b32_e32 v14, 0xffffff00, v0 + v_ashrrev_i32_e32 v15, 31, v14 + v_lshlrev_b32_e32 v0, 6, v12 + v_and_b32_e32 v2, 0xc0, v0 + v_lshrrev_b32_e32 v0, 1, v12 + v_lshl_add_u64 v[14:15], v[26:27], 0, v[14:15] + v_and_b32_e32 v12, 2, v0 + v_mov_b32_e32 v13, v3 + v_lshl_add_u64 v[14:15], v[14:15], 0, v[2:3] + v_lshl_add_u64 v[12:13], v[14:15], 0, v[12:13] + global_load_ubyte v0, v[12:13], off + s_waitcnt vmcnt(0) + v_lshlrev_b32_e32 v2, 24, v0 +.LBB1_23: ; %.preheader202 + ; in Loop: Header=BB1_3 Depth=1 + s_or_b64 exec, exec, s[6:7] + s_cmp_lt_i32 s23, 1 + s_cbranch_scc1 .LBB1_2 +; %bb.24: ; %.lr.ph.preheader + ; in Loop: Header=BB1_3 Depth=1 + s_waitcnt vmcnt(0) + v_or_b32_e32 v0, v8, v6 + v_or_b32_e32 v6, v9, v7 + v_or3_b32 v0, v10, v0, v1 + v_or3_b32 v1, v11, v6, v2 + s_mov_b32 s6, 0 + v_mov_b32_e32 v2, v31 + v_mov_b32_e32 v6, v30 + s_mov_b32 s7, 0 +.LBB1_25: ; %.lr.ph + ; Parent Loop BB1_3 Depth=1 + ; => This Inner Loop Header: Depth=2 + ds_read_b128 v[8:11], v2 + ds_read_b128 v[12:15], v6 + v_bfe_u32 v7, v0, s6, 8 + v_bfe_u32 v16, v1, s6, 8 + s_add_i32 s7, s7, 64 + s_add_i32 s6, s6, 8 + s_waitcnt lgkmcnt(0) + v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[8:11], v[12:15], a[0:15], v7, v16 op_sel_hi:[0,0,0] cbsz:4 blgp:4 + v_add_u32_e32 v6, 32, v6 + s_cmp_ge_i32 s7, s22 + v_add_u32_e32 v2, 32, v2 + s_cbranch_scc0 .LBB1_25 + s_branch .LBB1_2 +.LBB1_26: + v_accvgpr_write_b32 a0, 0 + v_accvgpr_mov_b32 a1, a0 + v_accvgpr_mov_b32 a2, a0 + v_accvgpr_mov_b32 a3, a0 + v_accvgpr_mov_b32 a4, a0 + v_accvgpr_mov_b32 a5, a0 + v_accvgpr_mov_b32 a6, a0 + v_accvgpr_mov_b32 a7, a0 + v_accvgpr_mov_b32 a8, a0 + v_accvgpr_mov_b32 a9, a0 + v_accvgpr_mov_b32 a10, a0 + v_accvgpr_mov_b32 a11, a0 + v_accvgpr_mov_b32 a12, a0 + v_accvgpr_mov_b32 a13, a0 + v_accvgpr_mov_b32 a14, a0 + v_accvgpr_mov_b32 a15, a0 +.LBB1_27: ; %Flow413 + s_waitcnt vmcnt(0) + s_nop 1 + v_accvgpr_read_b32 v0, a0 + v_accvgpr_read_b32 v1, a1 + v_accvgpr_read_b32 v2, a2 + v_accvgpr_read_b32 v3, a3 + v_accvgpr_read_b32 v4, a4 + v_accvgpr_read_b32 v5, a5 + v_accvgpr_read_b32 v6, a6 + v_accvgpr_read_b32 v7, a7 + v_accvgpr_read_b32 v8, a8 + v_accvgpr_read_b32 v9, a9 + v_accvgpr_read_b32 v10, a10 + v_accvgpr_read_b32 v11, a11 + v_accvgpr_read_b32 v12, a12 + v_accvgpr_read_b32 v13, a13 + v_accvgpr_read_b32 v14, a14 + v_accvgpr_read_b32 v15, a15 + s_cmp_lg_u32 s16, 1 + s_mov_b64 s[0:1], -1 + s_cbranch_scc0 .LBB1_62 +; %bb.28: + s_and_saveexec_b64 s[0:1], s[8:9] + s_cbranch_execz .LBB1_61 +; %bb.29: ; %.preheader200 + v_lshl_or_b32 v20, v28, 2, v29 + v_ashrrev_i32_e32 v19, 31, v18 + s_mul_hi_i32 s3, s12, s4 + s_mul_i32 s2, s12, s4 + s_ashr_i32 s6, s13, 31 + v_lshl_add_u64 v[16:17], v[18:19], 2, s[10:11] + v_cmp_gt_i32_e32 vcc, s12, v20 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_31 +; %bb.30: + v_ashrrev_i32_e32 v21, 31, v20 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[20:21] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v0, off +.LBB1_31: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 1, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_33 +; %bb.32: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v1, off +.LBB1_33: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 2, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_35 +; %bb.34: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v2, off +.LBB1_35: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 3, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_37 +; %bb.36: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v3, off +.LBB1_37: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 8, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_39 +; %bb.38: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v4, off +.LBB1_39: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 9, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_41 +; %bb.40: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v5, off +.LBB1_41: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 10, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_43 +; %bb.42: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v6, off +.LBB1_43: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 11, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_45 +; %bb.44: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v7, off +.LBB1_45: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 16, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_47 +; %bb.46: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v8, off +.LBB1_47: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 17, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_49 +; %bb.48: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v9, off +.LBB1_49: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 18, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_51 +; %bb.50: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v10, off +.LBB1_51: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 19, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_53 +; %bb.52: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v11, off +.LBB1_53: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 24, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_55 +; %bb.54: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v12, off +.LBB1_55: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 25, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_57 +; %bb.56: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v13, off +.LBB1_57: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v22, 26, v20 + v_cmp_gt_i32_e32 vcc, s12, v22 + s_and_saveexec_b64 s[4:5], vcc + s_cbranch_execz .LBB1_59 +; %bb.58: + v_ashrrev_i32_e32 v23, 31, v22 + v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] + v_mul_lo_u32 v19, v23, s13 + v_mul_lo_u32 v21, v22, s6 + v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 + v_add3_u32 v23, v23, v21, v19 + v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] + global_store_dword v[22:23], v14, off +.LBB1_59: + s_or_b64 exec, exec, s[4:5] + v_or_b32_e32 v20, 27, v20 + v_cmp_gt_i32_e32 vcc, s12, v20 + s_and_b64 exec, exec, vcc + s_cbranch_execz .LBB1_61 +; %bb.60: + v_ashrrev_i32_e32 v21, 31, v20 + v_lshl_add_u64 v[20:21], s[2:3], 0, v[20:21] + v_mul_lo_u32 v19, v21, s13 + v_mul_lo_u32 v22, v20, s6 + v_mad_u64_u32 v[20:21], s[2:3], v20, s13, 0 + v_add3_u32 v21, v21, v22, v19 + v_lshl_add_u64 v[16:17], v[20:21], 2, v[16:17] + global_store_dword v[16:17], v15, off +.LBB1_61: ; %Flow406 + s_or_b64 exec, exec, s[0:1] + s_mov_b64 s[0:1], 0 +.LBB1_62: ; %Flow409 + s_andn2_b64 vcc, exec, s[0:1] + s_cbranch_vccnz .LBB1_96 +; %bb.63: + s_and_saveexec_b64 s[0:1], s[8:9] + s_cbranch_execz .LBB1_96 +; %bb.64: ; %.preheader + v_lshl_or_b32 v20, v28, 2, v29 + v_ashrrev_i32_e32 v19, 31, v18 + v_lshl_add_u64 v[16:17], v[18:19], 1, s[10:11] + v_cmp_gt_i32_e32 vcc, s12, v20 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_66 +; %bb.65: + v_mad_i64_i32 v[18:19], s[2:3], v20, s13, 0 + v_lshl_add_u64 v[18:19], v[18:19], 1, v[16:17] + global_store_short_d16_hi v[18:19], v0, off +.LBB1_66: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 1, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_68 +; %bb.67: + v_mad_i64_i32 v[18:19], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[18:19], v[18:19], 1, v[16:17] + global_store_short_d16_hi v[18:19], v1, off +.LBB1_68: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 2, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_70 +; %bb.69: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v2, off +.LBB1_70: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 3, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_72 +; %bb.71: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v3, off +.LBB1_72: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 8, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_74 +; %bb.73: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v4, off +.LBB1_74: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 9, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_76 +; %bb.75: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v5, off +.LBB1_76: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 10, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_78 +; %bb.77: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v6, off +.LBB1_78: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 11, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_80 +; %bb.79: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v7, off +.LBB1_80: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 16, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_82 +; %bb.81: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v8, off +.LBB1_82: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 17, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_84 +; %bb.83: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v9, off +.LBB1_84: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 18, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_86 +; %bb.85: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v10, off +.LBB1_86: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 19, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_88 +; %bb.87: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v11, off +.LBB1_88: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 24, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_90 +; %bb.89: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v12, off +.LBB1_90: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 25, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_92 +; %bb.91: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v13, off +.LBB1_92: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 26, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB1_94 +; %bb.93: + v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v14, off +.LBB1_94: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v0, 27, v20 + v_cmp_gt_i32_e32 vcc, s12, v0 + s_and_b64 exec, exec, vcc + s_cbranch_execz .LBB1_96 +; %bb.95: + v_mad_i64_i32 v[0:1], s[0:1], v0, s13, 0 + v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] + global_store_short_d16_hi v[0:1], v15, off +.LBB1_96: ; %.loopexit + s_endpgm + .section .rodata,"a",@progbits + .p2align 6, 0x0 + .amdhsa_kernel _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .amdhsa_group_segment_fixed_size 16384 + .amdhsa_private_segment_fixed_size 0 + .amdhsa_kernarg_size 68 + .amdhsa_user_sgpr_count 2 + .amdhsa_user_sgpr_dispatch_ptr 0 + .amdhsa_user_sgpr_queue_ptr 0 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_user_sgpr_dispatch_id 0 + .amdhsa_user_sgpr_kernarg_preload_length 0 + .amdhsa_user_sgpr_kernarg_preload_offset 0 + .amdhsa_user_sgpr_private_segment_size 0 + .amdhsa_uses_dynamic_stack 0 + .amdhsa_enable_private_segment 0 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_sgpr_workgroup_info 0 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 52 + .amdhsa_next_free_sgpr 32 + .amdhsa_accum_offset 36 + .amdhsa_reserve_vcc 1 + .amdhsa_float_round_mode_32 0 + .amdhsa_float_round_mode_16_64 0 + .amdhsa_float_denorm_mode_32 3 + .amdhsa_float_denorm_mode_16_64 3 + .amdhsa_dx10_clamp 1 + .amdhsa_ieee_mode 1 + .amdhsa_fp16_overflow 0 + .amdhsa_tg_split 0 + .amdhsa_exception_fp_ieee_invalid_op 0 + .amdhsa_exception_fp_denorm_src 0 + .amdhsa_exception_fp_ieee_div_zero 0 + .amdhsa_exception_fp_ieee_overflow 0 + .amdhsa_exception_fp_ieee_underflow 0 + .amdhsa_exception_fp_ieee_inexact 0 + .amdhsa_exception_int_div_zero 0 + .end_amdhsa_kernel + .text +.Lfunc_end1: + .size _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, .Lfunc_end1-_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + ; -- End function + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.num_vgpr, 35 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.num_agpr, 16 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.numbered_sgpr, 32 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.num_named_barrier, 0 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.private_seg_size, 0 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.uses_vcc, 1 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.uses_flat_scratch, 0 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.has_dyn_sized_stack, 0 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.has_recursion, 0 + .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.has_indirect_call, 0 + .section .AMDGPU.csdata,"",@progbits +; Kernel info: +; codeLenInByte = 3892 +; TotalNumSgprs: 38 +; NumVgprs: 35 +; NumAgprs: 16 +; TotalNumVgprs: 52 +; ScratchSize: 0 +; MemoryBound: 1 +; FloatMode: 240 +; IeeeMode: 1 +; LDSByteSize: 16384 bytes/workgroup (compile time only) +; SGPRBlocks: 4 +; VGPRBlocks: 6 +; NumSGPRsForWavesPerEU: 38 +; NumVGPRsForWavesPerEU: 52 +; AccumOffset: 36 +; Occupancy: 8 +; WaveLimiterHint : 0 +; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 +; COMPUTE_PGM_RSRC2:USER_SGPR: 2 +; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 +; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Y_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Z_EN: 1 +; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 +; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 8 +; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 + .text + .protected _Z9reduce_skPKfP12hip_bfloat16ii ; -- Begin function _Z9reduce_skPKfP12hip_bfloat16ii + .globl _Z9reduce_skPKfP12hip_bfloat16ii + .p2align 8 + .type _Z9reduce_skPKfP12hip_bfloat16ii,@function +_Z9reduce_skPKfP12hip_bfloat16ii: ; @_Z9reduce_skPKfP12hip_bfloat16ii +; %bb.0: + s_load_dwordx2 s[4:5], s[0:1], 0x10 + v_lshl_add_u32 v0, s2, 8, v0 + s_waitcnt lgkmcnt(0) + v_cmp_gt_i32_e32 vcc, s4, v0 + s_and_saveexec_b64 s[2:3], vcc + s_cbranch_execz .LBB2_8 +; %bb.1: ; %.preheader + s_cmp_gt_i32 s5, 0 + v_ashrrev_i32_e32 v1, 31, v0 + s_cbranch_scc1 .LBB2_3 +; %bb.2: ; %.preheader.._crit_edge_crit_edge + s_load_dwordx2 s[2:3], s[0:1], 0x8 + v_mov_b32_e32 v2, 0 + s_cbranch_execz .LBB2_4 + s_branch .LBB2_7 +.LBB2_3: + s_load_dwordx2 s[2:3], s[0:1], 0x8 + v_mov_b32_e32 v2, 0 +.LBB2_4: ; %.lr.ph + s_load_dwordx2 s[6:7], s[0:1], 0x0 + s_ashr_i32 s1, s4, 31 + s_mov_b32 s0, s4 + s_lshl_b64 s[0:1], s[0:1], 2 + v_mov_b32_e32 v4, 0 + s_waitcnt lgkmcnt(0) + v_lshl_add_u64 v[2:3], v[0:1], 2, s[6:7] +.LBB2_5: ; =>This Inner Loop Header: Depth=1 + global_load_dword v5, v[2:3], off + s_add_i32 s5, s5, -1 + v_lshl_add_u64 v[2:3], v[2:3], 0, s[0:1] + s_cmp_eq_u32 s5, 0 + s_waitcnt vmcnt(0) + v_add_f32_e32 v4, v4, v5 + s_cbranch_scc0 .LBB2_5 +; %bb.6: ; %._crit_edge.loopexit + v_lshrrev_b32_e32 v2, 16, v4 +.LBB2_7: ; %Flow26 + s_waitcnt lgkmcnt(0) + v_lshl_add_u64 v[0:1], v[0:1], 1, s[2:3] + global_store_short v[0:1], v2, off +.LBB2_8: + s_endpgm + .section .rodata,"a",@progbits + .p2align 6, 0x0 + .amdhsa_kernel _Z9reduce_skPKfP12hip_bfloat16ii + .amdhsa_group_segment_fixed_size 0 + .amdhsa_private_segment_fixed_size 0 + .amdhsa_kernarg_size 24 + .amdhsa_user_sgpr_count 2 + .amdhsa_user_sgpr_dispatch_ptr 0 + .amdhsa_user_sgpr_queue_ptr 0 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_user_sgpr_dispatch_id 0 + .amdhsa_user_sgpr_kernarg_preload_length 0 + .amdhsa_user_sgpr_kernarg_preload_offset 0 + .amdhsa_user_sgpr_private_segment_size 0 + .amdhsa_uses_dynamic_stack 0 + .amdhsa_enable_private_segment 0 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 0 + .amdhsa_system_sgpr_workgroup_id_z 0 + .amdhsa_system_sgpr_workgroup_info 0 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 6 + .amdhsa_next_free_sgpr 8 + .amdhsa_accum_offset 8 + .amdhsa_reserve_vcc 1 + .amdhsa_float_round_mode_32 0 + .amdhsa_float_round_mode_16_64 0 + .amdhsa_float_denorm_mode_32 3 + .amdhsa_float_denorm_mode_16_64 3 + .amdhsa_dx10_clamp 1 + .amdhsa_ieee_mode 1 + .amdhsa_fp16_overflow 0 + .amdhsa_tg_split 0 + .amdhsa_exception_fp_ieee_invalid_op 0 + .amdhsa_exception_fp_denorm_src 0 + .amdhsa_exception_fp_ieee_div_zero 0 + .amdhsa_exception_fp_ieee_overflow 0 + .amdhsa_exception_fp_ieee_underflow 0 + .amdhsa_exception_fp_ieee_inexact 0 + .amdhsa_exception_int_div_zero 0 + .end_amdhsa_kernel + .text +.Lfunc_end2: + .size _Z9reduce_skPKfP12hip_bfloat16ii, .Lfunc_end2-_Z9reduce_skPKfP12hip_bfloat16ii + ; -- End function + .set _Z9reduce_skPKfP12hip_bfloat16ii.num_vgpr, 6 + .set _Z9reduce_skPKfP12hip_bfloat16ii.num_agpr, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.numbered_sgpr, 8 + .set _Z9reduce_skPKfP12hip_bfloat16ii.num_named_barrier, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.private_seg_size, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_vcc, 1 + .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_flat_scratch, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.has_dyn_sized_stack, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.has_recursion, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.has_indirect_call, 0 + .section .AMDGPU.csdata,"",@progbits +; Kernel info: +; codeLenInByte = 176 +; TotalNumSgprs: 14 +; NumVgprs: 6 +; NumAgprs: 0 +; TotalNumVgprs: 6 +; ScratchSize: 0 +; MemoryBound: 0 +; FloatMode: 240 +; IeeeMode: 1 +; LDSByteSize: 0 bytes/workgroup (compile time only) +; SGPRBlocks: 1 +; VGPRBlocks: 0 +; NumSGPRsForWavesPerEU: 14 +; NumVGPRsForWavesPerEU: 6 +; AccumOffset: 8 +; Occupancy: 8 +; WaveLimiterHint : 0 +; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 +; COMPUTE_PGM_RSRC2:USER_SGPR: 2 +; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 +; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Y_EN: 0 +; COMPUTE_PGM_RSRC2:TGID_Z_EN: 0 +; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 +; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 1 +; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 + .text + .p2alignl 6, 3212836864 + .fill 256, 4, 3212836864 + .section .AMDGPU.gpr_maximums,"",@progbits + .set amdgpu.max_num_vgpr, 0 + .set amdgpu.max_num_agpr, 0 + .set amdgpu.max_num_sgpr, 0 + .text + .type __hip_cuid_85127a709332f6f1,@object ; @__hip_cuid_85127a709332f6f1 + .section .bss,"aw",@nobits + .globl __hip_cuid_85127a709332f6f1 +__hip_cuid_85127a709332f6f1: + .byte 0 ; 0x0 + .size __hip_cuid_85127a709332f6f1, 1 + + .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" + .section ".note.GNU-stack","",@progbits + .addrsig + .addrsig_sym __hip_cuid_85127a709332f6f1 + .amdgpu_metadata +--- +amdhsa.kernels: + - .agpr_count: 16 + .args: + - .actual_access: read_only + .address_space: global + .offset: 0 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 8 + .size: 8 + .value_kind: global_buffer + - .actual_access: write_only + .address_space: global + .offset: 16 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 24 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 32 + .size: 8 + .value_kind: global_buffer + - .offset: 40 + .size: 4 + .value_kind: by_value + - .offset: 44 + .size: 4 + .value_kind: by_value + - .offset: 48 + .size: 4 + .value_kind: by_value + - .offset: 52 + .size: 4 + .value_kind: by_value + - .offset: 56 + .size: 4 + .value_kind: by_value + - .offset: 60 + .size: 4 + .value_kind: by_value + - .offset: 64 + .size: 4 + .value_kind: by_value + .group_segment_fixed_size: 0 + .kernarg_segment_align: 8 + .kernarg_segment_size: 68 + .language: OpenCL C + .language_version: + - 2 + - 0 + .max_flat_workgroup_size: 64 + .name: _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .private_segment_fixed_size: 0 + .sgpr_count: 34 + .sgpr_spill_count: 0 + .symbol: _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.kd + .uniform_work_group_size: 1 + .uses_dynamic_stack: false + .vgpr_count: 44 + .vgpr_spill_count: 0 + .wavefront_size: 64 + - .agpr_count: 16 + .args: + - .actual_access: read_only + .address_space: global + .offset: 0 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 8 + .size: 8 + .value_kind: global_buffer + - .actual_access: write_only + .address_space: global + .offset: 16 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 24 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 32 + .size: 8 + .value_kind: global_buffer + - .offset: 40 + .size: 4 + .value_kind: by_value + - .offset: 44 + .size: 4 + .value_kind: by_value + - .offset: 48 + .size: 4 + .value_kind: by_value + - .offset: 52 + .size: 4 + .value_kind: by_value + - .offset: 56 + .size: 4 + .value_kind: by_value + - .offset: 60 + .size: 4 + .value_kind: by_value + - .offset: 64 + .size: 4 + .value_kind: by_value + .group_segment_fixed_size: 16384 + .kernarg_segment_align: 8 + .kernarg_segment_size: 68 + .language: OpenCL C + .language_version: + - 2 + - 0 + .max_flat_workgroup_size: 256 + .name: _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .private_segment_fixed_size: 0 + .sgpr_count: 38 + .sgpr_spill_count: 0 + .symbol: _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.kd + .uniform_work_group_size: 1 + .uses_dynamic_stack: false + .vgpr_count: 52 + .vgpr_spill_count: 0 + .wavefront_size: 64 + - .agpr_count: 0 + .args: + - .actual_access: read_only + .address_space: global + .offset: 0 + .size: 8 + .value_kind: global_buffer + - .actual_access: write_only + .address_space: global + .offset: 8 + .size: 8 + .value_kind: global_buffer + - .offset: 16 + .size: 4 + .value_kind: by_value + - .offset: 20 + .size: 4 + .value_kind: by_value + .group_segment_fixed_size: 0 + .kernarg_segment_align: 8 + .kernarg_segment_size: 24 + .language: OpenCL C + .language_version: + - 2 + - 0 + .max_flat_workgroup_size: 1024 + .name: _Z9reduce_skPKfP12hip_bfloat16ii + .private_segment_fixed_size: 0 + .sgpr_count: 14 + .sgpr_spill_count: 0 + .symbol: _Z9reduce_skPKfP12hip_bfloat16ii.kd + .uniform_work_group_size: 1 + .uses_dynamic_stack: false + .vgpr_count: 6 + .vgpr_spill_count: 0 + .wavefront_size: 64 +amdhsa.target: amdgcn-amd-amdhsa--gfx950 +amdhsa.version: + - 1 + - 2 +... + + .end_amdgpu_metadata + +# __CLANG_OFFLOAD_BUNDLE____END__ hip-amdgcn-amd-amdhsa--gfx950 + +# __CLANG_OFFLOAD_BUNDLE____START__ host-x86_64-unknown-linux-gnu- + .file "src_v22.hip" + .text + .globl _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii # -- Begin function _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .p2align 4 + .type _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii,@function +_Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii: # @_Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .cfi_startproc +# %bb.0: + subq $200, %rsp + .cfi_def_cfa_offset 208 + movq %rdi, 88(%rsp) + movq %rsi, 80(%rsp) + movq %rdx, 72(%rsp) + movq %rcx, 64(%rsp) + movq %r8, 56(%rsp) + movl %r9d, 4(%rsp) + leaq 88(%rsp), %rax + movq %rax, 96(%rsp) + leaq 80(%rsp), %rax + movq %rax, 104(%rsp) + leaq 72(%rsp), %rax + movq %rax, 112(%rsp) + leaq 64(%rsp), %rax + movq %rax, 120(%rsp) + leaq 56(%rsp), %rax + movq %rax, 128(%rsp) + leaq 4(%rsp), %rax + movq %rax, 136(%rsp) + leaq 208(%rsp), %rax + movq %rax, 144(%rsp) + leaq 216(%rsp), %rax + movq %rax, 152(%rsp) + leaq 224(%rsp), %rax + movq %rax, 160(%rsp) + leaq 232(%rsp), %rax + movq %rax, 168(%rsp) + leaq 240(%rsp), %rax + movq %rax, 176(%rsp) + leaq 248(%rsp), %rax + movq %rax, 184(%rsp) + leaq 40(%rsp), %rdi + leaq 24(%rsp), %rsi + leaq 16(%rsp), %rdx + leaq 8(%rsp), %rcx + callq __hipPopCallConfiguration + movq 40(%rsp), %rsi + movl 48(%rsp), %edx + movq 24(%rsp), %rcx + movl 32(%rsp), %r8d + leaq 96(%rsp), %r9 + movl $_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, %edi + pushq 8(%rsp) + .cfi_adjust_cfa_offset 8 + pushq 24(%rsp) + .cfi_adjust_cfa_offset 8 + callq hipLaunchKernel + addq $216, %rsp + .cfi_adjust_cfa_offset -216 + retq +.Lfunc_end0: + .size _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .cfi_endproc + # -- End function + .globl _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii # -- Begin function _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .p2align 4 + .type _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii,@function +_Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii: # @_Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .cfi_startproc +# %bb.0: + subq $200, %rsp + .cfi_def_cfa_offset 208 + movq %rdi, 88(%rsp) + movq %rsi, 80(%rsp) + movq %rdx, 72(%rsp) + movq %rcx, 64(%rsp) + movq %r8, 56(%rsp) + movl %r9d, 4(%rsp) + leaq 88(%rsp), %rax + movq %rax, 96(%rsp) + leaq 80(%rsp), %rax + movq %rax, 104(%rsp) + leaq 72(%rsp), %rax + movq %rax, 112(%rsp) + leaq 64(%rsp), %rax + movq %rax, 120(%rsp) + leaq 56(%rsp), %rax + movq %rax, 128(%rsp) + leaq 4(%rsp), %rax + movq %rax, 136(%rsp) + leaq 208(%rsp), %rax + movq %rax, 144(%rsp) + leaq 216(%rsp), %rax + movq %rax, 152(%rsp) + leaq 224(%rsp), %rax + movq %rax, 160(%rsp) + leaq 232(%rsp), %rax + movq %rax, 168(%rsp) + leaq 240(%rsp), %rax + movq %rax, 176(%rsp) + leaq 248(%rsp), %rax + movq %rax, 184(%rsp) + leaq 40(%rsp), %rdi + leaq 24(%rsp), %rsi + leaq 16(%rsp), %rdx + leaq 8(%rsp), %rcx + callq __hipPopCallConfiguration + movq 40(%rsp), %rsi + movl 48(%rsp), %edx + movq 24(%rsp), %rcx + movl 32(%rsp), %r8d + leaq 96(%rsp), %r9 + movl $_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, %edi + pushq 8(%rsp) + .cfi_adjust_cfa_offset 8 + pushq 24(%rsp) + .cfi_adjust_cfa_offset 8 + callq hipLaunchKernel + addq $216, %rsp + .cfi_adjust_cfa_offset -216 + retq +.Lfunc_end1: + .size _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, .Lfunc_end1-_Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .cfi_endproc + # -- End function + .globl _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii # -- Begin function _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .p2align 4 + .type _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii,@function +_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii: # @_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .cfi_startproc +# %bb.0: + subq $120, %rsp + .cfi_def_cfa_offset 128 + movq %rdi, 72(%rsp) + movq %rsi, 64(%rsp) + movl %edx, 12(%rsp) + movl %ecx, 8(%rsp) + leaq 72(%rsp), %rax + movq %rax, 80(%rsp) + leaq 64(%rsp), %rax + movq %rax, 88(%rsp) + leaq 12(%rsp), %rax + movq %rax, 96(%rsp) + leaq 8(%rsp), %rax + movq %rax, 104(%rsp) + leaq 48(%rsp), %rdi + leaq 32(%rsp), %rsi + leaq 24(%rsp), %rdx + leaq 16(%rsp), %rcx + callq __hipPopCallConfiguration + movq 48(%rsp), %rsi + movl 56(%rsp), %edx + movq 32(%rsp), %rcx + movl 40(%rsp), %r8d + leaq 80(%rsp), %r9 + movl $_Z9reduce_skPKfP12hip_bfloat16ii, %edi + pushq 16(%rsp) + .cfi_adjust_cfa_offset 8 + pushq 32(%rsp) + .cfi_adjust_cfa_offset 8 + callq hipLaunchKernel + addq $136, %rsp + .cfi_adjust_cfa_offset -136 + retq +.Lfunc_end2: + .size _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii, .Lfunc_end2-_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .cfi_endproc + # -- End function + .p2align 4 # -- Begin function __hip_module_ctor + .type __hip_module_ctor,@function +__hip_module_ctor: # @__hip_module_ctor + .cfi_startproc +# %bb.0: + pushq %rbx + .cfi_def_cfa_offset 16 + subq $32, %rsp + .cfi_def_cfa_offset 48 + .cfi_offset %rbx, -16 + movq __hip_gpubin_handle_85127a709332f6f1(%rip), %rbx + testq %rbx, %rbx + jne .LBB3_2 +# %bb.1: + movl $__hip_fatbin_wrapper, %edi + callq __hipRegisterFatBinary + movq %rax, %rbx + movq %rax, __hip_gpubin_handle_85127a709332f6f1(%rip) +.LBB3_2: + xorps %xmm0, %xmm0 + movups %xmm0, 16(%rsp) + movups %xmm0, (%rsp) + movl $_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, %esi + movl $.L__unnamed_1, %edx + movl $.L__unnamed_1, %ecx + movq %rbx, %rdi + movl $-1, %r8d + xorl %r9d, %r9d + callq __hipRegisterFunction + xorps %xmm0, %xmm0 + movups %xmm0, 16(%rsp) + movups %xmm0, (%rsp) + movl $_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, %esi + movl $.L__unnamed_2, %edx + movl $.L__unnamed_2, %ecx + movq %rbx, %rdi + movl $-1, %r8d + xorl %r9d, %r9d + callq __hipRegisterFunction + xorps %xmm0, %xmm0 + movups %xmm0, 16(%rsp) + movups %xmm0, (%rsp) + movl $_Z9reduce_skPKfP12hip_bfloat16ii, %esi + movl $.L__unnamed_3, %edx + movl $.L__unnamed_3, %ecx + movq %rbx, %rdi + movl $-1, %r8d + xorl %r9d, %r9d + callq __hipRegisterFunction + movl $__hip_module_dtor, %edi + addq $32, %rsp + .cfi_def_cfa_offset 16 + popq %rbx + .cfi_def_cfa_offset 8 + jmp atexit # TAILCALL +.Lfunc_end3: + .size __hip_module_ctor, .Lfunc_end3-__hip_module_ctor + .cfi_endproc + # -- End function + .p2align 4 # -- Begin function __hip_module_dtor + .type __hip_module_dtor,@function +__hip_module_dtor: # @__hip_module_dtor + .cfi_startproc +# %bb.0: + movq __hip_gpubin_handle_85127a709332f6f1(%rip), %rdi + testq %rdi, %rdi + je .LBB4_2 +# %bb.1: + pushq %rax + .cfi_def_cfa_offset 16 + callq __hipUnregisterFatBinary + movq $0, __hip_gpubin_handle_85127a709332f6f1(%rip) + addq $8, %rsp + .cfi_def_cfa_offset 8 +.LBB4_2: + retq +.Lfunc_end4: + .size __hip_module_dtor, .Lfunc_end4-__hip_module_dtor + .cfi_endproc + # -- End function + .type _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii,@object # @_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .section .rodata,"a",@progbits + .globl _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .p2align 3, 0x0 +_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii: + .quad _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .size _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, 8 + + .type _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii,@object # @_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .globl _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .p2align 3, 0x0 +_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii: + .quad _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .size _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, 8 + + .type _Z9reduce_skPKfP12hip_bfloat16ii,@object # @_Z9reduce_skPKfP12hip_bfloat16ii + .globl _Z9reduce_skPKfP12hip_bfloat16ii + .p2align 3, 0x0 +_Z9reduce_skPKfP12hip_bfloat16ii: + .quad _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .size _Z9reduce_skPKfP12hip_bfloat16ii, 8 + + .type .L__unnamed_1,@object # @0 + .section .rodata.str1.1,"aMS",@progbits,1 +.L__unnamed_1: + .asciz "_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii" + .size .L__unnamed_1, 40 + + .type .L__unnamed_2,@object # @1 +.L__unnamed_2: + .asciz "_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii" + .size .L__unnamed_2, 39 + + .type .L__unnamed_3,@object # @2 +.L__unnamed_3: + .asciz "_Z9reduce_skPKfP12hip_bfloat16ii" + .size .L__unnamed_3, 33 + + .type __hip_fatbin_wrapper,@object # @__hip_fatbin_wrapper + .section .hipFatBinSegment,"a",@progbits + .p2align 3, 0x0 +__hip_fatbin_wrapper: + .long 1212764230 # 0x48495046 + .long 1 # 0x1 + .quad __hip_fatbin_85127a709332f6f1 + .quad 0 + .size __hip_fatbin_wrapper, 24 + + .type __hip_gpubin_handle_85127a709332f6f1,@object # @__hip_gpubin_handle_85127a709332f6f1 + .local __hip_gpubin_handle_85127a709332f6f1 + .comm __hip_gpubin_handle_85127a709332f6f1,8,8 + .section .init_array,"aw",@init_array + .p2align 3, 0x0 + .quad __hip_module_ctor + .type __hip_cuid_85127a709332f6f1,@object # @__hip_cuid_85127a709332f6f1 + .bss + .globl __hip_cuid_85127a709332f6f1 +__hip_cuid_85127a709332f6f1: + .byte 0 # 0x0 + .size __hip_cuid_85127a709332f6f1, 1 + + .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" + .section ".note.GNU-stack","",@progbits + .addrsig + .addrsig_sym _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .addrsig_sym _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .addrsig_sym _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .addrsig_sym __hip_module_ctor + .addrsig_sym __hip_module_dtor + .addrsig_sym _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii + .addrsig_sym _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii + .addrsig_sym _Z9reduce_skPKfP12hip_bfloat16ii + .addrsig_sym __hip_fatbin_85127a709332f6f1 + .addrsig_sym __hip_fatbin_wrapper + .addrsig_sym __hip_cuid_85127a709332f6f1 + +# __CLANG_OFFLOAD_BUNDLE____END__ host-x86_64-unknown-linux-gnu- diff --git a/examples/amd_mxfp4_mm/asm-dump/src.s b/examples/amd_mxfp4_mm/asm-dump/src.s new file mode 100644 index 0000000000..ef74cd27ca --- /dev/null +++ b/examples/amd_mxfp4_mm/asm-dump/src.s @@ -0,0 +1,1283 @@ + +# __CLANG_OFFLOAD_BUNDLE____START__ hip-amdgcn-amd-amdhsa--gfx950 + .amdgcn_target "amdgcn-amd-amdhsa--gfx950" + .amdhsa_code_object_version 6 + .text + .protected _Z9gemm_corePKhS0_PvS0_S0_iiiiiii ; -- Begin function _Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .globl _Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .p2align 8 + .type _Z9gemm_corePKhS0_PvS0_S0_iiiiiii,@function +_Z9gemm_corePKhS0_PvS0_S0_iiiiiii: ; @_Z9gemm_corePKhS0_PvS0_S0_iiiiiii +; %bb.0: + s_load_dwordx2 s[6:7], s[0:1], 0x10 + s_load_dwordx8 s[8:15], s[0:1], 0x28 + v_and_b32_e32 v1, 31, v0 + s_waitcnt lgkmcnt(0) + s_lshl_b32 s15, s2, 5 + s_lshl_b32 s23, s3, 5 + v_or_b32_e32 v2, s23, v1 + s_mul_i32 s22, s11, s4 + s_add_i32 s2, s22, s11 + s_min_i32 s5, s2, s10 + v_lshrrev_b32_e32 v14, 5, v0 + v_cmp_gt_i32_e64 s[2:3], s9, v2 + v_ashrrev_i32_e32 v3, 31, v2 + v_mov_b32_e32 v18, 0 + v_accvgpr_write_b32 a0, 0 + v_accvgpr_write_b32 a1, 0 + v_accvgpr_write_b32 a2, 0 + v_accvgpr_write_b32 a3, 0 + v_accvgpr_write_b32 a4, 0 + v_accvgpr_write_b32 a5, 0 + v_accvgpr_write_b32 a6, 0 + v_accvgpr_write_b32 a7, 0 + v_accvgpr_write_b32 a8, 0 + v_accvgpr_write_b32 a9, 0 + v_accvgpr_write_b32 a10, 0 + v_accvgpr_write_b32 a11, 0 + v_accvgpr_write_b32 a12, 0 + v_accvgpr_write_b32 a13, 0 + v_accvgpr_write_b32 a14, 0 + s_cmp_ge_i32 s22, s5 + v_accvgpr_write_b32 a15, 0 + s_cbranch_scc0 .LBB0_4 +; %bb.1: + s_cmp_lg_u32 s12, 1 + s_mov_b64 s[0:1], -1 + s_cbranch_scc1 .LBB0_26 +.LBB0_2: ; %Flow442 + s_andn2_b64 vcc, exec, s[0:1] + s_cbranch_vccz .LBB0_60 +.LBB0_3: ; %.loopexit + s_endpgm +.LBB0_4: + s_load_dwordx4 s[16:19], s[0:1], 0x0 + s_ashr_i32 s11, s10, 1 + v_or_b32_e32 v10, s15, v1 + v_mad_i64_i32 v[4:5], s[20:21], s11, v10, 0 + s_ashr_i32 s20, s22, 1 + v_lshlrev_b32_e32 v15, 4, v14 + v_add_u32_e32 v8, s20, v15 + v_cmp_gt_i32_e32 vcc, s8, v10 + v_mov_b32_e32 v19, v18 + v_mov_b32_e32 v20, v18 + v_mov_b32_e32 v21, v18 + s_waitcnt lgkmcnt(0) + v_lshl_add_u64 v[4:5], s[16:17], 0, v[4:5] + v_ashrrev_i32_e32 v9, 31, v8 + s_and_saveexec_b64 s[16:17], vcc + s_cbranch_execz .LBB0_6 +; %bb.5: ; %.loopexit44.i + v_lshl_add_u64 v[6:7], v[4:5], 0, v[8:9] + global_load_dwordx4 v[18:21], v[6:7], off +.LBB0_6: + s_or_b64 exec, exec, s[16:17] + s_load_dwordx2 s[20:21], s[0:1], 0x18 + v_mad_i64_i32 v[6:7], s[16:17], s11, v2, 0 + v_mov_b32_e32 v26, 0 + v_mov_b32_e32 v27, v26 + v_mov_b32_e32 v28, v26 + v_mov_b32_e32 v29, v26 + v_lshl_add_u64 v[6:7], s[18:19], 0, v[6:7] + s_and_saveexec_b64 s[16:17], s[2:3] + s_cbranch_execz .LBB0_8 +; %bb.7: ; %.loopexit.i + v_lshl_add_u64 v[8:9], v[6:7], 0, v[8:9] + global_load_dwordx4 v[26:29], v[8:9], off +.LBB0_8: + s_or_b64 exec, exec, s[16:17] + s_load_dwordx2 s[16:17], s[0:1], 0x20 + v_mad_i64_i32 v[8:9], s[0:1], s13, v10, 0 + s_ashr_i32 s0, s22, 5 + s_nop 0 + v_add_u32_e32 v10, s0, v14 + v_cmp_gt_i32_e64 s[0:1], s13, v10 + s_and_b64 s[18:19], vcc, s[0:1] + v_mov_b32_e32 v17, 0x7f + s_waitcnt lgkmcnt(0) + v_lshl_add_u64 v[8:9], s[20:21], 0, v[8:9] + v_mov_b32_e32 v16, 0x7f + s_and_saveexec_b64 s[0:1], s[18:19] + s_cbranch_execz .LBB0_10 +; %bb.9: + v_ashrrev_i32_e32 v11, 31, v10 + v_lshl_add_u64 v[12:13], v[8:9], 0, v[10:11] + global_load_ubyte v16, v[12:13], off +.LBB0_10: + s_or_b64 exec, exec, s[0:1] + s_ashr_i32 s0, s23, 5 + v_lshlrev_b32_e32 v0, 2, v0 + s_lshl_b32 s1, s14, 5 + v_and_b32_e32 v12, 60, v0 + v_mov_b32_e32 v13, 0 + v_mov_b32_e32 v0, s0 + s_ashr_i32 s10, s10, 5 + v_mad_i64_i32 v[34:35], s[0:1], s1, v0, v[12:13] + v_lshrrev_b32_e32 v0, 4, v1 + v_or_b32_e32 v34, v34, v0 + v_cmp_gt_i32_e64 s[0:1], s10, v10 + v_accvgpr_write_b32 a0, 0 + v_accvgpr_write_b32 a1, 0 + v_accvgpr_write_b32 a2, 0 + v_accvgpr_write_b32 a3, 0 + v_accvgpr_write_b32 a4, 0 + v_accvgpr_write_b32 a5, 0 + v_accvgpr_write_b32 a6, 0 + v_accvgpr_write_b32 a7, 0 + v_accvgpr_write_b32 a8, 0 + v_accvgpr_write_b32 a9, 0 + v_accvgpr_write_b32 a10, 0 + v_accvgpr_write_b32 a11, 0 + v_accvgpr_write_b32 a12, 0 + v_accvgpr_write_b32 a13, 0 + v_accvgpr_write_b32 a14, 0 + v_accvgpr_write_b32 a15, 0 + s_and_b64 s[18:19], s[2:3], s[0:1] + v_lshl_add_u64 v[0:1], s[16:17], 0, v[34:35] + s_and_saveexec_b64 s[0:1], s[18:19] + s_cbranch_execz .LBB0_12 +; %bb.11: + v_lshlrev_b32_e32 v11, 5, v10 + v_and_b32_e32 v34, 0xffffff00, v11 + v_ashrrev_i32_e32 v35, 31, v34 + v_lshlrev_b32_e32 v11, 6, v10 + v_and_b32_e32 v12, 0xc0, v11 + v_lshrrev_b32_e32 v10, 1, v10 + v_lshl_add_u64 v[34:35], v[0:1], 0, v[34:35] + v_and_b32_e32 v10, 2, v10 + v_mov_b32_e32 v11, v13 + v_lshl_add_u64 v[12:13], v[34:35], 0, v[12:13] + v_lshl_add_u64 v[10:11], v[12:13], 0, v[10:11] + global_load_ubyte v17, v[10:11], off +.LBB0_12: ; %_Z9load_fragPKhS0_S0_S0_iiiillllibbRDv32_hS2_RhS3_.exit + s_or_b64 exec, exec, s[0:1] + s_add_i32 s11, s22, 64 + s_cmp_ge_i32 s11, s5 + s_cbranch_scc1 .LBB0_24 +; %bb.13: ; %.lr.ph + v_mov_b32_e32 v34, 0 + v_accvgpr_write_b32 a15, 0 + v_accvgpr_write_b32 a14, 0 + v_accvgpr_write_b32 a13, 0 + v_accvgpr_write_b32 a12, 0 + v_accvgpr_write_b32 a11, 0 + v_accvgpr_write_b32 a10, 0 + v_accvgpr_write_b32 a9, 0 + v_accvgpr_write_b32 a8, 0 + v_accvgpr_write_b32 a7, 0 + v_accvgpr_write_b32 a6, 0 + v_accvgpr_write_b32 a5, 0 + v_accvgpr_write_b32 a4, 0 + v_accvgpr_write_b32 a3, 0 + v_accvgpr_write_b32 a2, 0 + v_accvgpr_write_b32 a1, 0 + v_accvgpr_write_b32 a0, 0 +.LBB0_14: ; =>This Inner Loop Header: Depth=1 + s_ashr_i32 s0, s11, 1 + v_add_u32_e32 v10, s0, v15 + v_mov_b32_e32 v35, v34 + v_mov_b32_e32 v36, v34 + v_mov_b32_e32 v37, v34 + v_mov_b64_e32 v[44:45], v[40:41] + v_ashrrev_i32_e32 v11, 31, v10 + v_mov_b64_e32 v[42:43], v[38:39] + v_mov_b64_e32 v[40:41], v[36:37] + v_mov_b64_e32 v[38:39], v[34:35] + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_16 +; %bb.15: ; %.loopexit44.i136 + ; in Loop: Header=BB0_14 Depth=1 + v_lshl_add_u64 v[12:13], v[4:5], 0, v[10:11] + global_load_dwordx4 v[38:41], v[12:13], off +.LBB0_16: ; in Loop: Header=BB0_14 Depth=1 + s_or_b64 exec, exec, s[0:1] + s_waitcnt vmcnt(0) + v_mov_b64_e32 v[52:53], v[40:41] + v_mov_b64_e32 v[50:51], v[38:39] + v_mov_b64_e32 v[48:49], v[36:37] + v_mov_b64_e32 v[46:47], v[34:35] + s_and_saveexec_b64 s[0:1], s[2:3] + s_cbranch_execz .LBB0_18 +; %bb.17: ; %.loopexit.i134 + ; in Loop: Header=BB0_14 Depth=1 + v_lshl_add_u64 v[10:11], v[6:7], 0, v[10:11] + global_load_dwordx4 v[46:49], v[10:11], off +.LBB0_18: ; in Loop: Header=BB0_14 Depth=1 + s_or_b64 exec, exec, s[0:1] + s_ashr_i32 s0, s11, 5 + v_add_u32_e32 v10, s0, v14 + v_cmp_gt_i32_e64 s[0:1], s13, v10 + s_and_b64 s[16:17], vcc, s[0:1] + v_mov_b32_e32 v12, 0x7f + v_mov_b32_e32 v11, 0x7f + s_and_saveexec_b64 s[0:1], s[16:17] + s_cbranch_execz .LBB0_20 +; %bb.19: ; in Loop: Header=BB0_14 Depth=1 + v_ashrrev_i32_e32 v11, 31, v10 + v_lshl_add_u64 v[22:23], v[8:9], 0, v[10:11] + global_load_ubyte v11, v[22:23], off +.LBB0_20: ; in Loop: Header=BB0_14 Depth=1 + s_or_b64 exec, exec, s[0:1] + v_cmp_gt_i32_e64 s[0:1], s10, v10 + s_and_b64 s[16:17], s[2:3], s[0:1] + s_and_saveexec_b64 s[0:1], s[16:17] + s_cbranch_execz .LBB0_22 +; %bb.21: ; in Loop: Header=BB0_14 Depth=1 + v_lshlrev_b32_e32 v12, 5, v10 + v_and_b32_e32 v12, 0xffffff00, v12 + v_ashrrev_i32_e32 v13, 31, v12 + v_lshlrev_b32_e32 v22, 6, v10 + v_and_b32_e32 v22, 0xc0, v22 + v_mov_b32_e32 v23, v34 + v_lshrrev_b32_e32 v10, 1, v10 + v_lshl_add_u64 v[12:13], v[0:1], 0, v[12:13] + v_and_b32_e32 v24, 2, v10 + v_mov_b32_e32 v25, v34 + v_lshl_add_u64 v[12:13], v[12:13], 0, v[22:23] + v_lshl_add_u64 v[12:13], v[12:13], 0, v[24:25] + global_load_ubyte v12, v[12:13], off +.LBB0_22: ; %_Z9load_fragPKhS0_S0_S0_iiiillllibbRDv32_hS2_RhS3_.exit138 + ; in Loop: Header=BB0_14 Depth=1 + s_or_b64 exec, exec, s[0:1] + v_and_b32_e32 v10, 0xff, v16 + v_and_b32_e32 v13, 0xff, v17 + s_add_i32 s11, s11, 64 + s_cmp_ge_i32 s11, s5 + v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[18:21], v[26:29], a[0:15], v10, v13 op_sel_hi:[0,0,0] cbsz:4 blgp:4 + s_cbranch_scc1 .LBB0_25 +; %bb.23: ; in Loop: Header=BB0_14 Depth=1 + v_mov_b64_e32 v[18:19], v[38:39] + s_waitcnt vmcnt(0) + v_mov_b64_e32 v[26:27], v[46:47] + v_mov_b64_e32 v[20:21], v[40:41] + v_mov_b64_e32 v[22:23], v[42:43] + v_mov_b64_e32 v[24:25], v[44:45] + v_mov_b64_e32 v[28:29], v[48:49] + v_mov_b64_e32 v[30:31], v[50:51] + v_mov_b64_e32 v[32:33], v[52:53] + v_mov_b32_e32 v17, v12 + v_mov_b32_e32 v16, v11 + s_branch .LBB0_14 +.LBB0_24: + s_waitcnt vmcnt(0) + v_mov_b64_e32 v[44:45], v[24:25] + v_mov_b64_e32 v[52:53], v[32:33] + v_mov_b64_e32 v[40:41], v[20:21] + v_mov_b64_e32 v[38:39], v[18:19] + v_mov_b64_e32 v[48:49], v[28:29] + v_mov_b64_e32 v[46:47], v[26:27] + v_mov_b32_e32 v12, v17 + v_mov_b32_e32 v11, v16 + v_mov_b64_e32 v[42:43], v[22:23] + v_mov_b64_e32 v[50:51], v[30:31] +.LBB0_25: ; %Flow444 + s_waitcnt vmcnt(0) + v_and_b32_e32 v0, 0xff, v11 + v_and_b32_e32 v1, 0xff, v12 + s_nop 1 + v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[38:41], v[46:49], a[0:15], v0, v1 op_sel_hi:[0,0,0] cbsz:4 blgp:4 + s_cmp_lg_u32 s12, 1 + s_mov_b64 s[0:1], -1 + s_cbranch_scc0 .LBB0_2 +.LBB0_26: + s_and_saveexec_b64 s[0:1], s[2:3] + s_cbranch_execz .LBB0_59 +; %bb.27: ; %.preheader154 + v_lshl_add_u32 v4, v14, 2, s15 + s_mul_hi_i32 s5, s8, s4 + s_mul_i32 s4, s8, s4 + s_ashr_i32 s12, s9, 31 + v_lshl_add_u64 v[0:1], v[2:3], 2, s[6:7] + v_cmp_gt_i32_e32 vcc, s8, v4 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_29 +; %bb.28: + v_ashrrev_i32_e32 v5, 31, v4 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[4:5] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a0, off +.LBB0_29: + s_or_b64 exec, exec, s[10:11] + v_or_b32_e32 v6, 1, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_31 +; %bb.30: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a1, off +.LBB0_31: + s_or_b64 exec, exec, s[10:11] + v_or_b32_e32 v6, 2, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_33 +; %bb.32: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a2, off +.LBB0_33: + s_or_b64 exec, exec, s[10:11] + v_or_b32_e32 v6, 3, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_35 +; %bb.34: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a3, off +.LBB0_35: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 8, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_37 +; %bb.36: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a4, off +.LBB0_37: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 9, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_39 +; %bb.38: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a5, off +.LBB0_39: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 10, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_41 +; %bb.40: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a6, off +.LBB0_41: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 11, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_43 +; %bb.42: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a7, off +.LBB0_43: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 16, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_45 +; %bb.44: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a8, off +.LBB0_45: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 17, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_47 +; %bb.46: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a9, off +.LBB0_47: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 18, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_49 +; %bb.48: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a10, off +.LBB0_49: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 19, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_51 +; %bb.50: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a11, off +.LBB0_51: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 24, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_53 +; %bb.52: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a12, off +.LBB0_53: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 25, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_55 +; %bb.54: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a13, off +.LBB0_55: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v6, 26, v4 + v_cmp_gt_i32_e32 vcc, s8, v6 + s_and_saveexec_b64 s[10:11], vcc + s_cbranch_execz .LBB0_57 +; %bb.56: + v_ashrrev_i32_e32 v7, 31, v6 + v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] + v_mul_lo_u32 v5, v7, s9 + v_mul_lo_u32 v8, v6, s12 + v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 + v_add3_u32 v7, v7, v8, v5 + v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] + global_store_dword v[6:7], a14, off +.LBB0_57: + s_or_b64 exec, exec, s[10:11] + v_add_u32_e32 v4, 27, v4 + v_cmp_gt_i32_e32 vcc, s8, v4 + s_and_b64 exec, exec, vcc + s_cbranch_execz .LBB0_59 +; %bb.58: + v_ashrrev_i32_e32 v5, 31, v4 + v_lshl_add_u64 v[4:5], s[4:5], 0, v[4:5] + v_mul_lo_u32 v6, v5, s9 + v_mul_lo_u32 v7, v4, s12 + v_mad_u64_u32 v[4:5], s[4:5], v4, s9, 0 + v_add3_u32 v5, v5, v7, v6 + v_lshl_add_u64 v[0:1], v[4:5], 2, v[0:1] + global_store_dword v[0:1], a15, off +.LBB0_59: ; %Flow439 + s_or_b64 exec, exec, s[0:1] + s_cbranch_execnz .LBB0_3 +.LBB0_60: + s_and_saveexec_b64 s[0:1], s[2:3] + s_cbranch_execz .LBB0_3 +; %bb.61: ; %.preheader + v_lshl_add_u32 v4, v14, 2, s15 + v_lshl_add_u64 v[0:1], v[2:3], 1, s[6:7] + v_cmp_gt_i32_e32 vcc, s8, v4 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_63 +; %bb.62: + v_mad_i64_i32 v[2:3], s[2:3], v4, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a0, off +.LBB0_63: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v2, 1, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_65 +; %bb.64: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a1, off +.LBB0_65: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v2, 2, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_67 +; %bb.66: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a2, off +.LBB0_67: + s_or_b64 exec, exec, s[0:1] + v_or_b32_e32 v2, 3, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_69 +; %bb.68: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a3, off +.LBB0_69: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 8, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_71 +; %bb.70: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a4, off +.LBB0_71: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 9, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_73 +; %bb.72: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a5, off +.LBB0_73: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 10, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_75 +; %bb.74: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a6, off +.LBB0_75: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 11, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_77 +; %bb.76: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a7, off +.LBB0_77: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 16, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_79 +; %bb.78: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a8, off +.LBB0_79: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 17, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_81 +; %bb.80: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a9, off +.LBB0_81: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 18, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_83 +; %bb.82: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a10, off +.LBB0_83: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 19, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_85 +; %bb.84: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a11, off +.LBB0_85: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 24, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_87 +; %bb.86: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a12, off +.LBB0_87: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 25, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_89 +; %bb.88: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a13, off +.LBB0_89: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 26, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_saveexec_b64 s[0:1], vcc + s_cbranch_execz .LBB0_91 +; %bb.90: + v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 + v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[2:3], a14, off +.LBB0_91: + s_or_b64 exec, exec, s[0:1] + v_add_u32_e32 v2, 27, v4 + v_cmp_gt_i32_e32 vcc, s8, v2 + s_and_b64 exec, exec, vcc + s_cbranch_execz .LBB0_3 +; %bb.92: + v_mad_i64_i32 v[2:3], s[0:1], v2, s9, 0 + v_lshl_add_u64 v[0:1], v[2:3], 1, v[0:1] + global_store_short_d16_hi v[0:1], a15, off + s_endpgm + .section .rodata,"a",@progbits + .p2align 6, 0x0 + .amdhsa_kernel _Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .amdhsa_group_segment_fixed_size 0 + .amdhsa_private_segment_fixed_size 0 + .amdhsa_kernarg_size 68 + .amdhsa_user_sgpr_count 2 + .amdhsa_user_sgpr_dispatch_ptr 0 + .amdhsa_user_sgpr_queue_ptr 0 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_user_sgpr_dispatch_id 0 + .amdhsa_user_sgpr_kernarg_preload_length 0 + .amdhsa_user_sgpr_kernarg_preload_offset 0 + .amdhsa_user_sgpr_private_segment_size 0 + .amdhsa_uses_dynamic_stack 0 + .amdhsa_enable_private_segment 0 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_sgpr_workgroup_info 0 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 72 + .amdhsa_next_free_sgpr 24 + .amdhsa_accum_offset 56 + .amdhsa_reserve_vcc 1 + .amdhsa_float_round_mode_32 0 + .amdhsa_float_round_mode_16_64 0 + .amdhsa_float_denorm_mode_32 3 + .amdhsa_float_denorm_mode_16_64 3 + .amdhsa_dx10_clamp 1 + .amdhsa_ieee_mode 1 + .amdhsa_fp16_overflow 0 + .amdhsa_tg_split 0 + .amdhsa_exception_fp_ieee_invalid_op 0 + .amdhsa_exception_fp_denorm_src 0 + .amdhsa_exception_fp_ieee_div_zero 0 + .amdhsa_exception_fp_ieee_overflow 0 + .amdhsa_exception_fp_ieee_underflow 0 + .amdhsa_exception_fp_ieee_inexact 0 + .amdhsa_exception_int_div_zero 0 + .end_amdhsa_kernel + .text +.Lfunc_end0: + .size _Z9gemm_corePKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z9gemm_corePKhS0_PvS0_S0_iiiiiii + ; -- End function + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.num_vgpr, 54 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.num_agpr, 16 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.numbered_sgpr, 24 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.num_named_barrier, 0 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.private_seg_size, 0 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.uses_vcc, 1 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.uses_flat_scratch, 0 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.has_dyn_sized_stack, 0 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.has_recursion, 0 + .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.has_indirect_call, 0 + .section .AMDGPU.csdata,"",@progbits +; Kernel info: +; codeLenInByte = 3424 +; TotalNumSgprs: 30 +; NumVgprs: 54 +; NumAgprs: 16 +; TotalNumVgprs: 72 +; ScratchSize: 0 +; MemoryBound: 0 +; FloatMode: 240 +; IeeeMode: 1 +; LDSByteSize: 0 bytes/workgroup (compile time only) +; SGPRBlocks: 3 +; VGPRBlocks: 8 +; NumSGPRsForWavesPerEU: 30 +; NumVGPRsForWavesPerEU: 72 +; AccumOffset: 56 +; Occupancy: 7 +; WaveLimiterHint : 0 +; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 +; COMPUTE_PGM_RSRC2:USER_SGPR: 2 +; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 +; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Y_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Z_EN: 1 +; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 +; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 13 +; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 + .text + .protected _Z9reduce_skPKfP12hip_bfloat16ii ; -- Begin function _Z9reduce_skPKfP12hip_bfloat16ii + .globl _Z9reduce_skPKfP12hip_bfloat16ii + .p2align 8 + .type _Z9reduce_skPKfP12hip_bfloat16ii,@function +_Z9reduce_skPKfP12hip_bfloat16ii: ; @_Z9reduce_skPKfP12hip_bfloat16ii +; %bb.0: + s_load_dwordx2 s[4:5], s[0:1], 0x10 + v_lshl_add_u32 v0, s2, 8, v0 + s_waitcnt lgkmcnt(0) + v_cmp_gt_i32_e32 vcc, s4, v0 + s_and_saveexec_b64 s[2:3], vcc + s_cbranch_execz .LBB1_8 +; %bb.1: ; %.preheader + s_cmp_gt_i32 s5, 0 + v_ashrrev_i32_e32 v1, 31, v0 + s_cbranch_scc1 .LBB1_3 +; %bb.2: ; %.preheader.._crit_edge_crit_edge + s_load_dwordx2 s[2:3], s[0:1], 0x8 + v_mov_b32_e32 v2, 0 + s_cbranch_execz .LBB1_4 + s_branch .LBB1_7 +.LBB1_3: + s_load_dwordx2 s[2:3], s[0:1], 0x8 + v_mov_b32_e32 v2, 0 +.LBB1_4: ; %.lr.ph + s_load_dwordx2 s[6:7], s[0:1], 0x0 + s_ashr_i32 s1, s4, 31 + s_mov_b32 s0, s4 + s_lshl_b64 s[0:1], s[0:1], 2 + v_mov_b32_e32 v4, 0 + s_waitcnt lgkmcnt(0) + v_lshl_add_u64 v[2:3], v[0:1], 2, s[6:7] +.LBB1_5: ; =>This Inner Loop Header: Depth=1 + global_load_dword v5, v[2:3], off + s_add_i32 s5, s5, -1 + v_lshl_add_u64 v[2:3], v[2:3], 0, s[0:1] + s_cmp_eq_u32 s5, 0 + s_waitcnt vmcnt(0) + v_add_f32_e32 v4, v4, v5 + s_cbranch_scc0 .LBB1_5 +; %bb.6: ; %._crit_edge.loopexit + v_lshrrev_b32_e32 v2, 16, v4 +.LBB1_7: ; %Flow26 + s_waitcnt lgkmcnt(0) + v_lshl_add_u64 v[0:1], v[0:1], 1, s[2:3] + global_store_short v[0:1], v2, off +.LBB1_8: + s_endpgm + .section .rodata,"a",@progbits + .p2align 6, 0x0 + .amdhsa_kernel _Z9reduce_skPKfP12hip_bfloat16ii + .amdhsa_group_segment_fixed_size 0 + .amdhsa_private_segment_fixed_size 0 + .amdhsa_kernarg_size 24 + .amdhsa_user_sgpr_count 2 + .amdhsa_user_sgpr_dispatch_ptr 0 + .amdhsa_user_sgpr_queue_ptr 0 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_user_sgpr_dispatch_id 0 + .amdhsa_user_sgpr_kernarg_preload_length 0 + .amdhsa_user_sgpr_kernarg_preload_offset 0 + .amdhsa_user_sgpr_private_segment_size 0 + .amdhsa_uses_dynamic_stack 0 + .amdhsa_enable_private_segment 0 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 0 + .amdhsa_system_sgpr_workgroup_id_z 0 + .amdhsa_system_sgpr_workgroup_info 0 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 6 + .amdhsa_next_free_sgpr 8 + .amdhsa_accum_offset 8 + .amdhsa_reserve_vcc 1 + .amdhsa_float_round_mode_32 0 + .amdhsa_float_round_mode_16_64 0 + .amdhsa_float_denorm_mode_32 3 + .amdhsa_float_denorm_mode_16_64 3 + .amdhsa_dx10_clamp 1 + .amdhsa_ieee_mode 1 + .amdhsa_fp16_overflow 0 + .amdhsa_tg_split 0 + .amdhsa_exception_fp_ieee_invalid_op 0 + .amdhsa_exception_fp_denorm_src 0 + .amdhsa_exception_fp_ieee_div_zero 0 + .amdhsa_exception_fp_ieee_overflow 0 + .amdhsa_exception_fp_ieee_underflow 0 + .amdhsa_exception_fp_ieee_inexact 0 + .amdhsa_exception_int_div_zero 0 + .end_amdhsa_kernel + .text +.Lfunc_end1: + .size _Z9reduce_skPKfP12hip_bfloat16ii, .Lfunc_end1-_Z9reduce_skPKfP12hip_bfloat16ii + ; -- End function + .set _Z9reduce_skPKfP12hip_bfloat16ii.num_vgpr, 6 + .set _Z9reduce_skPKfP12hip_bfloat16ii.num_agpr, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.numbered_sgpr, 8 + .set _Z9reduce_skPKfP12hip_bfloat16ii.num_named_barrier, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.private_seg_size, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_vcc, 1 + .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_flat_scratch, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.has_dyn_sized_stack, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.has_recursion, 0 + .set _Z9reduce_skPKfP12hip_bfloat16ii.has_indirect_call, 0 + .section .AMDGPU.csdata,"",@progbits +; Kernel info: +; codeLenInByte = 176 +; TotalNumSgprs: 14 +; NumVgprs: 6 +; NumAgprs: 0 +; TotalNumVgprs: 6 +; ScratchSize: 0 +; MemoryBound: 0 +; FloatMode: 240 +; IeeeMode: 1 +; LDSByteSize: 0 bytes/workgroup (compile time only) +; SGPRBlocks: 1 +; VGPRBlocks: 0 +; NumSGPRsForWavesPerEU: 14 +; NumVGPRsForWavesPerEU: 6 +; AccumOffset: 8 +; Occupancy: 8 +; WaveLimiterHint : 0 +; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 +; COMPUTE_PGM_RSRC2:USER_SGPR: 2 +; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 +; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 +; COMPUTE_PGM_RSRC2:TGID_Y_EN: 0 +; COMPUTE_PGM_RSRC2:TGID_Z_EN: 0 +; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 +; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 1 +; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 + .text + .p2alignl 6, 3212836864 + .fill 256, 4, 3212836864 + .section .AMDGPU.gpr_maximums,"",@progbits + .set amdgpu.max_num_vgpr, 0 + .set amdgpu.max_num_agpr, 0 + .set amdgpu.max_num_sgpr, 0 + .text + .type __hip_cuid_8b457700d3d88645,@object ; @__hip_cuid_8b457700d3d88645 + .section .bss,"aw",@nobits + .globl __hip_cuid_8b457700d3d88645 +__hip_cuid_8b457700d3d88645: + .byte 0 ; 0x0 + .size __hip_cuid_8b457700d3d88645, 1 + + .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" + .section ".note.GNU-stack","",@progbits + .addrsig + .addrsig_sym __hip_cuid_8b457700d3d88645 + .amdgpu_metadata +--- +amdhsa.kernels: + - .agpr_count: 16 + .args: + - .actual_access: read_only + .address_space: global + .offset: 0 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 8 + .size: 8 + .value_kind: global_buffer + - .actual_access: write_only + .address_space: global + .offset: 16 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 24 + .size: 8 + .value_kind: global_buffer + - .actual_access: read_only + .address_space: global + .offset: 32 + .size: 8 + .value_kind: global_buffer + - .offset: 40 + .size: 4 + .value_kind: by_value + - .offset: 44 + .size: 4 + .value_kind: by_value + - .offset: 48 + .size: 4 + .value_kind: by_value + - .offset: 52 + .size: 4 + .value_kind: by_value + - .offset: 56 + .size: 4 + .value_kind: by_value + - .offset: 60 + .size: 4 + .value_kind: by_value + - .offset: 64 + .size: 4 + .value_kind: by_value + .group_segment_fixed_size: 0 + .kernarg_segment_align: 8 + .kernarg_segment_size: 68 + .language: OpenCL C + .language_version: + - 2 + - 0 + .max_flat_workgroup_size: 64 + .name: _Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .private_segment_fixed_size: 0 + .sgpr_count: 30 + .sgpr_spill_count: 0 + .symbol: _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.kd + .uniform_work_group_size: 1 + .uses_dynamic_stack: false + .vgpr_count: 72 + .vgpr_spill_count: 0 + .wavefront_size: 64 + - .agpr_count: 0 + .args: + - .actual_access: read_only + .address_space: global + .offset: 0 + .size: 8 + .value_kind: global_buffer + - .actual_access: write_only + .address_space: global + .offset: 8 + .size: 8 + .value_kind: global_buffer + - .offset: 16 + .size: 4 + .value_kind: by_value + - .offset: 20 + .size: 4 + .value_kind: by_value + .group_segment_fixed_size: 0 + .kernarg_segment_align: 8 + .kernarg_segment_size: 24 + .language: OpenCL C + .language_version: + - 2 + - 0 + .max_flat_workgroup_size: 1024 + .name: _Z9reduce_skPKfP12hip_bfloat16ii + .private_segment_fixed_size: 0 + .sgpr_count: 14 + .sgpr_spill_count: 0 + .symbol: _Z9reduce_skPKfP12hip_bfloat16ii.kd + .uniform_work_group_size: 1 + .uses_dynamic_stack: false + .vgpr_count: 6 + .vgpr_spill_count: 0 + .wavefront_size: 64 +amdhsa.target: amdgcn-amd-amdhsa--gfx950 +amdhsa.version: + - 1 + - 2 +... + + .end_amdgpu_metadata + +# __CLANG_OFFLOAD_BUNDLE____END__ hip-amdgcn-amd-amdhsa--gfx950 + +# __CLANG_OFFLOAD_BUNDLE____START__ host-x86_64-unknown-linux-gnu- + .file "src.hip" + .text + .globl _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii # -- Begin function _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii + .p2align 4 + .type _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii,@function +_Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii: # @_Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii + .cfi_startproc +# %bb.0: + subq $200, %rsp + .cfi_def_cfa_offset 208 + movq %rdi, 88(%rsp) + movq %rsi, 80(%rsp) + movq %rdx, 72(%rsp) + movq %rcx, 64(%rsp) + movq %r8, 56(%rsp) + movl %r9d, 4(%rsp) + leaq 88(%rsp), %rax + movq %rax, 96(%rsp) + leaq 80(%rsp), %rax + movq %rax, 104(%rsp) + leaq 72(%rsp), %rax + movq %rax, 112(%rsp) + leaq 64(%rsp), %rax + movq %rax, 120(%rsp) + leaq 56(%rsp), %rax + movq %rax, 128(%rsp) + leaq 4(%rsp), %rax + movq %rax, 136(%rsp) + leaq 208(%rsp), %rax + movq %rax, 144(%rsp) + leaq 216(%rsp), %rax + movq %rax, 152(%rsp) + leaq 224(%rsp), %rax + movq %rax, 160(%rsp) + leaq 232(%rsp), %rax + movq %rax, 168(%rsp) + leaq 240(%rsp), %rax + movq %rax, 176(%rsp) + leaq 248(%rsp), %rax + movq %rax, 184(%rsp) + leaq 40(%rsp), %rdi + leaq 24(%rsp), %rsi + leaq 16(%rsp), %rdx + leaq 8(%rsp), %rcx + callq __hipPopCallConfiguration + movq 40(%rsp), %rsi + movl 48(%rsp), %edx + movq 24(%rsp), %rcx + movl 32(%rsp), %r8d + leaq 96(%rsp), %r9 + movl $_Z9gemm_corePKhS0_PvS0_S0_iiiiiii, %edi + pushq 8(%rsp) + .cfi_adjust_cfa_offset 8 + pushq 24(%rsp) + .cfi_adjust_cfa_offset 8 + callq hipLaunchKernel + addq $216, %rsp + .cfi_adjust_cfa_offset -216 + retq +.Lfunc_end0: + .size _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii + .cfi_endproc + # -- End function + .globl _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii # -- Begin function _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .p2align 4 + .type _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii,@function +_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii: # @_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .cfi_startproc +# %bb.0: + subq $120, %rsp + .cfi_def_cfa_offset 128 + movq %rdi, 72(%rsp) + movq %rsi, 64(%rsp) + movl %edx, 12(%rsp) + movl %ecx, 8(%rsp) + leaq 72(%rsp), %rax + movq %rax, 80(%rsp) + leaq 64(%rsp), %rax + movq %rax, 88(%rsp) + leaq 12(%rsp), %rax + movq %rax, 96(%rsp) + leaq 8(%rsp), %rax + movq %rax, 104(%rsp) + leaq 48(%rsp), %rdi + leaq 32(%rsp), %rsi + leaq 24(%rsp), %rdx + leaq 16(%rsp), %rcx + callq __hipPopCallConfiguration + movq 48(%rsp), %rsi + movl 56(%rsp), %edx + movq 32(%rsp), %rcx + movl 40(%rsp), %r8d + leaq 80(%rsp), %r9 + movl $_Z9reduce_skPKfP12hip_bfloat16ii, %edi + pushq 16(%rsp) + .cfi_adjust_cfa_offset 8 + pushq 32(%rsp) + .cfi_adjust_cfa_offset 8 + callq hipLaunchKernel + addq $136, %rsp + .cfi_adjust_cfa_offset -136 + retq +.Lfunc_end1: + .size _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii, .Lfunc_end1-_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .cfi_endproc + # -- End function + .p2align 4 # -- Begin function __hip_module_ctor + .type __hip_module_ctor,@function +__hip_module_ctor: # @__hip_module_ctor + .cfi_startproc +# %bb.0: + pushq %rbx + .cfi_def_cfa_offset 16 + subq $32, %rsp + .cfi_def_cfa_offset 48 + .cfi_offset %rbx, -16 + movq __hip_gpubin_handle_8b457700d3d88645(%rip), %rbx + testq %rbx, %rbx + jne .LBB2_2 +# %bb.1: + movl $__hip_fatbin_wrapper, %edi + callq __hipRegisterFatBinary + movq %rax, %rbx + movq %rax, __hip_gpubin_handle_8b457700d3d88645(%rip) +.LBB2_2: + xorps %xmm0, %xmm0 + movups %xmm0, 16(%rsp) + movups %xmm0, (%rsp) + movl $_Z9gemm_corePKhS0_PvS0_S0_iiiiiii, %esi + movl $.L__unnamed_1, %edx + movl $.L__unnamed_1, %ecx + movq %rbx, %rdi + movl $-1, %r8d + xorl %r9d, %r9d + callq __hipRegisterFunction + xorps %xmm0, %xmm0 + movups %xmm0, 16(%rsp) + movups %xmm0, (%rsp) + movl $_Z9reduce_skPKfP12hip_bfloat16ii, %esi + movl $.L__unnamed_2, %edx + movl $.L__unnamed_2, %ecx + movq %rbx, %rdi + movl $-1, %r8d + xorl %r9d, %r9d + callq __hipRegisterFunction + movl $__hip_module_dtor, %edi + addq $32, %rsp + .cfi_def_cfa_offset 16 + popq %rbx + .cfi_def_cfa_offset 8 + jmp atexit # TAILCALL +.Lfunc_end2: + .size __hip_module_ctor, .Lfunc_end2-__hip_module_ctor + .cfi_endproc + # -- End function + .p2align 4 # -- Begin function __hip_module_dtor + .type __hip_module_dtor,@function +__hip_module_dtor: # @__hip_module_dtor + .cfi_startproc +# %bb.0: + movq __hip_gpubin_handle_8b457700d3d88645(%rip), %rdi + testq %rdi, %rdi + je .LBB3_2 +# %bb.1: + pushq %rax + .cfi_def_cfa_offset 16 + callq __hipUnregisterFatBinary + movq $0, __hip_gpubin_handle_8b457700d3d88645(%rip) + addq $8, %rsp + .cfi_def_cfa_offset 8 +.LBB3_2: + retq +.Lfunc_end3: + .size __hip_module_dtor, .Lfunc_end3-__hip_module_dtor + .cfi_endproc + # -- End function + .type _Z9gemm_corePKhS0_PvS0_S0_iiiiiii,@object # @_Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .section .rodata,"a",@progbits + .globl _Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .p2align 3, 0x0 +_Z9gemm_corePKhS0_PvS0_S0_iiiiiii: + .quad _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii + .size _Z9gemm_corePKhS0_PvS0_S0_iiiiiii, 8 + + .type _Z9reduce_skPKfP12hip_bfloat16ii,@object # @_Z9reduce_skPKfP12hip_bfloat16ii + .globl _Z9reduce_skPKfP12hip_bfloat16ii + .p2align 3, 0x0 +_Z9reduce_skPKfP12hip_bfloat16ii: + .quad _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .size _Z9reduce_skPKfP12hip_bfloat16ii, 8 + + .type .L__unnamed_1,@object # @0 + .section .rodata.str1.1,"aMS",@progbits,1 +.L__unnamed_1: + .asciz "_Z9gemm_corePKhS0_PvS0_S0_iiiiiii" + .size .L__unnamed_1, 34 + + .type .L__unnamed_2,@object # @1 +.L__unnamed_2: + .asciz "_Z9reduce_skPKfP12hip_bfloat16ii" + .size .L__unnamed_2, 33 + + .type __hip_fatbin_wrapper,@object # @__hip_fatbin_wrapper + .section .hipFatBinSegment,"a",@progbits + .p2align 3, 0x0 +__hip_fatbin_wrapper: + .long 1212764230 # 0x48495046 + .long 1 # 0x1 + .quad __hip_fatbin_8b457700d3d88645 + .quad 0 + .size __hip_fatbin_wrapper, 24 + + .type __hip_gpubin_handle_8b457700d3d88645,@object # @__hip_gpubin_handle_8b457700d3d88645 + .local __hip_gpubin_handle_8b457700d3d88645 + .comm __hip_gpubin_handle_8b457700d3d88645,8,8 + .section .init_array,"aw",@init_array + .p2align 3, 0x0 + .quad __hip_module_ctor + .type __hip_cuid_8b457700d3d88645,@object # @__hip_cuid_8b457700d3d88645 + .bss + .globl __hip_cuid_8b457700d3d88645 +__hip_cuid_8b457700d3d88645: + .byte 0 # 0x0 + .size __hip_cuid_8b457700d3d88645, 1 + + .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" + .section ".note.GNU-stack","",@progbits + .addrsig + .addrsig_sym _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii + .addrsig_sym _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii + .addrsig_sym __hip_module_ctor + .addrsig_sym __hip_module_dtor + .addrsig_sym _Z9gemm_corePKhS0_PvS0_S0_iiiiiii + .addrsig_sym _Z9reduce_skPKfP12hip_bfloat16ii + .addrsig_sym __hip_fatbin_8b457700d3d88645 + .addrsig_sym __hip_fatbin_wrapper + .addrsig_sym __hip_cuid_8b457700d3d88645 + +# __CLANG_OFFLOAD_BUNDLE____END__ host-x86_64-unknown-linux-gnu- diff --git a/examples/gemm_fp4/gemm_fp4_sm100.cu b/examples/gemm_fp4/gemm_fp4_sm100.cu new file mode 100644 index 0000000000..741a79427c --- /dev/null +++ b/examples/gemm_fp4/gemm_fp4_sm100.cu @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc); +extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + __shared__ __align__(16) uint64_t mbar_mem[2]; + auto mbar = reinterpret_cast(mbar_mem); + __shared__ __align__(16) uint64_t mbarrier_mem[4]; + auto mbarrier = reinterpret_cast(mbarrier_mem); + __shared__ __align__(16) uint C_tmem[1]; + tl::Tcgen05SMemDescriptor desc_a; + tl::Tcgen05SMemDescriptor desc_b; + float C_local[64]; + if (tl::tl_shuffle_elect<0>()) { + tl::prefetch_tma_descriptor(A_desc); + tl::prefetch_tma_descriptor(B_desc); + tl::prefetch_tma_descriptor(C_desc); + } + if (tl::tl_shuffle_elect<0>()) { + mbar[0].init(1); + mbar[1].init(1); + mbarrier[0].init(1); + mbarrier[1].init(1); + mbarrier[2].init(1); + mbarrier[3].init(1); + } + tl::fence_barrier_init(); + __syncthreads(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_allocate((&(C_tmem[0])), 128); + } + __syncthreads(); + if (((int)threadIdx.x) == 0) { + mbarrier[0].arrive_and_expect_tx(16256); + tl::tma_load(A_desc, mbarrier[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, (((int)blockIdx.y) * 128)); + mbarrier[1].arrive_and_expect_tx(16256); + tl::tma_load(B_desc, mbarrier[1], (&(((fp4_e2_t*)buf_dyn_shmem)[130048])), 0, (((int)blockIdx.x) * 128)); + mbarrier[2].arrive_and_expect_tx(16256); + tl::tma_load(A_desc, mbarrier[2], (&(((fp4_e2_t*)buf_dyn_shmem)[97536])), 128, (((int)blockIdx.y) * 128)); + mbarrier[3].arrive_and_expect_tx(16256); + tl::tma_load(B_desc, mbarrier[3], (&(((fp4_e2_t*)buf_dyn_shmem)[227584])), 128, (((int)blockIdx.x) * 128)); + } + mbarrier[0].wait(0); + mbarrier[1].wait(0); + __syncthreads(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 1, 64, 0, 0, 2); + tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[130048])), 1, 64, 0, 0, 2); + #pragma unroll + for (int ki = 0; ki < 4; ++ki) { + tl::tcgen05mma_ss(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(136320656), 0, 0, 0, 0); + } + tl::tcgen05_mma_arrive((&(mbar[0]))); + } + mbar[0].wait(0); + mbarrier[2].wait(0); + mbarrier[3].wait(0); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[97536])), 1, 64, 0, 0, 2); + tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[227584])), 1, 64, 0, 0, 2); + tl::fence_proxy_async(); + #pragma unroll + for (int ki_1 = 0; ki_1 < 4; ++ki_1) { + tl::tcgen05mma_ss(uint64_t(desc_a + (ki_1 * 32)), uint64_t(desc_b + (ki_1 * 32)), (*reinterpret_cast(C_tmem)) + 0, 1, static_cast(136320656), 0, 0, 0, 0); + } + tl::tcgen05_mma_arrive((&(mbar[1]))); + } + mbar[1].wait(0); + tl::tcgen05_ld_32dp32bNx<64, false>(C_tmem[0], ((((int)threadIdx.x) >> 7) * 64), (&(C_local[0]))); + __syncthreads(); + #pragma unroll + for (int i = 0; i < 16; ++i) { + *(float4*)(((float*)buf_dyn_shmem) + (((((((((int)threadIdx.x) >> 7) * 8192) + ((i >> 3) * 4096)) + ((((int)threadIdx.x) & 127) * 32)) + (((((i & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 16)) + (((((i & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 8)) + ((((i & 1) + (((int)threadIdx.x) & 1)) & 1) * 4))) = *(float4*)(C_local + (i * 4)); + } + __syncthreads(); + if (((int)threadIdx.x) == 0) { + tl::fence_proxy_async(); + tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 128), (((int)blockIdx.y) * 128)); + tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[4096])), ((((int)blockIdx.x) * 128) + 32), (((int)blockIdx.y) * 128)); + tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[8192])), ((((int)blockIdx.x) * 128) + 64), (((int)blockIdx.y) * 128)); + tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[12288])), ((((int)blockIdx.x) * 128) + 96), (((int)blockIdx.y) * 128)); + tl::tma_store_arrive(); + tl::tma_store_wait<0>(); + } + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_deallocate((&(C_tmem[0])), 128); + } +} + diff --git a/examples/gemm_sm100/gemm_tcgen5mma_fp4.py b/examples/gemm_sm100/gemm_tcgen5mma_fp4.py new file mode 100644 index 0000000000..28195c4942 --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma_fp4.py @@ -0,0 +1,214 @@ +"""FP4 (float4_e2m1fn) GEMM on SM100/SM110 (B200/Thor) using TCGEN05 async MMA. + +Uses tcgen05.mma.kind::f8f6f4 with TMEM accumulator. +The current TMA path keeps global FP4 packed, then writes it into SMEM using the +ALIGN16B gap-aware layout expected by tcgen05.mma. + +Supported: SM100 (B100/B200), SM101/SM110 (DRIVE Thor), SM103 (B300). +""" + +import time +import os +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def matmul_fp4_sm100( + M, N, K, block_M, block_N, block_K, + in_dtype, out_dtype, accum_dtype, + num_stages=2, threads=256, + transpose_b=True, +): + """FP4 GEMM using TCGEN05 async MMA + TMEM. + + TCGEN05 handles packed FP4 natively — no unpack, no ldmatrix, no bit-shift. + Data flow: global(packed) -> TMA -> SMEM(ALIGN16B gap-aware) -> TCGEN05 MMA -> TMEM. + """ + A_shape = (M, K) + A_shared_shape = (block_M, block_K) + if transpose_b: + B_shape = (N, K) + B_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.tcgen05_gemm( + A_shared, B_shared, C_tmem, + transpose_A=False, transpose_B=True, + mbar=mbar, clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + B_shape = (K, N) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.tcgen05_gemm( + A_shared, B_shared, C_tmem, + transpose_A=False, transpose_B=False, + mbar=mbar, clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def unpack_fp4_to_float(packed_int8, M, K): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + flat = packed_int8.to(torch.uint8).reshape(M, K // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + unpacked = torch.stack([lo, hi], dim=-1).reshape(M, K).to(torch.int64) + return lut[unpacked] + + +M = int(os.environ.get("TL_FP4_M", "256")) +N = int(os.environ.get("TL_FP4_N", "256")) +K = int(os.environ.get("TL_FP4_K", "256")) +# +# Thor has a tighter per-CTA dynamic shared-memory budget than server Blackwell. +# Keep K=128 for the current ALIGN16B gap-aware layout, but reduce N so the +# example launches reliably before we debug numerical accuracy. +# +# We intentionally keep block_M=128 so TCGEN05 stores C through the standard +# atom_m=128 TMEM layout (Layout D) instead of the atom_m=64 WS-specific layout. +# +block_M = int(os.environ.get("TL_FP4_BLOCK_M", "128")) +block_N = int(os.environ.get("TL_FP4_BLOCK_N", "64")) +block_K = int(os.environ.get("TL_FP4_BLOCK_K", "128")) +in_dtype = T.float4_e2m1fn +out_dtype = T.float32 +accum_dtype = T.float32 +num_stages = 1 +threads = 128 + +input_mode = os.environ.get("TL_FP4_INPUT_MODE", "random") +transpose_b = os.environ.get("TL_FP4_TRANSPOSE_B", "1") != "0" +print( + f"Running FP4 GEMM (SM100/SM110 TCGEN05): M={M}, N={N}, K={K}, " + f"input_mode={input_mode}, transpose_b={transpose_b}" +) + +func = matmul_fp4_sm100( + M, N, K, block_M, block_N, block_K, + in_dtype, out_dtype, accum_dtype, + num_stages, threads, transpose_b, +) + +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) + +print("Compilation succeeded!") + +torch.manual_seed(42) + +def make_random_fp4(rows, cols, mode): + if mode == "positive": + lo = torch.randint(0, 8, (rows, cols // 2), device="cuda", dtype=torch.uint8) + hi = torch.randint(0, 8, (rows, cols // 2), device="cuda", dtype=torch.uint8) + return (lo | (hi << 4)).to(torch.int8) + if mode == "low_nibble": + lo = torch.randint(0, 16, (rows, cols // 2), device="cuda", dtype=torch.uint8) + return lo.to(torch.int8) + if mode == "high_nibble": + hi = torch.randint(0, 16, (rows, cols // 2), device="cuda", dtype=torch.uint8) + return (hi << 4).to(torch.int8) + if mode == "random": + return torch.randint(0, 256, (rows, cols // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + raise ValueError(f"Unsupported TL_FP4_INPUT_MODE={mode}") + + +# Packed FP4 tensors (2 per byte) — TCGEN05 reads packed data directly. +a_packed = make_random_fp4(M, K, input_mode) +b_packed = make_random_fp4(N, K, input_mode) if transpose_b else make_random_fp4(K, N, input_mode) + +# --- Test 1: zeros --- +a_zero = torch.zeros(M, K // 2, device="cuda", dtype=torch.int8) +b_zero = ( + torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) + if transpose_b + else torch.zeros(K, N // 2, device="cuda", dtype=torch.int8) +) +c_zero = jit_kernel(a_zero, b_zero) +assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" +print("[PASS] zeros in -> zeros out") + +# --- Test 2: numerical verification --- +c = jit_kernel(a_packed, b_packed) +a_float = unpack_fp4_to_float(a_packed, M, K) +b_float = unpack_fp4_to_float(b_packed, N, K) if transpose_b else unpack_fp4_to_float(b_packed, K, N) +ref_c = a_float @ (b_float.T if transpose_b else b_float) + +diff = (c.float() - ref_c).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) +print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") +if max_diff < 1.0: + print("[PASS] numerical verification") +else: + print("[WARN] large diff") + +# --- Benchmark --- +torch.cuda.synchronize() +start = time.perf_counter() +for _ in range(100): + jit_kernel(a_packed, b_packed) +torch.cuda.synchronize() +elapsed = (time.perf_counter() - start) / 100 * 1000 +print(f"Latency: {elapsed:.4f} ms") +print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") diff --git a/examples/gemm_sm100/test_fp4_diagonal.py b/examples/gemm_sm100/test_fp4_diagonal.py new file mode 100644 index 0000000000..f16147b2e8 --- /dev/null +++ b/examples/gemm_sm100/test_fp4_diagonal.py @@ -0,0 +1,216 @@ +"""Diagnostic: FP4 GEMM with diagonal B matrix. + +C = A @ B^T where B = identity → C should equal A (in float). +Mismatches reveal exactly which (row, col) the MMA reads incorrectly. +""" + +import os +import torch +import tilelang +import tilelang.language as T + +M, N, K = 256, 256, 256 +# Keep K=128 for the current ALIGN16B gap-aware layout, but reduce N so the +# diagnostic runs within Thor's dynamic shared-memory budget. +# +# We intentionally keep block_M=128 so TCGEN05 stores C through the standard +# atom_m=128 TMEM layout (Layout D) instead of the atom_m=64 WS-specific layout. +block_M, block_N, block_K = 128, 64, 128 +in_dtype = T.float4_e2m1fn +out_dtype = T.float32 +accum_dtype = T.float32 + +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def unpack_fp4_to_float(packed_int8, rows, cols): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[unpacked] + + +def make_fp4_diagonal(N, K): + """Create packed FP4 identity matrix B[N,K] where B[n,k]=1.0 if n==k.""" + packed = torch.zeros(N, K // 2, dtype=torch.uint8, device="cuda") + fp4_one = 2 # FP4 encoding for 1.0 + diag = torch.arange(min(N, K), device="cuda", dtype=torch.int64) + byte_idx = diag // 2 + nibble_shift = (diag % 2).to(torch.uint8) * 4 + packed[diag, byte_idx] = torch.bitwise_left_shift( + torch.full_like(nibble_shift, fp4_one, dtype=torch.uint8), nibble_shift + ) + return packed.to(torch.int8) + + +def make_fp4_row_pattern(M, K): + """A[m,k] = FP4 value based on (m % 8), so each row has a recognizable pattern.""" + vals = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda") + row_vals = vals[torch.arange(M, device="cuda", dtype=torch.int64) % 8] + packed_row_vals = row_vals | torch.bitwise_left_shift(row_vals, 4) + packed = packed_row_vals[:, None].expand(M, K // 2).contiguous() + return packed.to(torch.int8) + + +def print_generated_kernel_debug_summary(kernel_src_path: str): + keywords = ( + "arrive_and_expect_tx", + "initialize_tcgen05_descriptor", + "fence_proxy_async", + "tcgen05mma_ss", + ".wait(", + ) + with open(kernel_src_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + first_mma_line = None + first_fence_line = None + + print("Generated kernel debug summary:") + for line_no, line in enumerate(lines, start=1): + if "tcgen05mma_ss" in line and first_mma_line is None: + first_mma_line = line_no + if "fence_proxy_async" in line and first_fence_line is None: + first_fence_line = line_no + if any(keyword in line for keyword in keywords): + print(f"{line_no:4d}: {line.rstrip()}") + + if first_mma_line is None: + print("No tcgen05mma_ss call found in generated source.") + return + + if first_fence_line is None: + print("No fence_proxy_async() found in generated source.") + elif first_fence_line < first_mma_line: + print( + f"First fence_proxy_async() appears before first tcgen05mma_ss " + f"(fence line {first_fence_line}, mma line {first_mma_line})." + ) + else: + print( + f"First fence_proxy_async() appears after first tcgen05mma_ss " + f"(fence line {first_fence_line}, mma line {first_mma_line})." + ) + + +def matmul_fp4_sm100(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype): + @T.prim_func + def main( + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((N, K), in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), in_dtype) + B_shared = T.alloc_shared((block_N, block_K), in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.tcgen05_gemm(A_shared, B_shared, C_tmem, + transpose_A=False, transpose_B=True, + mbar=mbar, clear_accum=(k == 0)) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + return main + + +print(f"FP4 Diagonal Test: M={M}, N={N}, K={K}") + +func = matmul_fp4_sm100(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype) + +kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +}) +print("Compiled OK") + +kernel_src_path = os.path.join(os.path.dirname(__file__), "test_fp4_diagonal.generated.cu") +with open(kernel_src_path, "w", encoding="utf-8") as f: + f.write(kernel.get_kernel_source()) +print(f"Kernel source written to: {kernel_src_path}") +print_generated_kernel_debug_summary(kernel_src_path) + +if os.environ.get("TL_FP4_COMPILE_ONLY", "0") == "1": + print("TL_FP4_COMPILE_ONLY=1, skip kernel launch.") + raise SystemExit(0) + +# --- Test: A = row pattern, B = identity --- +print("Building FP4 test inputs...") +a_packed = make_fp4_row_pattern(M, K) +b_packed = make_fp4_diagonal(N, K) +torch.cuda.synchronize() +print("Launching kernel...") + +c = kernel(a_packed, b_packed) +torch.cuda.synchronize() +print("Kernel finished.") + +a_float = unpack_fp4_to_float(a_packed, M, K) +b_float = unpack_fp4_to_float(b_packed, N, K) +ref_c = a_float @ b_float.T # should ≈ A (since B is identity) + +diff = (c.float() - ref_c).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) +print(f"max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") + +out = c +print(f"\n{'='*80}") +print(f"Element-by-element comparison (first 8 rows x first 48 cols):") +print(f"{'='*80}") + +# Print header +print(f"{'':>6}", end="") +for col in range(48): + print(f" {col:>5}", end="") +print() + +for r in range(8): + # Expected row + print(f"E[{r:>2}]", end="") + for col in range(48): + print(f" {ref_c[r,col].item():>5.1f}", end="") + print() + # Actual row + print(f"A[{r:>2}]", end="") + for col in range(48): + v = out[r, col].item() + e = ref_c[r, col].item() + mark = " " if abs(v - e) < 0.01 else "*" + print(f"{v:>5.1f}{mark}", end="") + print() + print() + +print(f"\n{'='*80}") +print(f"Columns 120-135 (crossing 128 boundary):") +print(f"{'='*80}") +print(f"{'':>6}", end="") +for col in range(120, 136): + print(f" {col:>5}", end="") +print() +for r in range(4): + print(f"E[{r:>2}]", end="") + for col in range(120, 136): + print(f" {ref_c[r,col].item():>5.1f}", end="") + print() + print(f"A[{r:>2}]", end="") + for col in range(120, 136): + v = out[r, col].item() + e = ref_c[r, col].item() + mark = " " if abs(v - e) < 0.01 else "*" + print(f"{v:>5.1f}{mark}", end="") + print() + print() diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.py b/examples/gemm_sm100/test_fp4_single_tile_probe.py new file mode 100644 index 0000000000..5a70205f30 --- /dev/null +++ b/examples/gemm_sm100/test_fp4_single_tile_probe.py @@ -0,0 +1,276 @@ +"""Single-tile FP4 GEMM probe for SM100/SM110. + +Purpose: + 1. Remove the second K tile so we only exercise the first + TMA -> tcgen05.mma -> mbarrier wait sequence. + 2. Allow quick TMA vs non-TMA comparison via TL_FP4_DISABLE_TMA=1. + 3. Allow isolating the epilogue path via TL_FP4_LINEAR_EPILOGUE=1, + so we can compare "TMA input + linear global store" against the + default "TMA input + TMA store" path. + 4. Allow forcing an explicit fence between copy/wait and tcgen05.mma via + TL_FP4_FORCE_FENCE=1 to test whether the TMA->UMMA handoff is missing a + proxy fence. + 5. Allow stopping immediately after input load via TL_FP4_TMA_ONLY=1, so + TMA wait and tcgen05 MMA wait can be isolated. + 6. Allow forcing the same descriptor/gap-aware input TMA path without MMA via + TL_FP4_TMA_ONLY_DESCRIPTOR=1. + +Run: + python examples/gemm_sm100/test_fp4_single_tile_probe.py + TL_FP4_DISABLE_TMA=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py + TL_FP4_COMPILE_ONLY=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py + TL_FP4_LINEAR_EPILOGUE=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py + TL_FP4_FORCE_FENCE=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py + TL_FP4_TMA_ONLY=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py + TL_FP4_TMA_ONLY=1 TL_FP4_TMA_ONLY_DESCRIPTOR=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py +""" + +import os + +import torch +import tilelang +import tilelang.language as T + +M, N, K = 128, 64, 128 +block_M, block_N, block_K = 128, 64, 128 +in_dtype = T.float4_e2m1fn +out_dtype = T.float32 +accum_dtype = T.float32 + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def unpack_fp4_to_float(packed_int8, rows, cols): + lut = torch.tensor( + FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device + ) + flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[unpacked] + + +def make_fp4_diagonal(rows, cols): + packed = torch.zeros(rows, cols // 2, dtype=torch.uint8, device="cuda") + fp4_one = 2 # FP4 encoding for 1.0 + diag = torch.arange(min(rows, cols), device="cuda", dtype=torch.int64) + byte_idx = diag // 2 + nibble_shift = (diag % 2).to(torch.uint8) * 4 + packed[diag, byte_idx] = torch.bitwise_left_shift( + torch.full_like(nibble_shift, fp4_one, dtype=torch.uint8), nibble_shift + ) + return packed.to(torch.int8) + + +def make_fp4_row_pattern(rows, cols): + vals = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + ) + row_vals = vals[torch.arange(rows, device="cuda", dtype=torch.int64) % 8] + packed_row_vals = row_vals | torch.bitwise_left_shift(row_vals, 4) + packed = packed_row_vals[:, None].expand(rows, cols // 2).contiguous() + return packed.to(torch.int8) + + +def print_generated_kernel_debug_summary(kernel_src_path: str): + keywords = ( + "arrive_and_expect_tx", + "expect_transaction", + "tma_load(", + "tma_store(", + "tma_store_arrive", + "tma_store_wait", + "initialize_tcgen05_descriptor", + "fence_proxy_async", + "tcgen05mma_ss", + ".wait(", + ) + + with open(kernel_src_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + print("Generated kernel debug summary:") + for line_no, line in enumerate(lines, start=1): + if any(keyword in line for keyword in keywords): + print(f"{line_no:4d}: {line.rstrip()}") + + +def fp4_single_tile_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + linear_epilogue, + force_fence, + tma_only, + tma_only_descriptor, +): + @T.prim_func + def main( + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((N, K), in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(1, 1, threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), in_dtype) + B_shared = T.alloc_shared((block_N, block_K), in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + if not disable_tma: + load_mbar_a = T.alloc_barrier(128) + load_mbar_b = T.alloc_barrier(128) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + if not linear_epilogue: + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + if tma_only and tma_only_descriptor: + T.annotate_layout( + { + A_shared: tilelang.layout.make_tcgen05mma_swizzled_layout( + A_shared, continuity=block_K, k_major=True + ), + B_shared: tilelang.layout.make_tcgen05mma_swizzled_layout( + B_shared, continuity=block_K, k_major=True + ), + } + ) + + if disable_tma: + T.copy(A[0, 0], A_shared) + T.copy(B[0, 0], B_shared) + else: + T.tma_copy(A[0, 0], A_shared, barrier=load_mbar_a) + T.tma_copy(B[0, 0], B_shared, barrier=load_mbar_b) + T.barrier_arrive(load_mbar_a) + T.barrier_arrive(load_mbar_b) + T.mbarrier_wait_parity(load_mbar_a, 0) + T.mbarrier_wait_parity(load_mbar_b, 0) + T.fence_proxy_async() + if tma_only: + return + if force_fence: + T.fence_proxy_async() + T.tcgen05_gemm( + A_shared, + B_shared, + C_tmem, + transpose_A=False, + transpose_B=True, + mbar=mbar, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar, 0) + + T.copy(C_tmem, C_local) + if linear_epilogue: + T.copy(C_local, C[0, 0]) + else: + T.copy(C_local, C_shared) + T.copy(C_shared, C[0, 0]) + + return main + + +disable_tma = os.environ.get("TL_FP4_DISABLE_TMA", "0") == "1" +linear_epilogue = os.environ.get("TL_FP4_LINEAR_EPILOGUE", "0") == "1" +force_fence = os.environ.get("TL_FP4_FORCE_FENCE", "0") == "1" +tma_only = os.environ.get("TL_FP4_TMA_ONLY", "0") == "1" +tma_only_descriptor = os.environ.get("TL_FP4_TMA_ONLY_DESCRIPTOR", "0") == "1" +mode_parts = ["non_tma" if disable_tma else "tma"] +if linear_epilogue: + mode_parts.append("linear_epilogue") +if force_fence: + mode_parts.append("force_fence") +if tma_only: + mode_parts.append("tma_only") +if tma_only_descriptor: + mode_parts.append("descriptor") +mode = ".".join(mode_parts) +print( + f"FP4 single-tile probe: mode={mode}, M={M}, N={N}, K={K}, " + f"block=({block_M},{block_N},{block_K})" +) + +func = fp4_single_tile_kernel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + linear_epilogue, + force_fence, + tma_only, + tma_only_descriptor, +) +kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: disable_tma, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +print("Compiled OK") + +kernel_src_path = os.path.join( + os.path.dirname(__file__), f"test_fp4_single_tile_probe.{mode}.generated.cu" +) +with open(kernel_src_path, "w", encoding="utf-8") as f: + f.write(kernel.get_kernel_source()) +print(f"Kernel source written to: {kernel_src_path}") +print_generated_kernel_debug_summary(kernel_src_path) + +if os.environ.get("TL_FP4_COMPILE_ONLY", "0") == "1": + print("TL_FP4_COMPILE_ONLY=1, skip kernel launch.") + raise SystemExit(0) + +print("Building FP4 test inputs...") +a_packed = make_fp4_row_pattern(M, K) +b_packed = make_fp4_diagonal(N, K) +torch.cuda.synchronize() + +print("Launching kernel...") +c = kernel(a_packed, b_packed) +print("Kernel launch returned, synchronizing...") +torch.cuda.synchronize() +print("Kernel finished.") + +if tma_only: + print("TL_FP4_TMA_ONLY=1, skipped GEMM and output comparison.") + raise SystemExit(0) + +a_float = unpack_fp4_to_float(a_packed, M, K) +b_float = unpack_fp4_to_float(b_packed, N, K) +ref_c = a_float @ b_float.T + +diff = (c.float() - ref_c).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) +print(f"max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu new file mode 100644 index 0000000000..722dd72830 --- /dev/null +++ b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu @@ -0,0 +1,102 @@ +#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 +#define _tl_orig_alignas alignas +#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) +#include +#undef alignas +#define alignas _tl_orig_alignas +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C); +extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; + auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); + __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; + auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); + __shared__ __align__(16) uint64_t mbar_mem[1]; + auto mbar = reinterpret_cast(mbar_mem); + __shared__ __align__(16) uint C_tmem[1]; + float C_local[64]; + if (tl::tl_shuffle_elect<0>()) { + tl::prefetch_tma_descriptor(A_desc); + tl::prefetch_tma_descriptor(B_desc); + } + if (tl::tl_shuffle_elect<0>()) { + load_mbar_a[0].init(128); + load_mbar_b[0].init(128); + mbar[0].init(1); + } + tl::fence_barrier_init(); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_allocate((&(C_tmem[0])), 64); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (tl::tl_shuffle_elect<128>()) { + load_mbar_a[0].expect_transaction(16256); + tl::tma_load(A_desc, load_mbar_a[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, 0); + load_mbar_b[0].expect_transaction(8128); + tl::tma_load(B_desc, load_mbar_b[0], (&(((fp4_e2_t*)buf_dyn_shmem)[32512])), 0, 0); + } + load_mbar_a[0].arrive(); + load_mbar_b[0].arrive(); + load_mbar_a[0].wait(0); + load_mbar_b[0].wait(0); + tl::tcgen05_after_thread_sync(); + tl::fence_proxy_async(); + { + tl::Tcgen05SMemDescriptor desc_a; + tl::Tcgen05SMemDescriptor desc_b; + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 1, 64, 0, 0, 2); + tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[32512])), 1, 64, 0, 0, 2); + #pragma unroll + for (int ki = 0; ki < 4; ++ki) { + tl::tcgen05mma_ws_ss(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(135272080), 0, 0, 0, 0); + } + tl::tcgen05_mma_arrive((&(mbar[0]))); + } + } + mbar[0].wait(0); + tl::tcgen05_after_thread_sync(); + tl::tcgen05_ld_32dp32bNx<64, false>(C_tmem[0], 0, (&(C_local[0]))); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + #pragma unroll + for (int i = 0; i < 16; ++i) { + *(float4*)(((float*)buf_dyn_shmem) + ((((int)threadIdx.x) * 64) + (i * 4))) = *(float4*)(C_local + (i * 4)); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (tl::tl_shuffle_elect<128>()) { + tl::fence_proxy_async(); + tl::tma_store((&(C[0])), (&(((float*)buf_dyn_shmem)[0])), 32768); + tl::tma_store_arrive(); + tl::tma_store_wait<0>(); + } + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_deallocate((&(C_tmem[0])), 64); + } +} + diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu new file mode 100644 index 0000000000..f17c4de7fc --- /dev/null +++ b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu @@ -0,0 +1,90 @@ +#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 +#define _tl_orig_alignas alignas +#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) +#include +#undef alignas +#define alignas _tl_orig_alignas +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C); +extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; + auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); + __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; + auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); + __shared__ __align__(16) uint64_t mbar_mem[1]; + auto mbar = reinterpret_cast(mbar_mem); + __shared__ __align__(16) uint C_tmem[1]; + float C_local[64]; + if (tl::tl_shuffle_elect<0>()) { + tl::prefetch_tma_descriptor(A_desc); + tl::prefetch_tma_descriptor(B_desc); + } + if (tl::tl_shuffle_elect<0>()) { + load_mbar_a[0].init(128); + load_mbar_b[0].init(128); + mbar[0].init(1); + } + tl::fence_barrier_init(); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_allocate((&(C_tmem[0])), 64); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (tl::tl_shuffle_elect<128>()) { + load_mbar_a[0].expect_transaction(8192); + tl::tma_load(A_desc, load_mbar_a[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, 0); + load_mbar_b[0].expect_transaction(4096); + tl::tma_load(B_desc, load_mbar_b[0], (&(((fp4_e2_t*)buf_dyn_shmem)[16384])), 0, 0); + } + load_mbar_a[0].arrive(); + load_mbar_b[0].arrive(); + load_mbar_a[0].wait(0); + load_mbar_b[0].wait(0); + tl::tcgen05_after_thread_sync(); + tl::fence_proxy_async(); + { + tl::Tcgen05SMemDescriptor desc_a; + tl::Tcgen05SMemDescriptor desc_b; + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 1, 64, 0, 0, 2); + tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[16384])), 1, 64, 0, 0, 2); + #pragma unroll + for (int ki = 0; ki < 4; ++ki) { + tl::tcgen05mma_ws_ss(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(135272080), 0, 0, 0, 0); + } + tl::tcgen05_mma_arrive((&(mbar[0]))); + } + } + mbar[0].wait(0); + tl::tcgen05_after_thread_sync(); + tl::tcgen05_ld_32dp32bNx<64, false>(C_tmem[0], 0, (&(C_local[0]))); + #pragma unroll + for (int i = 0; i < 8; ++i) { + tl::store_global_256(&(*(ulonglong4*)(C + ((((int)threadIdx.x) * 64) + (i * 8)))), *(ulonglong4*)(C_local + (i * 8))); + } + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_deallocate((&(C_tmem[0])), 64); + } +} + diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu new file mode 100644 index 0000000000..a210b5b2bb --- /dev/null +++ b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu @@ -0,0 +1,65 @@ +#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 +#define _tl_orig_alignas alignas +#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) +#include +#undef alignas +#define alignas _tl_orig_alignas +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc); +extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; + auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); + __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; + auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); + __shared__ __align__(16) uint64_t mbar_mem[1]; + auto mbar = reinterpret_cast(mbar_mem); + __shared__ __align__(16) uint C_tmem[1]; + if (tl::tl_shuffle_elect<0>()) { + tl::prefetch_tma_descriptor(A_desc); + tl::prefetch_tma_descriptor(B_desc); + } + if (tl::tl_shuffle_elect<0>()) { + load_mbar_a[0].init(128); + load_mbar_b[0].init(128); + mbar[0].init(1); + } + tl::fence_barrier_init(); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_allocate((&(C_tmem[0])), 64); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (tl::tl_shuffle_elect<128>()) { + load_mbar_a[0].expect_transaction(8192); + tl::tma_load(A_desc, load_mbar_a[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, 0); + load_mbar_b[0].expect_transaction(4096); + tl::tma_load(B_desc, load_mbar_b[0], (&(((fp4_e2_t*)buf_dyn_shmem)[16384])), 0, 0); + } + load_mbar_a[0].arrive(); + load_mbar_b[0].arrive(); + load_mbar_a[0].wait(0); + load_mbar_b[0].wait(0); + tl::tcgen05_after_thread_sync(); + tl::fence_proxy_async(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_deallocate((&(C_tmem[0])), 64); + } +} + diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu new file mode 100644 index 0000000000..dc1fa8a3ab --- /dev/null +++ b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu @@ -0,0 +1,61 @@ +#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 +#define _tl_orig_alignas alignas +#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) +#include +#undef alignas +#define alignas _tl_orig_alignas +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B); +extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; + auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); + __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; + auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); + __shared__ __align__(16) uint64_t mbar_mem[1]; + auto mbar = reinterpret_cast(mbar_mem); + __shared__ __align__(16) uint C_tmem[1]; + if (tl::tl_shuffle_elect<0>()) { + load_mbar_a[0].init(128); + load_mbar_b[0].init(128); + mbar[0].init(1); + } + tl::fence_barrier_init(); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_allocate((&(C_tmem[0])), 64); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (tl::tl_shuffle_elect<128>()) { + load_mbar_a[0].expect_transaction(8192); + tl::tma_load((&(((fp4_e2_t*)buf_dyn_shmem)[0])), (&(A[0])), load_mbar_a[0], 8192); + load_mbar_b[0].expect_transaction(4096); + tl::tma_load((&(((fp4_e2_t*)buf_dyn_shmem)[16384])), (&(B[0])), load_mbar_b[0], 4096); + } + load_mbar_a[0].arrive(); + load_mbar_b[0].arrive(); + load_mbar_a[0].wait(0); + load_mbar_b[0].wait(0); + tl::tcgen05_after_thread_sync(); + tl::fence_proxy_async(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_deallocate((&(C_tmem[0])), 64); + } +} + diff --git a/examples/gemm_sm100/test_tma_fp4_roundtrip.py b/examples/gemm_sm100/test_tma_fp4_roundtrip.py new file mode 100644 index 0000000000..5fa454face --- /dev/null +++ b/examples/gemm_sm100/test_tma_fp4_roundtrip.py @@ -0,0 +1,207 @@ +"""Diagnostic: force TMA load for FP4 on SM100/SM110. + +This test isolates the GMEM -> SMEM TMA load path for FP4 ALIGN16B: + 1. input tile is loaded with explicit ``T.tma_copy(...)`` so it cannot + silently fall back to a normal copy; + 2. shared memory uses the ALIGN16B gap-aware layout that GEMM also relies on; + 3. the write-back to GMEM stays as a normal copy so the result reflects only + whether the TMA load populated SMEM correctly. + +Run: TILELANG_DISABLE_CACHE=1 python examples/gemm_sm100/test_tma_fp4_roundtrip.py +""" + +import os + +import torch +import tilelang +import tilelang.language as T + +M, K = 256, 256 +block_M, block_K = 128, 128 +in_dtype = T.float4_e2m1fn + + +def tma_roundtrip_kernel( + M, + K, + block_M, + block_K, + in_dtype, + force_tma_load: bool, + skip_store: bool, + use_explicit_tma_copy: bool, +): + """Copy A[block_M, block_K] to SMEM, then write back to Out.""" + A_shape = (M, K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + Out: T.Tensor(A_shape, in_dtype), + ): + with T.Kernel(1, T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), in_dtype) + if force_tma_load and use_explicit_tma_copy: + mbar = T.alloc_barrier(128) + + T.annotate_layout( + { + A_shared: tilelang.layout.make_align16b_swizzled_layout(A_shared), + } + ) + + if force_tma_load and use_explicit_tma_copy: + T.tma_copy(A[by * block_M, 0], A_shared, barrier=mbar) + T.barrier_arrive(mbar) + T.mbarrier_wait_parity(mbar, 0) + T.fence_proxy_async() + elif force_tma_load: + T.copy(A[by * block_M, 0], A_shared) + else: + T.copy(A[by * block_M, 0], A_shared) + if not skip_store: + T.copy(A_shared, Out[by * block_M, 0]) + + return main + + +def dump_and_check_kernel_source(kernel, mode: str, expect_tma_load: bool): + kernel_src_path = os.path.join( + os.path.dirname(__file__), + f"test_tma_fp4_roundtrip.{mode}.generated.cu", + ) + src = kernel.get_kernel_source() + with open(kernel_src_path, "w", encoding="utf-8") as f: + f.write(src) + print(f"Kernel source written to: {kernel_src_path}") + + keywords = ( + "arrive_and_expect_tx", + "tma_load(", + "tma_store(", + "fence_proxy_async", + ) + print(f"{mode} kernel debug summary:") + matched = [] + for line_no, line in enumerate(src.splitlines(), start=1): + if any(keyword in line for keyword in keywords): + matched.append((line_no, line.rstrip())) + print(f"{line_no:4d}: {line.rstrip()}") + + has_tma_load = any("tma_load(" in line for _, line in matched) + if expect_tma_load and not has_tma_load: + raise RuntimeError( + f"{mode} kernel was expected to use TMA load, but generated source " + "contains no tl::tma_load(...) call." + ) + if not expect_tma_load and has_tma_load: + raise RuntimeError( + f"{mode} kernel was expected to disable TMA load, but generated source " + "still contains tl::tma_load(...)." + ) + + +print(f"TMA FP4 load diagnostic: M={M}, K={K}, block=({block_M},{block_K})") +run_ref = os.environ.get("TL_FP4_SKIP_REF", "0") != "1" +skip_store = os.environ.get("TL_FP4_SKIP_STORE", "0") == "1" +use_explicit_tma_copy = os.environ.get("TL_FP4_USE_TMA_COPY", "1") == "1" +mode = "explicit_tma_copy" if use_explicit_tma_copy else "auto_tma_lower" +print(f"Load mode: {mode}") + +# --- Build with explicit TMA load --- +func_tma = tma_roundtrip_kernel( + M, + K, + block_M, + block_K, + in_dtype, + force_tma_load=True, + skip_store=skip_store, + use_explicit_tma_copy=use_explicit_tma_copy, +) +kernel_tma = tilelang.compile( + func_tma, + out_idx=[1], + target="cuda", + pass_configs={ + # Keep explicit T.tma_copy(), but disable implicit TMA selection on the + # store path so this test only validates the load side. + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +print("TMA kernel compiled OK") +dump_and_check_kernel_source(kernel_tma, "tma", expect_tma_load=True) + +kernel_ref = None +if run_ref: + # --- Build with TMA disabled (reference) --- + func_ref = tma_roundtrip_kernel( + M, + K, + block_M, + block_K, + in_dtype, + force_tma_load=False, + skip_store=skip_store, + use_explicit_tma_copy=False, + ) + kernel_ref = tilelang.compile( + func_ref, + out_idx=[1], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Reference kernel (no TMA) compiled OK") + dump_and_check_kernel_source(kernel_ref, "non_tma", expect_tma_load=False) + +torch.manual_seed(42) +a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + +# --- Run TMA kernel first and synchronize immediately so failures are attributed correctly --- +print("Launching TMA kernel...") +out_tma = kernel_tma(a_packed) +torch.cuda.synchronize() +print("TMA kernel finished.") + +out_ref = None +if kernel_ref is not None: + print("Launching reference kernel...") + out_ref = kernel_ref(a_packed) + torch.cuda.synchronize() + print("Reference kernel finished.") + +if skip_store: + print("\nStore path skipped; this run only validates kernel completion after TMA load.") + raise SystemExit(0) + +# --- Compare byte-by-byte --- +tma_bytes = out_tma.view(torch.uint8).cpu() +inp_bytes = a_packed.view(torch.uint8).cpu() + +match_tma_inp = (tma_bytes == inp_bytes).sum().item() +total = tma_bytes.numel() + +print(f"\nTotal bytes: {total}") +print(f"TMA vs Input: {match_tma_inp}/{total} match ({100*match_tma_inp/total:.1f}%)") +if out_ref is not None: + ref_bytes = out_ref.view(torch.uint8).cpu() + match_ref_inp = (ref_bytes == inp_bytes).sum().item() + match_tma_ref = (tma_bytes == ref_bytes).sum().item() + print(f"Ref vs Input: {match_ref_inp}/{total} match ({100*match_ref_inp/total:.1f}%)") + print(f"TMA vs Ref: {match_tma_ref}/{total} match ({100*match_tma_ref/total:.1f}%)") + +if match_tma_inp == total: + print("\n[PASS] TMA round-trip is byte-exact!") +else: + mismatch_idx = (tma_bytes != inp_bytes).nonzero(as_tuple=True)[0] + print(f"\nFirst 20 mismatched byte indices: {mismatch_idx[:20].tolist()}") + print(f" TMA bytes: {tma_bytes[mismatch_idx[:20]].tolist()}") + print(f" Inp bytes: {inp_bytes[mismatch_idx[:20]].tolist()}") + + row_size = K // 2 + mismatch_rows = (mismatch_idx // row_size).unique() + print(f"Mismatched rows ({len(mismatch_rows)} total): {mismatch_rows[:20].tolist()}") diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 5aedbb58e0..4b464f5f13 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -1933,8 +1933,6 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, << " + " << index_str << ")"; } else if (t == buffer_element_dtype) { int div_factor = 1; - // FP4 div_factor=2 only for global memory (packed 2/byte). - // Shared/local stores unpacked FP4 (1 byte per element, sizeof=1). bool is_packed_scope = scope.empty() || scope == "global"; if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 && is_packed_scope) { @@ -4113,6 +4111,9 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { } else if (op->dtype == DataType::Int(1) && scope == "shared") { constant_size = constant_size / 32; } + // SM100 (TCGEN05MMA): FP4 shared memory keeps full allocation (N elements + // of fp4_e2_t = N byte containers). With 16U4_ALIGN16B TMA, each logical + // FP4 occupies one 8-bit SMEM container consumed by tcgen05.mma. if (scope == "shared") { stream << ' ' << vid << '[' << constant_size << "];\n"; } else if (scope == "shared.barrier" || scope == "shared.cluster_barrier") { @@ -4236,12 +4237,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, int lanes = op->dtype.lanes(); // declare type. if (value_dtype.lanes() == element_dtype.lanes()) { - // For scalar fp4 loads from non-packed buffers, use tl_fp4_packed_load - // to correctly extract the nibble at the given index (the /2 in - // GetBufferRef maps two consecutive fp4 elements to the same byte, but - // reading that byte only returns the low nibble — the odd-indexed element - // is lost). - if (element_dtype.is_float4() && element_dtype.lanes() == 1) { + // Scalar global FP4 is packed (two logical values per byte). Shared FP4 on + // SM100 unpacksmem is byte-container based and must use normal buffer refs. + std::string scope = GetPtrStorageScope(buffer_var); + bool is_packed_fp4_scope = scope.empty() || scope == "global"; + if (element_dtype.is_float4() && element_dtype.lanes() == 1 && + is_packed_fp4_scope) { std::string idx_str = PrintExpr(index); std::string vid = GetVarID(buffer_var.get()); os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", " << idx_str << ")"; @@ -4336,11 +4337,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) { return; } if (value_dtype.lanes() == element_dtype.lanes()) { - // For scalar fp4 stores to non-packed buffers, use tl_fp4_packed_store - // to correctly handle nibble-level writes. The /2 in GetBufferRef maps two - // consecutive fp4 elements to the same byte, and a plain assignment - // overwrites the entire byte — destroying the neighboring nibble. - if (element_dtype.is_float4() && element_dtype.lanes() == 1) { + // Scalar global FP4 is packed (two logical values per byte). Shared FP4 on + // SM100 unpacksmem is byte-container based and must use normal buffer refs. + std::string scope = GetPtrStorageScope(buffer_var); + bool is_packed_fp4_scope = scope.empty() || scope == "global"; + if (element_dtype.is_float4() && element_dtype.lanes() == 1 && + is_packed_fp4_scope) { std::string idx_str = PrintExpr(index_expr); std::string value = this->PrintExpr(op->value); std::string vid = GetVarID(buffer_var.get()); diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc index 14cbd6348e..498fefb252 100644 --- a/src/backend/cuda/op/copy.cc +++ b/src/backend/cuda/op/copy.cc @@ -1068,6 +1068,11 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, return LowerNormal(op, T, analyzer); } } + bool is_align16b_subbyte_layout = + shared_layout.defined() && shared_tensor_unmapped->dtype.bits() < 8 && + shared_tensor_unmapped->dtype.lanes() == 1 && + StructuralEqual()(shared_layout, + makeAlign16BSwizzleLayout(shared_tensor_unmapped)); auto inner_box_dim = as_const_int(desc.smem_box[0]); if (inner_box_dim == nullptr) { @@ -1077,7 +1082,12 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, return LowerNormal(op, T, analyzer); } int instruction_dim = *inner_box_dim; - if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { + if (is_align16b_subbyte_layout) { + // 16U4_ALIGN16B uses one 8-bit SMEM container per logical FP4. Keep the + // logical box dimension instead of expanding a 128B swizzle span as if the + // data were densely packed. + instruction_dim = *inner_box_dim; + } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_64B)) { instruction_dim = TMAElementsForBytes(64, src->dtype); } else if (desc.swizzle == static_cast(CU_TENSOR_MAP_SWIZZLE_128B)) { instruction_dim = TMAElementsForBytes(128, src->dtype); @@ -1092,8 +1102,10 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, << " is not divisible by instruction_dim: " << instruction_dim; desc.smem_box.Set(0, PrimExpr(instruction_dim)); - int inner_box_dim_ = - TMABytesFromElements(instruction_dim, shared_tensor->dtype); + int inner_box_dim_ = is_align16b_subbyte_layout + ? instruction_dim * shared_tensor->dtype.bytes() + : TMABytesFromElements(instruction_dim, + shared_tensor->dtype); struct SwizzleCheck { int swizzle; @@ -1202,7 +1214,11 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, if (is_load && barrier_base_id >= 0) { PrimExpr total_bytes; - if ((*inner_box_dim) != instruction_dim) { + if (is_align16b_subbyte_layout) { + // mbarrier transaction bytes track the FP4 payload moved by TMA, not the + // byte-container shared-memory footprint introduced by 16U4_ALIGN16B. + total_bytes = TMABytesFromElements(total_elements, shared_tensor->dtype); + } else if ((*inner_box_dim) != instruction_dim) { int loop_extent = (*inner_box_dim) / instruction_dim; total_bytes = TMABytesFromElements(total_elements * loop_extent, shared_tensor->dtype); diff --git a/src/backend/cuda/op/gemm.cc b/src/backend/cuda/op/gemm.cc index abe33333f1..349ebf9a8a 100644 --- a/src/backend/cuda/op/gemm.cc +++ b/src/backend/cuda/op/gemm.cc @@ -362,12 +362,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); refl::GlobalDef().def( "tl.get_tcgen5_instr_desc", - [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, - DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a, - int scale_in_b) { - uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, - c_dtype, a_is_k_major, b_is_k_major, - scale_in_a, scale_in_b); + [](int atom_m, int atom_n, int atom_k, DataType a_dtype, + DataType b_dtype, DataType c_dtype, bool a_is_k_major, + bool b_is_k_major, int scale_in_a, int scale_in_b) { + uint32_t desc = GetTCGEN5InstrDesc( + atom_m, atom_n, atom_k, a_dtype, b_dtype, c_dtype, a_is_k_major, + b_is_k_major, scale_in_a, scale_in_b); return Integer(static_cast(desc)); }); refl::GlobalDef().def("tl.get_tcgen5_blockscaled_instr_desc", diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index d06182efa8..635f1c9bc9 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -488,6 +488,33 @@ Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { return ExpandLayout2D(base, buffer); } +// Layout swizzling for 128 bytes (ALIGN16B unpacksmem variant for sub-byte types) +// For CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B: each FP4 element occupies one +// 8-bit SMEM container. The FP4 payload lives inside that byte container +// (bits 2..5 for NVIDIA unpacksmem), so shared-memory indexing must not pack +// two logical FP4 values into one byte. +static Layout MakeAlign16BSwizzleLayout2D(int stride, int continuous, + int element_size) { + ICHECK(element_size < 8 && element_size > 0) + << "ALIGN16B layout is for sub-byte types, got element_size=" << element_size; + const int elements_per_chunk = 16; + ICHECK(stride % 8 == 0) << "stride=" << stride; + ICHECK(continuous % (elements_per_chunk * 8) == 0) + << "continuous=" << continuous + << " must be a multiple of " << (elements_per_chunk * 8); + + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + PrimExpr ts = FloorDiv(i, 8); + PrimExpr s = FloorMod(i, 8); + PrimExpr tc = FloorDiv(FloorDiv(j, elements_per_chunk), 8); + PrimExpr c = FloorMod(FloorDiv(j, elements_per_chunk), 8); + PrimExpr vec = FloorMod(j, elements_per_chunk); + PrimExpr c_swizzle = xor8x8(c, s); + PrimExpr index = vec + (c_swizzle + s * 8) * elements_per_chunk; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + // Layout swizzling for 128 bytes static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, int element_size) { @@ -516,6 +543,19 @@ Layout makeFullBankSwizzleLayout(const Buffer &buffer) { return ExpandLayout2D(base, buffer); } +Layout makeAlign16BSwizzleLayout(const Buffer &buffer) { + auto info = GetSwizzleShapeInfoChecked(buffer); + auto base = MakeAlign16BSwizzleLayout2D(static_cast(info.stride), + static_cast(info.continuous), + info.element_size); + Array leading_shape; + leading_shape.reserve(buffer->shape.size() - 2); + for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { + leading_shape.push_back(buffer->shape[i]); + } + return base->Expand(leading_shape); +} + // Detail implementation please ref to // bitblas::tl::mfma_layout::make_mfma_swizzle_layout Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, @@ -875,6 +915,11 @@ Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, if (element_size == 64) { ICHECK(0) << "float64 on sm100 is not supported now"; } + // Sub-byte types (FP4): use ALIGN16B unpacksmem layout. TMA writes each + // logical FP4 into an 8-bit SMEM container, matching tcgen05.mma consumption. + if (element_size < 8) { + return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous, element_size); + } int vector_size = 128 / element_size; if (mat_stride % 8 == 0) { @@ -954,6 +999,17 @@ SwizzleMode DetectSwizzleMode(const Layout &layout, const Buffer &buffer) { if (!TryGetSwizzleShapeInfo(buffer, &info)) { return SwizzleMode::kNone; } + + // Sub-byte types: check ALIGN16B unpacksmem layout (maps to 128B swizzle) + if (info.element_size < 8 && info.element_size > 0) { + if (info.stride % 8 == 0 && info.continuous % 128 == 0) { + if (StructuralEqual()(layout, makeAlign16BSwizzleLayout(buffer))) { + return SwizzleMode::kFull; + } + } + return SwizzleMode::kNone; + } + int vector_size = 128 / info.element_size; // Check from smallest to largest granularity diff --git a/src/layout/layout.cc b/src/layout/layout.cc index c5661dc91d..782aedfa8f 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -1163,6 +1163,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](const Buffer &buffer) { return makeQuarterBankSwizzleLayout(buffer); }) + .def("tl.make_align16b_swizzled_layout", + [](const Buffer &buffer) { + return makeAlign16BSwizzleLayout(buffer); + }) .def("tl.make_linear_layout", [](Array shape) { return makeLinearLayout(shape); }) .def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); }) diff --git a/src/layout/layout.h b/src/layout/layout.h index 8043c5765d..7b08bccf02 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -284,6 +284,7 @@ Layout makeTcgen05mmaSwizzledLayout(const Buffer &buffer, int continuity = -1, Layout makeFullBankSwizzleLayout(const Buffer &buffer); Layout makeHalfBankSwizzleLayout(const Buffer &buffer); Layout makeQuarterBankSwizzleLayout(const Buffer &buffer); +Layout makeAlign16BSwizzleLayout(const Buffer &buffer); // Swizzle mode for shared memory layouts (nvidia only) // Smaller enum value = smaller swizzle granularity diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 1ddfa9367e..ed6408ffef 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -156,7 +156,8 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype, } inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, - DataType ab_dtype, DataType c_dtype, + DataType a_dtype, DataType b_dtype, + DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a, int scale_in_b) { ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; @@ -193,8 +194,8 @@ inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, return 0u; }; - uint32_t a_format = encode_dtype(ab_dtype); - uint32_t b_format = a_format; + uint32_t a_format = encode_dtype(a_dtype); + uint32_t b_format = encode_dtype(b_dtype); uint32_t c_format = 0; if (c_dtype.is_float16()) { diff --git a/src/op/utils.cc b/src/op/utils.cc index 2be9704617..972dfa78a6 100644 --- a/src/op/utils.cc +++ b/src/op/utils.cc @@ -147,11 +147,11 @@ PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { // Maps TVM DataType to CUDA's CUtensorMapDataType enum value. int to_CUtensorMapDataType(DataType dtype) { - // CUDA 13 adds packed U4 TensorMap formats. The vendored CUDA stub may lag - // the installed toolkit, so keep the enum value by CUDA's documented order. - constexpr int kTensorMapDataType16U4Align8B = 13; + // CUDA headers can lag the driver API enum additions for packed narrow TMA + // formats. Keep the numeric value in sync with runtime.cc validation. + constexpr int kTensorMapDataType16U4Align16B = 14; if (dtype.is_float4_e2m1fn()) { - return kTensorMapDataType16U4Align8B; + return kTensorMapDataType16U4Align16B; } CUtensorMapDataType tp; diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index 6314dc8f89..cb3536af28 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -60,6 +60,7 @@ TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) #elif __CUDA_ARCH_LIST__ >= 1000 #include "cuda_fp8.h" +#include "cuda_fp4.h" #include #include #include diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 84e22f24e1..e0cbec36a2 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -204,6 +204,369 @@ struct MMA_Traits, } }; +// MMA_Traits specialization for FP4 (float_e2m1_t) with SM100_MMA_F8F6F4_SS. +// CUTLASS's to_UMMAFormat returns MXF4Format::E2M1(=1), but +// kind::f8f6f4 requires MXF8F6F4Format::E2M1(=5). This more-specialized +// partial specialization provides the correct descriptor encoding. +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = cute::float_e2m1_t; + using ValTypeB = cute::float_e2m1_t; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + static constexpr UMMA::InstrDescriptor make_fp4_desc() { + UMMA::InstrDescriptor d = {}; + d.c_format_ = uint8_t(UMMA::to_CFormat()); + d.a_format_ = 5; // MXF8F6F4Format::E2M1 + d.b_format_ = 5; + d.a_negate_ = uint8_t(a_neg); + d.b_negate_ = uint8_t(b_neg); + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; + return d; + } + + UMMA::InstrDescriptor idesc_ = make_fp4_desc(); + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_tmem::value, "Expected tmem"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + +// Same fix for SM100_MMA_F8F6F4_WS_SS (weight-stationary .ws variant). +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = cute::float_e2m1_t; + using ValTypeB = cute::float_e2m1_t; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + static constexpr UMMA::InstrDescriptor make_fp4_desc() { + UMMA::InstrDescriptor d = {}; + d.c_format_ = uint8_t(UMMA::to_CFormat()); + d.a_format_ = 5; // MXF8F6F4Format::E2M1 + d.b_format_ = 5; + d.a_negate_ = uint8_t(a_neg); + d.b_negate_ = uint8_t(b_neg); + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; + return d; + } + + UMMA::InstrDescriptor idesc_ = make_fp4_desc(); + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_tmem::value, "Expected tmem"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + SM100_MMA_F8F6F4_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + +// MMA_Traits specialization for FP4 (float_e2m1_t) with SM100_MMA_F8F6F4_TS. +// A operand comes from TMEM, B from SMEM. Same descriptor encoding fix as SS. +// 你的理解接近正确,但还需要澄清一点: +// 使用时其实不用再对 template 里面的参数“再次特化”,而是—— +// 1. 外层 template <...> 这部分负责声明参数列表,即让这个偏特化对所有可能的 组合都有效。 +// 2. 内层 struct MMA_Traits> 说明:只要模板参数匹配到 SM100_MMA_F8F6F4_TS 那一整组, +// 就会自动选择这个偏特化的实现,不需要“再特化”。 +// 换句话说,你写一个 MMA_Traits,如果 X 恰好能匹配 SM100_MMA_F8F6F4_TS<...> 那一组参数, +// 这个特化版本就会被用到。比如: +// using Traits = MMA_Traits>; +// Traits的内容就是匹配到此特化,无需再继续特化template参数。 +// 总结:外层 template 支持泛型匹配,struct MMA_Traits<...> 实现具体特化,实际用时只需传入对应参数即可自动选中。 + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = cute::float_e2m1_t; + using ValTypeB = cute::float_e2m1_t; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + static constexpr UMMA::InstrDescriptor make_fp4_desc() { + UMMA::InstrDescriptor d = {}; + d.c_format_ = uint8_t(UMMA::to_CFormat()); + d.a_format_ = 5; // MXF8F6F4Format::E2M1 + d.b_format_ = 5; + d.a_negate_ = uint8_t(a_neg); + d.b_negate_ = uint8_t(b_neg); + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; + return d; + } + + UMMA::InstrDescriptor idesc_ = make_fp4_desc(); + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem"); + static_assert(is_tmem::value, "Expected tmem"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_tmem::value, "Expected tmem"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + SM100_MMA_F8F6F4_TS::fma( + tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +// Mixed-precision MMA_Traits: A = float_e4m3_t (FP8), B = float_e2m1_t (FP4). +// Descriptor: a_format = 0 (MXF8F6F4Format::E4M3), b_format = 5 (MXF8F6F4Format::E2M1). +// SS variant (both operands from shared memory). +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = cute::float_e4m3_t; + using ValTypeB = cute::float_e2m1_t; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + static constexpr UMMA::InstrDescriptor make_a8w4_desc() { + UMMA::InstrDescriptor d = {}; + d.c_format_ = uint8_t(UMMA::to_CFormat()); + d.a_format_ = 0; // MXF8F6F4Format::E4M3 + d.b_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_negate_ = uint8_t(a_neg); + d.b_negate_ = uint8_t(b_neg); + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; + return d; + } + + UMMA::InstrDescriptor idesc_ = make_a8w4_desc(); + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_tmem::value, "Expected tmem"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + +// Mixed A8W4 — WS_SS variant (weight-stationary, M=64 or M=32). +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = cute::float_e4m3_t; + using ValTypeB = cute::float_e2m1_t; + using ValTypeC = c_type; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + static constexpr UMMA::InstrDescriptor make_a8w4_desc() { + UMMA::InstrDescriptor d = {}; + d.c_format_ = uint8_t(UMMA::to_CFormat()); + d.a_format_ = 0; // MXF8F6F4Format::E4M3 + d.b_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_negate_ = uint8_t(a_neg); + d.b_negate_ = uint8_t(b_neg); + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; + return d; + } + + UMMA::InstrDescriptor idesc_ = make_a8w4_desc(); + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_rmem::value, "Expected desc registers"); + static_assert(is_tmem::value, "Expected tmem"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + SM100_MMA_F8F6F4_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + namespace tl_tcgen5mma { using cutlass::gemm::collective::detail::sm100_smem_selector; @@ -336,6 +699,104 @@ struct DispatchInstruction>; }; +// FP4 (float_e2m1_t) — C = float +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +// FP4 (float_e2m1_t) — C = half_t +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +// Mixed A8W4 (float_e4m3_t x float_e2m1_t) — C = float +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +// Mixed A8W4 (float_e4m3_t x float_e2m1_t) — C = half_t +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + template diff --git a/src/tl_templates/cuda/instruction/tcgen05mma.h b/src/tl_templates/cuda/instruction/tcgen05mma.h index 5e700a73b5..c719e05e10 100644 --- a/src/tl_templates/cuda/instruction/tcgen05mma.h +++ b/src/tl_templates/cuda/instruction/tcgen05mma.h @@ -279,16 +279,6 @@ TL_DEVICE void tcgen05mma_ts( tmem_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3); } -// FP4 (e2m1) uses same f8f6f4 instruction kind as FP8 -template <> -TL_DEVICE void tcgen05mma_ts( - uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c, - uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, - int const &mask1, int const &mask2, int const &mask3) { - tcgen05mma_ts(tmem_a, desc_b, tmem_c, scalec, - desc_val, mask0, mask1, mask2, mask3); -} - // F16/BF16 instruction kind (maps to kind::f16) template <> TL_DEVICE void tcgen05mma_ss( @@ -534,16 +524,6 @@ TL_DEVICE void tcgen05mma_ss( desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3); } -// FP4 (e2m1) uses same f8f6f4 instruction kind as FP8 -template <> -TL_DEVICE void tcgen05mma_ss( - uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, - uint32_t const &scalec, uint32_t const &desc_val, int const &mask0, - int const &mask1, int const &mask2, int const &mask3) { - tcgen05mma_ss(desc_a, desc_b, tmem_c, scalec, - desc_val, mask0, mask1, mask2, mask3); -} - // WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx // Generic declaration falls back to static assert template diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 7799b79150..c496a8c343 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -9,6 +9,7 @@ from tilelang.utils import is_tensor_memory from tilelang.layout import ( Layout, + make_align16b_swizzled_layout, make_full_bank_swizzled_layout, make_half_bank_swizzled_layout, make_quarter_bank_swizzled_layout, @@ -133,8 +134,13 @@ def _initialize_k_dim(self, a_dtype=T.float16): def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper + tir_buffer = buffer.buffer if isinstance(buffer, BufferRegion) else buffer if layout is None or layout.is_equal(make_linear_layout(buffer)): return SwizzleMode.NONE + if DataType(tir_buffer.dtype).bits < 8: + if layout.is_equal(make_align16b_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_128B + raise ValueError(f"Unsupported sub-byte swizzle mode: {layout}") elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): return SwizzleMode.SWIZZLE_32B elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): @@ -918,6 +924,7 @@ def get_tcgen5_instr_desc( atom_n, atom_k, DataType(self.a_dtype), + DataType(self.b_dtype), DataType(self.accum_dtype), a_is_k_major, b_is_k_major, diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index ae50e86cb4..e108270546 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -11,6 +11,7 @@ make_full_bank_swizzled_layout, # noqa: F401 make_half_bank_swizzled_layout, # noqa: F401 make_quarter_bank_swizzled_layout, # noqa: F401 + make_align16b_swizzled_layout, # noqa: F401 make_linear_layout, # noqa: F401 make_gemm_fragment_8x8, # noqa: F401 make_gemm_fragment_8x8_transposed, # noqa: F401 diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 01a2f741cd..41b8f2a5b1 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -3,12 +3,16 @@ # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations +import warnings import tvm from tvm import tir from tilelang import _ffi_api from tilelang._typing import BufferLikeType, BufferLikeTypeTuple +_WARNED_LEGACY_TCGEN05_LAYOUT_FFI = False + + def _get_buffer_info(buffer_or_load_or_region: BufferLikeType) -> tuple[tir.Buffer, list[int], str]: """ Extract buffer, shape, and dtype from BufferLikeType. @@ -84,10 +88,36 @@ def make_wgmma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k # for TCGEN05MMA Intrinsics def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True): - buf, _, _ = _get_buffer_info(buffer) + global _WARNED_LEGACY_TCGEN05_LAYOUT_FFI + buf, shape, _ = _get_buffer_info(buffer) + stride, continuous = _get_stride_continuous(buffer) + element_size = _get_element_size(buffer) if continuity is None: - continuity = -1 - return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major) + continuity = continuous + try: + base = _ffi_api.make_tcgen05mma_swizzled_layout( + stride, + continuous, + continuity, + element_size, + k_major, + ) + return base.reshape(shape) + except TypeError as err: + # Keep Python sources compatible with older built libs that still expose + # the legacy FFI signature: (buffer, continuity, k_major). + if "Mismatched number of arguments" not in str(err): + raise + if not _WARNED_LEGACY_TCGEN05_LAYOUT_FFI: + warnings.warn( + "Detected legacy tcgen05 swizzle-layout FFI in the loaded native " + "TileLang library. Rebuild the native `build/` artifacts so FP4 " + "SM100/SM110 kernels use the current gap-aware ALIGN16B layout.", + RuntimeWarning, + stacklevel=2, + ) + _WARNED_LEGACY_TCGEN05_LAYOUT_FFI = True + return _ffi_api.make_tcgen05mma_swizzled_layout(buf, continuity, k_major) # swizzle 128B @@ -126,6 +156,24 @@ def make_quarter_bank_swizzled_layout(buffer: BufferLikeType): return _ffi_api.make_quarter_bank_swizzled_layout(buf) +def make_align16b_swizzled_layout(buffer: BufferLikeType): + """ALIGN16B gap-aware 128B swizzle for sub-byte types (e.g. FP4). + + Each group of 16 sub-byte elements occupies 16 SMEM bytes (8 data + 8 gap). + The layout's index formula, combined with div_factor=2, skips the gap bytes. + + Args: + buffer: BufferLikeType with a sub-byte dtype (e.g. float4_e2m1fn) + """ + buf, _, _ = _get_buffer_info(buffer) + if hasattr(_ffi_api, "make_align16b_swizzled_layout"): + return _ffi_api.make_align16b_swizzled_layout(buf) + # Compatibility with native builds that already route sub-byte TCGEN05 + # layouts through makeGemmABLayoutSm100, but do not expose the dedicated FFI + # wrapper yet. Rebuild native artifacts to use the direct symbol. + return make_tcgen05mma_swizzled_layout(buffer) + + def make_linear_layout(buffer_or_load_or_region: BufferLikeType): """ Create a row-major linear layout for any dimension. diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 28d4c805be..9b3de02407 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -6,6 +6,7 @@ make_full_bank_swizzled_layout, make_half_bank_swizzled_layout, make_quarter_bank_swizzled_layout, + make_tcgen05mma_swizzled_layout, make_linear_layout, ) from tilelang.intrinsics.tcgen05_macro_generator import ( @@ -43,8 +44,29 @@ class GemmTCGEN5(GemmBase): of operands A and B. """ - def infer_shared_layout(self, continuity: int) -> Callable[[tir.Buffer], Layout]: - """Infer a standard shared-memory swizzle layout for TCGEN05 operands.""" + def infer_shared_layout(self, continuity: int, k_major: bool) -> Callable[[tir.Buffer], Layout]: + """Infer the shared-memory layout for TCGEN05 operands. + + Sub-byte inputs (e.g. FP4) must use the TCGEN05-specific layout helper so + they match the ALIGN16B + gap-aware SMEM organization expected by TMA and + descriptor generation. When TMA is disabled (non-TMA / cp.async path), + the SIMT copy writes packed-linear data, so we return a plain linear layout + to keep the SMEM descriptor consistent with the actual data layout. + """ + if self.in_dtype.bits < 8: + import tvm + try: + _pass_ctx = tvm.transform.PassContext.current() + disable_tma = _pass_ctx.config.get("tl.disable_tma_lower", False) + except Exception: + disable_tma = False + if disable_tma: + # Non-TMA path: SIMT copy writes packed-linear nibbles; use a + # linear layout so the SMEM descriptor matches the actual data. + return make_linear_layout + return lambda buffer: make_tcgen05mma_swizzled_layout( + buffer, continuity=continuity, k_major=k_major + ) vectorized_size = 128 // self.in_dtype.bits if continuity % (vectorized_size * 8) == 0: return make_full_bank_swizzled_layout @@ -93,15 +115,15 @@ def infer_layout(self, target: Target, thread_nums: int): b_continuity = self.K if b_is_k_major else int(self.B.shape[-1]) # don't use N, as it may be for 2cta return { - self.A: self.infer_shared_layout(a_continuity)(self.A), - self.B: self.infer_shared_layout(b_continuity)(self.B), + self.A: self.infer_shared_layout(a_continuity, a_is_k_major)(self.A), + self.B: self.infer_shared_layout(b_continuity, b_is_k_major)(self.B), self.C: mma_emitter.make_mma_store_layout(self.C), } if self.is_gemm_ts(): b_continuity = self.K if b_is_k_major else int(self.B.shape[-1]) layouts = { self.A: mma_emitter.make_mma_store_layout(self.A), - self.B: self.infer_shared_layout(b_continuity)(self.B), + self.B: self.infer_shared_layout(b_continuity, b_is_k_major)(self.B), self.C: mma_emitter.make_mma_store_layout(self.C), } return layouts From eeddaac0bea4c00afa2ed16a0d3663a2be4a2353 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 11 May 2026 16:07:11 +0800 Subject: [PATCH 08/19] Implement SM100/SM110 FP4 and A8W4 examples using the TCGEN05/TMA unpacksmem path, functionalities from all examples were verified valid --- .../gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md | 107 ------- .../gemm_fp4/example_fusedmoe_a8w4_sm100.py | 171 +++++++++++ .../gemm_fp4/example_fusedmoe_a8w4_sm120.py | 45 ++- examples/gemm_fp4/example_gemm_a8w4_sm100.py | 129 ++++++++ examples/gemm_fp4/example_gemm_a8w4_sm120.py | 49 +++- .../example_gemm_fp4_sm100.py} | 109 ++++--- examples/gemm_fp4/example_gemm_fp4_sm120.py | 46 ++- examples/gemm_fp4/gemm_fp4_sm100.cu | 100 ------- examples/gemm_fp4/gemm_fp4_sm120.cu | 74 ----- examples/gemm_sm100/test_fp4_diagonal.py | 216 -------------- .../gemm_sm100/test_fp4_single_tile_probe.py | 276 ------------------ ...est_fp4_single_tile_probe.tma.generated.cu | 102 ------- ...ile_probe.tma.linear_epilogue.generated.cu | 90 ------ ...probe.tma.tma_only.descriptor.generated.cu | 65 ----- ...ingle_tile_probe.tma.tma_only.generated.cu | 61 ---- examples/gemm_sm100/test_tma_fp4_roundtrip.py | 207 ------------- src/backend/cuda/codegen/codegen_cuda.cc | 8 +- src/backend/cuda/op/copy.cc | 8 +- src/backend/cuda/op/gemm.cc | 6 +- src/layout/gemm_layouts.cc | 16 +- src/op/tcgen5_meta.h | 6 +- src/tl_templates/cuda/cuda_fp4.h | 4 +- src/tl_templates/cuda/gemm_mma.h | 4 +- src/tl_templates/cuda/gemm_sm100.h | 217 +++++++------- src/tl_templates/cuda/instruction/mma.h | 18 +- src/tl_templates/cuda/ldsm.h | 14 +- tilelang/intrinsics/mma_macro_generator.py | 7 +- .../intrinsics/tcgen05_macro_generator.py | 63 +++- tilelang/intrinsics/utils.py | 21 +- tilelang/tileop/gemm/gemm_base.py | 5 +- tilelang/tileop/gemm/gemm_tcgen05.py | 24 +- 31 files changed, 708 insertions(+), 1560 deletions(-) delete mode 100644 examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md create mode 100644 examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py create mode 100644 examples/gemm_fp4/example_gemm_a8w4_sm100.py rename examples/{gemm_sm100/gemm_tcgen5mma_fp4.py => gemm_fp4/example_gemm_fp4_sm100.py} (75%) delete mode 100644 examples/gemm_fp4/gemm_fp4_sm100.cu delete mode 100644 examples/gemm_fp4/gemm_fp4_sm120.cu delete mode 100644 examples/gemm_sm100/test_fp4_diagonal.py delete mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.py delete mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu delete mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu delete mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu delete mode 100644 examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu delete mode 100644 examples/gemm_sm100/test_tma_fp4_roundtrip.py diff --git a/examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md b/examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md deleted file mode 100644 index 5fb808496a..0000000000 --- a/examples/gemm_fp4/GEMM_NV_FP4_FEATURE_STEPS.md +++ /dev/null @@ -1,107 +0,0 @@ -# GEMM NV FP4 Feature – 现状 Review 与实现步骤评估 - -## 一、当前代码库现状 - -### 1. 已有基础(可直接复用) - -| 模块 | 现状 | -|------|------| -| **FP4 类型与转换** | `src/tl_templates/cuda/cuda_fp4.h` 已有 `__nv_fp4_e2m1`、`fp4_e2_t`、fp4↔half/float/bf16 转换,以及 packed load/store 辅助函数。 | -| **Load/Store** | `lower_ldg_stg.cc` 已对 `float4_e2m1fn` 做 4-bit 处理(bits()、total_bits 等)。 | -| **CuTeDSL codegen** | `codegen_cutedsl.cc` 已支持 `Float4E2M1FN`。 | -| **Python dtype** | `tilelang/language/dtypes.py` 已有 `float4_e2m1fn` 及 x2/x4/…/x64 向量类型;与 torch 的映射在 `dtypes.py` / `tensor.py`。 | -| **TCGEN5 窄精度 meta** | `src/op/tcgen5_meta.h` 中 `GetTCGEN5MMAMeta` 已把 `float4_e2m1fn` 与 float8/float6 一起纳入窄精度分支:K%32==0,M/N 与现有 F8/F6 同规则,C 为 float32 或 float16。 | -| **SM100 MMA 模板** | `gemm_sm100.h` 已有 `SM100_MMA_F8F6F4_WS_SS`(PTX `kind::f8f6f4`),支持 ≤8bit 类型;当前仅对 `cute::float_e4m3_t` / `float_e5m2_t` 做了 `DispatchInstruction`。 | -| **Dequant 示例** | `examples/dequantize_gemm/` 下有 FP4 相关示例(如 `example_dequant_gemm_fp4_hopper.py`),用于参考 FP4 数据流与 scaling。 | - -### 2. 缺口(需要补齐) - -| 模块 | 问题 | -|------|------| -| **指令描述符** | `GetTCGEN5InstrDesc()` 中 `encode_dtype()` 未处理 `float4_e2m1fn`,会走到 `LOG(FATAL)`。需为 FP4 分配正确的 format 编码(需查 CUTLASS/PTX 描述符定义)。 | -| **PTX 类型枚举与字符串** | `src/target/ptx.cc` 的 `DTypeFromString()` 无 `"float4_e2m1fn"`;`common.h` 的 `DataType` 枚举无 `kFloat4_e2m1fn`;`enum_to_str` / `num_bits` 等表也需扩展。 | -| **GEMM 到 codegen 的 dtype 传递** | 当前 GEMM Lower 时把 `ab_dtype` 转成字符串传给后端;需保证 TVM `float4_e2m1fn` → 字符串 `"float4_e2m1fn"` → PTX/codegen 能识别并在 C++ 端选用 FP4 路径。 | -| **C++ 侧 FP4 类型与 Dispatch** | `gemm_sm100.h` 中 `DispatchInstruction` 仅有 `float_e4m3_t` / `float_e5m2_t`;需增加 FP4 的 C++ 类型(如对接 `cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的 FP4 类型)以及对应的 `DispatchInstruction` 特化。 | -| **CuTe/CUTLASS 对 FP4 的支持** | `UMMA::make_instr_desc`、`sm100_smem_selector` 等是否支持 4-bit 类型需确认;若 CUTLASS 已有 FP4 描述符/布局,需在 TileLang 侧对齐。 | -| **Block-scaled FP4(可选)** | NVIDIA 文档中的 `mxf4.block_scale` / `mxf4nvf4.block_scale` 为单独路径,若要做 block-scaled GEMM,需额外设计 scale 的存储与传入方式,可作二期。 | - ---- - -## 二、推荐实现步骤(按依赖顺序) - -### Phase 1:类型与描述符贯通 - -1. **PTX / common 数据类型** - - 在 `common.h` 的 `DataType` 中增加 `kFloat4_e2m1fn`(或与现有命名规范一致)。 - - 在 `ptx.cc` 的 `DTypeFromString` 中增加 `"float4_e2m1fn"`(及可选别名)映射到该枚举。 - - 在 `enum_to_str`、`num_bits`、`dtype_str` 等表中为 FP4 填好条目,保证 codegen 能正确生成类型名与位宽。 - -2. **TCGEN5 指令描述符** - - 查 CUTLASS/PTX 文档中 tcgen05.mma `kind::f8f6f4` 下 FP4 的 descriptor 编码(与 e4m3/e5m2 的差异)。 - - 在 `tcgen5_meta.h` 的 `GetTCGEN5InstrDesc()` 里为 `float4_e2m1fn` 增加 `encode_dtype` 分支,返回正确的 format 码,保证现有 GEMM Lower 路径在 A/B 为 FP4 时能生成合法描述符。 - -3. **验证路径** - - 写一个最小用例:只做「分配 FP4 buffer + 从 Python 传 dtype 到 codegen」,确认: - - TVM dtype → 字符串 → `DTypeFromString` → `DTypeEnumToString` 一致; - - 若已有 FP4 的 LDG/STG,可顺带跑一条简单 load/store 或 dequant 示例,确认无回归。 - -### Phase 2:C++ GEMM 模板与 Dispatch - -4. **C++ FP4 类型与 CuTe 对接** - - 确定 GEMM 使用的 C++ 类型:`cuda_fp4.h` 的 `fp4_e2_t` 或 CUTLASS 的等价类型。 - - 若 CUTLASS 有 `make_instr_desc` 对 FP4 的重载,在 TileLang 的 `to_cute_type`(或当前用于 F8/F6 的映射处)增加 TVM float4_e2m1fn → 该 C++ 类型的映射。 - - 若 CUTLASS 的 `sm100_smem_selector` 对 4-bit 有特殊布局(例如 packed 2×fp4/byte),在 `GemmTensorOp` 的 A/B layout 路径中接入,保证与 descriptor 一致。 - -5. **DispatchInstruction 与 SS/WS** - - 在 `gemm_sm100.h` 中为 FP4 增加 `DispatchInstruction`(以及 C=half 若需要),指向已有的 `SM100_MMA_F8F6F4_SS` / `SM100_MMA_F8F6F4_WS_SS`。 - - 确认 PTX 的 `kind::f8f6f4` 对 FP4 的 M/N/K 与 tile 约束与当前 meta 一致(K=32 等);必要时对照 CUTLASS 72a_blackwell_nvfp4_bf16_gemm 示例。 - -6. **端到端 GEMM 调用** - - 在 `gemm.cc` / `gemm_py.cc` 的 Lower 中,当 `a_`/`b_` 为 float4_e2m1fn 时,传入的 `kind_dtype` 字符串需与 `DTypeFromString` 一致(如 `"float4_e2m1fn"`)。 - - 跑一次 SM100 上 A/B 为 FP4、C 为 float32 的 GEMM kernel,对比参考实现(如 CUTLASS 72a)做数值与正确性检查。 - -### Phase 3:示例与文档 - -7. **Example** - - 在 `examples/gemm_sm100/` 或新建 `examples/gemm_fp4_sm100/` 增加一个 NV FP4 GEMM 示例:T.float4_e2m1fn 作为 A/B dtype,accum float32,与 ref(先 dequant 再 bf16/f32 GEMM)对比。 - - 若后续做 block-scaled,可在同一目录下加另一个示例,区分「无 scale」与「block scale」两种用法。 - -8. **文档与 CI** - - 在 `docs/programming_guides/type_system.md` 或 GEMM 相关文档中注明:SM100 TCGEN5 MMA 支持 float4_e2m1fn,以及当前限制(K%32、M/N 与 meta 一致、是否支持 block scale)。 - - CI 中为 SM100 增加一条 FP4 GEMM 的编译/运行(若环境有 SM100 或通过 skip 条件处理)。 - ---- - -## 三、风险与依赖 - -- **CUTLASS 版本**:FP4 描述符与 `make_instr_desc` 的细节可能随 CUTLASS 版本变化,需以当前 submodule 版本为准核对。 -- **PTXAS / CUDA**:`cuda_fp4.h` 中已有对 CUDA < 13 的 workaround(如 `__tl_cvt_fp4_to_halfraw_naive`),若运行环境 CUDA 版本较新,可再确认是否仍需要这些分支。 -- **Block-scaled FP4**:若产品需要 4× 吞吐的 block-scaled 路径,需单独设计 scale 张量、描述符与 API,工作量大于「仅非 block-scaled FP4 GEMM」。 - ---- - -## 四、与 Flash Attention SM100 的优先级 - -- Flash Attention SM100 的 example(含 wasp)可后续再迭代;当前不作为重点。 -- GEMM NV FP4 作为独立 feature,与 FA 的改动解耦;完成上述 Phase 1–2 即可得到可用的 FP4 GEMM 路径,再视需求加 block scale 或集成到更高层算子。 - ---- - -## 五、本次已完成的修改(CuTeDSL/TCGEN5 路径) - -以下修改已落地,使 **CuTeDSL 生成的 SM100 GEMM kernel** 在 A/B 为 `float4_e2m1fn` 时能正确走 `kind::f8f6f4` 的 FP4 路径。当前未改 `gemm_sm100.h` 的 CUTLASS `DispatchInstruction`(该路径由 CUTLASS 模板实例化,与 CuTeDSL 生成 TIR 再 codegen 的路径分离)。 - -| 文件 | 修改内容 | -|------|----------| -| **src/target/ptx.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | -| **src/target/ptx.cc** | `enum_to_str` / `dtype_str` / `num_bits` 增加第 24 项;`DTypeFromString` 增加 `"float4_e2m1fn"`、`".e2m1"` → `kFloat4_e2m1fn`。 | -| **src/tl_templates/cuda/common.h** | `DataType` 枚举增加 `kFloat4_e2m1fn = 23`。 | -| **src/op/tcgen5_meta.h** | `GetTCGEN5InstrDesc()` 中 `encode_dtype` 增加 `dtype.is_float4_e2m1fn()` 分支,返回 `2`(FP4 format 编码)。 | -| **src/tl_templates/cuda/instruction/tcgen05mma.h** | 为 `DataType::kFloat4_e2m1fn` 增加 `tcgen05mma_ss`、`tcgen05mma_ts`、`tcgen05mma_ws_ss` 特化,均转发到 `kFloat8_e4m3` 的 f8f6f4 实现。 | -| **tilelang/intrinsics/mma_macro_generator.py** | `dtype_abbrv` 增加 `"float4_e2m1fn": "float4_e2m1fn"`,保证 lowering 时 `a_dtype_abbrv` 传入 `ptx_tcgen05_mma_ss("float4_e2m1fn", ...)`。 | - -**后续建议**(在有 SM100 / nvcc 环境时): - -1. 本地或 CI 执行完整构建并跑 `examples/gemm_sm100/` 中 BF16/F8 用例,确认无回归。 -2. 新增 FP4 GEMM 示例:`in_dtype=accum_dtype=T.float4_e2m1fn, T.float`,`block_K=128`(K%32 已由 meta 保证),与参考实现做数值对比。 -3. 若需走 CUTLASS 模板路径的 FP4 GEMM,再补 `gemm_sm100.h` 的 `to_cute_type` 与 `DispatchInstruction`。 diff --git a/examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py b/examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py new file mode 100644 index 0000000000..9e45cdf2b1 --- /dev/null +++ b/examples/gemm_fp4/example_fusedmoe_a8w4_sm100.py @@ -0,0 +1,171 @@ +"""Simplified A8W4 fused MoE shared expert on SM100/SM110. + +This mirrors the SM120 demo shape, but uses TCGEN05/TMEM: + gate = input(FP8) x W_gate(FP4) + up = input(FP8) x W_up(FP4) + out = silu(gate) * up + +The down projection is intentionally omitted here, matching the existing SM120 +diagnostic example and keeping the kernel focused on mixed FP8xFP4 TCGEN05 GEMM. +""" + +import os +import time + +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def unpack_fp4_to_float(packed_int8, rows, cols): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[unpacked] + + +def fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token=128, block_hidden=128, block_expert=64, threads=128, num_stages=1): + scale = 1.44269504 # log2(e) + + @T.prim_func + def main( + Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), + W_gate: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn), + W_up: T.Tensor((d_expert, d_hidden), T.float4_e2m1fn), + Output: T.Tensor((num_tokens, d_expert), "float32"), + ): + with T.Kernel( + T.ceildiv(d_expert, block_expert), + T.ceildiv(num_tokens, block_token), + threads=threads, + ) as (bx, by): + input_shared = T.alloc_shared((block_token, block_hidden), "float8_e4m3fn") + gate_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn) + up_shared = T.alloc_shared((block_expert, block_hidden), T.float4_e2m1fn) + + gate_tmem = T.alloc_tmem([block_token, block_expert], "float32") + up_tmem = T.alloc_tmem([block_token, block_expert], "float32") + gate_mbar = T.alloc_barrier(1) + up_mbar = T.alloc_barrier(1) + + gate_local = T.alloc_fragment((block_token, block_expert), "float32") + up_local = T.alloc_fragment((block_token, block_expert), "float32") + + for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): + T.copy(Input[by * block_token, k * block_hidden], input_shared) + T.copy(W_gate[bx * block_expert, k * block_hidden], gate_shared) + T.copy(W_up[bx * block_expert, k * block_hidden], up_shared) + + T.tcgen05_gemm( + input_shared, + gate_shared, + gate_tmem, + transpose_A=False, + transpose_B=True, + mbar=gate_mbar, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(gate_mbar, k % 2) + + T.tcgen05_gemm( + input_shared, + up_shared, + up_tmem, + transpose_A=False, + transpose_B=True, + mbar=up_mbar, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(up_mbar, k % 2) + + T.copy(gate_tmem, gate_local) + T.copy(up_tmem, up_local) + + for i, j in T.Parallel(block_token, block_expert): + gate = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) + up_local[i, j] = up_local[i, j] * gate + + T.copy(up_local, Output[by * block_token, bx * block_expert]) + + return main + + +num_tokens = int(os.environ.get("TL_MOE_TOKENS", "128")) +d_hidden = int(os.environ.get("TL_MOE_HIDDEN", "256")) +d_expert = int(os.environ.get("TL_MOE_EXPERT", "256")) +block_token = int(os.environ.get("TL_MOE_BLOCK_TOKEN", "128")) +block_hidden = int(os.environ.get("TL_MOE_BLOCK_HIDDEN", "128")) +block_expert = int(os.environ.get("TL_MOE_BLOCK_EXPERT", "64")) + +print( + f"Running SM100 A8W4 fused MoE: tokens={num_tokens}, hidden={d_hidden}, " + f"expert={d_expert}, block=({block_token},{block_hidden},{block_expert})" +) + +func = fusedmoe_a8w4_sm100(num_tokens, d_hidden, d_expert, block_token, block_hidden, block_expert) +jit_kernel = tilelang.compile( + func, + out_idx=[3], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +print("Compilation succeeded!") + +torch.manual_seed(42) +input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) +w_gate = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8) +w_up = torch.randint(0, 256, (d_expert, d_hidden // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + +z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) +z_gate = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8) +z_up = torch.zeros(d_expert, d_hidden // 2, device="cuda", dtype=torch.int8) +c_zero = jit_kernel(z_input, z_gate, z_up) +assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" +print("[PASS] zeros in -> zeros out") + +out = jit_kernel(input_fp8, w_gate, w_up) +input_f32 = input_fp8.to(torch.float32) +gate_logits = input_f32 @ unpack_fp4_to_float(w_gate, d_expert, d_hidden).T +up_logits = input_f32 @ unpack_fp4_to_float(w_up, d_expert, d_hidden).T +ref = up_logits * (gate_logits * torch.sigmoid(gate_logits)) + +diff = (out.float() - ref).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) +print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") +print("[PASS] MoE numerical verification" if rel_err < 0.05 else "[WARN] large diff") + +torch.cuda.synchronize() +start = time.perf_counter() +for _ in range(100): + jit_kernel(input_fp8, w_gate, w_up) +torch.cuda.synchronize() +elapsed = (time.perf_counter() - start) / 100 * 1000 +total_flops = 2 * num_tokens * d_hidden * d_expert * 2 +print(f"Latency: {elapsed:.4f} ms") +print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") diff --git a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py index 0035786347..fbd06a7e9d 100644 --- a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py +++ b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py @@ -13,7 +13,6 @@ see examples/fusedmoe/example_fusedmoe_tilelang.py. """ -import os import time import torch import tilelang @@ -21,8 +20,22 @@ FP4_E2M1_TO_FLOAT = [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] @@ -32,9 +45,14 @@ def fp4_uint8_to_float(t): def moe_shared_expert_a8w4( - num_tokens, d_hidden, d_expert, - block_token=128, block_hidden=128, block_expert=128, - threads=128, num_stages=1, + num_tokens, + d_hidden, + d_expert, + block_token=128, + block_hidden=128, + block_expert=128, + threads=128, + num_stages=1, ): """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM.""" scale = 1.44269504 # log2(e) for fast SiLU @@ -72,9 +90,7 @@ def main( # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate for i, j in T.Parallel(block_token, block_expert): - gate_local[i, j] = gate_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)) - ) + gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) up_local[i, j] = up_local[i, j] * gate_local[i, j] T.copy(up_local, output[bx * block_token, by * block_expert]) @@ -90,9 +106,14 @@ def main( print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}") func = moe_shared_expert_a8w4( - num_tokens, d_hidden, d_expert, - block_token=128, block_hidden=128, block_expert=128, - threads=128, num_stages=1, + num_tokens, + d_hidden, + d_expert, + block_token=128, + block_hidden=128, + block_expert=128, + threads=128, + num_stages=1, ) jit_kernel = tilelang.compile( diff --git a/examples/gemm_fp4/example_gemm_a8w4_sm100.py b/examples/gemm_fp4/example_gemm_a8w4_sm100.py new file mode 100644 index 0000000000..6b87f7c18d --- /dev/null +++ b/examples/gemm_fp4/example_gemm_a8w4_sm100.py @@ -0,0 +1,129 @@ +"""A8W4 GEMM on SM100/SM110 using TCGEN05 f8f6f4 MMA. + +A is FP8 e4m3 activation. B is FP4 e2m1 weight stored densely packed on the +PyTorch side and loaded through the SM100 unpacksmem TMA/shared path. +""" + +import os +import time + +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def unpack_fp4_to_float(packed_int8, rows, cols): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[unpacked] + + +def matmul_a8w4_sm100(M, N, K, block_M, block_N, block_K, out_dtype, accum_dtype, num_stages=1, threads=128): + A_shape = (M, K) + B_shape = (N, K) + + @T.prim_func + def main( + A: T.Tensor(A_shape, "float8_e4m3fn"), + B: T.Tensor(B_shape, T.float4_e2m1fn), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), "float8_e4m3fn") + B_shared = T.alloc_shared((block_N, block_K), T.float4_e2m1fn) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.tcgen05_gemm( + A_shared, + B_shared, + C_tmem, + transpose_A=False, + transpose_B=True, + mbar=mbar, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +M = int(os.environ.get("TL_A8W4_M", "256")) +N = int(os.environ.get("TL_A8W4_N", "256")) +K = int(os.environ.get("TL_A8W4_K", "256")) +block_M = int(os.environ.get("TL_A8W4_BLOCK_M", "128")) +block_N = int(os.environ.get("TL_A8W4_BLOCK_N", "64")) +block_K = int(os.environ.get("TL_A8W4_BLOCK_K", "128")) + +print(f"Running SM100 A8W4 GEMM: M={M}, N={N}, K={K}, block=({block_M},{block_N},{block_K})") + +func = matmul_a8w4_sm100(M, N, K, block_M, block_N, block_K, T.float32, T.float32) +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +print("Compilation succeeded!") + +torch.manual_seed(42) +a_fp8 = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) +b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + +a_zero = torch.zeros(M, K, device="cuda", dtype=torch.float8_e4m3fn) +b_zero = torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) +c_zero = jit_kernel(a_zero, b_zero) +assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" +print("[PASS] zeros in -> zeros out") + +c = jit_kernel(a_fp8, b_packed) +ref = a_fp8.to(torch.float32) @ unpack_fp4_to_float(b_packed, N, K).T +diff = (c.float() - ref).abs() +max_diff = diff.max().item() +rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) +print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") +print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") + +torch.cuda.synchronize() +start = time.perf_counter() +for _ in range(100): + jit_kernel(a_fp8, b_packed) +torch.cuda.synchronize() +elapsed = (time.perf_counter() - start) / 100 * 1000 +print(f"Latency: {elapsed:.4f} ms") +print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") diff --git a/examples/gemm_fp4/example_gemm_a8w4_sm120.py b/examples/gemm_fp4/example_gemm_a8w4_sm120.py index a5174e7493..b48564ba56 100644 --- a/examples/gemm_fp4/example_gemm_a8w4_sm120.py +++ b/examples/gemm_fp4/example_gemm_a8w4_sm120.py @@ -8,7 +8,6 @@ No block scaling. Direct FP8 x FP4 tensor core multiply-accumulate. """ -import os import time import torch import tilelang @@ -16,9 +15,16 @@ def gemm_a8w4( - M, N, K, block_M, block_N, block_K, - out_dtype, accum_dtype, - num_stages=2, threads=128, + M, + N, + K, + block_M, + block_N, + block_K, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, ): A_shape = (M, K) B_shape = (N, K) @@ -46,8 +52,22 @@ def main( FP4_E2M1_TO_FLOAT = [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] @@ -63,12 +83,19 @@ def fp4_uint8_to_float(tensor_uint8): accum_dtype = T.float32 print(f"Running A8W4 GEMM: M={M}, N={N}, K={K}") -print(f" A: float8_e4m3fn, B: FP4 (unpacked uint8)") +print(" A: float8_e4m3fn, B: FP4 (unpacked uint8)") func = gemm_a8w4( - M, N, K, block_M, block_N, block_K, - out_dtype, accum_dtype, - num_stages=2, threads=128, + M, + N, + K, + block_M, + block_N, + block_K, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, ) jit_kernel = tilelang.compile( @@ -113,7 +140,7 @@ def fp4_uint8_to_float(tensor_uint8): if rel_err < 0.01: print("[PASS] numerical verification (rel_err < 0.01)") else: - print(f"[WARN] large diff -- may indicate layout or data flow issue") + print("[WARN] large diff -- may indicate layout or data flow issue") # --- Benchmark --- torch.cuda.synchronize() diff --git a/examples/gemm_sm100/gemm_tcgen5mma_fp4.py b/examples/gemm_fp4/example_gemm_fp4_sm100.py similarity index 75% rename from examples/gemm_sm100/gemm_tcgen5mma_fp4.py rename to examples/gemm_fp4/example_gemm_fp4_sm100.py index 28195c4942..9ff596f49f 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma_fp4.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm100.py @@ -1,36 +1,55 @@ """FP4 (float4_e2m1fn) GEMM on SM100/SM110 (B200/Thor) using TCGEN05 async MMA. Uses tcgen05.mma.kind::f8f6f4 with TMEM accumulator. -The current TMA path keeps global FP4 packed, then writes it into SMEM using the -ALIGN16B gap-aware layout expected by tcgen05.mma. +The TMA path keeps global FP4 packed, then writes it into SMEM using the +ALIGN16B unpacksmem layout expected by tcgen05.mma. Supported: SM100 (B100/B200), SM101/SM110 (DRIVE Thor), SM103 (B300). """ -import time import os +import time + import torch import tilelang import tilelang.language as T FP4_E2M1_TO_FLOAT = [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] def matmul_fp4_sm100( - M, N, K, block_M, block_N, block_K, - in_dtype, out_dtype, accum_dtype, - num_stages=2, threads=256, + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages=2, + threads=256, transpose_b=True, ): - """FP4 GEMM using TCGEN05 async MMA + TMEM. - - TCGEN05 handles packed FP4 natively — no unpack, no ldmatrix, no bit-shift. - Data flow: global(packed) -> TMA -> SMEM(ALIGN16B gap-aware) -> TCGEN05 MMA -> TMEM. - """ + """FP4 GEMM using TCGEN05 async MMA + TMEM.""" A_shape = (M, K) A_shared_shape = (block_M, block_K) if transpose_b: @@ -55,9 +74,13 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) T.tcgen05_gemm( - A_shared, B_shared, C_tmem, - transpose_A=False, transpose_B=True, - mbar=mbar, clear_accum=(k == 0), + A_shared, + B_shared, + C_tmem, + transpose_A=False, + transpose_B=True, + mbar=mbar, + clear_accum=(k == 0), ) T.mbarrier_wait_parity(mbar, k % 2) @@ -88,9 +111,13 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[k * block_K, bx * block_N], B_shared) T.tcgen05_gemm( - A_shared, B_shared, C_tmem, - transpose_A=False, transpose_B=False, - mbar=mbar, clear_accum=(k == 0), + A_shared, + B_shared, + C_tmem, + transpose_A=False, + transpose_B=False, + mbar=mbar, + clear_accum=(k == 0), ) T.mbarrier_wait_parity(mbar, k % 2) @@ -113,14 +140,6 @@ def unpack_fp4_to_float(packed_int8, M, K): M = int(os.environ.get("TL_FP4_M", "256")) N = int(os.environ.get("TL_FP4_N", "256")) K = int(os.environ.get("TL_FP4_K", "256")) -# -# Thor has a tighter per-CTA dynamic shared-memory budget than server Blackwell. -# Keep K=128 for the current ALIGN16B gap-aware layout, but reduce N so the -# example launches reliably before we debug numerical accuracy. -# -# We intentionally keep block_M=128 so TCGEN05 stores C through the standard -# atom_m=128 TMEM layout (Layout D) instead of the atom_m=64 WS-specific layout. -# block_M = int(os.environ.get("TL_FP4_BLOCK_M", "128")) block_N = int(os.environ.get("TL_FP4_BLOCK_N", "64")) block_K = int(os.environ.get("TL_FP4_BLOCK_K", "128")) @@ -132,15 +151,21 @@ def unpack_fp4_to_float(packed_int8, M, K): input_mode = os.environ.get("TL_FP4_INPUT_MODE", "random") transpose_b = os.environ.get("TL_FP4_TRANSPOSE_B", "1") != "0" -print( - f"Running FP4 GEMM (SM100/SM110 TCGEN05): M={M}, N={N}, K={K}, " - f"input_mode={input_mode}, transpose_b={transpose_b}" -) +print(f"Running FP4 GEMM (SM100/SM110 TCGEN05): M={M}, N={N}, K={K}, input_mode={input_mode}, transpose_b={transpose_b}") func = matmul_fp4_sm100( - M, N, K, block_M, block_N, block_K, - in_dtype, out_dtype, accum_dtype, - num_stages, threads, transpose_b, + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + transpose_b, ) jit_kernel = tilelang.compile( @@ -157,6 +182,7 @@ def unpack_fp4_to_float(packed_int8, M, K): torch.manual_seed(42) + def make_random_fp4(rows, cols, mode): if mode == "positive": lo = torch.randint(0, 8, (rows, cols // 2), device="cuda", dtype=torch.uint8) @@ -173,22 +199,15 @@ def make_random_fp4(rows, cols, mode): raise ValueError(f"Unsupported TL_FP4_INPUT_MODE={mode}") -# Packed FP4 tensors (2 per byte) — TCGEN05 reads packed data directly. a_packed = make_random_fp4(M, K, input_mode) b_packed = make_random_fp4(N, K, input_mode) if transpose_b else make_random_fp4(K, N, input_mode) -# --- Test 1: zeros --- a_zero = torch.zeros(M, K // 2, device="cuda", dtype=torch.int8) -b_zero = ( - torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) - if transpose_b - else torch.zeros(K, N // 2, device="cuda", dtype=torch.int8) -) +b_zero = torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) if transpose_b else torch.zeros(K, N // 2, device="cuda", dtype=torch.int8) c_zero = jit_kernel(a_zero, b_zero) assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" print("[PASS] zeros in -> zeros out") -# --- Test 2: numerical verification --- c = jit_kernel(a_packed, b_packed) a_float = unpack_fp4_to_float(a_packed, M, K) b_float = unpack_fp4_to_float(b_packed, N, K) if transpose_b else unpack_fp4_to_float(b_packed, K, N) @@ -198,12 +217,8 @@ def make_random_fp4(rows, cols, mode): max_diff = diff.max().item() rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") -if max_diff < 1.0: - print("[PASS] numerical verification") -else: - print("[WARN] large diff") +print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") -# --- Benchmark --- torch.cuda.synchronize() start = time.perf_counter() for _ in range(100): diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py index 7d85319d49..c79fab8df3 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -17,9 +17,16 @@ def matmul_fp4( - M, N, K, block_M, block_N, block_K, - out_dtype, accum_dtype, - num_stages=2, threads=128, + M, + N, + K, + block_M, + block_N, + block_K, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, ): A_shape = (M, K) B_shape = (N, K) @@ -49,8 +56,22 @@ def main( FP4_E2M1_TO_FLOAT = [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] @@ -78,9 +99,16 @@ def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tens print(f" block_M={block_M}, block_N={block_N}, block_K={block_K}") func = matmul_fp4( - M, N, K, block_M, block_N, block_K, - out_dtype, accum_dtype, - num_stages=2, threads=128, + M, + N, + K, + block_M, + block_N, + block_K, + out_dtype, + accum_dtype, + num_stages=2, + threads=128, ) jit_kernel = tilelang.compile( @@ -126,7 +154,7 @@ def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tens if max_diff < 1.0: print("[PASS] numerical verification (max_abs_diff < 1.0)") else: - print(f"[WARN] large diff -- may indicate layout or data flow issue") + print("[WARN] large diff -- may indicate layout or data flow issue") # --- Benchmark --- torch.cuda.synchronize() diff --git a/examples/gemm_fp4/gemm_fp4_sm100.cu b/examples/gemm_fp4/gemm_fp4_sm100.cu deleted file mode 100644 index 741a79427c..0000000000 --- a/examples/gemm_fp4/gemm_fp4_sm100.cu +++ /dev/null @@ -1,100 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc); -extern "C" __global__ void __launch_bounds__(256, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, __grid_constant__ const CUtensorMap C_desc) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - __shared__ __align__(16) uint64_t mbar_mem[2]; - auto mbar = reinterpret_cast(mbar_mem); - __shared__ __align__(16) uint64_t mbarrier_mem[4]; - auto mbarrier = reinterpret_cast(mbarrier_mem); - __shared__ __align__(16) uint C_tmem[1]; - tl::Tcgen05SMemDescriptor desc_a; - tl::Tcgen05SMemDescriptor desc_b; - float C_local[64]; - if (tl::tl_shuffle_elect<0>()) { - tl::prefetch_tma_descriptor(A_desc); - tl::prefetch_tma_descriptor(B_desc); - tl::prefetch_tma_descriptor(C_desc); - } - if (tl::tl_shuffle_elect<0>()) { - mbar[0].init(1); - mbar[1].init(1); - mbarrier[0].init(1); - mbarrier[1].init(1); - mbarrier[2].init(1); - mbarrier[3].init(1); - } - tl::fence_barrier_init(); - __syncthreads(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_allocate((&(C_tmem[0])), 128); - } - __syncthreads(); - if (((int)threadIdx.x) == 0) { - mbarrier[0].arrive_and_expect_tx(16256); - tl::tma_load(A_desc, mbarrier[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, (((int)blockIdx.y) * 128)); - mbarrier[1].arrive_and_expect_tx(16256); - tl::tma_load(B_desc, mbarrier[1], (&(((fp4_e2_t*)buf_dyn_shmem)[130048])), 0, (((int)blockIdx.x) * 128)); - mbarrier[2].arrive_and_expect_tx(16256); - tl::tma_load(A_desc, mbarrier[2], (&(((fp4_e2_t*)buf_dyn_shmem)[97536])), 128, (((int)blockIdx.y) * 128)); - mbarrier[3].arrive_and_expect_tx(16256); - tl::tma_load(B_desc, mbarrier[3], (&(((fp4_e2_t*)buf_dyn_shmem)[227584])), 128, (((int)blockIdx.x) * 128)); - } - mbarrier[0].wait(0); - mbarrier[1].wait(0); - __syncthreads(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 1, 64, 0, 0, 2); - tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[130048])), 1, 64, 0, 0, 2); - #pragma unroll - for (int ki = 0; ki < 4; ++ki) { - tl::tcgen05mma_ss(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(136320656), 0, 0, 0, 0); - } - tl::tcgen05_mma_arrive((&(mbar[0]))); - } - mbar[0].wait(0); - mbarrier[2].wait(0); - mbarrier[3].wait(0); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[97536])), 1, 64, 0, 0, 2); - tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[227584])), 1, 64, 0, 0, 2); - tl::fence_proxy_async(); - #pragma unroll - for (int ki_1 = 0; ki_1 < 4; ++ki_1) { - tl::tcgen05mma_ss(uint64_t(desc_a + (ki_1 * 32)), uint64_t(desc_b + (ki_1 * 32)), (*reinterpret_cast(C_tmem)) + 0, 1, static_cast(136320656), 0, 0, 0, 0); - } - tl::tcgen05_mma_arrive((&(mbar[1]))); - } - mbar[1].wait(0); - tl::tcgen05_ld_32dp32bNx<64, false>(C_tmem[0], ((((int)threadIdx.x) >> 7) * 64), (&(C_local[0]))); - __syncthreads(); - #pragma unroll - for (int i = 0; i < 16; ++i) { - *(float4*)(((float*)buf_dyn_shmem) + (((((((((int)threadIdx.x) >> 7) * 8192) + ((i >> 3) * 4096)) + ((((int)threadIdx.x) & 127) * 32)) + (((((i & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 16)) + (((((i & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 8)) + ((((i & 1) + (((int)threadIdx.x) & 1)) & 1) * 4))) = *(float4*)(C_local + (i * 4)); - } - __syncthreads(); - if (((int)threadIdx.x) == 0) { - tl::fence_proxy_async(); - tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[0])), (((int)blockIdx.x) * 128), (((int)blockIdx.y) * 128)); - tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[4096])), ((((int)blockIdx.x) * 128) + 32), (((int)blockIdx.y) * 128)); - tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[8192])), ((((int)blockIdx.x) * 128) + 64), (((int)blockIdx.y) * 128)); - tl::tma_store(C_desc, (&(((float*)buf_dyn_shmem)[12288])), ((((int)blockIdx.x) * 128) + 96), (((int)blockIdx.y) * 128)); - tl::tma_store_arrive(); - tl::tma_store_wait<0>(); - } - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_deallocate((&(C_tmem[0])), 128); - } -} - diff --git a/examples/gemm_fp4/gemm_fp4_sm120.cu b/examples/gemm_fp4/gemm_fp4_sm120.cu deleted file mode 100644 index 85f0abc5c9..0000000000 --- a/examples/gemm_fp4/gemm_fp4_sm120.cu +++ /dev/null @@ -1,74 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B, float* __restrict__ C); -extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B, float* __restrict__ C) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - float C_local[128]; - fp4_e2_2_t A_local_packed[32]; - fp4_e2_2_t B_local_packed[32]; - #pragma unroll - for (int i = 0; i < 32; ++i) { - float broadcast_var = 0x0p+0f/*0.000000e+00*/; - *(float4*)(C_local + (i * 4)) = make_float4(broadcast_var, broadcast_var, broadcast_var, broadcast_var); - } - #pragma unroll - for (int i_1 = 0; i_1 < 4; ++i_1) { - *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + ((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16))) = *(fp4_e2_32_t*)(A + ((((((int)blockIdx.y) * 16384) + (i_1 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16))); - } - #pragma unroll - for (int i_2 = 0; i_2 < 4; ++i_2) { - *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + (((((i_2 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384)) = *(fp4_e2_32_t*)(B + ((((((int)blockIdx.x) * 16384) + (i_2 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16))); - } - #pragma unroll - for (int i_3 = 0; i_3 < 4; ++i_3) { - *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + (((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 8192)) = *(fp4_e2_32_t*)(A + (((((((int)blockIdx.y) * 16384) + (i_3 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16)) + 64)); - } - #pragma unroll - for (int i_4 = 0; i_4 < 4; ++i_4) { - *(fp4_e2_32_t*)(((fp4_e2_t*)buf_dyn_shmem) + (((((i_4 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576)) = *(fp4_e2_32_t*)(B + (((((((int)blockIdx.x) * 16384) + (i_4 * 4096)) + ((((int)threadIdx.x) >> 2) * 128)) + ((((int)threadIdx.x) & 3) * 16)) + 64)); - } - __syncthreads(); - for (int ki = 0; ki < 4; ++ki) { - for (int i_5 = 0; i_5 < 4; ++i_5) { - tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[(((((((((int)threadIdx.x) & 63) >> 5) * 8192) + (i_5 * 2048)) + (((((int)threadIdx.x) & 15) >> 3) * 1024)) + ((((((((int)threadIdx.x) & 15) * 128) + (((((((int)threadIdx.x) & 7) >> 2) + (ki >> 1)) & 1) * 64)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki & 1)) & 1) * 32)) + (((((int)threadIdx.x) & 31) >> 4) * 16)) & 1023)) / 2)])) + 0, A_local_packed + (i_5 * 16)); - } - for (int i_6 = 0; i_6 < 4; ++i_6) { - tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[(((((((((((int)threadIdx.x) >> 6) * 4096) + (i_6 * 1024)) + (((((int)threadIdx.x) & 31) >> 4) * 512)) + ((((int)threadIdx.x) & 7) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + (ki >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki & 1)) & 1) * 16)) + (((((int)threadIdx.x) & 15) >> 3) * 8)) + 16384)])) + 0, B_local_packed + (i_6 * 16)); - } - for (int i_7 = 0; i_7 < 4; ++i_7) { - for (int j = 0; j < 4; ++j) { - tl::mma_sync(reinterpret_cast(C_local + ((i_7 * 32) + (j * 8))), reinterpret_cast(A_local_packed + (i_7 * 16)), reinterpret_cast(B_local_packed + (j * 16))); - tl::mma_sync(reinterpret_cast(C_local + (((i_7 * 32) + (j * 8)) + 4)), reinterpret_cast(A_local_packed + (i_7 * 16)), reinterpret_cast(B_local_packed + ((j * 16) + 8))); - } - } - } - for (int ki_1 = 0; ki_1 < 4; ++ki_1) { - for (int i_8 = 0; i_8 < 4; ++i_8) { - tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[((((((((((int)threadIdx.x) & 63) >> 5) * 8192) + (i_8 * 2048)) + (((((int)threadIdx.x) & 15) >> 3) * 1024)) + ((((((((int)threadIdx.x) & 15) * 128) + (((((((int)threadIdx.x) & 7) >> 2) + (ki_1 >> 1)) & 1) * 64)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki_1 & 1)) & 1) * 32)) + (((((int)threadIdx.x) & 31) >> 4) * 16)) & 1023)) + 16384) / 2)])) + 0, A_local_packed + (i_8 * 16)); - } - for (int i_9 = 0; i_9 < 4; ++i_9) { - tl::ptx_ldmatrix_x4((&(((fp4_e2_t*)buf_dyn_shmem)[(((((((((((int)threadIdx.x) >> 6) * 4096) + (i_9 * 1024)) + (((((int)threadIdx.x) & 31) >> 4) * 512)) + ((((int)threadIdx.x) & 7) * 64)) + (((((((int)threadIdx.x) & 7) >> 2) + (ki_1 >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 3) >> 1) + (ki_1 & 1)) & 1) * 16)) + (((((int)threadIdx.x) & 15) >> 3) * 8)) + 24576)])) + 0, B_local_packed + (i_9 * 16)); - } - for (int i_10 = 0; i_10 < 4; ++i_10) { - for (int j_1 = 0; j_1 < 4; ++j_1) { - tl::mma_sync(reinterpret_cast(C_local + ((i_10 * 32) + (j_1 * 8))), reinterpret_cast(A_local_packed + (i_10 * 16)), reinterpret_cast(B_local_packed + (j_1 * 16))); - tl::mma_sync(reinterpret_cast(C_local + (((i_10 * 32) + (j_1 * 8)) + 4)), reinterpret_cast(A_local_packed + (i_10 * 16)), reinterpret_cast(B_local_packed + ((j_1 * 16) + 8))); - } - } - } - #pragma unroll - for (int i_11 = 0; i_11 < 64; ++i_11) { - *(float2*)(C + (((((((((((int)blockIdx.y) * 32768) + (((((int)threadIdx.x) & 63) >> 5) * 16384)) + ((i_11 >> 4) * 4096)) + ((i_11 & 1) * 2048)) + (((((int)threadIdx.x) & 31) >> 2) * 256)) + (((int)blockIdx.x) * 128)) + ((((int)threadIdx.x) >> 6) * 64)) + (((i_11 & 15) >> 1) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(float2*)(C_local + (i_11 * 2)); - } -} - diff --git a/examples/gemm_sm100/test_fp4_diagonal.py b/examples/gemm_sm100/test_fp4_diagonal.py deleted file mode 100644 index f16147b2e8..0000000000 --- a/examples/gemm_sm100/test_fp4_diagonal.py +++ /dev/null @@ -1,216 +0,0 @@ -"""Diagnostic: FP4 GEMM with diagonal B matrix. - -C = A @ B^T where B = identity → C should equal A (in float). -Mismatches reveal exactly which (row, col) the MMA reads incorrectly. -""" - -import os -import torch -import tilelang -import tilelang.language as T - -M, N, K = 256, 256, 256 -# Keep K=128 for the current ALIGN16B gap-aware layout, but reduce N so the -# diagnostic runs within Thor's dynamic shared-memory budget. -# -# We intentionally keep block_M=128 so TCGEN05 stores C through the standard -# atom_m=128 TMEM layout (Layout D) instead of the atom_m=64 WS-specific layout. -block_M, block_N, block_K = 128, 64, 128 -in_dtype = T.float4_e2m1fn -out_dtype = T.float32 -accum_dtype = T.float32 - -FP4_E2M1_TO_FLOAT = [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, -] - - -def unpack_fp4_to_float(packed_int8, rows, cols): - lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) - flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) - lo = flat & 0x0F - hi = (flat >> 4) & 0x0F - unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) - return lut[unpacked] - - -def make_fp4_diagonal(N, K): - """Create packed FP4 identity matrix B[N,K] where B[n,k]=1.0 if n==k.""" - packed = torch.zeros(N, K // 2, dtype=torch.uint8, device="cuda") - fp4_one = 2 # FP4 encoding for 1.0 - diag = torch.arange(min(N, K), device="cuda", dtype=torch.int64) - byte_idx = diag // 2 - nibble_shift = (diag % 2).to(torch.uint8) * 4 - packed[diag, byte_idx] = torch.bitwise_left_shift( - torch.full_like(nibble_shift, fp4_one, dtype=torch.uint8), nibble_shift - ) - return packed.to(torch.int8) - - -def make_fp4_row_pattern(M, K): - """A[m,k] = FP4 value based on (m % 8), so each row has a recognizable pattern.""" - vals = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda") - row_vals = vals[torch.arange(M, device="cuda", dtype=torch.int64) % 8] - packed_row_vals = row_vals | torch.bitwise_left_shift(row_vals, 4) - packed = packed_row_vals[:, None].expand(M, K // 2).contiguous() - return packed.to(torch.int8) - - -def print_generated_kernel_debug_summary(kernel_src_path: str): - keywords = ( - "arrive_and_expect_tx", - "initialize_tcgen05_descriptor", - "fence_proxy_async", - "tcgen05mma_ss", - ".wait(", - ) - with open(kernel_src_path, "r", encoding="utf-8") as f: - lines = f.readlines() - - first_mma_line = None - first_fence_line = None - - print("Generated kernel debug summary:") - for line_no, line in enumerate(lines, start=1): - if "tcgen05mma_ss" in line and first_mma_line is None: - first_mma_line = line_no - if "fence_proxy_async" in line and first_fence_line is None: - first_fence_line = line_no - if any(keyword in line for keyword in keywords): - print(f"{line_no:4d}: {line.rstrip()}") - - if first_mma_line is None: - print("No tcgen05mma_ss call found in generated source.") - return - - if first_fence_line is None: - print("No fence_proxy_async() found in generated source.") - elif first_fence_line < first_mma_line: - print( - f"First fence_proxy_async() appears before first tcgen05mma_ss " - f"(fence line {first_fence_line}, mma line {first_mma_line})." - ) - else: - print( - f"First fence_proxy_async() appears after first tcgen05mma_ss " - f"(fence line {first_fence_line}, mma line {first_mma_line})." - ) - - -def matmul_fp4_sm100(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype): - @T.prim_func - def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((N, K), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), in_dtype) - B_shared = T.alloc_shared((block_N, block_K), in_dtype) - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - mbar = T.alloc_barrier(1) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.tcgen05_gemm(A_shared, B_shared, C_tmem, - transpose_A=False, transpose_B=True, - mbar=mbar, clear_accum=(k == 0)) - T.mbarrier_wait_parity(mbar, k % 2) - - T.copy(C_tmem, C_local) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - return main - - -print(f"FP4 Diagonal Test: M={M}, N={N}, K={K}") - -func = matmul_fp4_sm100(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype) - -kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -}) -print("Compiled OK") - -kernel_src_path = os.path.join(os.path.dirname(__file__), "test_fp4_diagonal.generated.cu") -with open(kernel_src_path, "w", encoding="utf-8") as f: - f.write(kernel.get_kernel_source()) -print(f"Kernel source written to: {kernel_src_path}") -print_generated_kernel_debug_summary(kernel_src_path) - -if os.environ.get("TL_FP4_COMPILE_ONLY", "0") == "1": - print("TL_FP4_COMPILE_ONLY=1, skip kernel launch.") - raise SystemExit(0) - -# --- Test: A = row pattern, B = identity --- -print("Building FP4 test inputs...") -a_packed = make_fp4_row_pattern(M, K) -b_packed = make_fp4_diagonal(N, K) -torch.cuda.synchronize() -print("Launching kernel...") - -c = kernel(a_packed, b_packed) -torch.cuda.synchronize() -print("Kernel finished.") - -a_float = unpack_fp4_to_float(a_packed, M, K) -b_float = unpack_fp4_to_float(b_packed, N, K) -ref_c = a_float @ b_float.T # should ≈ A (since B is identity) - -diff = (c.float() - ref_c).abs() -max_diff = diff.max().item() -rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) -print(f"max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") - -out = c -print(f"\n{'='*80}") -print(f"Element-by-element comparison (first 8 rows x first 48 cols):") -print(f"{'='*80}") - -# Print header -print(f"{'':>6}", end="") -for col in range(48): - print(f" {col:>5}", end="") -print() - -for r in range(8): - # Expected row - print(f"E[{r:>2}]", end="") - for col in range(48): - print(f" {ref_c[r,col].item():>5.1f}", end="") - print() - # Actual row - print(f"A[{r:>2}]", end="") - for col in range(48): - v = out[r, col].item() - e = ref_c[r, col].item() - mark = " " if abs(v - e) < 0.01 else "*" - print(f"{v:>5.1f}{mark}", end="") - print() - print() - -print(f"\n{'='*80}") -print(f"Columns 120-135 (crossing 128 boundary):") -print(f"{'='*80}") -print(f"{'':>6}", end="") -for col in range(120, 136): - print(f" {col:>5}", end="") -print() -for r in range(4): - print(f"E[{r:>2}]", end="") - for col in range(120, 136): - print(f" {ref_c[r,col].item():>5.1f}", end="") - print() - print(f"A[{r:>2}]", end="") - for col in range(120, 136): - v = out[r, col].item() - e = ref_c[r, col].item() - mark = " " if abs(v - e) < 0.01 else "*" - print(f"{v:>5.1f}{mark}", end="") - print() - print() diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.py b/examples/gemm_sm100/test_fp4_single_tile_probe.py deleted file mode 100644 index 5a70205f30..0000000000 --- a/examples/gemm_sm100/test_fp4_single_tile_probe.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Single-tile FP4 GEMM probe for SM100/SM110. - -Purpose: - 1. Remove the second K tile so we only exercise the first - TMA -> tcgen05.mma -> mbarrier wait sequence. - 2. Allow quick TMA vs non-TMA comparison via TL_FP4_DISABLE_TMA=1. - 3. Allow isolating the epilogue path via TL_FP4_LINEAR_EPILOGUE=1, - so we can compare "TMA input + linear global store" against the - default "TMA input + TMA store" path. - 4. Allow forcing an explicit fence between copy/wait and tcgen05.mma via - TL_FP4_FORCE_FENCE=1 to test whether the TMA->UMMA handoff is missing a - proxy fence. - 5. Allow stopping immediately after input load via TL_FP4_TMA_ONLY=1, so - TMA wait and tcgen05 MMA wait can be isolated. - 6. Allow forcing the same descriptor/gap-aware input TMA path without MMA via - TL_FP4_TMA_ONLY_DESCRIPTOR=1. - -Run: - python examples/gemm_sm100/test_fp4_single_tile_probe.py - TL_FP4_DISABLE_TMA=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py - TL_FP4_COMPILE_ONLY=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py - TL_FP4_LINEAR_EPILOGUE=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py - TL_FP4_FORCE_FENCE=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py - TL_FP4_TMA_ONLY=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py - TL_FP4_TMA_ONLY=1 TL_FP4_TMA_ONLY_DESCRIPTOR=1 python examples/gemm_sm100/test_fp4_single_tile_probe.py -""" - -import os - -import torch -import tilelang -import tilelang.language as T - -M, N, K = 128, 64, 128 -block_M, block_N, block_K = 128, 64, 128 -in_dtype = T.float4_e2m1fn -out_dtype = T.float32 -accum_dtype = T.float32 - -FP4_E2M1_TO_FLOAT = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, -] - - -def unpack_fp4_to_float(packed_int8, rows, cols): - lut = torch.tensor( - FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device - ) - flat = packed_int8.to(torch.uint8).reshape(rows, cols // 2) - lo = flat & 0x0F - hi = (flat >> 4) & 0x0F - unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) - return lut[unpacked] - - -def make_fp4_diagonal(rows, cols): - packed = torch.zeros(rows, cols // 2, dtype=torch.uint8, device="cuda") - fp4_one = 2 # FP4 encoding for 1.0 - diag = torch.arange(min(rows, cols), device="cuda", dtype=torch.int64) - byte_idx = diag // 2 - nibble_shift = (diag % 2).to(torch.uint8) * 4 - packed[diag, byte_idx] = torch.bitwise_left_shift( - torch.full_like(nibble_shift, fp4_one, dtype=torch.uint8), nibble_shift - ) - return packed.to(torch.int8) - - -def make_fp4_row_pattern(rows, cols): - vals = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" - ) - row_vals = vals[torch.arange(rows, device="cuda", dtype=torch.int64) % 8] - packed_row_vals = row_vals | torch.bitwise_left_shift(row_vals, 4) - packed = packed_row_vals[:, None].expand(rows, cols // 2).contiguous() - return packed.to(torch.int8) - - -def print_generated_kernel_debug_summary(kernel_src_path: str): - keywords = ( - "arrive_and_expect_tx", - "expect_transaction", - "tma_load(", - "tma_store(", - "tma_store_arrive", - "tma_store_wait", - "initialize_tcgen05_descriptor", - "fence_proxy_async", - "tcgen05mma_ss", - ".wait(", - ) - - with open(kernel_src_path, "r", encoding="utf-8") as f: - lines = f.readlines() - - print("Generated kernel debug summary:") - for line_no, line in enumerate(lines, start=1): - if any(keyword in line for keyword in keywords): - print(f"{line_no:4d}: {line.rstrip()}") - - -def fp4_single_tile_kernel( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - linear_epilogue, - force_fence, - tma_only, - tma_only_descriptor, -): - @T.prim_func - def main( - A: T.Tensor((M, K), in_dtype), - B: T.Tensor((N, K), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(1, 1, threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), in_dtype) - B_shared = T.alloc_shared((block_N, block_K), in_dtype) - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - if not disable_tma: - load_mbar_a = T.alloc_barrier(128) - load_mbar_b = T.alloc_barrier(128) - mbar = T.alloc_barrier(1) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - if not linear_epilogue: - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - if tma_only and tma_only_descriptor: - T.annotate_layout( - { - A_shared: tilelang.layout.make_tcgen05mma_swizzled_layout( - A_shared, continuity=block_K, k_major=True - ), - B_shared: tilelang.layout.make_tcgen05mma_swizzled_layout( - B_shared, continuity=block_K, k_major=True - ), - } - ) - - if disable_tma: - T.copy(A[0, 0], A_shared) - T.copy(B[0, 0], B_shared) - else: - T.tma_copy(A[0, 0], A_shared, barrier=load_mbar_a) - T.tma_copy(B[0, 0], B_shared, barrier=load_mbar_b) - T.barrier_arrive(load_mbar_a) - T.barrier_arrive(load_mbar_b) - T.mbarrier_wait_parity(load_mbar_a, 0) - T.mbarrier_wait_parity(load_mbar_b, 0) - T.fence_proxy_async() - if tma_only: - return - if force_fence: - T.fence_proxy_async() - T.tcgen05_gemm( - A_shared, - B_shared, - C_tmem, - transpose_A=False, - transpose_B=True, - mbar=mbar, - clear_accum=True, - ) - T.mbarrier_wait_parity(mbar, 0) - - T.copy(C_tmem, C_local) - if linear_epilogue: - T.copy(C_local, C[0, 0]) - else: - T.copy(C_local, C_shared) - T.copy(C_shared, C[0, 0]) - - return main - - -disable_tma = os.environ.get("TL_FP4_DISABLE_TMA", "0") == "1" -linear_epilogue = os.environ.get("TL_FP4_LINEAR_EPILOGUE", "0") == "1" -force_fence = os.environ.get("TL_FP4_FORCE_FENCE", "0") == "1" -tma_only = os.environ.get("TL_FP4_TMA_ONLY", "0") == "1" -tma_only_descriptor = os.environ.get("TL_FP4_TMA_ONLY_DESCRIPTOR", "0") == "1" -mode_parts = ["non_tma" if disable_tma else "tma"] -if linear_epilogue: - mode_parts.append("linear_epilogue") -if force_fence: - mode_parts.append("force_fence") -if tma_only: - mode_parts.append("tma_only") -if tma_only_descriptor: - mode_parts.append("descriptor") -mode = ".".join(mode_parts) -print( - f"FP4 single-tile probe: mode={mode}, M={M}, N={N}, K={K}, " - f"block=({block_M},{block_N},{block_K})" -) - -func = fp4_single_tile_kernel( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - linear_epilogue, - force_fence, - tma_only, - tma_only_descriptor, -) -kernel = tilelang.compile( - func, - out_idx=[2], - target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: disable_tma, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, -) -print("Compiled OK") - -kernel_src_path = os.path.join( - os.path.dirname(__file__), f"test_fp4_single_tile_probe.{mode}.generated.cu" -) -with open(kernel_src_path, "w", encoding="utf-8") as f: - f.write(kernel.get_kernel_source()) -print(f"Kernel source written to: {kernel_src_path}") -print_generated_kernel_debug_summary(kernel_src_path) - -if os.environ.get("TL_FP4_COMPILE_ONLY", "0") == "1": - print("TL_FP4_COMPILE_ONLY=1, skip kernel launch.") - raise SystemExit(0) - -print("Building FP4 test inputs...") -a_packed = make_fp4_row_pattern(M, K) -b_packed = make_fp4_diagonal(N, K) -torch.cuda.synchronize() - -print("Launching kernel...") -c = kernel(a_packed, b_packed) -print("Kernel launch returned, synchronizing...") -torch.cuda.synchronize() -print("Kernel finished.") - -if tma_only: - print("TL_FP4_TMA_ONLY=1, skipped GEMM and output comparison.") - raise SystemExit(0) - -a_float = unpack_fp4_to_float(a_packed, M, K) -b_float = unpack_fp4_to_float(b_packed, N, K) -ref_c = a_float @ b_float.T - -diff = (c.float() - ref_c).abs() -max_diff = diff.max().item() -rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) -print(f"max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu deleted file mode 100644 index 722dd72830..0000000000 --- a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.generated.cu +++ /dev/null @@ -1,102 +0,0 @@ -#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 -#define _tl_orig_alignas alignas -#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) -#include -#undef alignas -#define alignas _tl_orig_alignas -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C); -extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; - auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); - __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; - auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); - __shared__ __align__(16) uint64_t mbar_mem[1]; - auto mbar = reinterpret_cast(mbar_mem); - __shared__ __align__(16) uint C_tmem[1]; - float C_local[64]; - if (tl::tl_shuffle_elect<0>()) { - tl::prefetch_tma_descriptor(A_desc); - tl::prefetch_tma_descriptor(B_desc); - } - if (tl::tl_shuffle_elect<0>()) { - load_mbar_a[0].init(128); - load_mbar_b[0].init(128); - mbar[0].init(1); - } - tl::fence_barrier_init(); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_allocate((&(C_tmem[0])), 64); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (tl::tl_shuffle_elect<128>()) { - load_mbar_a[0].expect_transaction(16256); - tl::tma_load(A_desc, load_mbar_a[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, 0); - load_mbar_b[0].expect_transaction(8128); - tl::tma_load(B_desc, load_mbar_b[0], (&(((fp4_e2_t*)buf_dyn_shmem)[32512])), 0, 0); - } - load_mbar_a[0].arrive(); - load_mbar_b[0].arrive(); - load_mbar_a[0].wait(0); - load_mbar_b[0].wait(0); - tl::tcgen05_after_thread_sync(); - tl::fence_proxy_async(); - { - tl::Tcgen05SMemDescriptor desc_a; - tl::Tcgen05SMemDescriptor desc_b; - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 1, 64, 0, 0, 2); - tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[32512])), 1, 64, 0, 0, 2); - #pragma unroll - for (int ki = 0; ki < 4; ++ki) { - tl::tcgen05mma_ws_ss(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(135272080), 0, 0, 0, 0); - } - tl::tcgen05_mma_arrive((&(mbar[0]))); - } - } - mbar[0].wait(0); - tl::tcgen05_after_thread_sync(); - tl::tcgen05_ld_32dp32bNx<64, false>(C_tmem[0], 0, (&(C_local[0]))); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - #pragma unroll - for (int i = 0; i < 16; ++i) { - *(float4*)(((float*)buf_dyn_shmem) + ((((int)threadIdx.x) * 64) + (i * 4))) = *(float4*)(C_local + (i * 4)); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (tl::tl_shuffle_elect<128>()) { - tl::fence_proxy_async(); - tl::tma_store((&(C[0])), (&(((float*)buf_dyn_shmem)[0])), 32768); - tl::tma_store_arrive(); - tl::tma_store_wait<0>(); - } - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_deallocate((&(C_tmem[0])), 64); - } -} - diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu deleted file mode 100644 index f17c4de7fc..0000000000 --- a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.linear_epilogue.generated.cu +++ /dev/null @@ -1,90 +0,0 @@ -#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 -#define _tl_orig_alignas alignas -#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) -#include -#undef alignas -#define alignas _tl_orig_alignas -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C); -extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc, float* __restrict__ C) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; - auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); - __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; - auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); - __shared__ __align__(16) uint64_t mbar_mem[1]; - auto mbar = reinterpret_cast(mbar_mem); - __shared__ __align__(16) uint C_tmem[1]; - float C_local[64]; - if (tl::tl_shuffle_elect<0>()) { - tl::prefetch_tma_descriptor(A_desc); - tl::prefetch_tma_descriptor(B_desc); - } - if (tl::tl_shuffle_elect<0>()) { - load_mbar_a[0].init(128); - load_mbar_b[0].init(128); - mbar[0].init(1); - } - tl::fence_barrier_init(); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_allocate((&(C_tmem[0])), 64); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (tl::tl_shuffle_elect<128>()) { - load_mbar_a[0].expect_transaction(8192); - tl::tma_load(A_desc, load_mbar_a[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, 0); - load_mbar_b[0].expect_transaction(4096); - tl::tma_load(B_desc, load_mbar_b[0], (&(((fp4_e2_t*)buf_dyn_shmem)[16384])), 0, 0); - } - load_mbar_a[0].arrive(); - load_mbar_b[0].arrive(); - load_mbar_a[0].wait(0); - load_mbar_b[0].wait(0); - tl::tcgen05_after_thread_sync(); - tl::fence_proxy_async(); - { - tl::Tcgen05SMemDescriptor desc_a; - tl::Tcgen05SMemDescriptor desc_b; - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::initialize_tcgen05_descriptor(desc_a, (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 1, 64, 0, 0, 2); - tl::initialize_tcgen05_descriptor(desc_b, (&(((fp4_e2_t*)buf_dyn_shmem)[16384])), 1, 64, 0, 0, 2); - #pragma unroll - for (int ki = 0; ki < 4; ++ki) { - tl::tcgen05mma_ws_ss(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(135272080), 0, 0, 0, 0); - } - tl::tcgen05_mma_arrive((&(mbar[0]))); - } - } - mbar[0].wait(0); - tl::tcgen05_after_thread_sync(); - tl::tcgen05_ld_32dp32bNx<64, false>(C_tmem[0], 0, (&(C_local[0]))); - #pragma unroll - for (int i = 0; i < 8; ++i) { - tl::store_global_256(&(*(ulonglong4*)(C + ((((int)threadIdx.x) * 64) + (i * 8)))), *(ulonglong4*)(C_local + (i * 8))); - } - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_deallocate((&(C_tmem[0])), 64); - } -} - diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu deleted file mode 100644 index a210b5b2bb..0000000000 --- a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.descriptor.generated.cu +++ /dev/null @@ -1,65 +0,0 @@ -#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 -#define _tl_orig_alignas alignas -#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) -#include -#undef alignas -#define alignas _tl_orig_alignas -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc); -extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(__grid_constant__ const CUtensorMap A_desc, __grid_constant__ const CUtensorMap B_desc) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; - auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); - __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; - auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); - __shared__ __align__(16) uint64_t mbar_mem[1]; - auto mbar = reinterpret_cast(mbar_mem); - __shared__ __align__(16) uint C_tmem[1]; - if (tl::tl_shuffle_elect<0>()) { - tl::prefetch_tma_descriptor(A_desc); - tl::prefetch_tma_descriptor(B_desc); - } - if (tl::tl_shuffle_elect<0>()) { - load_mbar_a[0].init(128); - load_mbar_b[0].init(128); - mbar[0].init(1); - } - tl::fence_barrier_init(); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_allocate((&(C_tmem[0])), 64); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (tl::tl_shuffle_elect<128>()) { - load_mbar_a[0].expect_transaction(8192); - tl::tma_load(A_desc, load_mbar_a[0], (&(((fp4_e2_t*)buf_dyn_shmem)[0])), 0, 0); - load_mbar_b[0].expect_transaction(4096); - tl::tma_load(B_desc, load_mbar_b[0], (&(((fp4_e2_t*)buf_dyn_shmem)[16384])), 0, 0); - } - load_mbar_a[0].arrive(); - load_mbar_b[0].arrive(); - load_mbar_a[0].wait(0); - load_mbar_b[0].wait(0); - tl::tcgen05_after_thread_sync(); - tl::fence_proxy_async(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_deallocate((&(C_tmem[0])), 64); - } -} - diff --git a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu b/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu deleted file mode 100644 index dc1fa8a3ab..0000000000 --- a/examples/gemm_sm100/test_fp4_single_tile_probe.tma.tma_only.generated.cu +++ /dev/null @@ -1,61 +0,0 @@ -#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 -#define _tl_orig_alignas alignas -#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) -#include -#undef alignas -#define alignas _tl_orig_alignas -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B); -extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(const fp4_e2_t* __restrict__ A, const fp4_e2_t* __restrict__ B) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - __shared__ __align__(16) uint64_t load_mbar_a_mem[1]; - auto load_mbar_a = reinterpret_cast(load_mbar_a_mem); - __shared__ __align__(16) uint64_t load_mbar_b_mem[1]; - auto load_mbar_b = reinterpret_cast(load_mbar_b_mem); - __shared__ __align__(16) uint64_t mbar_mem[1]; - auto mbar = reinterpret_cast(mbar_mem); - __shared__ __align__(16) uint C_tmem[1]; - if (tl::tl_shuffle_elect<0>()) { - load_mbar_a[0].init(128); - load_mbar_b[0].init(128); - mbar[0].init(1); - } - tl::fence_barrier_init(); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_allocate((&(C_tmem[0])), 64); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (tl::tl_shuffle_elect<128>()) { - load_mbar_a[0].expect_transaction(8192); - tl::tma_load((&(((fp4_e2_t*)buf_dyn_shmem)[0])), (&(A[0])), load_mbar_a[0], 8192); - load_mbar_b[0].expect_transaction(4096); - tl::tma_load((&(((fp4_e2_t*)buf_dyn_shmem)[16384])), (&(B[0])), load_mbar_b[0], 4096); - } - load_mbar_a[0].arrive(); - load_mbar_b[0].arrive(); - load_mbar_a[0].wait(0); - load_mbar_b[0].wait(0); - tl::tcgen05_after_thread_sync(); - tl::fence_proxy_async(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_deallocate((&(C_tmem[0])), 64); - } -} - diff --git a/examples/gemm_sm100/test_tma_fp4_roundtrip.py b/examples/gemm_sm100/test_tma_fp4_roundtrip.py deleted file mode 100644 index 5fa454face..0000000000 --- a/examples/gemm_sm100/test_tma_fp4_roundtrip.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Diagnostic: force TMA load for FP4 on SM100/SM110. - -This test isolates the GMEM -> SMEM TMA load path for FP4 ALIGN16B: - 1. input tile is loaded with explicit ``T.tma_copy(...)`` so it cannot - silently fall back to a normal copy; - 2. shared memory uses the ALIGN16B gap-aware layout that GEMM also relies on; - 3. the write-back to GMEM stays as a normal copy so the result reflects only - whether the TMA load populated SMEM correctly. - -Run: TILELANG_DISABLE_CACHE=1 python examples/gemm_sm100/test_tma_fp4_roundtrip.py -""" - -import os - -import torch -import tilelang -import tilelang.language as T - -M, K = 256, 256 -block_M, block_K = 128, 128 -in_dtype = T.float4_e2m1fn - - -def tma_roundtrip_kernel( - M, - K, - block_M, - block_K, - in_dtype, - force_tma_load: bool, - skip_store: bool, - use_explicit_tma_copy: bool, -): - """Copy A[block_M, block_K] to SMEM, then write back to Out.""" - A_shape = (M, K) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - Out: T.Tensor(A_shape, in_dtype), - ): - with T.Kernel(1, T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), in_dtype) - if force_tma_load and use_explicit_tma_copy: - mbar = T.alloc_barrier(128) - - T.annotate_layout( - { - A_shared: tilelang.layout.make_align16b_swizzled_layout(A_shared), - } - ) - - if force_tma_load and use_explicit_tma_copy: - T.tma_copy(A[by * block_M, 0], A_shared, barrier=mbar) - T.barrier_arrive(mbar) - T.mbarrier_wait_parity(mbar, 0) - T.fence_proxy_async() - elif force_tma_load: - T.copy(A[by * block_M, 0], A_shared) - else: - T.copy(A[by * block_M, 0], A_shared) - if not skip_store: - T.copy(A_shared, Out[by * block_M, 0]) - - return main - - -def dump_and_check_kernel_source(kernel, mode: str, expect_tma_load: bool): - kernel_src_path = os.path.join( - os.path.dirname(__file__), - f"test_tma_fp4_roundtrip.{mode}.generated.cu", - ) - src = kernel.get_kernel_source() - with open(kernel_src_path, "w", encoding="utf-8") as f: - f.write(src) - print(f"Kernel source written to: {kernel_src_path}") - - keywords = ( - "arrive_and_expect_tx", - "tma_load(", - "tma_store(", - "fence_proxy_async", - ) - print(f"{mode} kernel debug summary:") - matched = [] - for line_no, line in enumerate(src.splitlines(), start=1): - if any(keyword in line for keyword in keywords): - matched.append((line_no, line.rstrip())) - print(f"{line_no:4d}: {line.rstrip()}") - - has_tma_load = any("tma_load(" in line for _, line in matched) - if expect_tma_load and not has_tma_load: - raise RuntimeError( - f"{mode} kernel was expected to use TMA load, but generated source " - "contains no tl::tma_load(...) call." - ) - if not expect_tma_load and has_tma_load: - raise RuntimeError( - f"{mode} kernel was expected to disable TMA load, but generated source " - "still contains tl::tma_load(...)." - ) - - -print(f"TMA FP4 load diagnostic: M={M}, K={K}, block=({block_M},{block_K})") -run_ref = os.environ.get("TL_FP4_SKIP_REF", "0") != "1" -skip_store = os.environ.get("TL_FP4_SKIP_STORE", "0") == "1" -use_explicit_tma_copy = os.environ.get("TL_FP4_USE_TMA_COPY", "1") == "1" -mode = "explicit_tma_copy" if use_explicit_tma_copy else "auto_tma_lower" -print(f"Load mode: {mode}") - -# --- Build with explicit TMA load --- -func_tma = tma_roundtrip_kernel( - M, - K, - block_M, - block_K, - in_dtype, - force_tma_load=True, - skip_store=skip_store, - use_explicit_tma_copy=use_explicit_tma_copy, -) -kernel_tma = tilelang.compile( - func_tma, - out_idx=[1], - target="cuda", - pass_configs={ - # Keep explicit T.tma_copy(), but disable implicit TMA selection on the - # store path so this test only validates the load side. - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, -) -print("TMA kernel compiled OK") -dump_and_check_kernel_source(kernel_tma, "tma", expect_tma_load=True) - -kernel_ref = None -if run_ref: - # --- Build with TMA disabled (reference) --- - func_ref = tma_roundtrip_kernel( - M, - K, - block_M, - block_K, - in_dtype, - force_tma_load=False, - skip_store=skip_store, - use_explicit_tma_copy=False, - ) - kernel_ref = tilelang.compile( - func_ref, - out_idx=[1], - target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, - ) - print("Reference kernel (no TMA) compiled OK") - dump_and_check_kernel_source(kernel_ref, "non_tma", expect_tma_load=False) - -torch.manual_seed(42) -a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) - -# --- Run TMA kernel first and synchronize immediately so failures are attributed correctly --- -print("Launching TMA kernel...") -out_tma = kernel_tma(a_packed) -torch.cuda.synchronize() -print("TMA kernel finished.") - -out_ref = None -if kernel_ref is not None: - print("Launching reference kernel...") - out_ref = kernel_ref(a_packed) - torch.cuda.synchronize() - print("Reference kernel finished.") - -if skip_store: - print("\nStore path skipped; this run only validates kernel completion after TMA load.") - raise SystemExit(0) - -# --- Compare byte-by-byte --- -tma_bytes = out_tma.view(torch.uint8).cpu() -inp_bytes = a_packed.view(torch.uint8).cpu() - -match_tma_inp = (tma_bytes == inp_bytes).sum().item() -total = tma_bytes.numel() - -print(f"\nTotal bytes: {total}") -print(f"TMA vs Input: {match_tma_inp}/{total} match ({100*match_tma_inp/total:.1f}%)") -if out_ref is not None: - ref_bytes = out_ref.view(torch.uint8).cpu() - match_ref_inp = (ref_bytes == inp_bytes).sum().item() - match_tma_ref = (tma_bytes == ref_bytes).sum().item() - print(f"Ref vs Input: {match_ref_inp}/{total} match ({100*match_ref_inp/total:.1f}%)") - print(f"TMA vs Ref: {match_tma_ref}/{total} match ({100*match_tma_ref/total:.1f}%)") - -if match_tma_inp == total: - print("\n[PASS] TMA round-trip is byte-exact!") -else: - mismatch_idx = (tma_bytes != inp_bytes).nonzero(as_tuple=True)[0] - print(f"\nFirst 20 mismatched byte indices: {mismatch_idx[:20].tolist()}") - print(f" TMA bytes: {tma_bytes[mismatch_idx[:20]].tolist()}") - print(f" Inp bytes: {inp_bytes[mismatch_idx[:20]].tolist()}") - - row_size = K // 2 - mismatch_rows = (mismatch_idx // row_size).unique() - print(f"Mismatched rows ({len(mismatch_rows)} total): {mismatch_rows[:20].tolist()}") diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 4b464f5f13..3a7f713e65 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -1934,8 +1934,8 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, } else if (t == buffer_element_dtype) { int div_factor = 1; bool is_packed_scope = scope.empty() || scope == "global"; - if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 - && is_packed_scope) { + if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 && + is_packed_scope) { div_factor = 2; } index_str = @@ -1944,8 +1944,8 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, } else { int div_factor = 1; bool is_packed_scope = scope.empty() || scope == "global"; - if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 - && is_packed_scope) { + if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1 && + is_packed_scope) { div_factor = 2; } index_str = diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc index 498fefb252..e7bd269daf 100644 --- a/src/backend/cuda/op/copy.cc +++ b/src/backend/cuda/op/copy.cc @@ -1102,10 +1102,10 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, << " is not divisible by instruction_dim: " << instruction_dim; desc.smem_box.Set(0, PrimExpr(instruction_dim)); - int inner_box_dim_ = is_align16b_subbyte_layout - ? instruction_dim * shared_tensor->dtype.bytes() - : TMABytesFromElements(instruction_dim, - shared_tensor->dtype); + int inner_box_dim_ = + is_align16b_subbyte_layout + ? instruction_dim * shared_tensor->dtype.bytes() + : TMABytesFromElements(instruction_dim, shared_tensor->dtype); struct SwizzleCheck { int swizzle; diff --git a/src/backend/cuda/op/gemm.cc b/src/backend/cuda/op/gemm.cc index 349ebf9a8a..8f72f192e0 100644 --- a/src/backend/cuda/op/gemm.cc +++ b/src/backend/cuda/op/gemm.cc @@ -362,9 +362,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); refl::GlobalDef().def( "tl.get_tcgen5_instr_desc", - [](int atom_m, int atom_n, int atom_k, DataType a_dtype, - DataType b_dtype, DataType c_dtype, bool a_is_k_major, - bool b_is_k_major, int scale_in_a, int scale_in_b) { + [](int atom_m, int atom_n, int atom_k, DataType a_dtype, DataType b_dtype, + DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a, + int scale_in_b) { uint32_t desc = GetTCGEN5InstrDesc( atom_m, atom_n, atom_k, a_dtype, b_dtype, c_dtype, a_is_k_major, b_is_k_major, scale_in_a, scale_in_b); diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 635f1c9bc9..e96e838e3f 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -488,20 +488,21 @@ Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { return ExpandLayout2D(base, buffer); } -// Layout swizzling for 128 bytes (ALIGN16B unpacksmem variant for sub-byte types) -// For CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B: each FP4 element occupies one -// 8-bit SMEM container. The FP4 payload lives inside that byte container +// Layout swizzling for 128 bytes (ALIGN16B unpacksmem variant for sub-byte +// types) For CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B: each FP4 element occupies +// one 8-bit SMEM container. The FP4 payload lives inside that byte container // (bits 2..5 for NVIDIA unpacksmem), so shared-memory indexing must not pack // two logical FP4 values into one byte. static Layout MakeAlign16BSwizzleLayout2D(int stride, int continuous, int element_size) { ICHECK(element_size < 8 && element_size > 0) - << "ALIGN16B layout is for sub-byte types, got element_size=" << element_size; + << "ALIGN16B layout is for sub-byte types, got element_size=" + << element_size; const int elements_per_chunk = 16; ICHECK(stride % 8 == 0) << "stride=" << stride; ICHECK(continuous % (elements_per_chunk * 8) == 0) - << "continuous=" << continuous - << " must be a multiple of " << (elements_per_chunk * 8); + << "continuous=" << continuous << " must be a multiple of " + << (elements_per_chunk * 8); Var i = InputPlaceholder(0); Var j = InputPlaceholder(1); @@ -918,7 +919,8 @@ Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, // Sub-byte types (FP4): use ALIGN16B unpacksmem layout. TMA writes each // logical FP4 into an 8-bit SMEM container, matching tcgen05.mma consumption. if (element_size < 8) { - return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous, element_size); + return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous, + element_size); } int vector_size = 128 / element_size; diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index ed6408ffef..5b198a9500 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -157,9 +157,9 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype, inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, DataType a_dtype, DataType b_dtype, - DataType c_dtype, - bool a_is_k_major, bool b_is_k_major, - int scale_in_a, int scale_in_b) { + DataType c_dtype, bool a_is_k_major, + bool b_is_k_major, int scale_in_a, + int scale_in_b) { ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; ICHECK(atom_k == 16 || atom_k == 32) diff --git a/src/tl_templates/cuda/cuda_fp4.h b/src/tl_templates/cuda/cuda_fp4.h index d2a24bd98a..803a0e2c74 100644 --- a/src/tl_templates/cuda/cuda_fp4.h +++ b/src/tl_templates/cuda/cuda_fp4.h @@ -27,7 +27,9 @@ class fp4_e2_2_t { } // Set low 4 bits (first fp4) - TL_DEVICE void set_x(fp4_e2_t val) { __x = (__x & 0xF0) | (val.raw() & 0x0F); } + TL_DEVICE void set_x(fp4_e2_t val) { + __x = (__x & 0xF0) | (val.raw() & 0x0F); + } // Set high 4 bits (second fp4) TL_DEVICE void set_y(fp4_e2_t val) { diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h index cb3536af28..0218374e2a 100644 --- a/src/tl_templates/cuda/gemm_mma.h +++ b/src/tl_templates/cuda/gemm_mma.h @@ -42,8 +42,8 @@ using _X = Underscore; #ifdef __CUDA_ARCH_LIST__ #if __CUDA_ARCH_LIST__ >= 1200 -#include "cuda_fp8.h" #include "cuda_fp4.h" +#include "cuda_fp8.h" #include #include TL_DISPATCH_MMA_TEMPLATE(fp4_e2_t, fp4_e2_t, float, SM120_16x8x32_TN) @@ -59,8 +59,8 @@ TL_DISPATCH_MMA(float, float, float, SM80_16x8x8_F32TF32TF32F32_TN) TL_DISPATCH_MMA(int8_t, int8_t, int, SM80_16x8x32_S32S8S8S32_TN) TL_DISPATCH_MMA(double, double, double, SM80_8x8x4_F64F64F64F64_TN) #elif __CUDA_ARCH_LIST__ >= 1000 -#include "cuda_fp8.h" #include "cuda_fp4.h" +#include "cuda_fp8.h" #include #include #include diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index e0cbec36a2..3615501daf 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -208,17 +208,14 @@ struct MMA_Traits, // CUTLASS's to_UMMAFormat returns MXF4Format::E2M1(=1), but // kind::f8f6f4 requires MXF8F6F4Format::E2M1(=5). This more-specialized // partial specialization provides the correct descriptor encoding. -template -struct MMA_Traits, cute::C, +struct MMA_Traits, cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, - cute::integral_constant> -{ + cute::integral_constant> { using ValTypeD = c_type; using ValTypeA = cute::float_e2m1_t; using ValTypeB = cute::float_e2m1_t; @@ -231,25 +228,25 @@ struct MMA_Traits, Int, Int>; - using ThrID = Layout<_1>; - using ALayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; static constexpr UMMA::InstrDescriptor make_fp4_desc() { UMMA::InstrDescriptor d = {}; d.c_format_ = uint8_t(UMMA::to_CFormat()); - d.a_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_format_ = 5; // MXF8F6F4Format::E2M1 d.b_format_ = 5; d.a_negate_ = uint8_t(a_neg); d.b_negate_ = uint8_t(b_neg); - d.a_major_ = uint8_t(a_major); - d.b_major_ = uint8_t(b_major); - d.n_dim_ = N >> 3; - d.m_dim_ = M >> 4; + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; return d; } @@ -277,17 +274,14 @@ struct MMA_Traits -struct MMA_Traits, cute::C, +struct MMA_Traits, cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, - cute::integral_constant> -{ + cute::integral_constant> { using ValTypeD = c_type; using ValTypeA = cute::float_e2m1_t; using ValTypeB = cute::float_e2m1_t; @@ -300,25 +294,25 @@ struct MMA_Traits, Int, Int>; - using ThrID = Layout<_1>; - using ALayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; static constexpr UMMA::InstrDescriptor make_fp4_desc() { UMMA::InstrDescriptor d = {}; d.c_format_ = uint8_t(UMMA::to_CFormat()); - d.a_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_format_ = 5; // MXF8F6F4Format::E2M1 d.b_format_ = 5; d.a_negate_ = uint8_t(a_neg); d.b_negate_ = uint8_t(b_neg); - d.a_major_ = uint8_t(a_major); - d.b_major_ = uint8_t(b_major); - d.n_dim_ = N >> 3; - d.m_dim_ = M >> 4; + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; return d; } @@ -349,56 +343,57 @@ struct MMA_Traits 这部分负责声明参数列表,即让这个偏特化对所有可能的 组合都有效。 -// 2. 内层 struct MMA_Traits> 说明:只要模板参数匹配到 SM100_MMA_F8F6F4_TS 那一整组, +// 1. 外层 template <...> 这部分负责声明参数列表,即让这个偏特化对所有可能的 +// 组合都有效。 +// 2. 内层 struct MMA_Traits> 说明:只要模板参数匹配到 +// SM100_MMA_F8F6F4_TS 那一整组, // 就会自动选择这个偏特化的实现,不需要“再特化”。 -// 换句话说,你写一个 MMA_Traits,如果 X 恰好能匹配 SM100_MMA_F8F6F4_TS<...> 那一组参数, -// 这个特化版本就会被用到。比如: -// using Traits = MMA_Traits>; +// 换句话说,你写一个 MMA_Traits,如果 X 恰好能匹配 SM100_MMA_F8F6F4_TS<...> +// 那一组参数, 这个特化版本就会被用到。比如: +// using Traits = MMA_Traits>; // Traits的内容就是匹配到此特化,无需再继续特化template参数。 -// 总结:外层 template 支持泛型匹配,struct MMA_Traits<...> 实现具体特化,实际用时只需传入对应参数即可自动选中。 - -template -struct MMA_Traits> -{ +// 总结:外层 template 支持泛型匹配,struct MMA_Traits<...> +// 实现具体特化,实际用时只需传入对应参数即可自动选中。 + +template +struct MMA_Traits< + SM100_MMA_F8F6F4_TS> { using ValTypeD = c_type; using ValTypeA = cute::float_e2m1_t; using ValTypeB = cute::float_e2m1_t; using ValTypeC = c_type; using FrgTypeA = UMMA::tmem_frg_1sm; + UMMA::TmemAllocMode::NonInterleaved>; using FrgTypeB = UMMA::smem_desc; - using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeC = + UMMA::tmem_frg_1sm; constexpr static int K = 32; using Shape_MNK = Shape, Int, Int>; - using ThrID = Layout<_1>; - using ALayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; static constexpr UMMA::InstrDescriptor make_fp4_desc() { UMMA::InstrDescriptor d = {}; d.c_format_ = uint8_t(UMMA::to_CFormat()); - d.a_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_format_ = 5; // MXF8F6F4Format::E2M1 d.b_format_ = 5; d.a_negate_ = uint8_t(a_neg); d.b_negate_ = uint8_t(b_neg); - d.a_major_ = uint8_t(a_major); - d.b_major_ = uint8_t(b_major); - d.n_dim_ = N >> 3; - d.m_dim_ = M >> 4; + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; return d; } @@ -420,27 +415,24 @@ struct MMA_Traits(traits.idesc_); - SM100_MMA_F8F6F4_TS::fma( - tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + SM100_MMA_F8F6F4_TS::fma(tmem_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); } }; // Mixed-precision MMA_Traits: A = float_e4m3_t (FP8), B = float_e2m1_t (FP4). -// Descriptor: a_format = 0 (MXF8F6F4Format::E4M3), b_format = 5 (MXF8F6F4Format::E2M1). -// SS variant (both operands from shared memory). -template -struct MMA_Traits, cute::C, +struct MMA_Traits, cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, - cute::integral_constant> -{ + cute::integral_constant> { using ValTypeD = c_type; using ValTypeA = cute::float_e4m3_t; using ValTypeB = cute::float_e2m1_t; @@ -453,25 +445,25 @@ struct MMA_Traits, Int, Int>; - using ThrID = Layout<_1>; - using ALayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; static constexpr UMMA::InstrDescriptor make_a8w4_desc() { UMMA::InstrDescriptor d = {}; d.c_format_ = uint8_t(UMMA::to_CFormat()); - d.a_format_ = 0; // MXF8F6F4Format::E4M3 - d.b_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_format_ = 0; // MXF8F6F4Format::E4M3 + d.b_format_ = 5; // MXF8F6F4Format::E2M1 d.a_negate_ = uint8_t(a_neg); d.b_negate_ = uint8_t(b_neg); - d.a_major_ = uint8_t(a_major); - d.b_major_ = uint8_t(b_major); - d.n_dim_ = N >> 3; - d.m_dim_ = M >> 4; + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; return d; } @@ -499,17 +491,14 @@ struct MMA_Traits -struct MMA_Traits, cute::C, +struct MMA_Traits, cute::C, cute::integral_constant, cute::integral_constant, cute::integral_constant, - cute::integral_constant> -{ + cute::integral_constant> { using ValTypeD = c_type; using ValTypeA = cute::float_e4m3_t; using ValTypeB = cute::float_e2m1_t; @@ -522,25 +511,25 @@ struct MMA_Traits, Int, Int>; - using ThrID = Layout<_1>; - using ALayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; static constexpr UMMA::InstrDescriptor make_a8w4_desc() { UMMA::InstrDescriptor d = {}; d.c_format_ = uint8_t(UMMA::to_CFormat()); - d.a_format_ = 0; // MXF8F6F4Format::E4M3 - d.b_format_ = 5; // MXF8F6F4Format::E2M1 + d.a_format_ = 0; // MXF8F6F4Format::E4M3 + d.b_format_ = 5; // MXF8F6F4Format::E2M1 d.a_negate_ = uint8_t(a_neg); d.b_negate_ = uint8_t(b_neg); - d.a_major_ = uint8_t(a_major); - d.b_major_ = uint8_t(b_major); - d.n_dim_ = N >> 3; - d.m_dim_ = M >> 4; + d.a_major_ = uint8_t(a_major); + d.b_major_ = uint8_t(b_major); + d.n_dim_ = N >> 3; + d.m_dim_ = M >> 4; return d; } diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index b8fad7a94b..ff120f41be 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -1,9 +1,9 @@ #pragma once #include "../common.h" +#include #include #include -#include #ifndef __CUDACC_RTC__ #include @@ -148,13 +148,16 @@ TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true, false, cute::SM80_8x8x4_F64F64F64F64_TN) // FP4 inputs (k32, SM120 kind::f8f6f4) -using SM120_FP4_FP4_F32_TN = cute::SM120_16x8x32_TN; +using SM120_FP4_FP4_F32_TN = + cute::SM120_16x8x32_TN; TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, 16, 8, 32, false, true, false, SM120_FP4_FP4_F32_TN) // Mixed FP8 x FP4 and FP4 x FP8 (k32, SM120 kind::f8f6f4) -using SM120_FP8_FP4_F32_TN = cute::SM120_16x8x32_TN; -using SM120_FP4_FP8_F32_TN = cute::SM120_16x8x32_TN; +using SM120_FP8_FP4_F32_TN = + cute::SM120_16x8x32_TN; +using SM120_FP4_FP8_F32_TN = + cute::SM120_16x8x32_TN; TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat4_e2m1fn, kFloat32, 16, 8, 32, false, true, false, SM120_FP8_FP4_F32_TN) TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat8_e4m3, kFloat32, 16, 8, 32, @@ -183,11 +186,12 @@ TL_DEVICE void mma_sync( using BReg = typename Dispatcher::BRegType; constexpr int nA = detail::MmaImplTraits::kARegs; constexpr int nB = detail::MmaImplTraits::kBRegs; - AReg as[nA]; BReg bs[nB]; - #pragma unroll + AReg as[nA]; + BReg bs[nB]; +#pragma unroll for (int i = 0; i < nA; ++i) as[i] = (AType == DataType::kFloat4_e2m1fn) ? (a[i] << 2) : a[i]; - #pragma unroll +#pragma unroll for (int i = 0; i < nB; ++i) bs[i] = (BType == DataType::kFloat4_e2m1fn) ? (b[i] << 2) : b[i]; Dispatcher::exec(c, as, bs, c); diff --git a/src/tl_templates/cuda/ldsm.h b/src/tl_templates/cuda/ldsm.h index abd1471e5c..49321e91d6 100644 --- a/src/tl_templates/cuda/ldsm.h +++ b/src/tl_templates/cuda/ldsm.h @@ -55,7 +55,7 @@ TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr, // Reads packed 4-bit data from shared memory and unpacks each 4-bit value // into the low 4 bits of an 8-bit container (upper 4 bits zeroed). TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr, - void *const local_ptr) { + void *const local_ptr) { uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); int32_t *value = reinterpret_cast(local_ptr); asm volatile( @@ -65,7 +65,7 @@ TL_DEVICE void ptx_ldmatrix_b4x16_x1(void const *const smem_ptr, } TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr, - void *const local_ptr) { + void *const local_ptr) { uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); int32_t *value = reinterpret_cast(local_ptr); asm volatile( @@ -75,13 +75,13 @@ TL_DEVICE void ptx_ldmatrix_b4x16_x2(void const *const smem_ptr, } TL_DEVICE void ptx_ldmatrix_b4x16_x4(void const *const smem_ptr, - void *const local_ptr) { + void *const local_ptr) { uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); int32_t *value = reinterpret_cast(local_ptr); - asm volatile( - "ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n" - : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) - : "r"(smem_int_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, " + "%2, %3}, [%4];\n" + : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3]) + : "r"(smem_int_ptr)); } TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr, diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 26e34e6a51..a1167aa3b2 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -118,7 +118,12 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) - self.k_dim = min(256 // a_dtype.bits, self.chunk) + # SM120 f8f6f4 FP4 MMA is m16n8k32. Although 256 / 4 would allow a + # k64 fragment by bit count, there is no k64 dispatcher for FP4. + if str(a_dtype) == "float4_e2m1fn": + self.k_dim = min(32, self.chunk) + else: + self.k_dim = min(256 // a_dtype.bits, self.chunk) def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_a = (m_dim * k_dim) // warp_size diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index c496a8c343..8868c1e8d1 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -919,16 +919,55 @@ def get_tcgen5_instr_desc( self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int ) -> PrimExpr: """Build the 64-bit instruction descriptor for a ``tcgen05.mma`` PTX call.""" - desc = _ffi_api.get_tcgen5_instr_desc( - atom_m, - atom_n, - atom_k, - DataType(self.a_dtype), - DataType(self.b_dtype), - DataType(self.accum_dtype), - a_is_k_major, - b_is_k_major, - scale_in_a, - scale_in_b, - ) + + def encode_dtype(dtype: str) -> int: + dtype = str(DataType(dtype)) + if dtype == "float16": + return 0 + if dtype == "bfloat16": + return 1 + if dtype in ("float8_e4m3", "float8_e4m3fn", "float8_e4m3fnuz"): + return 0 + if dtype in ("float8_e5m2", "float8_e5m2fnuz"): + return 1 + if dtype == "float6_e2m3fn": + return 3 + if dtype == "float6_e3m2fn": + return 4 + if dtype == "float4_e2m1fn": + return 5 + if dtype == "int8": + return 1 + if dtype == "uint8": + return 0 + raise ValueError(f"Unsupported dtype for TCGEN5MMA descriptor: {dtype}") + + def set_bits(value: int, start: int, width: int) -> int: + return (value & ((1 << width) - 1)) << start + + accum_dtype = str(DataType(self.accum_dtype)) + if accum_dtype == "float16": + c_format = 0 + elif accum_dtype == "float32": + c_format = 1 + elif accum_dtype == "int32": + c_format = 2 + else: + raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA descriptor: {accum_dtype}") + + if atom_m % 16 != 0 or atom_n % 8 != 0 or atom_k not in (16, 32): + raise ValueError(f"Unsupported TCGEN5MMA atom shape: m={atom_m}, n={atom_n}, k={atom_k}") + if scale_in_a not in (1, -1) or scale_in_b not in (1, -1): + raise ValueError("scale_in_a and scale_in_b must be +/-1 for TCGEN5MMA") + + desc = 0 + desc |= set_bits(c_format, 4, 2) + desc |= set_bits(encode_dtype(self.a_dtype), 7, 3) + desc |= set_bits(encode_dtype(self.b_dtype), 10, 3) + desc |= set_bits(1 if scale_in_a == -1 else 0, 13, 1) + desc |= set_bits(1 if scale_in_b == -1 else 0, 14, 1) + desc |= set_bits(0 if a_is_k_major else 1, 15, 1) + desc |= set_bits(0 if b_is_k_major else 1, 16, 1) + desc |= set_bits(atom_n >> 3, 17, 6) + desc |= set_bits(atom_m >> 4, 24, 5) return lift(desc) diff --git a/tilelang/intrinsics/utils.py b/tilelang/intrinsics/utils.py index f65fff1a9b..57ab83a8cb 100644 --- a/tilelang/intrinsics/utils.py +++ b/tilelang/intrinsics/utils.py @@ -7,6 +7,8 @@ ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x16_to_shared_16x32_fp4_layout_a, + ldmatrix_32x16_to_shared_16x32_fp4_layout_b, mma_store_32x8_to_shared_16x16_layout, mma_store_32x2_to_shared_8x8_layout_fp64, ) @@ -49,17 +51,26 @@ def get_ldmatrix_offset( else: new_row_idx, new_col_idx = transform_func(row_idx, col_idx) return new_row_idx, new_col_idx - elif dtype_bits <= 8: + elif dtype_bits == 4: + if matrix == "B" and transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_fp4_layout_b + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx, new_col_idx + elif matrix == "A" and not transposed: + transform_func = ldmatrix_32x16_to_shared_16x32_fp4_layout_a + new_row_idx, new_col_idx = transform_func(row_idx, col_idx) + return new_row_idx, new_col_idx + else: + raise ValueError("ldmatrix only supports B transposed and A non-transposed for fp4") + elif dtype_bits == 8: if matrix == "B" and transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_b new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - pack_factor = 8 // dtype_bits - return new_row_idx, new_col_idx * pack_factor + return new_row_idx, new_col_idx elif matrix == "A" and not transposed: transform_func = ldmatrix_32x16_to_shared_16x32_layout_a new_row_idx, new_col_idx = transform_func(row_idx, col_idx) - pack_factor = 8 // dtype_bits - return new_row_idx, new_col_idx * pack_factor + return new_row_idx, new_col_idx else: raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") else: diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index fdd00db2f3..825a01b390 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -89,7 +89,10 @@ def in_dtype(self) -> str: """ if is_tensor_memory(self.A): return self.B.dtype - return self.A.dtype + dtype = self.A.dtype + if dtype == "uint8" and self.C.dtype in ("float32", "float16"): + return "float4_e2m1fn" + return dtype @property def in_dtype_b(self) -> str: diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 9b3de02407..ef6efb6e3e 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -15,7 +15,7 @@ from tilelang import language as T from tilelang.utils.language import retrieve_ptr from tilelang.transform.simplify import _Simplify -from tvm import tir +from tvm import DataType, tir from tvm.target import Target from tvm.ir import Range from tvm.arith import Analyzer @@ -44,7 +44,7 @@ class GemmTCGEN5(GemmBase): of operands A and B. """ - def infer_shared_layout(self, continuity: int, k_major: bool) -> Callable[[tir.Buffer], Layout]: + def infer_shared_layout(self, dtype, continuity: int, k_major: bool) -> Callable[[tir.Buffer], Layout]: """Infer the shared-memory layout for TCGEN05 operands. Sub-byte inputs (e.g. FP4) must use the TCGEN05-specific layout helper so @@ -53,8 +53,10 @@ def infer_shared_layout(self, continuity: int, k_major: bool) -> Callable[[tir.B the SIMT copy writes packed-linear data, so we return a plain linear layout to keep the SMEM descriptor consistent with the actual data layout. """ - if self.in_dtype.bits < 8: + dtype_bits = dtype.bits if hasattr(dtype, "bits") else DataType(dtype).bits + if dtype_bits < 8: import tvm + try: _pass_ctx = tvm.transform.PassContext.current() disable_tma = _pass_ctx.config.get("tl.disable_tma_lower", False) @@ -64,10 +66,8 @@ def infer_shared_layout(self, continuity: int, k_major: bool) -> Callable[[tir.B # Non-TMA path: SIMT copy writes packed-linear nibbles; use a # linear layout so the SMEM descriptor matches the actual data. return make_linear_layout - return lambda buffer: make_tcgen05mma_swizzled_layout( - buffer, continuity=continuity, k_major=k_major - ) - vectorized_size = 128 // self.in_dtype.bits + return lambda buffer: make_tcgen05mma_swizzled_layout(buffer, continuity=continuity, k_major=k_major) + vectorized_size = 128 // dtype_bits if continuity % (vectorized_size * 8) == 0: return make_full_bank_swizzled_layout elif continuity % (vectorized_size * 4) == 0: @@ -93,7 +93,7 @@ def infer_layout(self, target: Target, thread_nums: int): warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, - b_dtype=self.in_dtype, + b_dtype=self.in_dtype_b, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, @@ -115,15 +115,15 @@ def infer_layout(self, target: Target, thread_nums: int): b_continuity = self.K if b_is_k_major else int(self.B.shape[-1]) # don't use N, as it may be for 2cta return { - self.A: self.infer_shared_layout(a_continuity, a_is_k_major)(self.A), - self.B: self.infer_shared_layout(b_continuity, b_is_k_major)(self.B), + self.A: self.infer_shared_layout(self.in_dtype, a_continuity, a_is_k_major)(self.A), + self.B: self.infer_shared_layout(self.in_dtype_b, b_continuity, b_is_k_major)(self.B), self.C: mma_emitter.make_mma_store_layout(self.C), } if self.is_gemm_ts(): b_continuity = self.K if b_is_k_major else int(self.B.shape[-1]) layouts = { self.A: mma_emitter.make_mma_store_layout(self.A), - self.B: self.infer_shared_layout(b_continuity, b_is_k_major)(self.B), + self.B: self.infer_shared_layout(self.in_dtype_b, b_continuity, b_is_k_major)(self.B), self.C: mma_emitter.make_mma_store_layout(self.C), } return layouts @@ -148,7 +148,7 @@ def lower( warp_col_tiles = int(self.N // n_warp) mma_emitter = TensorCoreIntrinEmitter( a_dtype=self.in_dtype, - b_dtype=self.in_dtype, + b_dtype=self.in_dtype_b, accum_dtype=self.accum_dtype, a_transposed=self.trans_A, b_transposed=self.trans_B, From 3fb9c9d1a25642a45ee1a68fd575c7fa78203f34 Mon Sep 17 00:00:00 2001 From: wahao Date: Sat, 14 Mar 2026 22:31:09 +0800 Subject: [PATCH 09/19] feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120 --- .../gemm_fp4/example_fusedmoe_a8w4_sm120.py | 27 ++++++++++++++++++ examples/gemm_fp4/example_gemm_a8w4_sm120.py | 28 +++++++++++++++++++ tilelang/tileop/gemm/gemm_mma.py | 19 +++++++++++++ 3 files changed, 74 insertions(+) diff --git a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py index fbd06a7e9d..104b139e50 100644 --- a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py +++ b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py @@ -13,6 +13,10 @@ see examples/fusedmoe/example_fusedmoe_tilelang.py. """ +<<<<<<< HEAD +======= +import os +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) import time import torch import tilelang @@ -20,6 +24,7 @@ FP4_E2M1_TO_FLOAT = [ +<<<<<<< HEAD 0.0, 0.5, 1.0, @@ -36,6 +41,10 @@ -3.0, -4.0, -6.0, +======= + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ] @@ -45,6 +54,7 @@ def fp4_uint8_to_float(t): def moe_shared_expert_a8w4( +<<<<<<< HEAD num_tokens, d_hidden, d_expert, @@ -53,6 +63,11 @@ def moe_shared_expert_a8w4( block_expert=128, threads=128, num_stages=1, +======= + num_tokens, d_hidden, d_expert, + block_token=128, block_hidden=128, block_expert=128, + threads=128, num_stages=1, +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ): """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM.""" scale = 1.44269504 # log2(e) for fast SiLU @@ -90,7 +105,13 @@ def main( # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate for i, j in T.Parallel(block_token, block_expert): +<<<<<<< HEAD gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) +======= + gate_local[i, j] = gate_local[i, j] * ( + 1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)) + ) +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) up_local[i, j] = up_local[i, j] * gate_local[i, j] T.copy(up_local, output[bx * block_token, by * block_expert]) @@ -106,6 +127,7 @@ def main( print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}") func = moe_shared_expert_a8w4( +<<<<<<< HEAD num_tokens, d_hidden, d_expert, @@ -114,6 +136,11 @@ def main( block_expert=128, threads=128, num_stages=1, +======= + num_tokens, d_hidden, d_expert, + block_token=128, block_hidden=128, block_expert=128, + threads=128, num_stages=1, +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ) jit_kernel = tilelang.compile( diff --git a/examples/gemm_fp4/example_gemm_a8w4_sm120.py b/examples/gemm_fp4/example_gemm_a8w4_sm120.py index b48564ba56..3083850076 100644 --- a/examples/gemm_fp4/example_gemm_a8w4_sm120.py +++ b/examples/gemm_fp4/example_gemm_a8w4_sm120.py @@ -8,6 +8,10 @@ No block scaling. Direct FP8 x FP4 tensor core multiply-accumulate. """ +<<<<<<< HEAD +======= +import os +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) import time import torch import tilelang @@ -15,6 +19,7 @@ def gemm_a8w4( +<<<<<<< HEAD M, N, K, @@ -25,6 +30,11 @@ def gemm_a8w4( accum_dtype, num_stages=2, threads=128, +======= + M, N, K, block_M, block_N, block_K, + out_dtype, accum_dtype, + num_stages=2, threads=128, +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ): A_shape = (M, K) B_shape = (N, K) @@ -52,6 +62,7 @@ def main( FP4_E2M1_TO_FLOAT = [ +<<<<<<< HEAD 0.0, 0.5, 1.0, @@ -68,6 +79,10 @@ def main( -3.0, -4.0, -6.0, +======= + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ] @@ -83,6 +98,7 @@ def fp4_uint8_to_float(tensor_uint8): accum_dtype = T.float32 print(f"Running A8W4 GEMM: M={M}, N={N}, K={K}") +<<<<<<< HEAD print(" A: float8_e4m3fn, B: FP4 (unpacked uint8)") func = gemm_a8w4( @@ -96,6 +112,14 @@ def fp4_uint8_to_float(tensor_uint8): accum_dtype, num_stages=2, threads=128, +======= +print(f" A: float8_e4m3fn, B: FP4 (unpacked uint8)") + +func = gemm_a8w4( + M, N, K, block_M, block_N, block_K, + out_dtype, accum_dtype, + num_stages=2, threads=128, +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ) jit_kernel = tilelang.compile( @@ -140,7 +164,11 @@ def fp4_uint8_to_float(tensor_uint8): if rel_err < 0.01: print("[PASS] numerical verification (rel_err < 0.01)") else: +<<<<<<< HEAD print("[WARN] large diff -- may indicate layout or data flow issue") +======= + print(f"[WARN] large diff -- may indicate layout or data flow issue") +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) # --- Benchmark --- torch.cuda.synchronize() diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index b5c2ba2c37..2282eccfa4 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -75,7 +75,26 @@ def lower( mbar_phase_expr: tir.PrimExpr | None = None, ): thread_nums = thread_bounds.extent +<<<<<<< HEAD mma_emitter = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) +======= + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.MMA) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype_b, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + thread_var=thread_var, + ) +>>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) in_dtype = self.in_dtype in_dtype_b = self.in_dtype_b From 63c945170b99abed9bb114d06a595113726bfb04 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 11 May 2026 16:59:44 +0800 Subject: [PATCH 10/19] Resolve FP4 example conflicts and remove generated artifacts --- examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s | 2193 ----------------- examples/amd_mxfp4_mm/asm-dump/src.s | 1283 ---------- .../gemm_fp4/example_fusedmoe_a8w4_sm120.py | 27 - examples/gemm_fp4/example_gemm_a8w4_sm120.py | 28 - tilelang/tileop/gemm/gemm_mma.py | 19 - 5 files changed, 3550 deletions(-) delete mode 100644 examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s delete mode 100644 examples/amd_mxfp4_mm/asm-dump/src.s diff --git a/examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s b/examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s deleted file mode 100644 index a724dd05e7..0000000000 --- a/examples/amd_mxfp4_mm/asm-dump-v22/src_v22.s +++ /dev/null @@ -1,2193 +0,0 @@ - -# __CLANG_OFFLOAD_BUNDLE____START__ hip-amdgcn-amd-amdhsa--gfx950 - .amdgcn_target "amdgcn-amd-amdhsa--gfx950" - .amdhsa_code_object_version 6 - .text - .protected _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii ; -- Begin function _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .globl _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .p2align 8 - .type _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii,@function -_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii: ; @_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii -; %bb.0: - s_load_dwordx2 s[6:7], s[0:1], 0x10 - s_load_dwordx8 s[8:15], s[0:1], 0x28 - v_and_b32_e32 v1, 31, v0 - s_waitcnt lgkmcnt(0) - s_lshl_b32 s15, s2, 5 - s_lshl_b32 s17, s3, 5 - v_or_b32_e32 v2, s17, v1 - s_mul_i32 s5, s11, s4 - s_add_i32 s2, s5, s11 - s_min_i32 s11, s2, s10 - v_lshrrev_b32_e32 v14, 5, v0 - v_cmp_gt_i32_e64 s[2:3], s9, v2 - v_ashrrev_i32_e32 v3, 31, v2 - v_mov_b32_e32 v5, 0 - v_accvgpr_write_b32 a0, 0 - v_accvgpr_write_b32 a1, 0 - v_accvgpr_write_b32 a2, 0 - v_accvgpr_write_b32 a3, 0 - v_accvgpr_write_b32 a4, 0 - v_accvgpr_write_b32 a5, 0 - v_accvgpr_write_b32 a6, 0 - v_accvgpr_write_b32 a7, 0 - v_accvgpr_write_b32 a8, 0 - v_accvgpr_write_b32 a9, 0 - v_accvgpr_write_b32 a10, 0 - v_accvgpr_write_b32 a11, 0 - v_accvgpr_write_b32 a12, 0 - v_accvgpr_write_b32 a13, 0 - v_accvgpr_write_b32 a14, 0 - s_cmp_ge_i32 s5, s11 - v_accvgpr_write_b32 a15, 0 - s_cbranch_scc1 .LBB0_11 -; %bb.1: ; %.preheader123 - s_load_dwordx4 s[20:23], s[0:1], 0x0 - s_load_dwordx4 s[24:27], s[0:1], 0x18 - v_lshlrev_b32_e32 v0, 2, v0 - s_ashr_i32 s16, s10, 5 - v_or_b32_e32 v12, s15, v1 - v_lshrrev_b32_e32 v4, 4, v1 - v_and_b32_e32 v10, 60, v0 - s_lshl_b32 s0, s14, 5 - s_ashr_i32 s1, s17, 5 - s_ashr_i32 s10, s10, 1 - s_waitcnt lgkmcnt(0) - v_mov_b64_e32 v[0:1], s[20:21] - v_mov_b64_e32 v[6:7], s[22:23] - v_mov_b64_e32 v[8:9], s[24:25] - s_mul_hi_i32 s14, s0, s1 - s_mul_i32 s17, s0, s1 - v_mad_i64_i32 v[0:1], s[0:1], s10, v12, v[0:1] - v_mad_i64_i32 v[6:7], s[0:1], s10, v2, v[6:7] - v_mad_i64_i32 v[8:9], s[0:1], s13, v12, v[8:9] - s_add_u32 s0, s26, s17 - v_mov_b32_e32 v11, v5 - s_addc_u32 s1, s27, s14 - v_lshl_add_u64 v[10:11], s[0:1], 0, v[10:11] - v_cmp_gt_i32_e32 vcc, s8, v12 - v_accvgpr_write_b32 a15, 0 - v_accvgpr_write_b32 a14, 0 - v_accvgpr_write_b32 a13, 0 - v_accvgpr_write_b32 a12, 0 - v_accvgpr_write_b32 a11, 0 - v_accvgpr_write_b32 a10, 0 - v_accvgpr_write_b32 a9, 0 - v_accvgpr_write_b32 a8, 0 - v_accvgpr_write_b32 a7, 0 - v_accvgpr_write_b32 a6, 0 - v_accvgpr_write_b32 a5, 0 - v_accvgpr_write_b32 a4, 0 - v_accvgpr_write_b32 a3, 0 - v_accvgpr_write_b32 a2, 0 - v_accvgpr_write_b32 a1, 0 - v_accvgpr_write_b32 a0, 0 - v_lshlrev_b32_e32 v15, 4, v14 - v_lshl_add_u64 v[10:11], v[10:11], 0, v[4:5] - s_branch .LBB0_3 -.LBB0_2: ; %_Z14load_frag_gmemPKhS0_S0_S0_iiiillllibbRDv32_hS2_RhS3_.exit - ; in Loop: Header=BB0_3 Depth=1 - s_or_b64 exec, exec, s[0:1] - s_waitcnt vmcnt(0) - v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[16:19], v[20:23], a[0:15], v13, v4 op_sel_hi:[0,0,0] cbsz:4 blgp:4 - s_add_i32 s5, s5, 64 - s_cmp_lt_i32 s5, s11 - s_cbranch_scc0 .LBB0_11 -.LBB0_3: ; =>This Inner Loop Header: Depth=1 - s_ashr_i32 s0, s5, 1 - v_add_u32_e32 v12, s0, v15 - v_mov_b32_e32 v16, 0 - v_mov_b32_e32 v17, 0 - v_mov_b32_e32 v18, 0 - v_mov_b32_e32 v20, 0 - v_ashrrev_i32_e32 v13, 31, v12 - v_mov_b32_e32 v19, 0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_5 -; %bb.4: ; %.loopexit45.i - ; in Loop: Header=BB0_3 Depth=1 - v_lshl_add_u64 v[16:17], v[0:1], 0, v[12:13] - global_load_dwordx4 v[16:19], v[16:17], off -.LBB0_5: ; in Loop: Header=BB0_3 Depth=1 - s_or_b64 exec, exec, s[0:1] - v_mov_b32_e32 v21, 0 - v_mov_b32_e32 v22, 0 - v_mov_b32_e32 v23, 0 - s_and_saveexec_b64 s[0:1], s[2:3] - s_cbranch_execz .LBB0_7 -; %bb.6: ; %.loopexit.i - ; in Loop: Header=BB0_3 Depth=1 - v_lshl_add_u64 v[12:13], v[6:7], 0, v[12:13] - global_load_dwordx4 v[20:23], v[12:13], off -.LBB0_7: ; in Loop: Header=BB0_3 Depth=1 - s_or_b64 exec, exec, s[0:1] - s_ashr_i32 s0, s5, 5 - v_add_u32_e32 v12, s0, v14 - v_cmp_gt_i32_e64 s[0:1], s13, v12 - s_and_b64 s[18:19], vcc, s[0:1] - v_mov_b32_e32 v4, 0x7f - v_mov_b32_e32 v13, 0x7f - s_and_saveexec_b64 s[0:1], s[18:19] - s_cbranch_execz .LBB0_9 -; %bb.8: ; in Loop: Header=BB0_3 Depth=1 - v_ashrrev_i32_e32 v13, 31, v12 - v_lshl_add_u64 v[24:25], v[8:9], 0, v[12:13] - global_load_ubyte v13, v[24:25], off -.LBB0_9: ; in Loop: Header=BB0_3 Depth=1 - s_or_b64 exec, exec, s[0:1] - v_cmp_gt_i32_e64 s[0:1], s16, v12 - s_and_b64 s[18:19], s[2:3], s[0:1] - s_and_saveexec_b64 s[0:1], s[18:19] - s_cbranch_execz .LBB0_2 -; %bb.10: ; in Loop: Header=BB0_3 Depth=1 - v_lshlrev_b32_e32 v4, 5, v12 - v_and_b32_e32 v24, 0xffffff00, v4 - v_ashrrev_i32_e32 v25, 31, v24 - v_lshlrev_b32_e32 v4, 6, v12 - v_and_b32_e32 v4, 0xc0, v4 - v_lshrrev_b32_e32 v12, 1, v12 - v_lshl_add_u64 v[24:25], v[10:11], 0, v[24:25] - v_and_b32_e32 v26, 2, v12 - v_mov_b32_e32 v27, v5 - v_lshl_add_u64 v[24:25], v[24:25], 0, v[4:5] - v_lshl_add_u64 v[24:25], v[24:25], 0, v[26:27] - global_load_ubyte v4, v[24:25], off - s_branch .LBB0_2 -.LBB0_11: ; %.loopexit124 - s_cmp_lg_u32 s12, 1 - s_mov_b64 s[0:1], -1 - s_cbranch_scc0 .LBB0_46 -; %bb.12: - s_and_saveexec_b64 s[0:1], s[2:3] - s_cbranch_execz .LBB0_45 -; %bb.13: ; %.preheader121 - v_lshl_add_u32 v4, v14, 2, s15 - s_mul_hi_i32 s5, s8, s4 - s_mul_i32 s4, s8, s4 - s_ashr_i32 s12, s9, 31 - v_lshl_add_u64 v[0:1], v[2:3], 2, s[6:7] - v_cmp_gt_i32_e32 vcc, s8, v4 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_15 -; %bb.14: - v_ashrrev_i32_e32 v5, 31, v4 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[4:5] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a0, off -.LBB0_15: - s_or_b64 exec, exec, s[10:11] - v_or_b32_e32 v6, 1, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_17 -; %bb.16: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a1, off -.LBB0_17: - s_or_b64 exec, exec, s[10:11] - v_or_b32_e32 v6, 2, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_19 -; %bb.18: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a2, off -.LBB0_19: - s_or_b64 exec, exec, s[10:11] - v_or_b32_e32 v6, 3, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_21 -; %bb.20: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a3, off -.LBB0_21: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 8, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_23 -; %bb.22: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a4, off -.LBB0_23: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 9, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_25 -; %bb.24: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a5, off -.LBB0_25: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 10, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_27 -; %bb.26: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a6, off -.LBB0_27: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 11, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_29 -; %bb.28: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a7, off -.LBB0_29: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 16, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_31 -; %bb.30: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a8, off -.LBB0_31: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 17, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_33 -; %bb.32: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a9, off -.LBB0_33: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 18, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_35 -; %bb.34: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a10, off -.LBB0_35: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 19, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_37 -; %bb.36: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a11, off -.LBB0_37: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 24, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_39 -; %bb.38: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a12, off -.LBB0_39: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 25, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_41 -; %bb.40: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a13, off -.LBB0_41: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 26, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_43 -; %bb.42: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a14, off -.LBB0_43: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v4, 27, v4 - v_cmp_gt_i32_e32 vcc, s8, v4 - s_and_b64 exec, exec, vcc - s_cbranch_execz .LBB0_45 -; %bb.44: - v_ashrrev_i32_e32 v5, 31, v4 - v_lshl_add_u64 v[4:5], s[4:5], 0, v[4:5] - v_mul_lo_u32 v6, v5, s9 - v_mul_lo_u32 v7, v4, s12 - v_mad_u64_u32 v[4:5], s[4:5], v4, s9, 0 - v_add3_u32 v5, v5, v7, v6 - v_lshl_add_u64 v[0:1], v[4:5], 2, v[0:1] - global_store_dword v[0:1], a15, off -.LBB0_45: ; %Flow403 - s_or_b64 exec, exec, s[0:1] - s_mov_b64 s[0:1], 0 -.LBB0_46: ; %Flow406 - s_andn2_b64 vcc, exec, s[0:1] - s_cbranch_vccnz .LBB0_80 -; %bb.47: - s_and_saveexec_b64 s[0:1], s[2:3] - s_cbranch_execz .LBB0_80 -; %bb.48: ; %.preheader - v_lshl_add_u32 v4, v14, 2, s15 - v_lshl_add_u64 v[0:1], v[2:3], 1, s[6:7] - v_cmp_gt_i32_e32 vcc, s8, v4 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_50 -; %bb.49: - v_mad_i64_i32 v[2:3], s[2:3], v4, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a0, off -.LBB0_50: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v2, 1, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_52 -; %bb.51: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a1, off -.LBB0_52: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v2, 2, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_54 -; %bb.53: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a2, off -.LBB0_54: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v2, 3, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_56 -; %bb.55: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a3, off -.LBB0_56: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 8, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_58 -; %bb.57: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a4, off -.LBB0_58: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 9, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_60 -; %bb.59: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a5, off -.LBB0_60: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 10, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_62 -; %bb.61: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a6, off -.LBB0_62: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 11, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_64 -; %bb.63: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a7, off -.LBB0_64: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 16, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_66 -; %bb.65: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a8, off -.LBB0_66: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 17, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_68 -; %bb.67: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a9, off -.LBB0_68: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 18, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_70 -; %bb.69: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a10, off -.LBB0_70: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 19, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_72 -; %bb.71: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a11, off -.LBB0_72: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 24, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_74 -; %bb.73: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a12, off -.LBB0_74: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 25, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_76 -; %bb.75: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a13, off -.LBB0_76: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 26, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_78 -; %bb.77: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a14, off -.LBB0_78: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 27, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_b64 exec, exec, vcc - s_cbranch_execz .LBB0_80 -; %bb.79: - v_mad_i64_i32 v[2:3], s[0:1], v2, s9, 0 - v_lshl_add_u64 v[0:1], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[0:1], a15, off -.LBB0_80: ; %.loopexit - s_endpgm - .section .rodata,"a",@progbits - .p2align 6, 0x0 - .amdhsa_kernel _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .amdhsa_group_segment_fixed_size 0 - .amdhsa_private_segment_fixed_size 0 - .amdhsa_kernarg_size 68 - .amdhsa_user_sgpr_count 2 - .amdhsa_user_sgpr_dispatch_ptr 0 - .amdhsa_user_sgpr_queue_ptr 0 - .amdhsa_user_sgpr_kernarg_segment_ptr 1 - .amdhsa_user_sgpr_dispatch_id 0 - .amdhsa_user_sgpr_kernarg_preload_length 0 - .amdhsa_user_sgpr_kernarg_preload_offset 0 - .amdhsa_user_sgpr_private_segment_size 0 - .amdhsa_uses_dynamic_stack 0 - .amdhsa_enable_private_segment 0 - .amdhsa_system_sgpr_workgroup_id_x 1 - .amdhsa_system_sgpr_workgroup_id_y 1 - .amdhsa_system_sgpr_workgroup_id_z 1 - .amdhsa_system_sgpr_workgroup_info 0 - .amdhsa_system_vgpr_workitem_id 0 - .amdhsa_next_free_vgpr 44 - .amdhsa_next_free_sgpr 28 - .amdhsa_accum_offset 28 - .amdhsa_reserve_vcc 1 - .amdhsa_float_round_mode_32 0 - .amdhsa_float_round_mode_16_64 0 - .amdhsa_float_denorm_mode_32 3 - .amdhsa_float_denorm_mode_16_64 3 - .amdhsa_dx10_clamp 1 - .amdhsa_ieee_mode 1 - .amdhsa_fp16_overflow 0 - .amdhsa_tg_split 0 - .amdhsa_exception_fp_ieee_invalid_op 0 - .amdhsa_exception_fp_denorm_src 0 - .amdhsa_exception_fp_ieee_div_zero 0 - .amdhsa_exception_fp_ieee_overflow 0 - .amdhsa_exception_fp_ieee_underflow 0 - .amdhsa_exception_fp_ieee_inexact 0 - .amdhsa_exception_int_div_zero 0 - .end_amdhsa_kernel - .text -.Lfunc_end0: - .size _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - ; -- End function - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.num_vgpr, 28 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.num_agpr, 16 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.numbered_sgpr, 28 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.num_named_barrier, 0 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.private_seg_size, 0 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.uses_vcc, 1 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.uses_flat_scratch, 0 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.has_dyn_sized_stack, 0 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.has_recursion, 0 - .set _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.has_indirect_call, 0 - .section .AMDGPU.csdata,"",@progbits -; Kernel info: -; codeLenInByte = 2828 -; TotalNumSgprs: 34 -; NumVgprs: 28 -; NumAgprs: 16 -; TotalNumVgprs: 44 -; ScratchSize: 0 -; MemoryBound: 0 -; FloatMode: 240 -; IeeeMode: 1 -; LDSByteSize: 0 bytes/workgroup (compile time only) -; SGPRBlocks: 4 -; VGPRBlocks: 5 -; NumSGPRsForWavesPerEU: 34 -; NumVGPRsForWavesPerEU: 44 -; AccumOffset: 28 -; Occupancy: 8 -; WaveLimiterHint : 0 -; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 -; COMPUTE_PGM_RSRC2:USER_SGPR: 2 -; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 -; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Y_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Z_EN: 1 -; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 -; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 6 -; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 - .text - .protected _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii ; -- Begin function _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .globl _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .p2align 8 - .type _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii,@function -_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii: ; @_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii -; %bb.0: - s_load_dwordx2 s[10:11], s[0:1], 0x10 - s_load_dwordx8 s[12:19], s[0:1], 0x28 - s_lshl_b32 s3, s3, 6 - v_lshrrev_b32_e32 v2, 1, v0 - v_and_b32_e32 v1, 31, v0 - v_and_or_b32 v5, v2, 32, s3 - s_waitcnt lgkmcnt(0) - s_mul_i32 s5, s15, s4 - s_add_i32 s6, s5, s15 - s_lshl_b32 s2, s2, 6 - v_lshrrev_b32_e32 v4, 2, v0 - v_or_b32_e32 v18, v5, v1 - s_min_i32 s20, s6, s14 - v_bfe_u32 v28, v0, 5, 1 - v_and_or_b32 v29, v4, 32, s2 - s_cmp_ge_i32 s5, s20 - v_cmp_gt_i32_e64 s[8:9], s13, v18 - s_cbranch_scc1 .LBB1_26 -; %bb.1: ; %.lr.ph211 - s_load_dwordx4 s[24:27], s[0:1], 0x0 - s_load_dwordx4 s[28:31], s[0:1], 0x18 - s_ashr_i32 s21, s14, 5 - v_lshlrev_b32_e32 v11, 5, v0 - v_mov_b32_e32 v3, 0 - v_ashrrev_i32_e32 v12, 5, v5 - s_ashr_i32 s14, s14, 1 - v_add_u32_e32 v5, s2, v4 - s_waitcnt lgkmcnt(0) - v_mov_b64_e32 v[8:9], s[24:25] - v_and_b32_e32 v32, 0x60, v11 - v_mad_i64_i32 v[8:9], s[6:7], s14, v5, v[8:9] - v_mov_b32_e32 v33, v3 - v_add_u32_e32 v13, s3, v4 - v_lshl_add_u64 v[20:21], v[8:9], 0, v[32:33] - v_mov_b64_e32 v[8:9], s[26:27] - v_mad_i64_i32 v[8:9], s[6:7], s14, v13, v[8:9] - v_or_b32_e32 v10, v29, v1 - v_cmp_gt_i32_e64 s[2:3], s12, v5 - v_lshl_add_u64 v[22:23], v[8:9], 0, v[32:33] - v_lshlrev_b32_e32 v33, 7, v4 - v_mov_b64_e32 v[4:5], s[28:29] - v_lshlrev_b32_e32 v6, 2, v0 - s_lshl_b32 s15, s18, 5 - v_mad_i64_i32 v[24:25], s[6:7], s17, v10, v[4:5] - v_mov_b64_e32 v[4:5], s[30:31] - v_and_b32_e32 v6, 60, v6 - v_mov_b32_e32 v7, v3 - v_mad_i64_i32 v[4:5], s[6:7], s15, v12, v[4:5] - v_lshrrev_b32_e32 v2, 4, v1 - v_lshl_add_u64 v[4:5], v[4:5], 0, v[6:7] - v_lshlrev_b32_e32 v0, 6, v0 - v_lshlrev_b32_e32 v1, 7, v1 - s_movk_i32 s6, 0x1000 - s_movk_i32 s14, 0x2000 - v_lshl_add_u64 v[26:27], v[4:5], 0, v[2:3] - v_lshlrev_b32_e32 v2, 4, v28 - v_and_or_b32 v0, v0, s6, v1 - v_or3_b32 v30, v0, v2, s14 - v_and_b32_e32 v0, 0x1000, v11 - v_or3_b32 v31, v0, v1, v2 - v_mov_b32_e32 v2, v3 - v_cmp_gt_i32_e32 vcc, s12, v10 - v_cmp_gt_i32_e64 s[0:1], s13, v13 - v_or_b32_e32 v34, 0x2000, v33 - v_mov_b32_e32 v4, v3 - v_mov_b32_e32 v5, v3 - v_mov_b32_e32 v6, v3 - v_mov_b32_e32 v8, v3 - v_mov_b32_e32 v9, v3 - v_mov_b32_e32 v10, v3 - v_mov_b32_e32 v11, v3 - v_mov_b32_e32 v12, v3 - v_mov_b32_e32 v13, v3 - v_mov_b32_e32 v14, v3 - v_mov_b32_e32 v15, v3 - v_mov_b32_e32 v16, v3 - v_mov_b32_e32 v17, v3 - v_accvgpr_write_b32 a0, v2 - v_add_u32_e32 v19, 32, v32 - v_accvgpr_write_b32 a1, v3 - v_accvgpr_write_b32 a2, v4 - v_accvgpr_write_b32 a3, v5 - v_accvgpr_write_b32 a4, v6 - v_accvgpr_write_b32 a5, v7 - v_accvgpr_write_b32 a6, v8 - v_accvgpr_write_b32 a7, v9 - v_accvgpr_write_b32 a8, v10 - v_accvgpr_write_b32 a9, v11 - v_accvgpr_write_b32 a10, v12 - v_accvgpr_write_b32 a11, v13 - v_accvgpr_write_b32 a12, v14 - v_accvgpr_write_b32 a13, v15 - v_accvgpr_write_b32 a14, v16 - v_accvgpr_write_b32 a15, v17 - v_add_u32_e32 v4, v33, v32 - v_add_u32_e32 v5, v34, v32 - s_branch .LBB1_3 -.LBB1_2: ; %._crit_edge - ; in Loop: Header=BB1_3 Depth=1 - s_addk_i32 s5, 0x100 - s_cmp_ge_i32 s5, s20 - s_barrier - s_cbranch_scc1 .LBB1_27 -.LBB1_3: ; =>This Loop Header: Depth=1 - ; Child Loop BB1_25 Depth 2 - s_sub_i32 s23, s20, s5 - s_min_i32 s22, s23, 0x100 - s_lshr_b32 s6, s22, 1 - v_cmp_ge_u32_e64 s[6:7], s6, v19 - s_ashr_i32 s14, s5, 1 - s_and_b64 s[24:25], s[2:3], s[6:7] - s_ashr_i32 s15, s14, 31 - s_waitcnt vmcnt(0) - v_mov_b32_e32 v6, 0 - v_mov_b32_e32 v7, 0 - v_mov_b32_e32 v8, 0 - v_mov_b32_e32 v9, 0 - v_mov_b32_e32 v10, 0 - v_mov_b32_e32 v11, 0 - v_mov_b32_e32 v12, 0 - v_mov_b32_e32 v13, 0 - s_and_saveexec_b64 s[18:19], s[24:25] - s_cbranch_execz .LBB1_5 -; %bb.4: ; in Loop: Header=BB1_3 Depth=1 - v_lshl_add_u64 v[0:1], v[20:21], 0, s[14:15] - global_load_dwordx4 v[6:9], v[0:1], off - global_load_dwordx4 v[10:13], v[0:1], off offset:16 -.LBB1_5: ; %_Z8copy_32bPhPKhb.exit - ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[18:19] - s_and_b64 s[18:19], s[0:1], s[6:7] - s_waitcnt vmcnt(1) - ds_write_b128 v4, v[6:9] - s_waitcnt vmcnt(0) - ds_write_b128 v4, v[10:13] offset:16 - v_mov_b32_e32 v6, 0 - v_mov_b32_e32 v7, 0 - v_mov_b32_e32 v8, 0 - v_mov_b32_e32 v9, 0 - v_mov_b32_e32 v10, 0 - v_mov_b32_e32 v11, 0 - v_mov_b32_e32 v12, 0 - v_mov_b32_e32 v13, 0 - s_and_saveexec_b64 s[6:7], s[18:19] - s_cbranch_execz .LBB1_7 -; %bb.6: ; in Loop: Header=BB1_3 Depth=1 - v_lshl_add_u64 v[0:1], v[22:23], 0, s[14:15] - global_load_dwordx4 v[6:9], v[0:1], off - global_load_dwordx4 v[10:13], v[0:1], off offset:16 -.LBB1_7: ; %_Z8copy_32bPhPKhb.exit197 - ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - s_ashr_i32 s6, s5, 5 - v_add_u32_e32 v0, s6, v28 - v_cmp_gt_i32_e64 s[6:7], s17, v0 - s_waitcnt vmcnt(1) - ds_write_b128 v5, v[6:9] - s_waitcnt vmcnt(0) - ds_write_b128 v5, v[10:13] offset:16 - s_and_b64 s[14:15], vcc, s[6:7] - v_mov_b32_e32 v7, 0x7f - v_mov_b32_e32 v6, 0x7f - s_waitcnt lgkmcnt(0) - s_barrier - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_9 -; %bb.8: ; in Loop: Header=BB1_3 Depth=1 - v_ashrrev_i32_e32 v1, 31, v0 - v_lshl_add_u64 v[8:9], v[24:25], 0, v[0:1] - global_load_ubyte v6, v[8:9], off -.LBB1_9: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_cmp_gt_i32_e64 s[6:7], s21, v0 - s_and_b64 s[14:15], s[8:9], s[6:7] - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_11 -; %bb.10: ; in Loop: Header=BB1_3 Depth=1 - v_lshlrev_b32_e32 v1, 5, v0 - v_and_b32_e32 v8, 0xffffff00, v1 - v_ashrrev_i32_e32 v9, 31, v8 - v_lshlrev_b32_e32 v1, 6, v0 - v_and_b32_e32 v2, 0xc0, v1 - v_lshrrev_b32_e32 v1, 1, v0 - v_lshl_add_u64 v[8:9], v[26:27], 0, v[8:9] - v_and_b32_e32 v10, 2, v1 - v_mov_b32_e32 v11, v3 - v_lshl_add_u64 v[8:9], v[8:9], 0, v[2:3] - v_lshl_add_u64 v[8:9], v[8:9], 0, v[10:11] - global_load_ubyte v7, v[8:9], off -.LBB1_11: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_add_u32_e32 v10, 2, v0 - v_cmp_gt_i32_e64 s[6:7], s17, v10 - s_and_b64 s[14:15], vcc, s[6:7] - v_mov_b32_e32 v9, 0x7f00 - v_mov_b32_e32 v8, 0x7f00 - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_13 -; %bb.12: ; in Loop: Header=BB1_3 Depth=1 - v_ashrrev_i32_e32 v1, 31, v0 - v_lshl_add_u64 v[12:13], v[24:25], 0, v[0:1] - global_load_ubyte v1, v[12:13], off offset:2 - s_waitcnt vmcnt(0) - v_lshlrev_b32_e32 v8, 8, v1 -.LBB1_13: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_cmp_gt_i32_e64 s[6:7], s21, v10 - s_and_b64 s[14:15], s[8:9], s[6:7] - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_15 -; %bb.14: ; in Loop: Header=BB1_3 Depth=1 - v_lshlrev_b32_e32 v1, 5, v10 - v_and_b32_e32 v12, 0xffffff00, v1 - v_ashrrev_i32_e32 v13, 31, v12 - v_lshlrev_b32_e32 v1, 6, v10 - v_and_b32_e32 v2, 0xc0, v1 - v_lshrrev_b32_e32 v1, 1, v10 - v_lshl_add_u64 v[12:13], v[26:27], 0, v[12:13] - v_and_b32_e32 v10, 2, v1 - v_mov_b32_e32 v11, v3 - v_lshl_add_u64 v[12:13], v[12:13], 0, v[2:3] - v_lshl_add_u64 v[10:11], v[12:13], 0, v[10:11] - global_load_ubyte v1, v[10:11], off - s_waitcnt vmcnt(0) - v_lshlrev_b32_e32 v9, 8, v1 -.LBB1_15: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_add_u32_e32 v12, 4, v0 - v_cmp_gt_i32_e64 s[6:7], s17, v12 - s_and_b64 s[14:15], vcc, s[6:7] - v_mov_b32_e32 v11, 0x7f0000 - v_mov_b32_e32 v10, 0x7f0000 - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_17 -; %bb.16: ; in Loop: Header=BB1_3 Depth=1 - v_ashrrev_i32_e32 v1, 31, v0 - v_lshl_add_u64 v[14:15], v[24:25], 0, v[0:1] - global_load_ubyte v1, v[14:15], off offset:4 - s_waitcnt vmcnt(0) - v_lshlrev_b32_e32 v10, 16, v1 -.LBB1_17: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_cmp_gt_i32_e64 s[6:7], s21, v12 - s_and_b64 s[14:15], s[8:9], s[6:7] - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_19 -; %bb.18: ; in Loop: Header=BB1_3 Depth=1 - v_lshlrev_b32_e32 v1, 5, v12 - v_and_b32_e32 v14, 0xffffff00, v1 - v_ashrrev_i32_e32 v15, 31, v14 - v_lshlrev_b32_e32 v1, 6, v0 - v_and_b32_e32 v2, 0xc0, v1 - v_lshrrev_b32_e32 v1, 1, v12 - v_lshl_add_u64 v[14:15], v[26:27], 0, v[14:15] - v_and_b32_e32 v12, 2, v1 - v_mov_b32_e32 v13, v3 - v_lshl_add_u64 v[14:15], v[14:15], 0, v[2:3] - v_lshl_add_u64 v[12:13], v[14:15], 0, v[12:13] - global_load_ubyte v1, v[12:13], off - s_waitcnt vmcnt(0) - v_lshlrev_b32_e32 v11, 16, v1 -.LBB1_19: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_add_u32_e32 v12, 6, v0 - v_cmp_gt_i32_e64 s[6:7], s17, v12 - s_and_b64 s[14:15], vcc, s[6:7] - v_mov_b32_e32 v2, 0x7f000000 - v_mov_b32_e32 v1, 0x7f000000 - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_21 -; %bb.20: ; in Loop: Header=BB1_3 Depth=1 - v_ashrrev_i32_e32 v1, 31, v0 - v_lshl_add_u64 v[0:1], v[24:25], 0, v[0:1] - global_load_ubyte v0, v[0:1], off offset:6 - s_waitcnt vmcnt(0) - v_lshlrev_b32_e32 v1, 24, v0 -.LBB1_21: ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - v_cmp_gt_i32_e64 s[6:7], s21, v12 - s_and_b64 s[14:15], s[8:9], s[6:7] - s_and_saveexec_b64 s[6:7], s[14:15] - s_cbranch_execz .LBB1_23 -; %bb.22: ; in Loop: Header=BB1_3 Depth=1 - v_lshlrev_b32_e32 v0, 5, v12 - v_and_b32_e32 v14, 0xffffff00, v0 - v_ashrrev_i32_e32 v15, 31, v14 - v_lshlrev_b32_e32 v0, 6, v12 - v_and_b32_e32 v2, 0xc0, v0 - v_lshrrev_b32_e32 v0, 1, v12 - v_lshl_add_u64 v[14:15], v[26:27], 0, v[14:15] - v_and_b32_e32 v12, 2, v0 - v_mov_b32_e32 v13, v3 - v_lshl_add_u64 v[14:15], v[14:15], 0, v[2:3] - v_lshl_add_u64 v[12:13], v[14:15], 0, v[12:13] - global_load_ubyte v0, v[12:13], off - s_waitcnt vmcnt(0) - v_lshlrev_b32_e32 v2, 24, v0 -.LBB1_23: ; %.preheader202 - ; in Loop: Header=BB1_3 Depth=1 - s_or_b64 exec, exec, s[6:7] - s_cmp_lt_i32 s23, 1 - s_cbranch_scc1 .LBB1_2 -; %bb.24: ; %.lr.ph.preheader - ; in Loop: Header=BB1_3 Depth=1 - s_waitcnt vmcnt(0) - v_or_b32_e32 v0, v8, v6 - v_or_b32_e32 v6, v9, v7 - v_or3_b32 v0, v10, v0, v1 - v_or3_b32 v1, v11, v6, v2 - s_mov_b32 s6, 0 - v_mov_b32_e32 v2, v31 - v_mov_b32_e32 v6, v30 - s_mov_b32 s7, 0 -.LBB1_25: ; %.lr.ph - ; Parent Loop BB1_3 Depth=1 - ; => This Inner Loop Header: Depth=2 - ds_read_b128 v[8:11], v2 - ds_read_b128 v[12:15], v6 - v_bfe_u32 v7, v0, s6, 8 - v_bfe_u32 v16, v1, s6, 8 - s_add_i32 s7, s7, 64 - s_add_i32 s6, s6, 8 - s_waitcnt lgkmcnt(0) - v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[8:11], v[12:15], a[0:15], v7, v16 op_sel_hi:[0,0,0] cbsz:4 blgp:4 - v_add_u32_e32 v6, 32, v6 - s_cmp_ge_i32 s7, s22 - v_add_u32_e32 v2, 32, v2 - s_cbranch_scc0 .LBB1_25 - s_branch .LBB1_2 -.LBB1_26: - v_accvgpr_write_b32 a0, 0 - v_accvgpr_mov_b32 a1, a0 - v_accvgpr_mov_b32 a2, a0 - v_accvgpr_mov_b32 a3, a0 - v_accvgpr_mov_b32 a4, a0 - v_accvgpr_mov_b32 a5, a0 - v_accvgpr_mov_b32 a6, a0 - v_accvgpr_mov_b32 a7, a0 - v_accvgpr_mov_b32 a8, a0 - v_accvgpr_mov_b32 a9, a0 - v_accvgpr_mov_b32 a10, a0 - v_accvgpr_mov_b32 a11, a0 - v_accvgpr_mov_b32 a12, a0 - v_accvgpr_mov_b32 a13, a0 - v_accvgpr_mov_b32 a14, a0 - v_accvgpr_mov_b32 a15, a0 -.LBB1_27: ; %Flow413 - s_waitcnt vmcnt(0) - s_nop 1 - v_accvgpr_read_b32 v0, a0 - v_accvgpr_read_b32 v1, a1 - v_accvgpr_read_b32 v2, a2 - v_accvgpr_read_b32 v3, a3 - v_accvgpr_read_b32 v4, a4 - v_accvgpr_read_b32 v5, a5 - v_accvgpr_read_b32 v6, a6 - v_accvgpr_read_b32 v7, a7 - v_accvgpr_read_b32 v8, a8 - v_accvgpr_read_b32 v9, a9 - v_accvgpr_read_b32 v10, a10 - v_accvgpr_read_b32 v11, a11 - v_accvgpr_read_b32 v12, a12 - v_accvgpr_read_b32 v13, a13 - v_accvgpr_read_b32 v14, a14 - v_accvgpr_read_b32 v15, a15 - s_cmp_lg_u32 s16, 1 - s_mov_b64 s[0:1], -1 - s_cbranch_scc0 .LBB1_62 -; %bb.28: - s_and_saveexec_b64 s[0:1], s[8:9] - s_cbranch_execz .LBB1_61 -; %bb.29: ; %.preheader200 - v_lshl_or_b32 v20, v28, 2, v29 - v_ashrrev_i32_e32 v19, 31, v18 - s_mul_hi_i32 s3, s12, s4 - s_mul_i32 s2, s12, s4 - s_ashr_i32 s6, s13, 31 - v_lshl_add_u64 v[16:17], v[18:19], 2, s[10:11] - v_cmp_gt_i32_e32 vcc, s12, v20 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_31 -; %bb.30: - v_ashrrev_i32_e32 v21, 31, v20 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[20:21] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v0, off -.LBB1_31: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 1, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_33 -; %bb.32: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v1, off -.LBB1_33: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 2, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_35 -; %bb.34: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v2, off -.LBB1_35: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 3, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_37 -; %bb.36: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v3, off -.LBB1_37: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 8, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_39 -; %bb.38: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v4, off -.LBB1_39: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 9, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_41 -; %bb.40: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v5, off -.LBB1_41: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 10, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_43 -; %bb.42: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v6, off -.LBB1_43: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 11, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_45 -; %bb.44: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v7, off -.LBB1_45: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 16, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_47 -; %bb.46: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v8, off -.LBB1_47: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 17, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_49 -; %bb.48: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v9, off -.LBB1_49: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 18, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_51 -; %bb.50: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v10, off -.LBB1_51: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 19, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_53 -; %bb.52: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v11, off -.LBB1_53: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 24, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_55 -; %bb.54: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v12, off -.LBB1_55: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 25, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_57 -; %bb.56: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v13, off -.LBB1_57: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v22, 26, v20 - v_cmp_gt_i32_e32 vcc, s12, v22 - s_and_saveexec_b64 s[4:5], vcc - s_cbranch_execz .LBB1_59 -; %bb.58: - v_ashrrev_i32_e32 v23, 31, v22 - v_lshl_add_u64 v[22:23], s[2:3], 0, v[22:23] - v_mul_lo_u32 v19, v23, s13 - v_mul_lo_u32 v21, v22, s6 - v_mad_u64_u32 v[22:23], s[14:15], v22, s13, 0 - v_add3_u32 v23, v23, v21, v19 - v_lshl_add_u64 v[22:23], v[22:23], 2, v[16:17] - global_store_dword v[22:23], v14, off -.LBB1_59: - s_or_b64 exec, exec, s[4:5] - v_or_b32_e32 v20, 27, v20 - v_cmp_gt_i32_e32 vcc, s12, v20 - s_and_b64 exec, exec, vcc - s_cbranch_execz .LBB1_61 -; %bb.60: - v_ashrrev_i32_e32 v21, 31, v20 - v_lshl_add_u64 v[20:21], s[2:3], 0, v[20:21] - v_mul_lo_u32 v19, v21, s13 - v_mul_lo_u32 v22, v20, s6 - v_mad_u64_u32 v[20:21], s[2:3], v20, s13, 0 - v_add3_u32 v21, v21, v22, v19 - v_lshl_add_u64 v[16:17], v[20:21], 2, v[16:17] - global_store_dword v[16:17], v15, off -.LBB1_61: ; %Flow406 - s_or_b64 exec, exec, s[0:1] - s_mov_b64 s[0:1], 0 -.LBB1_62: ; %Flow409 - s_andn2_b64 vcc, exec, s[0:1] - s_cbranch_vccnz .LBB1_96 -; %bb.63: - s_and_saveexec_b64 s[0:1], s[8:9] - s_cbranch_execz .LBB1_96 -; %bb.64: ; %.preheader - v_lshl_or_b32 v20, v28, 2, v29 - v_ashrrev_i32_e32 v19, 31, v18 - v_lshl_add_u64 v[16:17], v[18:19], 1, s[10:11] - v_cmp_gt_i32_e32 vcc, s12, v20 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_66 -; %bb.65: - v_mad_i64_i32 v[18:19], s[2:3], v20, s13, 0 - v_lshl_add_u64 v[18:19], v[18:19], 1, v[16:17] - global_store_short_d16_hi v[18:19], v0, off -.LBB1_66: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 1, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_68 -; %bb.67: - v_mad_i64_i32 v[18:19], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[18:19], v[18:19], 1, v[16:17] - global_store_short_d16_hi v[18:19], v1, off -.LBB1_68: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 2, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_70 -; %bb.69: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v2, off -.LBB1_70: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 3, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_72 -; %bb.71: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v3, off -.LBB1_72: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 8, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_74 -; %bb.73: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v4, off -.LBB1_74: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 9, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_76 -; %bb.75: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v5, off -.LBB1_76: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 10, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_78 -; %bb.77: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v6, off -.LBB1_78: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 11, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_80 -; %bb.79: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v7, off -.LBB1_80: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 16, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_82 -; %bb.81: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v8, off -.LBB1_82: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 17, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_84 -; %bb.83: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v9, off -.LBB1_84: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 18, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_86 -; %bb.85: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v10, off -.LBB1_86: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 19, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_88 -; %bb.87: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v11, off -.LBB1_88: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 24, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_90 -; %bb.89: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v12, off -.LBB1_90: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 25, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_92 -; %bb.91: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v13, off -.LBB1_92: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 26, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB1_94 -; %bb.93: - v_mad_i64_i32 v[0:1], s[2:3], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v14, off -.LBB1_94: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v0, 27, v20 - v_cmp_gt_i32_e32 vcc, s12, v0 - s_and_b64 exec, exec, vcc - s_cbranch_execz .LBB1_96 -; %bb.95: - v_mad_i64_i32 v[0:1], s[0:1], v0, s13, 0 - v_lshl_add_u64 v[0:1], v[0:1], 1, v[16:17] - global_store_short_d16_hi v[0:1], v15, off -.LBB1_96: ; %.loopexit - s_endpgm - .section .rodata,"a",@progbits - .p2align 6, 0x0 - .amdhsa_kernel _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .amdhsa_group_segment_fixed_size 16384 - .amdhsa_private_segment_fixed_size 0 - .amdhsa_kernarg_size 68 - .amdhsa_user_sgpr_count 2 - .amdhsa_user_sgpr_dispatch_ptr 0 - .amdhsa_user_sgpr_queue_ptr 0 - .amdhsa_user_sgpr_kernarg_segment_ptr 1 - .amdhsa_user_sgpr_dispatch_id 0 - .amdhsa_user_sgpr_kernarg_preload_length 0 - .amdhsa_user_sgpr_kernarg_preload_offset 0 - .amdhsa_user_sgpr_private_segment_size 0 - .amdhsa_uses_dynamic_stack 0 - .amdhsa_enable_private_segment 0 - .amdhsa_system_sgpr_workgroup_id_x 1 - .amdhsa_system_sgpr_workgroup_id_y 1 - .amdhsa_system_sgpr_workgroup_id_z 1 - .amdhsa_system_sgpr_workgroup_info 0 - .amdhsa_system_vgpr_workitem_id 0 - .amdhsa_next_free_vgpr 52 - .amdhsa_next_free_sgpr 32 - .amdhsa_accum_offset 36 - .amdhsa_reserve_vcc 1 - .amdhsa_float_round_mode_32 0 - .amdhsa_float_round_mode_16_64 0 - .amdhsa_float_denorm_mode_32 3 - .amdhsa_float_denorm_mode_16_64 3 - .amdhsa_dx10_clamp 1 - .amdhsa_ieee_mode 1 - .amdhsa_fp16_overflow 0 - .amdhsa_tg_split 0 - .amdhsa_exception_fp_ieee_invalid_op 0 - .amdhsa_exception_fp_denorm_src 0 - .amdhsa_exception_fp_ieee_div_zero 0 - .amdhsa_exception_fp_ieee_overflow 0 - .amdhsa_exception_fp_ieee_underflow 0 - .amdhsa_exception_fp_ieee_inexact 0 - .amdhsa_exception_int_div_zero 0 - .end_amdhsa_kernel - .text -.Lfunc_end1: - .size _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, .Lfunc_end1-_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - ; -- End function - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.num_vgpr, 35 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.num_agpr, 16 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.numbered_sgpr, 32 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.num_named_barrier, 0 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.private_seg_size, 0 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.uses_vcc, 1 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.uses_flat_scratch, 0 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.has_dyn_sized_stack, 0 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.has_recursion, 0 - .set _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.has_indirect_call, 0 - .section .AMDGPU.csdata,"",@progbits -; Kernel info: -; codeLenInByte = 3892 -; TotalNumSgprs: 38 -; NumVgprs: 35 -; NumAgprs: 16 -; TotalNumVgprs: 52 -; ScratchSize: 0 -; MemoryBound: 1 -; FloatMode: 240 -; IeeeMode: 1 -; LDSByteSize: 16384 bytes/workgroup (compile time only) -; SGPRBlocks: 4 -; VGPRBlocks: 6 -; NumSGPRsForWavesPerEU: 38 -; NumVGPRsForWavesPerEU: 52 -; AccumOffset: 36 -; Occupancy: 8 -; WaveLimiterHint : 0 -; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 -; COMPUTE_PGM_RSRC2:USER_SGPR: 2 -; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 -; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Y_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Z_EN: 1 -; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 -; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 8 -; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 - .text - .protected _Z9reduce_skPKfP12hip_bfloat16ii ; -- Begin function _Z9reduce_skPKfP12hip_bfloat16ii - .globl _Z9reduce_skPKfP12hip_bfloat16ii - .p2align 8 - .type _Z9reduce_skPKfP12hip_bfloat16ii,@function -_Z9reduce_skPKfP12hip_bfloat16ii: ; @_Z9reduce_skPKfP12hip_bfloat16ii -; %bb.0: - s_load_dwordx2 s[4:5], s[0:1], 0x10 - v_lshl_add_u32 v0, s2, 8, v0 - s_waitcnt lgkmcnt(0) - v_cmp_gt_i32_e32 vcc, s4, v0 - s_and_saveexec_b64 s[2:3], vcc - s_cbranch_execz .LBB2_8 -; %bb.1: ; %.preheader - s_cmp_gt_i32 s5, 0 - v_ashrrev_i32_e32 v1, 31, v0 - s_cbranch_scc1 .LBB2_3 -; %bb.2: ; %.preheader.._crit_edge_crit_edge - s_load_dwordx2 s[2:3], s[0:1], 0x8 - v_mov_b32_e32 v2, 0 - s_cbranch_execz .LBB2_4 - s_branch .LBB2_7 -.LBB2_3: - s_load_dwordx2 s[2:3], s[0:1], 0x8 - v_mov_b32_e32 v2, 0 -.LBB2_4: ; %.lr.ph - s_load_dwordx2 s[6:7], s[0:1], 0x0 - s_ashr_i32 s1, s4, 31 - s_mov_b32 s0, s4 - s_lshl_b64 s[0:1], s[0:1], 2 - v_mov_b32_e32 v4, 0 - s_waitcnt lgkmcnt(0) - v_lshl_add_u64 v[2:3], v[0:1], 2, s[6:7] -.LBB2_5: ; =>This Inner Loop Header: Depth=1 - global_load_dword v5, v[2:3], off - s_add_i32 s5, s5, -1 - v_lshl_add_u64 v[2:3], v[2:3], 0, s[0:1] - s_cmp_eq_u32 s5, 0 - s_waitcnt vmcnt(0) - v_add_f32_e32 v4, v4, v5 - s_cbranch_scc0 .LBB2_5 -; %bb.6: ; %._crit_edge.loopexit - v_lshrrev_b32_e32 v2, 16, v4 -.LBB2_7: ; %Flow26 - s_waitcnt lgkmcnt(0) - v_lshl_add_u64 v[0:1], v[0:1], 1, s[2:3] - global_store_short v[0:1], v2, off -.LBB2_8: - s_endpgm - .section .rodata,"a",@progbits - .p2align 6, 0x0 - .amdhsa_kernel _Z9reduce_skPKfP12hip_bfloat16ii - .amdhsa_group_segment_fixed_size 0 - .amdhsa_private_segment_fixed_size 0 - .amdhsa_kernarg_size 24 - .amdhsa_user_sgpr_count 2 - .amdhsa_user_sgpr_dispatch_ptr 0 - .amdhsa_user_sgpr_queue_ptr 0 - .amdhsa_user_sgpr_kernarg_segment_ptr 1 - .amdhsa_user_sgpr_dispatch_id 0 - .amdhsa_user_sgpr_kernarg_preload_length 0 - .amdhsa_user_sgpr_kernarg_preload_offset 0 - .amdhsa_user_sgpr_private_segment_size 0 - .amdhsa_uses_dynamic_stack 0 - .amdhsa_enable_private_segment 0 - .amdhsa_system_sgpr_workgroup_id_x 1 - .amdhsa_system_sgpr_workgroup_id_y 0 - .amdhsa_system_sgpr_workgroup_id_z 0 - .amdhsa_system_sgpr_workgroup_info 0 - .amdhsa_system_vgpr_workitem_id 0 - .amdhsa_next_free_vgpr 6 - .amdhsa_next_free_sgpr 8 - .amdhsa_accum_offset 8 - .amdhsa_reserve_vcc 1 - .amdhsa_float_round_mode_32 0 - .amdhsa_float_round_mode_16_64 0 - .amdhsa_float_denorm_mode_32 3 - .amdhsa_float_denorm_mode_16_64 3 - .amdhsa_dx10_clamp 1 - .amdhsa_ieee_mode 1 - .amdhsa_fp16_overflow 0 - .amdhsa_tg_split 0 - .amdhsa_exception_fp_ieee_invalid_op 0 - .amdhsa_exception_fp_denorm_src 0 - .amdhsa_exception_fp_ieee_div_zero 0 - .amdhsa_exception_fp_ieee_overflow 0 - .amdhsa_exception_fp_ieee_underflow 0 - .amdhsa_exception_fp_ieee_inexact 0 - .amdhsa_exception_int_div_zero 0 - .end_amdhsa_kernel - .text -.Lfunc_end2: - .size _Z9reduce_skPKfP12hip_bfloat16ii, .Lfunc_end2-_Z9reduce_skPKfP12hip_bfloat16ii - ; -- End function - .set _Z9reduce_skPKfP12hip_bfloat16ii.num_vgpr, 6 - .set _Z9reduce_skPKfP12hip_bfloat16ii.num_agpr, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.numbered_sgpr, 8 - .set _Z9reduce_skPKfP12hip_bfloat16ii.num_named_barrier, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.private_seg_size, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_vcc, 1 - .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_flat_scratch, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.has_dyn_sized_stack, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.has_recursion, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.has_indirect_call, 0 - .section .AMDGPU.csdata,"",@progbits -; Kernel info: -; codeLenInByte = 176 -; TotalNumSgprs: 14 -; NumVgprs: 6 -; NumAgprs: 0 -; TotalNumVgprs: 6 -; ScratchSize: 0 -; MemoryBound: 0 -; FloatMode: 240 -; IeeeMode: 1 -; LDSByteSize: 0 bytes/workgroup (compile time only) -; SGPRBlocks: 1 -; VGPRBlocks: 0 -; NumSGPRsForWavesPerEU: 14 -; NumVGPRsForWavesPerEU: 6 -; AccumOffset: 8 -; Occupancy: 8 -; WaveLimiterHint : 0 -; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 -; COMPUTE_PGM_RSRC2:USER_SGPR: 2 -; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 -; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Y_EN: 0 -; COMPUTE_PGM_RSRC2:TGID_Z_EN: 0 -; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 -; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 1 -; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 - .text - .p2alignl 6, 3212836864 - .fill 256, 4, 3212836864 - .section .AMDGPU.gpr_maximums,"",@progbits - .set amdgpu.max_num_vgpr, 0 - .set amdgpu.max_num_agpr, 0 - .set amdgpu.max_num_sgpr, 0 - .text - .type __hip_cuid_85127a709332f6f1,@object ; @__hip_cuid_85127a709332f6f1 - .section .bss,"aw",@nobits - .globl __hip_cuid_85127a709332f6f1 -__hip_cuid_85127a709332f6f1: - .byte 0 ; 0x0 - .size __hip_cuid_85127a709332f6f1, 1 - - .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" - .section ".note.GNU-stack","",@progbits - .addrsig - .addrsig_sym __hip_cuid_85127a709332f6f1 - .amdgpu_metadata ---- -amdhsa.kernels: - - .agpr_count: 16 - .args: - - .actual_access: read_only - .address_space: global - .offset: 0 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 8 - .size: 8 - .value_kind: global_buffer - - .actual_access: write_only - .address_space: global - .offset: 16 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 24 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 32 - .size: 8 - .value_kind: global_buffer - - .offset: 40 - .size: 4 - .value_kind: by_value - - .offset: 44 - .size: 4 - .value_kind: by_value - - .offset: 48 - .size: 4 - .value_kind: by_value - - .offset: 52 - .size: 4 - .value_kind: by_value - - .offset: 56 - .size: 4 - .value_kind: by_value - - .offset: 60 - .size: 4 - .value_kind: by_value - - .offset: 64 - .size: 4 - .value_kind: by_value - .group_segment_fixed_size: 0 - .kernarg_segment_align: 8 - .kernarg_segment_size: 68 - .language: OpenCL C - .language_version: - - 2 - - 0 - .max_flat_workgroup_size: 64 - .name: _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .private_segment_fixed_size: 0 - .sgpr_count: 34 - .sgpr_spill_count: 0 - .symbol: _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii.kd - .uniform_work_group_size: 1 - .uses_dynamic_stack: false - .vgpr_count: 44 - .vgpr_spill_count: 0 - .wavefront_size: 64 - - .agpr_count: 16 - .args: - - .actual_access: read_only - .address_space: global - .offset: 0 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 8 - .size: 8 - .value_kind: global_buffer - - .actual_access: write_only - .address_space: global - .offset: 16 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 24 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 32 - .size: 8 - .value_kind: global_buffer - - .offset: 40 - .size: 4 - .value_kind: by_value - - .offset: 44 - .size: 4 - .value_kind: by_value - - .offset: 48 - .size: 4 - .value_kind: by_value - - .offset: 52 - .size: 4 - .value_kind: by_value - - .offset: 56 - .size: 4 - .value_kind: by_value - - .offset: 60 - .size: 4 - .value_kind: by_value - - .offset: 64 - .size: 4 - .value_kind: by_value - .group_segment_fixed_size: 16384 - .kernarg_segment_align: 8 - .kernarg_segment_size: 68 - .language: OpenCL C - .language_version: - - 2 - - 0 - .max_flat_workgroup_size: 256 - .name: _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .private_segment_fixed_size: 0 - .sgpr_count: 38 - .sgpr_spill_count: 0 - .symbol: _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii.kd - .uniform_work_group_size: 1 - .uses_dynamic_stack: false - .vgpr_count: 52 - .vgpr_spill_count: 0 - .wavefront_size: 64 - - .agpr_count: 0 - .args: - - .actual_access: read_only - .address_space: global - .offset: 0 - .size: 8 - .value_kind: global_buffer - - .actual_access: write_only - .address_space: global - .offset: 8 - .size: 8 - .value_kind: global_buffer - - .offset: 16 - .size: 4 - .value_kind: by_value - - .offset: 20 - .size: 4 - .value_kind: by_value - .group_segment_fixed_size: 0 - .kernarg_segment_align: 8 - .kernarg_segment_size: 24 - .language: OpenCL C - .language_version: - - 2 - - 0 - .max_flat_workgroup_size: 1024 - .name: _Z9reduce_skPKfP12hip_bfloat16ii - .private_segment_fixed_size: 0 - .sgpr_count: 14 - .sgpr_spill_count: 0 - .symbol: _Z9reduce_skPKfP12hip_bfloat16ii.kd - .uniform_work_group_size: 1 - .uses_dynamic_stack: false - .vgpr_count: 6 - .vgpr_spill_count: 0 - .wavefront_size: 64 -amdhsa.target: amdgcn-amd-amdhsa--gfx950 -amdhsa.version: - - 1 - - 2 -... - - .end_amdgpu_metadata - -# __CLANG_OFFLOAD_BUNDLE____END__ hip-amdgcn-amd-amdhsa--gfx950 - -# __CLANG_OFFLOAD_BUNDLE____START__ host-x86_64-unknown-linux-gnu- - .file "src_v22.hip" - .text - .globl _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii # -- Begin function _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .p2align 4 - .type _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii,@function -_Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii: # @_Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .cfi_startproc -# %bb.0: - subq $200, %rsp - .cfi_def_cfa_offset 208 - movq %rdi, 88(%rsp) - movq %rsi, 80(%rsp) - movq %rdx, 72(%rsp) - movq %rcx, 64(%rsp) - movq %r8, 56(%rsp) - movl %r9d, 4(%rsp) - leaq 88(%rsp), %rax - movq %rax, 96(%rsp) - leaq 80(%rsp), %rax - movq %rax, 104(%rsp) - leaq 72(%rsp), %rax - movq %rax, 112(%rsp) - leaq 64(%rsp), %rax - movq %rax, 120(%rsp) - leaq 56(%rsp), %rax - movq %rax, 128(%rsp) - leaq 4(%rsp), %rax - movq %rax, 136(%rsp) - leaq 208(%rsp), %rax - movq %rax, 144(%rsp) - leaq 216(%rsp), %rax - movq %rax, 152(%rsp) - leaq 224(%rsp), %rax - movq %rax, 160(%rsp) - leaq 232(%rsp), %rax - movq %rax, 168(%rsp) - leaq 240(%rsp), %rax - movq %rax, 176(%rsp) - leaq 248(%rsp), %rax - movq %rax, 184(%rsp) - leaq 40(%rsp), %rdi - leaq 24(%rsp), %rsi - leaq 16(%rsp), %rdx - leaq 8(%rsp), %rcx - callq __hipPopCallConfiguration - movq 40(%rsp), %rsi - movl 48(%rsp), %edx - movq 24(%rsp), %rcx - movl 32(%rsp), %r8d - leaq 96(%rsp), %r9 - movl $_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, %edi - pushq 8(%rsp) - .cfi_adjust_cfa_offset 8 - pushq 24(%rsp) - .cfi_adjust_cfa_offset 8 - callq hipLaunchKernel - addq $216, %rsp - .cfi_adjust_cfa_offset -216 - retq -.Lfunc_end0: - .size _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .cfi_endproc - # -- End function - .globl _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii # -- Begin function _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .p2align 4 - .type _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii,@function -_Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii: # @_Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .cfi_startproc -# %bb.0: - subq $200, %rsp - .cfi_def_cfa_offset 208 - movq %rdi, 88(%rsp) - movq %rsi, 80(%rsp) - movq %rdx, 72(%rsp) - movq %rcx, 64(%rsp) - movq %r8, 56(%rsp) - movl %r9d, 4(%rsp) - leaq 88(%rsp), %rax - movq %rax, 96(%rsp) - leaq 80(%rsp), %rax - movq %rax, 104(%rsp) - leaq 72(%rsp), %rax - movq %rax, 112(%rsp) - leaq 64(%rsp), %rax - movq %rax, 120(%rsp) - leaq 56(%rsp), %rax - movq %rax, 128(%rsp) - leaq 4(%rsp), %rax - movq %rax, 136(%rsp) - leaq 208(%rsp), %rax - movq %rax, 144(%rsp) - leaq 216(%rsp), %rax - movq %rax, 152(%rsp) - leaq 224(%rsp), %rax - movq %rax, 160(%rsp) - leaq 232(%rsp), %rax - movq %rax, 168(%rsp) - leaq 240(%rsp), %rax - movq %rax, 176(%rsp) - leaq 248(%rsp), %rax - movq %rax, 184(%rsp) - leaq 40(%rsp), %rdi - leaq 24(%rsp), %rsi - leaq 16(%rsp), %rdx - leaq 8(%rsp), %rcx - callq __hipPopCallConfiguration - movq 40(%rsp), %rsi - movl 48(%rsp), %edx - movq 24(%rsp), %rcx - movl 32(%rsp), %r8d - leaq 96(%rsp), %r9 - movl $_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, %edi - pushq 8(%rsp) - .cfi_adjust_cfa_offset 8 - pushq 24(%rsp) - .cfi_adjust_cfa_offset 8 - callq hipLaunchKernel - addq $216, %rsp - .cfi_adjust_cfa_offset -216 - retq -.Lfunc_end1: - .size _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, .Lfunc_end1-_Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .cfi_endproc - # -- End function - .globl _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii # -- Begin function _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .p2align 4 - .type _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii,@function -_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii: # @_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .cfi_startproc -# %bb.0: - subq $120, %rsp - .cfi_def_cfa_offset 128 - movq %rdi, 72(%rsp) - movq %rsi, 64(%rsp) - movl %edx, 12(%rsp) - movl %ecx, 8(%rsp) - leaq 72(%rsp), %rax - movq %rax, 80(%rsp) - leaq 64(%rsp), %rax - movq %rax, 88(%rsp) - leaq 12(%rsp), %rax - movq %rax, 96(%rsp) - leaq 8(%rsp), %rax - movq %rax, 104(%rsp) - leaq 48(%rsp), %rdi - leaq 32(%rsp), %rsi - leaq 24(%rsp), %rdx - leaq 16(%rsp), %rcx - callq __hipPopCallConfiguration - movq 48(%rsp), %rsi - movl 56(%rsp), %edx - movq 32(%rsp), %rcx - movl 40(%rsp), %r8d - leaq 80(%rsp), %r9 - movl $_Z9reduce_skPKfP12hip_bfloat16ii, %edi - pushq 16(%rsp) - .cfi_adjust_cfa_offset 8 - pushq 32(%rsp) - .cfi_adjust_cfa_offset 8 - callq hipLaunchKernel - addq $136, %rsp - .cfi_adjust_cfa_offset -136 - retq -.Lfunc_end2: - .size _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii, .Lfunc_end2-_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .cfi_endproc - # -- End function - .p2align 4 # -- Begin function __hip_module_ctor - .type __hip_module_ctor,@function -__hip_module_ctor: # @__hip_module_ctor - .cfi_startproc -# %bb.0: - pushq %rbx - .cfi_def_cfa_offset 16 - subq $32, %rsp - .cfi_def_cfa_offset 48 - .cfi_offset %rbx, -16 - movq __hip_gpubin_handle_85127a709332f6f1(%rip), %rbx - testq %rbx, %rbx - jne .LBB3_2 -# %bb.1: - movl $__hip_fatbin_wrapper, %edi - callq __hipRegisterFatBinary - movq %rax, %rbx - movq %rax, __hip_gpubin_handle_85127a709332f6f1(%rip) -.LBB3_2: - xorps %xmm0, %xmm0 - movups %xmm0, 16(%rsp) - movups %xmm0, (%rsp) - movl $_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, %esi - movl $.L__unnamed_1, %edx - movl $.L__unnamed_1, %ecx - movq %rbx, %rdi - movl $-1, %r8d - xorl %r9d, %r9d - callq __hipRegisterFunction - xorps %xmm0, %xmm0 - movups %xmm0, 16(%rsp) - movups %xmm0, (%rsp) - movl $_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, %esi - movl $.L__unnamed_2, %edx - movl $.L__unnamed_2, %ecx - movq %rbx, %rdi - movl $-1, %r8d - xorl %r9d, %r9d - callq __hipRegisterFunction - xorps %xmm0, %xmm0 - movups %xmm0, 16(%rsp) - movups %xmm0, (%rsp) - movl $_Z9reduce_skPKfP12hip_bfloat16ii, %esi - movl $.L__unnamed_3, %edx - movl $.L__unnamed_3, %ecx - movq %rbx, %rdi - movl $-1, %r8d - xorl %r9d, %r9d - callq __hipRegisterFunction - movl $__hip_module_dtor, %edi - addq $32, %rsp - .cfi_def_cfa_offset 16 - popq %rbx - .cfi_def_cfa_offset 8 - jmp atexit # TAILCALL -.Lfunc_end3: - .size __hip_module_ctor, .Lfunc_end3-__hip_module_ctor - .cfi_endproc - # -- End function - .p2align 4 # -- Begin function __hip_module_dtor - .type __hip_module_dtor,@function -__hip_module_dtor: # @__hip_module_dtor - .cfi_startproc -# %bb.0: - movq __hip_gpubin_handle_85127a709332f6f1(%rip), %rdi - testq %rdi, %rdi - je .LBB4_2 -# %bb.1: - pushq %rax - .cfi_def_cfa_offset 16 - callq __hipUnregisterFatBinary - movq $0, __hip_gpubin_handle_85127a709332f6f1(%rip) - addq $8, %rsp - .cfi_def_cfa_offset 8 -.LBB4_2: - retq -.Lfunc_end4: - .size __hip_module_dtor, .Lfunc_end4-__hip_module_dtor - .cfi_endproc - # -- End function - .type _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii,@object # @_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .section .rodata,"a",@progbits - .globl _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .p2align 3, 0x0 -_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii: - .quad _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .size _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii, 8 - - .type _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii,@object # @_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .globl _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .p2align 3, 0x0 -_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii: - .quad _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .size _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii, 8 - - .type _Z9reduce_skPKfP12hip_bfloat16ii,@object # @_Z9reduce_skPKfP12hip_bfloat16ii - .globl _Z9reduce_skPKfP12hip_bfloat16ii - .p2align 3, 0x0 -_Z9reduce_skPKfP12hip_bfloat16ii: - .quad _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .size _Z9reduce_skPKfP12hip_bfloat16ii, 8 - - .type .L__unnamed_1,@object # @0 - .section .rodata.str1.1,"aMS",@progbits,1 -.L__unnamed_1: - .asciz "_Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii" - .size .L__unnamed_1, 40 - - .type .L__unnamed_2,@object # @1 -.L__unnamed_2: - .asciz "_Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii" - .size .L__unnamed_2, 39 - - .type .L__unnamed_3,@object # @2 -.L__unnamed_3: - .asciz "_Z9reduce_skPKfP12hip_bfloat16ii" - .size .L__unnamed_3, 33 - - .type __hip_fatbin_wrapper,@object # @__hip_fatbin_wrapper - .section .hipFatBinSegment,"a",@progbits - .p2align 3, 0x0 -__hip_fatbin_wrapper: - .long 1212764230 # 0x48495046 - .long 1 # 0x1 - .quad __hip_fatbin_85127a709332f6f1 - .quad 0 - .size __hip_fatbin_wrapper, 24 - - .type __hip_gpubin_handle_85127a709332f6f1,@object # @__hip_gpubin_handle_85127a709332f6f1 - .local __hip_gpubin_handle_85127a709332f6f1 - .comm __hip_gpubin_handle_85127a709332f6f1,8,8 - .section .init_array,"aw",@init_array - .p2align 3, 0x0 - .quad __hip_module_ctor - .type __hip_cuid_85127a709332f6f1,@object # @__hip_cuid_85127a709332f6f1 - .bss - .globl __hip_cuid_85127a709332f6f1 -__hip_cuid_85127a709332f6f1: - .byte 0 # 0x0 - .size __hip_cuid_85127a709332f6f1, 1 - - .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" - .section ".note.GNU-stack","",@progbits - .addrsig - .addrsig_sym _Z29__device_stub__gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .addrsig_sym _Z28__device_stub__gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .addrsig_sym _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .addrsig_sym __hip_module_ctor - .addrsig_sym __hip_module_dtor - .addrsig_sym _Z14gemm_core_gmemPKhS0_PvS0_S0_iiiiiii - .addrsig_sym _Z13gemm_core_ldsPKhS0_PvS0_S0_iiiiiii - .addrsig_sym _Z9reduce_skPKfP12hip_bfloat16ii - .addrsig_sym __hip_fatbin_85127a709332f6f1 - .addrsig_sym __hip_fatbin_wrapper - .addrsig_sym __hip_cuid_85127a709332f6f1 - -# __CLANG_OFFLOAD_BUNDLE____END__ host-x86_64-unknown-linux-gnu- diff --git a/examples/amd_mxfp4_mm/asm-dump/src.s b/examples/amd_mxfp4_mm/asm-dump/src.s deleted file mode 100644 index ef74cd27ca..0000000000 --- a/examples/amd_mxfp4_mm/asm-dump/src.s +++ /dev/null @@ -1,1283 +0,0 @@ - -# __CLANG_OFFLOAD_BUNDLE____START__ hip-amdgcn-amd-amdhsa--gfx950 - .amdgcn_target "amdgcn-amd-amdhsa--gfx950" - .amdhsa_code_object_version 6 - .text - .protected _Z9gemm_corePKhS0_PvS0_S0_iiiiiii ; -- Begin function _Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .globl _Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .p2align 8 - .type _Z9gemm_corePKhS0_PvS0_S0_iiiiiii,@function -_Z9gemm_corePKhS0_PvS0_S0_iiiiiii: ; @_Z9gemm_corePKhS0_PvS0_S0_iiiiiii -; %bb.0: - s_load_dwordx2 s[6:7], s[0:1], 0x10 - s_load_dwordx8 s[8:15], s[0:1], 0x28 - v_and_b32_e32 v1, 31, v0 - s_waitcnt lgkmcnt(0) - s_lshl_b32 s15, s2, 5 - s_lshl_b32 s23, s3, 5 - v_or_b32_e32 v2, s23, v1 - s_mul_i32 s22, s11, s4 - s_add_i32 s2, s22, s11 - s_min_i32 s5, s2, s10 - v_lshrrev_b32_e32 v14, 5, v0 - v_cmp_gt_i32_e64 s[2:3], s9, v2 - v_ashrrev_i32_e32 v3, 31, v2 - v_mov_b32_e32 v18, 0 - v_accvgpr_write_b32 a0, 0 - v_accvgpr_write_b32 a1, 0 - v_accvgpr_write_b32 a2, 0 - v_accvgpr_write_b32 a3, 0 - v_accvgpr_write_b32 a4, 0 - v_accvgpr_write_b32 a5, 0 - v_accvgpr_write_b32 a6, 0 - v_accvgpr_write_b32 a7, 0 - v_accvgpr_write_b32 a8, 0 - v_accvgpr_write_b32 a9, 0 - v_accvgpr_write_b32 a10, 0 - v_accvgpr_write_b32 a11, 0 - v_accvgpr_write_b32 a12, 0 - v_accvgpr_write_b32 a13, 0 - v_accvgpr_write_b32 a14, 0 - s_cmp_ge_i32 s22, s5 - v_accvgpr_write_b32 a15, 0 - s_cbranch_scc0 .LBB0_4 -; %bb.1: - s_cmp_lg_u32 s12, 1 - s_mov_b64 s[0:1], -1 - s_cbranch_scc1 .LBB0_26 -.LBB0_2: ; %Flow442 - s_andn2_b64 vcc, exec, s[0:1] - s_cbranch_vccz .LBB0_60 -.LBB0_3: ; %.loopexit - s_endpgm -.LBB0_4: - s_load_dwordx4 s[16:19], s[0:1], 0x0 - s_ashr_i32 s11, s10, 1 - v_or_b32_e32 v10, s15, v1 - v_mad_i64_i32 v[4:5], s[20:21], s11, v10, 0 - s_ashr_i32 s20, s22, 1 - v_lshlrev_b32_e32 v15, 4, v14 - v_add_u32_e32 v8, s20, v15 - v_cmp_gt_i32_e32 vcc, s8, v10 - v_mov_b32_e32 v19, v18 - v_mov_b32_e32 v20, v18 - v_mov_b32_e32 v21, v18 - s_waitcnt lgkmcnt(0) - v_lshl_add_u64 v[4:5], s[16:17], 0, v[4:5] - v_ashrrev_i32_e32 v9, 31, v8 - s_and_saveexec_b64 s[16:17], vcc - s_cbranch_execz .LBB0_6 -; %bb.5: ; %.loopexit44.i - v_lshl_add_u64 v[6:7], v[4:5], 0, v[8:9] - global_load_dwordx4 v[18:21], v[6:7], off -.LBB0_6: - s_or_b64 exec, exec, s[16:17] - s_load_dwordx2 s[20:21], s[0:1], 0x18 - v_mad_i64_i32 v[6:7], s[16:17], s11, v2, 0 - v_mov_b32_e32 v26, 0 - v_mov_b32_e32 v27, v26 - v_mov_b32_e32 v28, v26 - v_mov_b32_e32 v29, v26 - v_lshl_add_u64 v[6:7], s[18:19], 0, v[6:7] - s_and_saveexec_b64 s[16:17], s[2:3] - s_cbranch_execz .LBB0_8 -; %bb.7: ; %.loopexit.i - v_lshl_add_u64 v[8:9], v[6:7], 0, v[8:9] - global_load_dwordx4 v[26:29], v[8:9], off -.LBB0_8: - s_or_b64 exec, exec, s[16:17] - s_load_dwordx2 s[16:17], s[0:1], 0x20 - v_mad_i64_i32 v[8:9], s[0:1], s13, v10, 0 - s_ashr_i32 s0, s22, 5 - s_nop 0 - v_add_u32_e32 v10, s0, v14 - v_cmp_gt_i32_e64 s[0:1], s13, v10 - s_and_b64 s[18:19], vcc, s[0:1] - v_mov_b32_e32 v17, 0x7f - s_waitcnt lgkmcnt(0) - v_lshl_add_u64 v[8:9], s[20:21], 0, v[8:9] - v_mov_b32_e32 v16, 0x7f - s_and_saveexec_b64 s[0:1], s[18:19] - s_cbranch_execz .LBB0_10 -; %bb.9: - v_ashrrev_i32_e32 v11, 31, v10 - v_lshl_add_u64 v[12:13], v[8:9], 0, v[10:11] - global_load_ubyte v16, v[12:13], off -.LBB0_10: - s_or_b64 exec, exec, s[0:1] - s_ashr_i32 s0, s23, 5 - v_lshlrev_b32_e32 v0, 2, v0 - s_lshl_b32 s1, s14, 5 - v_and_b32_e32 v12, 60, v0 - v_mov_b32_e32 v13, 0 - v_mov_b32_e32 v0, s0 - s_ashr_i32 s10, s10, 5 - v_mad_i64_i32 v[34:35], s[0:1], s1, v0, v[12:13] - v_lshrrev_b32_e32 v0, 4, v1 - v_or_b32_e32 v34, v34, v0 - v_cmp_gt_i32_e64 s[0:1], s10, v10 - v_accvgpr_write_b32 a0, 0 - v_accvgpr_write_b32 a1, 0 - v_accvgpr_write_b32 a2, 0 - v_accvgpr_write_b32 a3, 0 - v_accvgpr_write_b32 a4, 0 - v_accvgpr_write_b32 a5, 0 - v_accvgpr_write_b32 a6, 0 - v_accvgpr_write_b32 a7, 0 - v_accvgpr_write_b32 a8, 0 - v_accvgpr_write_b32 a9, 0 - v_accvgpr_write_b32 a10, 0 - v_accvgpr_write_b32 a11, 0 - v_accvgpr_write_b32 a12, 0 - v_accvgpr_write_b32 a13, 0 - v_accvgpr_write_b32 a14, 0 - v_accvgpr_write_b32 a15, 0 - s_and_b64 s[18:19], s[2:3], s[0:1] - v_lshl_add_u64 v[0:1], s[16:17], 0, v[34:35] - s_and_saveexec_b64 s[0:1], s[18:19] - s_cbranch_execz .LBB0_12 -; %bb.11: - v_lshlrev_b32_e32 v11, 5, v10 - v_and_b32_e32 v34, 0xffffff00, v11 - v_ashrrev_i32_e32 v35, 31, v34 - v_lshlrev_b32_e32 v11, 6, v10 - v_and_b32_e32 v12, 0xc0, v11 - v_lshrrev_b32_e32 v10, 1, v10 - v_lshl_add_u64 v[34:35], v[0:1], 0, v[34:35] - v_and_b32_e32 v10, 2, v10 - v_mov_b32_e32 v11, v13 - v_lshl_add_u64 v[12:13], v[34:35], 0, v[12:13] - v_lshl_add_u64 v[10:11], v[12:13], 0, v[10:11] - global_load_ubyte v17, v[10:11], off -.LBB0_12: ; %_Z9load_fragPKhS0_S0_S0_iiiillllibbRDv32_hS2_RhS3_.exit - s_or_b64 exec, exec, s[0:1] - s_add_i32 s11, s22, 64 - s_cmp_ge_i32 s11, s5 - s_cbranch_scc1 .LBB0_24 -; %bb.13: ; %.lr.ph - v_mov_b32_e32 v34, 0 - v_accvgpr_write_b32 a15, 0 - v_accvgpr_write_b32 a14, 0 - v_accvgpr_write_b32 a13, 0 - v_accvgpr_write_b32 a12, 0 - v_accvgpr_write_b32 a11, 0 - v_accvgpr_write_b32 a10, 0 - v_accvgpr_write_b32 a9, 0 - v_accvgpr_write_b32 a8, 0 - v_accvgpr_write_b32 a7, 0 - v_accvgpr_write_b32 a6, 0 - v_accvgpr_write_b32 a5, 0 - v_accvgpr_write_b32 a4, 0 - v_accvgpr_write_b32 a3, 0 - v_accvgpr_write_b32 a2, 0 - v_accvgpr_write_b32 a1, 0 - v_accvgpr_write_b32 a0, 0 -.LBB0_14: ; =>This Inner Loop Header: Depth=1 - s_ashr_i32 s0, s11, 1 - v_add_u32_e32 v10, s0, v15 - v_mov_b32_e32 v35, v34 - v_mov_b32_e32 v36, v34 - v_mov_b32_e32 v37, v34 - v_mov_b64_e32 v[44:45], v[40:41] - v_ashrrev_i32_e32 v11, 31, v10 - v_mov_b64_e32 v[42:43], v[38:39] - v_mov_b64_e32 v[40:41], v[36:37] - v_mov_b64_e32 v[38:39], v[34:35] - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_16 -; %bb.15: ; %.loopexit44.i136 - ; in Loop: Header=BB0_14 Depth=1 - v_lshl_add_u64 v[12:13], v[4:5], 0, v[10:11] - global_load_dwordx4 v[38:41], v[12:13], off -.LBB0_16: ; in Loop: Header=BB0_14 Depth=1 - s_or_b64 exec, exec, s[0:1] - s_waitcnt vmcnt(0) - v_mov_b64_e32 v[52:53], v[40:41] - v_mov_b64_e32 v[50:51], v[38:39] - v_mov_b64_e32 v[48:49], v[36:37] - v_mov_b64_e32 v[46:47], v[34:35] - s_and_saveexec_b64 s[0:1], s[2:3] - s_cbranch_execz .LBB0_18 -; %bb.17: ; %.loopexit.i134 - ; in Loop: Header=BB0_14 Depth=1 - v_lshl_add_u64 v[10:11], v[6:7], 0, v[10:11] - global_load_dwordx4 v[46:49], v[10:11], off -.LBB0_18: ; in Loop: Header=BB0_14 Depth=1 - s_or_b64 exec, exec, s[0:1] - s_ashr_i32 s0, s11, 5 - v_add_u32_e32 v10, s0, v14 - v_cmp_gt_i32_e64 s[0:1], s13, v10 - s_and_b64 s[16:17], vcc, s[0:1] - v_mov_b32_e32 v12, 0x7f - v_mov_b32_e32 v11, 0x7f - s_and_saveexec_b64 s[0:1], s[16:17] - s_cbranch_execz .LBB0_20 -; %bb.19: ; in Loop: Header=BB0_14 Depth=1 - v_ashrrev_i32_e32 v11, 31, v10 - v_lshl_add_u64 v[22:23], v[8:9], 0, v[10:11] - global_load_ubyte v11, v[22:23], off -.LBB0_20: ; in Loop: Header=BB0_14 Depth=1 - s_or_b64 exec, exec, s[0:1] - v_cmp_gt_i32_e64 s[0:1], s10, v10 - s_and_b64 s[16:17], s[2:3], s[0:1] - s_and_saveexec_b64 s[0:1], s[16:17] - s_cbranch_execz .LBB0_22 -; %bb.21: ; in Loop: Header=BB0_14 Depth=1 - v_lshlrev_b32_e32 v12, 5, v10 - v_and_b32_e32 v12, 0xffffff00, v12 - v_ashrrev_i32_e32 v13, 31, v12 - v_lshlrev_b32_e32 v22, 6, v10 - v_and_b32_e32 v22, 0xc0, v22 - v_mov_b32_e32 v23, v34 - v_lshrrev_b32_e32 v10, 1, v10 - v_lshl_add_u64 v[12:13], v[0:1], 0, v[12:13] - v_and_b32_e32 v24, 2, v10 - v_mov_b32_e32 v25, v34 - v_lshl_add_u64 v[12:13], v[12:13], 0, v[22:23] - v_lshl_add_u64 v[12:13], v[12:13], 0, v[24:25] - global_load_ubyte v12, v[12:13], off -.LBB0_22: ; %_Z9load_fragPKhS0_S0_S0_iiiillllibbRDv32_hS2_RhS3_.exit138 - ; in Loop: Header=BB0_14 Depth=1 - s_or_b64 exec, exec, s[0:1] - v_and_b32_e32 v10, 0xff, v16 - v_and_b32_e32 v13, 0xff, v17 - s_add_i32 s11, s11, 64 - s_cmp_ge_i32 s11, s5 - v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[18:21], v[26:29], a[0:15], v10, v13 op_sel_hi:[0,0,0] cbsz:4 blgp:4 - s_cbranch_scc1 .LBB0_25 -; %bb.23: ; in Loop: Header=BB0_14 Depth=1 - v_mov_b64_e32 v[18:19], v[38:39] - s_waitcnt vmcnt(0) - v_mov_b64_e32 v[26:27], v[46:47] - v_mov_b64_e32 v[20:21], v[40:41] - v_mov_b64_e32 v[22:23], v[42:43] - v_mov_b64_e32 v[24:25], v[44:45] - v_mov_b64_e32 v[28:29], v[48:49] - v_mov_b64_e32 v[30:31], v[50:51] - v_mov_b64_e32 v[32:33], v[52:53] - v_mov_b32_e32 v17, v12 - v_mov_b32_e32 v16, v11 - s_branch .LBB0_14 -.LBB0_24: - s_waitcnt vmcnt(0) - v_mov_b64_e32 v[44:45], v[24:25] - v_mov_b64_e32 v[52:53], v[32:33] - v_mov_b64_e32 v[40:41], v[20:21] - v_mov_b64_e32 v[38:39], v[18:19] - v_mov_b64_e32 v[48:49], v[28:29] - v_mov_b64_e32 v[46:47], v[26:27] - v_mov_b32_e32 v12, v17 - v_mov_b32_e32 v11, v16 - v_mov_b64_e32 v[42:43], v[22:23] - v_mov_b64_e32 v[50:51], v[30:31] -.LBB0_25: ; %Flow444 - s_waitcnt vmcnt(0) - v_and_b32_e32 v0, 0xff, v11 - v_and_b32_e32 v1, 0xff, v12 - s_nop 1 - v_mfma_scale_f32_32x32x64_f8f6f4 a[0:15], v[38:41], v[46:49], a[0:15], v0, v1 op_sel_hi:[0,0,0] cbsz:4 blgp:4 - s_cmp_lg_u32 s12, 1 - s_mov_b64 s[0:1], -1 - s_cbranch_scc0 .LBB0_2 -.LBB0_26: - s_and_saveexec_b64 s[0:1], s[2:3] - s_cbranch_execz .LBB0_59 -; %bb.27: ; %.preheader154 - v_lshl_add_u32 v4, v14, 2, s15 - s_mul_hi_i32 s5, s8, s4 - s_mul_i32 s4, s8, s4 - s_ashr_i32 s12, s9, 31 - v_lshl_add_u64 v[0:1], v[2:3], 2, s[6:7] - v_cmp_gt_i32_e32 vcc, s8, v4 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_29 -; %bb.28: - v_ashrrev_i32_e32 v5, 31, v4 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[4:5] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a0, off -.LBB0_29: - s_or_b64 exec, exec, s[10:11] - v_or_b32_e32 v6, 1, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_31 -; %bb.30: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a1, off -.LBB0_31: - s_or_b64 exec, exec, s[10:11] - v_or_b32_e32 v6, 2, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_33 -; %bb.32: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a2, off -.LBB0_33: - s_or_b64 exec, exec, s[10:11] - v_or_b32_e32 v6, 3, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_35 -; %bb.34: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a3, off -.LBB0_35: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 8, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_37 -; %bb.36: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a4, off -.LBB0_37: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 9, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_39 -; %bb.38: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a5, off -.LBB0_39: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 10, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_41 -; %bb.40: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a6, off -.LBB0_41: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 11, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_43 -; %bb.42: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a7, off -.LBB0_43: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 16, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_45 -; %bb.44: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a8, off -.LBB0_45: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 17, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_47 -; %bb.46: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a9, off -.LBB0_47: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 18, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_49 -; %bb.48: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a10, off -.LBB0_49: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 19, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_51 -; %bb.50: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a11, off -.LBB0_51: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 24, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_53 -; %bb.52: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a12, off -.LBB0_53: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 25, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_55 -; %bb.54: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a13, off -.LBB0_55: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v6, 26, v4 - v_cmp_gt_i32_e32 vcc, s8, v6 - s_and_saveexec_b64 s[10:11], vcc - s_cbranch_execz .LBB0_57 -; %bb.56: - v_ashrrev_i32_e32 v7, 31, v6 - v_lshl_add_u64 v[6:7], s[4:5], 0, v[6:7] - v_mul_lo_u32 v5, v7, s9 - v_mul_lo_u32 v8, v6, s12 - v_mad_u64_u32 v[6:7], s[16:17], v6, s9, 0 - v_add3_u32 v7, v7, v8, v5 - v_lshl_add_u64 v[6:7], v[6:7], 2, v[0:1] - global_store_dword v[6:7], a14, off -.LBB0_57: - s_or_b64 exec, exec, s[10:11] - v_add_u32_e32 v4, 27, v4 - v_cmp_gt_i32_e32 vcc, s8, v4 - s_and_b64 exec, exec, vcc - s_cbranch_execz .LBB0_59 -; %bb.58: - v_ashrrev_i32_e32 v5, 31, v4 - v_lshl_add_u64 v[4:5], s[4:5], 0, v[4:5] - v_mul_lo_u32 v6, v5, s9 - v_mul_lo_u32 v7, v4, s12 - v_mad_u64_u32 v[4:5], s[4:5], v4, s9, 0 - v_add3_u32 v5, v5, v7, v6 - v_lshl_add_u64 v[0:1], v[4:5], 2, v[0:1] - global_store_dword v[0:1], a15, off -.LBB0_59: ; %Flow439 - s_or_b64 exec, exec, s[0:1] - s_cbranch_execnz .LBB0_3 -.LBB0_60: - s_and_saveexec_b64 s[0:1], s[2:3] - s_cbranch_execz .LBB0_3 -; %bb.61: ; %.preheader - v_lshl_add_u32 v4, v14, 2, s15 - v_lshl_add_u64 v[0:1], v[2:3], 1, s[6:7] - v_cmp_gt_i32_e32 vcc, s8, v4 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_63 -; %bb.62: - v_mad_i64_i32 v[2:3], s[2:3], v4, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a0, off -.LBB0_63: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v2, 1, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_65 -; %bb.64: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a1, off -.LBB0_65: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v2, 2, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_67 -; %bb.66: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a2, off -.LBB0_67: - s_or_b64 exec, exec, s[0:1] - v_or_b32_e32 v2, 3, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_69 -; %bb.68: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a3, off -.LBB0_69: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 8, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_71 -; %bb.70: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a4, off -.LBB0_71: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 9, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_73 -; %bb.72: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a5, off -.LBB0_73: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 10, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_75 -; %bb.74: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a6, off -.LBB0_75: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 11, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_77 -; %bb.76: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a7, off -.LBB0_77: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 16, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_79 -; %bb.78: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a8, off -.LBB0_79: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 17, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_81 -; %bb.80: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a9, off -.LBB0_81: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 18, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_83 -; %bb.82: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a10, off -.LBB0_83: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 19, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_85 -; %bb.84: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a11, off -.LBB0_85: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 24, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_87 -; %bb.86: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a12, off -.LBB0_87: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 25, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_89 -; %bb.88: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a13, off -.LBB0_89: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 26, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_saveexec_b64 s[0:1], vcc - s_cbranch_execz .LBB0_91 -; %bb.90: - v_mad_i64_i32 v[2:3], s[2:3], v2, s9, 0 - v_lshl_add_u64 v[2:3], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[2:3], a14, off -.LBB0_91: - s_or_b64 exec, exec, s[0:1] - v_add_u32_e32 v2, 27, v4 - v_cmp_gt_i32_e32 vcc, s8, v2 - s_and_b64 exec, exec, vcc - s_cbranch_execz .LBB0_3 -; %bb.92: - v_mad_i64_i32 v[2:3], s[0:1], v2, s9, 0 - v_lshl_add_u64 v[0:1], v[2:3], 1, v[0:1] - global_store_short_d16_hi v[0:1], a15, off - s_endpgm - .section .rodata,"a",@progbits - .p2align 6, 0x0 - .amdhsa_kernel _Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .amdhsa_group_segment_fixed_size 0 - .amdhsa_private_segment_fixed_size 0 - .amdhsa_kernarg_size 68 - .amdhsa_user_sgpr_count 2 - .amdhsa_user_sgpr_dispatch_ptr 0 - .amdhsa_user_sgpr_queue_ptr 0 - .amdhsa_user_sgpr_kernarg_segment_ptr 1 - .amdhsa_user_sgpr_dispatch_id 0 - .amdhsa_user_sgpr_kernarg_preload_length 0 - .amdhsa_user_sgpr_kernarg_preload_offset 0 - .amdhsa_user_sgpr_private_segment_size 0 - .amdhsa_uses_dynamic_stack 0 - .amdhsa_enable_private_segment 0 - .amdhsa_system_sgpr_workgroup_id_x 1 - .amdhsa_system_sgpr_workgroup_id_y 1 - .amdhsa_system_sgpr_workgroup_id_z 1 - .amdhsa_system_sgpr_workgroup_info 0 - .amdhsa_system_vgpr_workitem_id 0 - .amdhsa_next_free_vgpr 72 - .amdhsa_next_free_sgpr 24 - .amdhsa_accum_offset 56 - .amdhsa_reserve_vcc 1 - .amdhsa_float_round_mode_32 0 - .amdhsa_float_round_mode_16_64 0 - .amdhsa_float_denorm_mode_32 3 - .amdhsa_float_denorm_mode_16_64 3 - .amdhsa_dx10_clamp 1 - .amdhsa_ieee_mode 1 - .amdhsa_fp16_overflow 0 - .amdhsa_tg_split 0 - .amdhsa_exception_fp_ieee_invalid_op 0 - .amdhsa_exception_fp_denorm_src 0 - .amdhsa_exception_fp_ieee_div_zero 0 - .amdhsa_exception_fp_ieee_overflow 0 - .amdhsa_exception_fp_ieee_underflow 0 - .amdhsa_exception_fp_ieee_inexact 0 - .amdhsa_exception_int_div_zero 0 - .end_amdhsa_kernel - .text -.Lfunc_end0: - .size _Z9gemm_corePKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z9gemm_corePKhS0_PvS0_S0_iiiiiii - ; -- End function - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.num_vgpr, 54 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.num_agpr, 16 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.numbered_sgpr, 24 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.num_named_barrier, 0 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.private_seg_size, 0 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.uses_vcc, 1 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.uses_flat_scratch, 0 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.has_dyn_sized_stack, 0 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.has_recursion, 0 - .set _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.has_indirect_call, 0 - .section .AMDGPU.csdata,"",@progbits -; Kernel info: -; codeLenInByte = 3424 -; TotalNumSgprs: 30 -; NumVgprs: 54 -; NumAgprs: 16 -; TotalNumVgprs: 72 -; ScratchSize: 0 -; MemoryBound: 0 -; FloatMode: 240 -; IeeeMode: 1 -; LDSByteSize: 0 bytes/workgroup (compile time only) -; SGPRBlocks: 3 -; VGPRBlocks: 8 -; NumSGPRsForWavesPerEU: 30 -; NumVGPRsForWavesPerEU: 72 -; AccumOffset: 56 -; Occupancy: 7 -; WaveLimiterHint : 0 -; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 -; COMPUTE_PGM_RSRC2:USER_SGPR: 2 -; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 -; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Y_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Z_EN: 1 -; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 -; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 13 -; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 - .text - .protected _Z9reduce_skPKfP12hip_bfloat16ii ; -- Begin function _Z9reduce_skPKfP12hip_bfloat16ii - .globl _Z9reduce_skPKfP12hip_bfloat16ii - .p2align 8 - .type _Z9reduce_skPKfP12hip_bfloat16ii,@function -_Z9reduce_skPKfP12hip_bfloat16ii: ; @_Z9reduce_skPKfP12hip_bfloat16ii -; %bb.0: - s_load_dwordx2 s[4:5], s[0:1], 0x10 - v_lshl_add_u32 v0, s2, 8, v0 - s_waitcnt lgkmcnt(0) - v_cmp_gt_i32_e32 vcc, s4, v0 - s_and_saveexec_b64 s[2:3], vcc - s_cbranch_execz .LBB1_8 -; %bb.1: ; %.preheader - s_cmp_gt_i32 s5, 0 - v_ashrrev_i32_e32 v1, 31, v0 - s_cbranch_scc1 .LBB1_3 -; %bb.2: ; %.preheader.._crit_edge_crit_edge - s_load_dwordx2 s[2:3], s[0:1], 0x8 - v_mov_b32_e32 v2, 0 - s_cbranch_execz .LBB1_4 - s_branch .LBB1_7 -.LBB1_3: - s_load_dwordx2 s[2:3], s[0:1], 0x8 - v_mov_b32_e32 v2, 0 -.LBB1_4: ; %.lr.ph - s_load_dwordx2 s[6:7], s[0:1], 0x0 - s_ashr_i32 s1, s4, 31 - s_mov_b32 s0, s4 - s_lshl_b64 s[0:1], s[0:1], 2 - v_mov_b32_e32 v4, 0 - s_waitcnt lgkmcnt(0) - v_lshl_add_u64 v[2:3], v[0:1], 2, s[6:7] -.LBB1_5: ; =>This Inner Loop Header: Depth=1 - global_load_dword v5, v[2:3], off - s_add_i32 s5, s5, -1 - v_lshl_add_u64 v[2:3], v[2:3], 0, s[0:1] - s_cmp_eq_u32 s5, 0 - s_waitcnt vmcnt(0) - v_add_f32_e32 v4, v4, v5 - s_cbranch_scc0 .LBB1_5 -; %bb.6: ; %._crit_edge.loopexit - v_lshrrev_b32_e32 v2, 16, v4 -.LBB1_7: ; %Flow26 - s_waitcnt lgkmcnt(0) - v_lshl_add_u64 v[0:1], v[0:1], 1, s[2:3] - global_store_short v[0:1], v2, off -.LBB1_8: - s_endpgm - .section .rodata,"a",@progbits - .p2align 6, 0x0 - .amdhsa_kernel _Z9reduce_skPKfP12hip_bfloat16ii - .amdhsa_group_segment_fixed_size 0 - .amdhsa_private_segment_fixed_size 0 - .amdhsa_kernarg_size 24 - .amdhsa_user_sgpr_count 2 - .amdhsa_user_sgpr_dispatch_ptr 0 - .amdhsa_user_sgpr_queue_ptr 0 - .amdhsa_user_sgpr_kernarg_segment_ptr 1 - .amdhsa_user_sgpr_dispatch_id 0 - .amdhsa_user_sgpr_kernarg_preload_length 0 - .amdhsa_user_sgpr_kernarg_preload_offset 0 - .amdhsa_user_sgpr_private_segment_size 0 - .amdhsa_uses_dynamic_stack 0 - .amdhsa_enable_private_segment 0 - .amdhsa_system_sgpr_workgroup_id_x 1 - .amdhsa_system_sgpr_workgroup_id_y 0 - .amdhsa_system_sgpr_workgroup_id_z 0 - .amdhsa_system_sgpr_workgroup_info 0 - .amdhsa_system_vgpr_workitem_id 0 - .amdhsa_next_free_vgpr 6 - .amdhsa_next_free_sgpr 8 - .amdhsa_accum_offset 8 - .amdhsa_reserve_vcc 1 - .amdhsa_float_round_mode_32 0 - .amdhsa_float_round_mode_16_64 0 - .amdhsa_float_denorm_mode_32 3 - .amdhsa_float_denorm_mode_16_64 3 - .amdhsa_dx10_clamp 1 - .amdhsa_ieee_mode 1 - .amdhsa_fp16_overflow 0 - .amdhsa_tg_split 0 - .amdhsa_exception_fp_ieee_invalid_op 0 - .amdhsa_exception_fp_denorm_src 0 - .amdhsa_exception_fp_ieee_div_zero 0 - .amdhsa_exception_fp_ieee_overflow 0 - .amdhsa_exception_fp_ieee_underflow 0 - .amdhsa_exception_fp_ieee_inexact 0 - .amdhsa_exception_int_div_zero 0 - .end_amdhsa_kernel - .text -.Lfunc_end1: - .size _Z9reduce_skPKfP12hip_bfloat16ii, .Lfunc_end1-_Z9reduce_skPKfP12hip_bfloat16ii - ; -- End function - .set _Z9reduce_skPKfP12hip_bfloat16ii.num_vgpr, 6 - .set _Z9reduce_skPKfP12hip_bfloat16ii.num_agpr, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.numbered_sgpr, 8 - .set _Z9reduce_skPKfP12hip_bfloat16ii.num_named_barrier, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.private_seg_size, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_vcc, 1 - .set _Z9reduce_skPKfP12hip_bfloat16ii.uses_flat_scratch, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.has_dyn_sized_stack, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.has_recursion, 0 - .set _Z9reduce_skPKfP12hip_bfloat16ii.has_indirect_call, 0 - .section .AMDGPU.csdata,"",@progbits -; Kernel info: -; codeLenInByte = 176 -; TotalNumSgprs: 14 -; NumVgprs: 6 -; NumAgprs: 0 -; TotalNumVgprs: 6 -; ScratchSize: 0 -; MemoryBound: 0 -; FloatMode: 240 -; IeeeMode: 1 -; LDSByteSize: 0 bytes/workgroup (compile time only) -; SGPRBlocks: 1 -; VGPRBlocks: 0 -; NumSGPRsForWavesPerEU: 14 -; NumVGPRsForWavesPerEU: 6 -; AccumOffset: 8 -; Occupancy: 8 -; WaveLimiterHint : 0 -; COMPUTE_PGM_RSRC2:SCRATCH_EN: 0 -; COMPUTE_PGM_RSRC2:USER_SGPR: 2 -; COMPUTE_PGM_RSRC2:TRAP_HANDLER: 0 -; COMPUTE_PGM_RSRC2:TGID_X_EN: 1 -; COMPUTE_PGM_RSRC2:TGID_Y_EN: 0 -; COMPUTE_PGM_RSRC2:TGID_Z_EN: 0 -; COMPUTE_PGM_RSRC2:TIDIG_COMP_CNT: 0 -; COMPUTE_PGM_RSRC3_GFX90A:ACCUM_OFFSET: 1 -; COMPUTE_PGM_RSRC3_GFX90A:TG_SPLIT: 0 - .text - .p2alignl 6, 3212836864 - .fill 256, 4, 3212836864 - .section .AMDGPU.gpr_maximums,"",@progbits - .set amdgpu.max_num_vgpr, 0 - .set amdgpu.max_num_agpr, 0 - .set amdgpu.max_num_sgpr, 0 - .text - .type __hip_cuid_8b457700d3d88645,@object ; @__hip_cuid_8b457700d3d88645 - .section .bss,"aw",@nobits - .globl __hip_cuid_8b457700d3d88645 -__hip_cuid_8b457700d3d88645: - .byte 0 ; 0x0 - .size __hip_cuid_8b457700d3d88645, 1 - - .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" - .section ".note.GNU-stack","",@progbits - .addrsig - .addrsig_sym __hip_cuid_8b457700d3d88645 - .amdgpu_metadata ---- -amdhsa.kernels: - - .agpr_count: 16 - .args: - - .actual_access: read_only - .address_space: global - .offset: 0 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 8 - .size: 8 - .value_kind: global_buffer - - .actual_access: write_only - .address_space: global - .offset: 16 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 24 - .size: 8 - .value_kind: global_buffer - - .actual_access: read_only - .address_space: global - .offset: 32 - .size: 8 - .value_kind: global_buffer - - .offset: 40 - .size: 4 - .value_kind: by_value - - .offset: 44 - .size: 4 - .value_kind: by_value - - .offset: 48 - .size: 4 - .value_kind: by_value - - .offset: 52 - .size: 4 - .value_kind: by_value - - .offset: 56 - .size: 4 - .value_kind: by_value - - .offset: 60 - .size: 4 - .value_kind: by_value - - .offset: 64 - .size: 4 - .value_kind: by_value - .group_segment_fixed_size: 0 - .kernarg_segment_align: 8 - .kernarg_segment_size: 68 - .language: OpenCL C - .language_version: - - 2 - - 0 - .max_flat_workgroup_size: 64 - .name: _Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .private_segment_fixed_size: 0 - .sgpr_count: 30 - .sgpr_spill_count: 0 - .symbol: _Z9gemm_corePKhS0_PvS0_S0_iiiiiii.kd - .uniform_work_group_size: 1 - .uses_dynamic_stack: false - .vgpr_count: 72 - .vgpr_spill_count: 0 - .wavefront_size: 64 - - .agpr_count: 0 - .args: - - .actual_access: read_only - .address_space: global - .offset: 0 - .size: 8 - .value_kind: global_buffer - - .actual_access: write_only - .address_space: global - .offset: 8 - .size: 8 - .value_kind: global_buffer - - .offset: 16 - .size: 4 - .value_kind: by_value - - .offset: 20 - .size: 4 - .value_kind: by_value - .group_segment_fixed_size: 0 - .kernarg_segment_align: 8 - .kernarg_segment_size: 24 - .language: OpenCL C - .language_version: - - 2 - - 0 - .max_flat_workgroup_size: 1024 - .name: _Z9reduce_skPKfP12hip_bfloat16ii - .private_segment_fixed_size: 0 - .sgpr_count: 14 - .sgpr_spill_count: 0 - .symbol: _Z9reduce_skPKfP12hip_bfloat16ii.kd - .uniform_work_group_size: 1 - .uses_dynamic_stack: false - .vgpr_count: 6 - .vgpr_spill_count: 0 - .wavefront_size: 64 -amdhsa.target: amdgcn-amd-amdhsa--gfx950 -amdhsa.version: - - 1 - - 2 -... - - .end_amdgpu_metadata - -# __CLANG_OFFLOAD_BUNDLE____END__ hip-amdgcn-amd-amdhsa--gfx950 - -# __CLANG_OFFLOAD_BUNDLE____START__ host-x86_64-unknown-linux-gnu- - .file "src.hip" - .text - .globl _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii # -- Begin function _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii - .p2align 4 - .type _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii,@function -_Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii: # @_Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii - .cfi_startproc -# %bb.0: - subq $200, %rsp - .cfi_def_cfa_offset 208 - movq %rdi, 88(%rsp) - movq %rsi, 80(%rsp) - movq %rdx, 72(%rsp) - movq %rcx, 64(%rsp) - movq %r8, 56(%rsp) - movl %r9d, 4(%rsp) - leaq 88(%rsp), %rax - movq %rax, 96(%rsp) - leaq 80(%rsp), %rax - movq %rax, 104(%rsp) - leaq 72(%rsp), %rax - movq %rax, 112(%rsp) - leaq 64(%rsp), %rax - movq %rax, 120(%rsp) - leaq 56(%rsp), %rax - movq %rax, 128(%rsp) - leaq 4(%rsp), %rax - movq %rax, 136(%rsp) - leaq 208(%rsp), %rax - movq %rax, 144(%rsp) - leaq 216(%rsp), %rax - movq %rax, 152(%rsp) - leaq 224(%rsp), %rax - movq %rax, 160(%rsp) - leaq 232(%rsp), %rax - movq %rax, 168(%rsp) - leaq 240(%rsp), %rax - movq %rax, 176(%rsp) - leaq 248(%rsp), %rax - movq %rax, 184(%rsp) - leaq 40(%rsp), %rdi - leaq 24(%rsp), %rsi - leaq 16(%rsp), %rdx - leaq 8(%rsp), %rcx - callq __hipPopCallConfiguration - movq 40(%rsp), %rsi - movl 48(%rsp), %edx - movq 24(%rsp), %rcx - movl 32(%rsp), %r8d - leaq 96(%rsp), %r9 - movl $_Z9gemm_corePKhS0_PvS0_S0_iiiiiii, %edi - pushq 8(%rsp) - .cfi_adjust_cfa_offset 8 - pushq 24(%rsp) - .cfi_adjust_cfa_offset 8 - callq hipLaunchKernel - addq $216, %rsp - .cfi_adjust_cfa_offset -216 - retq -.Lfunc_end0: - .size _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii, .Lfunc_end0-_Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii - .cfi_endproc - # -- End function - .globl _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii # -- Begin function _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .p2align 4 - .type _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii,@function -_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii: # @_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .cfi_startproc -# %bb.0: - subq $120, %rsp - .cfi_def_cfa_offset 128 - movq %rdi, 72(%rsp) - movq %rsi, 64(%rsp) - movl %edx, 12(%rsp) - movl %ecx, 8(%rsp) - leaq 72(%rsp), %rax - movq %rax, 80(%rsp) - leaq 64(%rsp), %rax - movq %rax, 88(%rsp) - leaq 12(%rsp), %rax - movq %rax, 96(%rsp) - leaq 8(%rsp), %rax - movq %rax, 104(%rsp) - leaq 48(%rsp), %rdi - leaq 32(%rsp), %rsi - leaq 24(%rsp), %rdx - leaq 16(%rsp), %rcx - callq __hipPopCallConfiguration - movq 48(%rsp), %rsi - movl 56(%rsp), %edx - movq 32(%rsp), %rcx - movl 40(%rsp), %r8d - leaq 80(%rsp), %r9 - movl $_Z9reduce_skPKfP12hip_bfloat16ii, %edi - pushq 16(%rsp) - .cfi_adjust_cfa_offset 8 - pushq 32(%rsp) - .cfi_adjust_cfa_offset 8 - callq hipLaunchKernel - addq $136, %rsp - .cfi_adjust_cfa_offset -136 - retq -.Lfunc_end1: - .size _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii, .Lfunc_end1-_Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .cfi_endproc - # -- End function - .p2align 4 # -- Begin function __hip_module_ctor - .type __hip_module_ctor,@function -__hip_module_ctor: # @__hip_module_ctor - .cfi_startproc -# %bb.0: - pushq %rbx - .cfi_def_cfa_offset 16 - subq $32, %rsp - .cfi_def_cfa_offset 48 - .cfi_offset %rbx, -16 - movq __hip_gpubin_handle_8b457700d3d88645(%rip), %rbx - testq %rbx, %rbx - jne .LBB2_2 -# %bb.1: - movl $__hip_fatbin_wrapper, %edi - callq __hipRegisterFatBinary - movq %rax, %rbx - movq %rax, __hip_gpubin_handle_8b457700d3d88645(%rip) -.LBB2_2: - xorps %xmm0, %xmm0 - movups %xmm0, 16(%rsp) - movups %xmm0, (%rsp) - movl $_Z9gemm_corePKhS0_PvS0_S0_iiiiiii, %esi - movl $.L__unnamed_1, %edx - movl $.L__unnamed_1, %ecx - movq %rbx, %rdi - movl $-1, %r8d - xorl %r9d, %r9d - callq __hipRegisterFunction - xorps %xmm0, %xmm0 - movups %xmm0, 16(%rsp) - movups %xmm0, (%rsp) - movl $_Z9reduce_skPKfP12hip_bfloat16ii, %esi - movl $.L__unnamed_2, %edx - movl $.L__unnamed_2, %ecx - movq %rbx, %rdi - movl $-1, %r8d - xorl %r9d, %r9d - callq __hipRegisterFunction - movl $__hip_module_dtor, %edi - addq $32, %rsp - .cfi_def_cfa_offset 16 - popq %rbx - .cfi_def_cfa_offset 8 - jmp atexit # TAILCALL -.Lfunc_end2: - .size __hip_module_ctor, .Lfunc_end2-__hip_module_ctor - .cfi_endproc - # -- End function - .p2align 4 # -- Begin function __hip_module_dtor - .type __hip_module_dtor,@function -__hip_module_dtor: # @__hip_module_dtor - .cfi_startproc -# %bb.0: - movq __hip_gpubin_handle_8b457700d3d88645(%rip), %rdi - testq %rdi, %rdi - je .LBB3_2 -# %bb.1: - pushq %rax - .cfi_def_cfa_offset 16 - callq __hipUnregisterFatBinary - movq $0, __hip_gpubin_handle_8b457700d3d88645(%rip) - addq $8, %rsp - .cfi_def_cfa_offset 8 -.LBB3_2: - retq -.Lfunc_end3: - .size __hip_module_dtor, .Lfunc_end3-__hip_module_dtor - .cfi_endproc - # -- End function - .type _Z9gemm_corePKhS0_PvS0_S0_iiiiiii,@object # @_Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .section .rodata,"a",@progbits - .globl _Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .p2align 3, 0x0 -_Z9gemm_corePKhS0_PvS0_S0_iiiiiii: - .quad _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii - .size _Z9gemm_corePKhS0_PvS0_S0_iiiiiii, 8 - - .type _Z9reduce_skPKfP12hip_bfloat16ii,@object # @_Z9reduce_skPKfP12hip_bfloat16ii - .globl _Z9reduce_skPKfP12hip_bfloat16ii - .p2align 3, 0x0 -_Z9reduce_skPKfP12hip_bfloat16ii: - .quad _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .size _Z9reduce_skPKfP12hip_bfloat16ii, 8 - - .type .L__unnamed_1,@object # @0 - .section .rodata.str1.1,"aMS",@progbits,1 -.L__unnamed_1: - .asciz "_Z9gemm_corePKhS0_PvS0_S0_iiiiiii" - .size .L__unnamed_1, 34 - - .type .L__unnamed_2,@object # @1 -.L__unnamed_2: - .asciz "_Z9reduce_skPKfP12hip_bfloat16ii" - .size .L__unnamed_2, 33 - - .type __hip_fatbin_wrapper,@object # @__hip_fatbin_wrapper - .section .hipFatBinSegment,"a",@progbits - .p2align 3, 0x0 -__hip_fatbin_wrapper: - .long 1212764230 # 0x48495046 - .long 1 # 0x1 - .quad __hip_fatbin_8b457700d3d88645 - .quad 0 - .size __hip_fatbin_wrapper, 24 - - .type __hip_gpubin_handle_8b457700d3d88645,@object # @__hip_gpubin_handle_8b457700d3d88645 - .local __hip_gpubin_handle_8b457700d3d88645 - .comm __hip_gpubin_handle_8b457700d3d88645,8,8 - .section .init_array,"aw",@init_array - .p2align 3, 0x0 - .quad __hip_module_ctor - .type __hip_cuid_8b457700d3d88645,@object # @__hip_cuid_8b457700d3d88645 - .bss - .globl __hip_cuid_8b457700d3d88645 -__hip_cuid_8b457700d3d88645: - .byte 0 # 0x0 - .size __hip_cuid_8b457700d3d88645, 1 - - .ident "AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.0 26014 7b800a19466229b8479a78de19143dc33c3ab9b5)" - .section ".note.GNU-stack","",@progbits - .addrsig - .addrsig_sym _Z24__device_stub__gemm_corePKhS0_PvS0_S0_iiiiiii - .addrsig_sym _Z24__device_stub__reduce_skPKfP12hip_bfloat16ii - .addrsig_sym __hip_module_ctor - .addrsig_sym __hip_module_dtor - .addrsig_sym _Z9gemm_corePKhS0_PvS0_S0_iiiiiii - .addrsig_sym _Z9reduce_skPKfP12hip_bfloat16ii - .addrsig_sym __hip_fatbin_8b457700d3d88645 - .addrsig_sym __hip_fatbin_wrapper - .addrsig_sym __hip_cuid_8b457700d3d88645 - -# __CLANG_OFFLOAD_BUNDLE____END__ host-x86_64-unknown-linux-gnu- diff --git a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py index 104b139e50..fbd06a7e9d 100644 --- a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py +++ b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py @@ -13,10 +13,6 @@ see examples/fusedmoe/example_fusedmoe_tilelang.py. """ -<<<<<<< HEAD -======= -import os ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) import time import torch import tilelang @@ -24,7 +20,6 @@ FP4_E2M1_TO_FLOAT = [ -<<<<<<< HEAD 0.0, 0.5, 1.0, @@ -41,10 +36,6 @@ -3.0, -4.0, -6.0, -======= - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ] @@ -54,7 +45,6 @@ def fp4_uint8_to_float(t): def moe_shared_expert_a8w4( -<<<<<<< HEAD num_tokens, d_hidden, d_expert, @@ -63,11 +53,6 @@ def moe_shared_expert_a8w4( block_expert=128, threads=128, num_stages=1, -======= - num_tokens, d_hidden, d_expert, - block_token=128, block_hidden=128, block_expert=128, - threads=128, num_stages=1, ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ): """Single shared expert: gate_up GEMM -> SiLU*up -> down GEMM.""" scale = 1.44269504 # log2(e) for fast SiLU @@ -105,13 +90,7 @@ def main( # Fused SiLU activation: gate = gate * sigmoid(gate), then up = up * gate for i, j in T.Parallel(block_token, block_expert): -<<<<<<< HEAD gate_local[i, j] = gate_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale))) -======= - gate_local[i, j] = gate_local[i, j] * ( - 1.0 / (1.0 + T.exp2(-gate_local[i, j] * scale)) - ) ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) up_local[i, j] = up_local[i, j] * gate_local[i, j] T.copy(up_local, output[bx * block_token, by * block_expert]) @@ -127,7 +106,6 @@ def main( print(f"Running FP4 MoE (A8W4): tokens={num_tokens}, hidden={d_hidden}, expert={d_expert}") func = moe_shared_expert_a8w4( -<<<<<<< HEAD num_tokens, d_hidden, d_expert, @@ -136,11 +114,6 @@ def main( block_expert=128, threads=128, num_stages=1, -======= - num_tokens, d_hidden, d_expert, - block_token=128, block_hidden=128, block_expert=128, - threads=128, num_stages=1, ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ) jit_kernel = tilelang.compile( diff --git a/examples/gemm_fp4/example_gemm_a8w4_sm120.py b/examples/gemm_fp4/example_gemm_a8w4_sm120.py index 3083850076..b48564ba56 100644 --- a/examples/gemm_fp4/example_gemm_a8w4_sm120.py +++ b/examples/gemm_fp4/example_gemm_a8w4_sm120.py @@ -8,10 +8,6 @@ No block scaling. Direct FP8 x FP4 tensor core multiply-accumulate. """ -<<<<<<< HEAD -======= -import os ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) import time import torch import tilelang @@ -19,7 +15,6 @@ def gemm_a8w4( -<<<<<<< HEAD M, N, K, @@ -30,11 +25,6 @@ def gemm_a8w4( accum_dtype, num_stages=2, threads=128, -======= - M, N, K, block_M, block_N, block_K, - out_dtype, accum_dtype, - num_stages=2, threads=128, ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ): A_shape = (M, K) B_shape = (N, K) @@ -62,7 +52,6 @@ def main( FP4_E2M1_TO_FLOAT = [ -<<<<<<< HEAD 0.0, 0.5, 1.0, @@ -79,10 +68,6 @@ def main( -3.0, -4.0, -6.0, -======= - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ] @@ -98,7 +83,6 @@ def fp4_uint8_to_float(tensor_uint8): accum_dtype = T.float32 print(f"Running A8W4 GEMM: M={M}, N={N}, K={K}") -<<<<<<< HEAD print(" A: float8_e4m3fn, B: FP4 (unpacked uint8)") func = gemm_a8w4( @@ -112,14 +96,6 @@ def fp4_uint8_to_float(tensor_uint8): accum_dtype, num_stages=2, threads=128, -======= -print(f" A: float8_e4m3fn, B: FP4 (unpacked uint8)") - -func = gemm_a8w4( - M, N, K, block_M, block_N, block_K, - out_dtype, accum_dtype, - num_stages=2, threads=128, ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) ) jit_kernel = tilelang.compile( @@ -164,11 +140,7 @@ def fp4_uint8_to_float(tensor_uint8): if rel_err < 0.01: print("[PASS] numerical verification (rel_err < 0.01)") else: -<<<<<<< HEAD print("[WARN] large diff -- may indicate layout or data flow issue") -======= - print(f"[WARN] large diff -- may indicate layout or data flow issue") ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) # --- Benchmark --- torch.cuda.synchronize() diff --git a/tilelang/tileop/gemm/gemm_mma.py b/tilelang/tileop/gemm/gemm_mma.py index 2282eccfa4..b5c2ba2c37 100644 --- a/tilelang/tileop/gemm/gemm_mma.py +++ b/tilelang/tileop/gemm/gemm_mma.py @@ -75,26 +75,7 @@ def lower( mbar_phase_expr: tir.PrimExpr | None = None, ): thread_nums = thread_bounds.extent -<<<<<<< HEAD mma_emitter = self._make_mma_emitter(target, thread_nums, thread_var=thread_var) -======= - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GemmInst.MMA) - warp_row_tiles = int(self.M // m_warp) - warp_col_tiles = int(self.N // n_warp) - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=self.in_dtype, - b_dtype=self.in_dtype_b, - accum_dtype=self.accum_dtype, - a_transposed=self.trans_A, - b_transposed=self.trans_B, - block_row_warps=m_warp, - block_col_warps=n_warp, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=self.chunk, - thread_var=thread_var, - ) ->>>>>>> f13a6b71 (feat: A8W4 mixed-type MMA (FP8xFP4) + FP4 MoE example on SM120) in_dtype = self.in_dtype in_dtype_b = self.in_dtype_b From 70afbc31760ef272c968913840e949687022eb04 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 11 May 2026 17:57:46 +0800 Subject: [PATCH 11/19] rewrite: rewrite codes according to code rabbit's suggestions --- src/backend/cuda/op/copy.cc | 5 ++++- tilelang/cuda/intrinsics/macro/mma_macro_generator.py | 2 ++ .../cuda/intrinsics/macro/tcgen05_macro_generator.py | 11 ++++++----- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc index e7bd269daf..12584eb5f3 100644 --- a/src/backend/cuda/op/copy.cc +++ b/src/backend/cuda/op/copy.cc @@ -1215,9 +1215,12 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, if (is_load && barrier_base_id >= 0) { PrimExpr total_bytes; if (is_align16b_subbyte_layout) { + int loop_extent = + ((*inner_box_dim) != instruction_dim) ? ((*inner_box_dim) / instruction_dim) : 1; // mbarrier transaction bytes track the FP4 payload moved by TMA, not the // byte-container shared-memory footprint introduced by 16U4_ALIGN16B. - total_bytes = TMABytesFromElements(total_elements, shared_tensor->dtype); + total_bytes = + TMABytesFromElements(total_elements * loop_extent, shared_tensor->dtype); } else if ((*inner_box_dim) != instruction_dim) { int loop_extent = (*inner_box_dim) / instruction_dim; total_bytes = TMABytesFromElements(total_elements * loop_extent, diff --git a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py index 6be1e6b2f4..27dfc4fa30 100644 --- a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py @@ -121,6 +121,8 @@ def _initialize_k_dim(self, a_dtype=T.float16): # SM120 f8f6f4 FP4 MMA is m16n8k32. Although 256 / 4 would allow a # k64 fragment by bit count, there is no k64 dispatcher for FP4. if str(a_dtype) == "float4_e2m1fn": + if self.chunk < 32: + raise ValueError("FP4 MMA requires chunk to be a multiple of 32 (m16n8k32)") self.k_dim = min(32, self.chunk) else: self.k_dim = min(256 // a_dtype.bits, self.chunk) diff --git a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py index 67daaedfb1..2a17114fb6 100644 --- a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py @@ -160,17 +160,17 @@ def _initialize_k_dim(self, a_dtype=T.float16): def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper tir_buffer = buffer.buffer if isinstance(buffer, BufferRegion) else buffer - if layout is None or layout.is_equal(make_linear_layout(buffer)): + if layout is None or layout.is_equal(make_linear_layout(tir_buffer)): return SwizzleMode.NONE if DataType(tir_buffer.dtype).bits < 8: - if layout.is_equal(make_align16b_swizzled_layout(buffer)): + if layout.is_equal(make_align16b_swizzled_layout(tir_buffer)): return SwizzleMode.SWIZZLE_128B raise ValueError(f"Unsupported sub-byte swizzle mode: {layout}") - elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): + elif layout.is_equal(make_quarter_bank_swizzled_layout(tir_buffer)): return SwizzleMode.SWIZZLE_32B - elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): + elif layout.is_equal(make_half_bank_swizzled_layout(tir_buffer)): return SwizzleMode.SWIZZLE_64B - elif layout.is_equal(make_full_bank_swizzled_layout(buffer)): + elif layout.is_equal(make_full_bank_swizzled_layout(tir_buffer)): return SwizzleMode.SWIZZLE_128B else: raise ValueError(f"Unsupported swizzle mode: {layout}") @@ -543,6 +543,7 @@ def get_tcgen5_instr_desc( self, atom_m: int, atom_n: int, atom_k: int, a_is_k_major: bool, b_is_k_major: bool, scale_in_a: int, scale_in_b: int ) -> PrimExpr: """Build the 64-bit instruction descriptor for a ``tcgen05.mma`` PTX call.""" + def encode_dtype(dtype: str) -> int: dtype = str(DataType(dtype)) if dtype == "float16": From f816a4ec99844b7b183ce3a0e0a2600b179ef089 Mon Sep 17 00:00:00 2001 From: wahao Date: Fri, 15 May 2026 16:20:50 +0800 Subject: [PATCH 12/19] handle non-align16B fp4 shared layout gracefully --- .../gemm_fp4/example_fusedmoe_a8w4_sm120.py | 20 ++++++++----------- examples/gemm_fp4/example_gemm_fp4_sm120.py | 5 +++-- src/layout/gemm_layouts.cc | 9 +++++++++ src/tl_templates/cuda/gemm_sm100.h | 18 +++-------------- .../macro/tcgen05_macro_generator.py | 2 +- tilelang/cuda/op/gemm/gemm_tcgen05.py | 2 +- 6 files changed, 25 insertions(+), 31 deletions(-) diff --git a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py index fbd06a7e9d..6d38785060 100644 --- a/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py +++ b/examples/gemm_fp4/example_fusedmoe_a8w4_sm120.py @@ -1,10 +1,9 @@ """FP4 Fused MoE shared expert kernel on SM120 (A8W4 mode). -Demonstrates the core MoE compute pattern with FP4 weights: +Demonstrates the core MoE gate/up compute pattern with FP4 weights: 1. Gate GEMM: input(FP8) x W_gate(FP4) -> gate logits 2. Up GEMM: input(FP8) x W_up(FP4) -> up logits 3. SiLU(gate) * up - 4. Down GEMM: activated(FP8) x W_down(FP4) -> output Uses SM120 native kind::f8f6f4 MMA (FP8 x FP4 -> FP32). Expert weights are stored as unpacked uint8 (1 FP4 per byte, low nibble). @@ -59,11 +58,10 @@ def moe_shared_expert_a8w4( @T.prim_func def main( - input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), + Input: T.Tensor((num_tokens, d_hidden), "float8_e4m3fn"), W_gate: T.Tensor((d_expert, d_hidden), "uint8"), W_up: T.Tensor((d_expert, d_hidden), "uint8"), - W_down: T.Tensor((d_hidden, d_expert), "uint8"), - output: T.Tensor((num_tokens, d_hidden), "float32"), + output: T.Tensor((num_tokens, d_expert), "float32"), ): # Step 1: Gate + Up GEMMs (fused in one kernel launch) with T.Kernel( @@ -82,7 +80,7 @@ def main( T.clear(up_local) for k in T.Pipelined(T.ceildiv(d_hidden, block_hidden), num_stages=num_stages): - T.copy(input[bx * block_token, k * block_hidden], input_shared) + T.copy(Input[bx * block_token, k * block_hidden], input_shared) T.copy(W_gate[by * block_expert, k * block_hidden], W_gate_shared) T.copy(W_up[by * block_expert, k * block_hidden], W_up_shared) T.gemm(input_shared, W_gate_shared, gate_local, transpose_B=True) @@ -118,7 +116,7 @@ def main( jit_kernel = tilelang.compile( func, - out_idx=[4], + out_idx=[3], target="cuda", pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, @@ -134,18 +132,16 @@ def main( input_fp8 = torch.randn(num_tokens, d_hidden, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) W_gate_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) W_up_uint8 = torch.randint(0, 16, (d_expert, d_hidden), device="cuda", dtype=torch.uint8) -W_down_uint8 = torch.randint(0, 16, (d_hidden, d_expert), device="cuda", dtype=torch.uint8) # --- Test 1: zeros --- z_input = torch.zeros(num_tokens, d_hidden, device="cuda", dtype=torch.float8_e4m3fn) z_gate = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) z_up = torch.zeros(d_expert, d_hidden, device="cuda", dtype=torch.uint8) -z_down = torch.zeros(d_hidden, d_expert, device="cuda", dtype=torch.uint8) -c_zero = jit_kernel(z_input, z_gate, z_up, z_down) +c_zero = jit_kernel(z_input, z_gate, z_up) print(f"[{'PASS' if c_zero.abs().max().item() == 0.0 else 'FAIL'}] zeros in -> zeros out") # --- Test 2: numerical verification (gate+up only, no down GEMM in this kernel) --- -out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8, W_down_uint8) +out = jit_kernel(input_fp8, W_gate_uint8, W_up_uint8) # Reference input_f32 = input_fp8.to(torch.float32) @@ -170,7 +166,7 @@ def main( torch.cuda.synchronize() start = time.perf_counter() for _ in range(100): - jit_kernel(input_fp8, W_gate_uint8, W_up_uint8, W_down_uint8) + jit_kernel(input_fp8, W_gate_uint8, W_up_uint8) torch.cuda.synchronize() elapsed = (time.perf_counter() - start) / 100 * 1000 total_flops = 2 * num_tokens * d_hidden * d_expert * 2 # 2 GEMMs diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py index c79fab8df3..98f32f4edc 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -122,8 +122,9 @@ def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tens ) print("Compilation succeeded!") -with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f: - f.write(jit_kernel.get_kernel_source()) +if os.environ.get("TL_FP4_DUMP_CUDA", "0") != "0": + with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f: + f.write(jit_kernel.get_kernel_source()) torch.manual_seed(42) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index e96e838e3f..7de43d3c8d 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -919,6 +919,15 @@ Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, // Sub-byte types (FP4): use ALIGN16B unpacksmem layout. TMA writes each // logical FP4 into an 8-bit SMEM container, matching tcgen05.mma consumption. if (element_size < 8) { + if (mat_stride % 8 == 0 && mat_continuous % 128 == 0) { + return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous, + element_size); + } + int vector_size = 128 / element_size; + if (mat_continuous % vector_size == 0) { + return makeLinearLayout( + Array{Integer(mat_stride), Integer(mat_continuous)}); + } return MakeAlign16BSwizzleLayout2D(mat_stride, mat_continuous, element_size); } diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 3615501daf..c6085eae66 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -341,21 +341,9 @@ struct MMA_Traits 这部分负责声明参数列表,即让这个偏特化对所有可能的 -// 组合都有效。 -// 2. 内层 struct MMA_Traits> 说明:只要模板参数匹配到 -// SM100_MMA_F8F6F4_TS 那一整组, -// 就会自动选择这个偏特化的实现,不需要“再特化”。 -// 换句话说,你写一个 MMA_Traits,如果 X 恰好能匹配 SM100_MMA_F8F6F4_TS<...> -// 那一组参数, 这个特化版本就会被用到。比如: -// using Traits = MMA_Traits>; -// Traits的内容就是匹配到此特化,无需再继续特化template参数。 -// 总结:外层 template 支持泛型匹配,struct MMA_Traits<...> -// 实现具体特化,实际用时只需传入对应参数即可自动选中。 - +// This partial specialization is selected automatically for TS atoms whose +// operands are FP4. It keeps the runtime descriptor encoding in the f8f6f4 +// domain, where E2M1 is encoded as MXF8F6F4Format::E2M1 (=5). template struct MMA_Traits< diff --git a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py index 2a17114fb6..51242d174a 100644 --- a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py @@ -674,7 +674,7 @@ def compute_tcgen05_b_desc_params(self, B_buf) -> TCGEN05DescriptorParams: n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim k_dim = self.chunk micro_size_k = self.micro_size_k - elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8 + elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8 b_is_k_major = self.b_transposed b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) diff --git a/tilelang/cuda/op/gemm/gemm_tcgen05.py b/tilelang/cuda/op/gemm/gemm_tcgen05.py index 3de048fa6e..7068fcac03 100644 --- a/tilelang/cuda/op/gemm/gemm_tcgen05.py +++ b/tilelang/cuda/op/gemm/gemm_tcgen05.py @@ -60,7 +60,7 @@ def infer_shared_layout(self, dtype, continuity: int, k_major: bool) -> Callable try: _pass_ctx = tvm.transform.PassContext.current() disable_tma = _pass_ctx.config.get("tl.disable_tma_lower", False) - except Exception: + except (AttributeError, KeyError, TypeError): disable_tma = False if disable_tma: # Non-TMA path: SIMT copy writes packed-linear nibbles; use a From 57449a88b6328544caabbc14d328143bf5cbecb4 Mon Sep 17 00:00:00 2001 From: wahao Date: Fri, 15 May 2026 17:01:33 +0800 Subject: [PATCH 13/19] refresh the code for pre-commit check --- src/backend/cuda/op/copy.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/backend/cuda/op/copy.cc b/src/backend/cuda/op/copy.cc index 12584eb5f3..9e53e8f437 100644 --- a/src/backend/cuda/op/copy.cc +++ b/src/backend/cuda/op/copy.cc @@ -1215,12 +1215,13 @@ Stmt Copy::LowerBulk(const CopyNode &op, const LowerArgs &T, if (is_load && barrier_base_id >= 0) { PrimExpr total_bytes; if (is_align16b_subbyte_layout) { - int loop_extent = - ((*inner_box_dim) != instruction_dim) ? ((*inner_box_dim) / instruction_dim) : 1; + int loop_extent = ((*inner_box_dim) != instruction_dim) + ? ((*inner_box_dim) / instruction_dim) + : 1; // mbarrier transaction bytes track the FP4 payload moved by TMA, not the // byte-container shared-memory footprint introduced by 16U4_ALIGN16B. - total_bytes = - TMABytesFromElements(total_elements * loop_extent, shared_tensor->dtype); + total_bytes = TMABytesFromElements(total_elements * loop_extent, + shared_tensor->dtype); } else if ((*inner_box_dim) != instruction_dim) { int loop_extent = (*inner_box_dim) / instruction_dim; total_bytes = TMABytesFromElements(total_elements * loop_extent, From 0d2a9cba57b19bd447ee10e9e209e5a09748f467 Mon Sep 17 00:00:00 2001 From: wahao Date: Tue, 2 Jun 2026 09:35:03 +0800 Subject: [PATCH 14/19] tmp save --- examples/gemm_fp4/example_gemm_nvfp4_sm120.py | 161 ++++++++++++++++++ src/backend/cuda/codegen/codegen_cuda.cc | 86 ++++++++++ src/op/builtin.cc | 5 + src/op/builtin.h | 10 ++ src/tl_templates/cuda/instruction/mma.h | 119 +++++++++++++ .../intrinsics/macro/mma_macro_generator.py | 94 ++++++++++ tilelang/language/ast/ir.py | 2 + tilelang/language/tir/ir.py | 1 + tilelang/language/tir/op.py | 48 ++++++ 9 files changed, 526 insertions(+) create mode 100644 examples/gemm_fp4/example_gemm_nvfp4_sm120.py diff --git a/examples/gemm_fp4/example_gemm_nvfp4_sm120.py b/examples/gemm_fp4/example_gemm_nvfp4_sm120.py new file mode 100644 index 0000000000..61843e0d9c --- /dev/null +++ b/examples/gemm_fp4/example_gemm_nvfp4_sm120.py @@ -0,0 +1,161 @@ +"""Minimal NVFP4 GEMM on SM120 using block-scaled MMA. + +This example exercises ``mma.sync.aligned.kind::mxf4nvf4.block_scale`` for +FP4 E2M1 inputs with UE4M3 scale factors. It intentionally starts with a +single SM120 MMA shape (m16n8k64) before wiring NVFP4 into the generic GEMM API. +""" + +import os +import torch +import tilelang +import tilelang.language as T +from tilelang.cuda.intrinsics.macro.mma_macro_generator import TensorCoreIntrinEmitter +from tilelang.layout import make_swizzled_layout + + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + +# UE4M3 encoding for scale 1.0, packed four times into one uint32 register. +UE4M3_ONE_X4 = 0x38383838 + + +def unpack_fp4_to_uint8(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Unpack (M, K//2) int8 -> (M, K) uint8, one FP4 value per byte.""" + flat = packed_int8.to(torch.uint8).reshape(M, K // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + return torch.stack([lo, hi], dim=-1).reshape(M, K).contiguous() + + +def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Unpack FP4 E2M1 data to float32 through a lookup table.""" + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + unpacked = unpack_fp4_to_uint8(packed_int8, M, K).to(torch.int64) + return lut[unpacked] + + +def matmul_nvfp4_sm120(M=16, N=8, K=64, out_dtype=T.float32, accum_dtype=T.float32): + block_M, block_N, block_K = 16, 8, 64 + threads = 32 + + emitter = TensorCoreIntrinEmitter( + a_dtype=T.float4_e2m1fn, + b_dtype=T.float4_e2m1fn, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=1, + block_col_warps=1, + warp_row_tiles=block_M, + warp_col_tiles=block_N, + chunk=block_K, + is_blockscaled=True, + scale_dtype=T.float8_e4m3, + scale_vec_size=16, + ) + + local_size_a = emitter.local_size_a + local_size_b = emitter.local_size_b + local_size_c = emitter.local_size_out + + @T.prim_func + def main( + A: T.Tensor((M, K), "uint8"), + B: T.Tensor((N, K), "uint8"), + SFA: T.Tensor((T.ceildiv(M, block_M), T.ceildiv(K, block_K)), "uint32"), + SFB: T.Tensor((T.ceildiv(N, block_N), T.ceildiv(K, block_K)), "uint32"), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), "uint8") + B_shared = T.alloc_shared((block_N, block_K), "uint8") + C_shared = T.alloc_shared((1, 1, block_M, block_N), out_dtype) + A_local = T.alloc_local((local_size_a,), T.float4_e2m1fn) + B_local = T.alloc_local((local_size_b,), T.float4_e2m1fn) + C_local = T.alloc_local((local_size_c,), accum_dtype) + SFA_local = T.alloc_local((1,), "uint32") + SFB_local = T.alloc_local((1,), "uint32") + + T.annotate_layout( + { + A_shared: make_swizzled_layout(A_shared), + B_shared: make_swizzled_layout(B_shared), + } + ) + + T.clear(C_local) + for ko in T.serial(T.ceildiv(K, block_K)): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + SFA_local[0] = SFA[by, ko] + SFB_local[0] = SFB[bx, ko] + emitter.ldmatrix_a(A_local, A_shared, 0) + emitter.ldmatrix_b(B_local, B_shared, 0) + emitter.mma_blockscaled(A_local, B_local, C_local, SFA_local, SFB_local) + + emitter.stmatrix(C_local, C_shared) + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[0, 0, i, j] + + return main + + +if __name__ == "__main__": + M, N, K = 16, 8, 64 + print(f"Running NVFP4 SM120 block-scaled MMA: M={M}, N={N}, K={K}") + + kernel = tilelang.compile( + matmul_nvfp4_sm120(M, N, K), + out_idx=[4], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Compilation succeeded!") + + if os.environ.get("TL_NVFP4_DUMP_CUDA", "0") != "0": + with open(os.path.join(os.path.dirname(__file__), "gemm_nvfp4_sm120.cu"), "w") as f: + f.write(kernel.get_kernel_source()) + + torch.manual_seed(0) + a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + a_unpacked = unpack_fp4_to_uint8(a_packed, M, K) + b_unpacked = unpack_fp4_to_uint8(b_packed, N, K) + + sfa = torch.full((1, 1), UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) + sfb = torch.full((1, 1), UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) + + a_zero = torch.zeros((M, K), device="cuda", dtype=torch.uint8) + b_zero = torch.zeros((N, K), device="cuda", dtype=torch.uint8) + c_zero = kernel(a_zero, b_zero, sfa, sfb) + print(f"[ZERO] max_abs={c_zero.abs().max().item():.4f}") + + a_one = torch.full((M, K), 2, device="cuda", dtype=torch.uint8) + b_one = torch.full((N, K), 2, device="cuda", dtype=torch.uint8) + c_one = kernel(a_one, b_one, sfa, sfb) + print(f"[ONE] first={c_one[0, 0].item():.4f}, expected={float(K):.4f}") + + c = kernel(a_unpacked, b_unpacked, sfa, sfb) + ref = unpack_fp4_to_float(a_packed, M, K) @ unpack_fp4_to_float(b_packed, N, K).T + diff = (c.float() - ref).abs() + print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 3a7f713e65..479ce18802 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -2604,6 +2604,92 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(C_ptr)", c_ref); replacer.register_rule("(C_offset)", c_bias); this->stream << replacer.rewrite(mma_call); + } else if (op->op.same_as(tl::ptx_mma_blockscaled())) { + ICHECK_EQ(op->args.size(), 17U); + std::string dtype = Downcast(op->args[0])->value; + std::string shape = Downcast(op->args[1])->value; + std::string A_layout = Downcast(op->args[2])->value; + std::string B_layout = Downcast(op->args[3])->value; + std::string A_dtype = Downcast(op->args[4])->value; + std::string B_dtype = Downcast(op->args[5])->value; + std::string C_dtype = Downcast(op->args[6])->value; + std::string SF_dtype = Downcast(op->args[7])->value; + int scale_vec_size = Downcast(op->args[8])->value; + std::string a_ref = this->PrintExpr(op->args[9]); + std::string a_bias = this->PrintExpr(op->args[10]); + std::string b_ref = this->PrintExpr(op->args[11]); + std::string b_bias = this->PrintExpr(op->args[12]); + std::string c_ref = this->PrintExpr(op->args[13]); + std::string c_bias = this->PrintExpr(op->args[14]); + std::string sfa_ref = this->PrintExpr(op->args[15]); + std::string sfb_ref = this->PrintExpr(op->args[16]); + + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + auto dtype_sf_enum = tl::codegen::ptx::DTypeFromString(SF_dtype); + ICHECK(dtype_enum == dtype_c_enum) + << "ptx_mma_blockscaled dtype must match C_dtype"; + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + + need_mma_instruction_h_ = true; + this->PrintIndent(); + std::string mma_call = + "tl::mma_sync_blockscaled<(AType), (BType), (CType), (SFType), (M), " + "(N), (K), (TransA), (TransB), (VS)>(" + "reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "reinterpret_cast((A_ptr) + (A_offset)), " + "reinterpret_cast((B_ptr) + (B_offset)), " + "reinterpret_cast((SFA_ptr)), " + "reinterpret_cast((SFB_ptr)));\n"; + tl::codegen::Replacer replacer; + replacer.register_rule("(AType)", + tl::codegen::ptx::DTypeEnumToString(dtype_a_enum)); + replacer.register_rule("(BType)", + tl::codegen::ptx::DTypeEnumToString(dtype_b_enum)); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(SFType)", + tl::codegen::ptx::DTypeEnumToString(dtype_sf_enum)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true"); + replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true"); + replacer.register_rule("(VS)", std::to_string(scale_vec_size)); + replacer.register_rule("(CRegType)", + tl::codegen::GetMMARegisterType(dtype_c_enum)); + replacer.register_rule("(ARegType)", "uint32_t"); + replacer.register_rule("(BRegType)", "uint32_t"); + replacer.register_rule("(SFRegType)", + "typename tl::detail::BlockScaledMmaDispatcher<" + + tl::codegen::ptx::DTypeEnumToString( + dtype_a_enum) + + ", " + + tl::codegen::ptx::DTypeEnumToString( + dtype_b_enum) + + ", " + + tl::codegen::ptx::DTypeEnumToString( + dtype_c_enum) + + ", " + + tl::codegen::ptx::DTypeEnumToString( + dtype_sf_enum) + + ", " + std::to_string(m) + ", " + + std::to_string(n) + ", " + std::to_string(k) + + ", " + (A_layout == "row" ? "false" : "true") + + ", " + (B_layout == "row" ? "false" : "true") + + ", " + std::to_string(scale_vec_size) + + ">::SFRegType"); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", a_bias); + replacer.register_rule("(B_ptr)", b_ref); + replacer.register_rule("(B_offset)", b_bias); + replacer.register_rule("(C_ptr)", c_ref); + replacer.register_rule("(C_offset)", c_bias); + replacer.register_rule("(SFA_ptr)", sfa_ref); + replacer.register_rule("(SFB_ptr)", sfb_ref); + this->stream << replacer.rewrite(mma_call); } else if (op->op.same_as(builtin::ptx_mma_sp())) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 508f90c654..f74a2457cb 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -252,6 +252,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_mma_blockscaled) + .set_num_inputs(17) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix) .set_num_inputs(4) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 69a65ab10a..410f3e32e9 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -431,6 +431,16 @@ TVM_DLL const Op &ptx_deallocate_tensor_memory(); */ TVM_DLL const Op &ptx_mma_sm70(); +/*! + * \brief TileLang intrinsic for SM120 block-scaled MMA. + * + * ptx_mma_blockscaled(dtype, shape, A_layout, B_layout, A_dtype, B_dtype, + * C_dtype, SF_dtype, scale_vec_size, + * A_ptr, A_offset, B_ptr, B_offset, C_ptr, C_offset, + * SFA_ptr, SFB_ptr) + */ +TVM_DLL const Op &ptx_mma_blockscaled(); + /*! * \brief tvm intrinsics for ldmatrix * diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index ff120f41be..91074c7518 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -31,6 +31,13 @@ template struct MmaImplTraits { static constexpr int kCRegs = std::extent_v; }; +template struct BlockScaledMmaImplTraits : MmaImplTraits { + using SFReg = std::remove_extent_t; + + static constexpr int kSFARegs = std::extent_v; + static constexpr int kSFBRegs = std::extent_v; +}; + template TL_DEVICE void @@ -55,6 +62,40 @@ TL_DEVICE void call_fma(typename MmaImplTraits::DReg *d, std::make_index_sequence::kCRegs>{}); } +template +TL_DEVICE void call_blockscaled_fma_impl( + typename BlockScaledMmaImplTraits::DReg *d, + const typename BlockScaledMmaImplTraits::AReg *a, + const typename BlockScaledMmaImplTraits::BReg *b, + const typename BlockScaledMmaImplTraits::CReg *c, + const typename BlockScaledMmaImplTraits::SFReg *sfa, + const typename BlockScaledMmaImplTraits::SFReg *sfb, + std::index_sequence, std::index_sequence, + std::index_sequence, std::index_sequence, + std::index_sequence, std::index_sequence) { + Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]..., sfa[SFAIdx]..., + sfb[SFBIdx]...); +} + +template +TL_DEVICE void +call_blockscaled_fma(typename BlockScaledMmaImplTraits::DReg *d, + const typename BlockScaledMmaImplTraits::AReg *a, + const typename BlockScaledMmaImplTraits::BReg *b, + const typename BlockScaledMmaImplTraits::CReg *c, + const typename BlockScaledMmaImplTraits::SFReg *sfa, + const typename BlockScaledMmaImplTraits::SFReg *sfb) { + call_blockscaled_fma_impl( + d, a, b, c, sfa, sfb, + std::make_index_sequence::kDRegs>{}, + std::make_index_sequence::kARegs>{}, + std::make_index_sequence::kBRegs>{}, + std::make_index_sequence::kCRegs>{}, + std::make_index_sequence::kSFARegs>{}, + std::make_index_sequence::kSFBRegs>{}); +} + template struct MmaDispatcher { @@ -69,6 +110,22 @@ struct MmaDispatcher { } }; +template +struct BlockScaledMmaDispatcher { + using CRegType = void; + using ARegType = void; + using BRegType = void; + using SFRegType = void; + + static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *, + const CRegType *, const SFRegType *, + const SFRegType *) { + static_assert(always_false_v>, + "tl::mma_sync_blockscaled: unsupported configuration"); + } +}; + #define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \ NValue, KValue, TransAValue, TransBValue, \ SaturateValue, ImplType) \ @@ -90,6 +147,30 @@ struct MmaDispatcher { } \ }; +#define TL_DEFINE_BLOCKSCALED_MMA_DISPATCHER( \ + ATypeEnum, BTypeEnum, CTypeEnum, SFTypeEnum, MValue, NValue, KValue, \ + TransAValue, TransBValue, VSValue, ImplType) \ + template <> \ + struct BlockScaledMmaDispatcher< \ + DataType::ATypeEnum, DataType::BTypeEnum, DataType::CTypeEnum, \ + DataType::SFTypeEnum, MValue, NValue, KValue, TransAValue, TransBValue, \ + VSValue> { \ + using Impl = ImplType; \ + using Traits = BlockScaledMmaImplTraits; \ + using CRegType = typename Traits::DReg; \ + using ARegType = typename Traits::AReg; \ + using BRegType = typename Traits::BReg; \ + using SFRegType = typename Traits::SFReg; \ + static_assert( \ + std::is_same_v, \ + "tl::mma_sync_blockscaled requires matching accumulator/output regs"); \ + static TL_DEVICE void exec(CRegType *d, const ARegType *a, \ + const BRegType *b, const CRegType *c, \ + const SFRegType *sfa, const SFRegType *sfb) { \ + call_blockscaled_fma(d, a, b, c, sfa, sfb); \ + } \ + }; + // FP16 inputs (TN layout: A row-major, B column-major) TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true, false, cute::SM80_16x8x16_F16F16F16F16_TN) @@ -163,7 +244,22 @@ TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat4_e2m1fn, kFloat32, 16, 8, 32, TL_DEFINE_MMA_DISPATCHER(kFloat4_e2m1fn, kFloat8_e4m3, kFloat32, 16, 8, 32, false, true, false, SM120_FP4_FP8_F32_TN) +// NVFP4 block-scaled MMA (SM120 kind::mxf4nvf4.block_scale). +using SM120_NVFP4_NVFP4_F32_UE4M3_TN = + cute::SM120::BLOCKSCALED::SM120_16x8x64_TN_VS< + cute::float_e2m1_t, cute::float_e2m1_t, float, cute::float_ue4m3_t, 32>; +using SM120_NVFP4_NVFP4_F32_UE4M3_VS16_TN = + cute::SM120::BLOCKSCALED::SM120_16x8x64_TN_VS< + cute::float_e2m1_t, cute::float_e2m1_t, float, cute::float_ue4m3_t, 16>; +TL_DEFINE_BLOCKSCALED_MMA_DISPATCHER( + kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, kFloat8_e4m3, 16, 8, 64, false, + true, 32, SM120_NVFP4_NVFP4_F32_UE4M3_TN) +TL_DEFINE_BLOCKSCALED_MMA_DISPATCHER( + kFloat4_e2m1fn, kFloat4_e2m1fn, kFloat32, kFloat8_e4m3, 16, 8, 64, false, + true, 16, SM120_NVFP4_NVFP4_F32_UE4M3_VS16_TN) + #undef TL_DEFINE_MMA_DISPATCHER +#undef TL_DEFINE_BLOCKSCALED_MMA_DISPATCHER } // namespace detail @@ -200,4 +296,27 @@ TL_DEVICE void mma_sync( } } +template +TL_DEVICE void mma_sync_blockscaled( + typename detail::BlockScaledMmaDispatcher< + AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::CRegType *c, + const typename detail::BlockScaledMmaDispatcher< + AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::ARegType *a, + const typename detail::BlockScaledMmaDispatcher< + AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::BRegType *b, + const typename detail::BlockScaledMmaDispatcher< + AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::SFRegType + *sfa, + const typename detail::BlockScaledMmaDispatcher< + AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>::SFRegType + *sfb) { + using Dispatcher = detail::BlockScaledMmaDispatcher< + AType, BType, CType, SFType, M, N, K, TransA, TransB, VS>; + static_assert(!std::is_void_v, + "tl::mma_sync_blockscaled: unsupported configuration"); + + Dispatcher::exec(c, a, b, c, sfa, sfb); +} + } // namespace tl diff --git a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py index 27dfc4fa30..dfc4e55bf3 100644 --- a/tilelang/cuda/intrinsics/macro/mma_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/mma_macro_generator.py @@ -81,10 +81,16 @@ def __init__( num_elems_per_byte: int = 1, is_m_first: bool | None = False, thread_var: Var | None = None, + is_blockscaled: bool = False, + scale_dtype: str = T.float8_e4m3, + scale_vec_size: int = 16, ): self.a_dtype = a_dtype self.b_dtype = b_dtype self.accum_dtype = accum_dtype + self.is_blockscaled = is_blockscaled + self.scale_dtype = scale_dtype + self.scale_vec_size = scale_vec_size self.a_transposed = a_transposed self.b_transposed = b_transposed # Hint Information @@ -118,6 +124,11 @@ def __init__( def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) + if self.is_blockscaled: + if str(a_dtype) != "float4_e2m1fn" or self.chunk < 64: + raise ValueError("SM120 block-scaled NVFP4 MMA requires FP4 inputs and chunk >= 64") + self.k_dim = 64 + return # SM120 f8f6f4 FP4 MMA is m16n8k32. Although 256 / 4 would allow a # k64 fragment by bit count, there is no k64 dispatcher for FP4. if str(a_dtype) == "float4_e2m1fn": @@ -497,6 +508,36 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf): return _warp_mma(A_local_buf, B_local_buf, C_local_buf) + def mma_blockscaled( + self, + A_local_buf: Buffer, + B_local_buf: Buffer, + C_local_buf: Buffer, + SFA_local_buf: Buffer, + SFB_local_buf: Buffer, + k_inner: PrimExpr | None = 0, + ): + if self.n_dim != 8: + raise ValueError("SM120 block-scaled MMA emitter currently supports one m16n8k64 atom at a time") + warp_rows = self.warp_rows + warp_cols = self.warp_cols + + @T.macro + def _warp_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf): + for i, j in T.grid(warp_rows, warp_cols): + self.mma_blockscaled_atom( + A_local_buf, + B_local_buf, + C_local_buf, + SFA_local_buf, + SFB_local_buf, + i, + j, + k_inner, + ) + + return _warp_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf) + # ---- Atom-level interface ---- @property @@ -598,6 +639,59 @@ def _atom_mma(A_local_buf, B_local_buf, C_local_buf): return _atom_mma(A_local_buf, B_local_buf, C_local_buf) + def mma_blockscaled_atom( + self, + A_local_buf: Buffer, + B_local_buf: Buffer, + C_local_buf: Buffer, + SFA_local_buf: Buffer, + SFB_local_buf: Buffer, + inst_m_idx: PrimExpr | int, + inst_n_idx: PrimExpr | int, + k_inner: PrimExpr | int = 0, + ): + local_size_a = self.local_size_a + local_size_b = self.local_size_b + local_size_out = self.local_size_out + a_dtype_abbrv = self.a_dtype_abbrv + b_dtype_abbrv = self.b_dtype_abbrv + accum_dtype = self.accum_dtype + accum_dtype_abbrv = self.accum_dtype_abbrv + mma_prefix = self.mma_prefix + + a_is_fragment = is_fragment(A_local_buf) + b_is_fragment = is_fragment(B_local_buf) + a_local_stride: PrimExpr = k_inner * self.warp_rows * local_size_a if a_is_fragment else 0 + b_local_stride: PrimExpr = k_inner * self.warp_cols * local_size_b if b_is_fragment else 0 + + A_offset = a_local_stride + inst_m_idx * local_size_a + B_offset = b_local_stride + inst_n_idx * local_size_b + C_offset = inst_m_idx * self.warp_cols * local_size_out + inst_n_idx * local_size_out + + @T.macro + def _atom_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf): + T.ptx_mma_blockscaled( + accum_dtype, + mma_prefix, + "row", + "col", + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + self.scale_dtype, + self.scale_vec_size, + A_local_buf.data, + A_offset, + B_local_buf.data, + B_offset, + C_local_buf.data, + C_offset, + SFA_local_buf.data, + SFB_local_buf.data, + ) + + return _atom_mma_blockscaled(A_local_buf, B_local_buf, C_local_buf, SFA_local_buf, SFB_local_buf) + def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None): block_row_warps = self.block_row_warps block_col_warps = self.block_col_warps diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 7865b74be7..bca77ec6c0 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -1881,6 +1881,7 @@ def wrapped(*args, **kwargs): call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_blockscaled = _dtype_forward(_tir_op.ptx_mma_blockscaled) ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) @@ -2135,6 +2136,7 @@ def wrapped(*args, **kwargs): "tvm_warp_shuffle_down", "tvm_warp_activemask", "ptx_mma", + "ptx_mma_blockscaled", "ptx_mma_sp", "ptx_wgmma_ss", "ptx_wgmma_rs", diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 384be21ccc..62b4ba8ae3 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -287,6 +287,7 @@ def wrapped(*args, **kwargs): call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) ptx_mma = _dtype_forward(_tir_op.ptx_mma) +ptx_mma_blockscaled = _dtype_forward(_tir_op.ptx_mma_blockscaled) ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 0aee38da35..d72a8c32b8 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -960,6 +960,54 @@ def ptx_mma( ) +def ptx_mma_blockscaled( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + SF_dtype, + scale_vec_size, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + scale_factor_a, + scale_factor_b, +): + """TileLang intrinsic for SM120 block-scaled MMA. + + This minimal frontend covers NVFP4 on SM120: + ``mma.sync.aligned.kind::mxf4nvf4.block_scale`` with FP4 A/B, + FP32 accumulators, and UE4M3 scale factors. + """ + return call_intrin( + "handle", + _tvm_op.Op.get("tl.ptx_mma_blockscaled"), + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + SF_dtype, + scale_vec_size, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + scale_factor_a, + scale_factor_b, + ) + + def ptx_mma_sp( dtype, shape, From 48db769af988bb416072018cf2b256f72c4cbfc2 Mon Sep 17 00:00:00 2001 From: wahao Date: Wed, 3 Jun 2026 21:58:44 +0800 Subject: [PATCH 15/19] feat(sm120): add T.nvfp4_gemm as frontend, lower to mma_blockscaled, fix ptx_mma_blockscaled meta param must be of StringImm/IntImm, add and verified gemm_nvfp4 & fused moe examples --- .../gemm_fp4/example_fusedmoe_nvfp4_sm120.py | 205 ++++++++++++++++++ examples/gemm_fp4/example_gemm_nvfp4_sm120.py | 130 ++++++----- tilelang/cuda/op/gemm/gemm_mma.py | 33 ++- tilelang/language/__init__.py | 1 + tilelang/language/gemm_op.py | 102 +++++++++ tilelang/language/tir/op.py | 21 +- tilelang/tileop/gemm/gemm_base.py | 2 + 7 files changed, 426 insertions(+), 68 deletions(-) create mode 100644 examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py diff --git a/examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py b/examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py new file mode 100644 index 0000000000..6c47a06b66 --- /dev/null +++ b/examples/gemm_fp4/example_fusedmoe_nvfp4_sm120.py @@ -0,0 +1,205 @@ +"""Native NVFP4 fused MoE gate/up example on SM120. + +This mirrors the FlashInfer/TRT-LLM NVFP4 data contract: + +* activations and expert weights are declared as ``T.float4_e2m1fn`` tensors; + the host still supplies torch uint8/int8 packed storage because PyTorch does + not expose a scalar FP4 tensor dtype. +* scale factors are FP8 E4M3/UE4M3 values, packed four-at-a-time into the + single scale register consumed by the current m16n8k64 MMA.SF wrapper. +* routing is represented by precomputed ``token_selected_experts`` and + ``token_final_scales``. This example implements the common fused FC1 pattern: + gate/up GEMMs followed by SiLU(gate) * up. The down projection is intentionally + left to a future generic block-scaled GEMM tiling path. + +The kernel is deliberately one CTA / one expert tile so the example validates the +native NVFP4 API and scale path without hiding behind uint8 tensor annotations. +""" + +import os +import time + +import torch +import tilelang +import tilelang.language as T +from tilelang.layout import make_swizzled_layout +import argparse + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def pack_e4m3x4_to_u32(values: tuple[float, float, float, float]) -> int: + raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) + return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) + + +def unpack_fp4_to_float(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor: + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) + flat = packed.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + codes = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[codes] + + +def fusedmoe_nvfp4_sm120( + num_tokens=128, + hidden_size=256, + intermediate_size=128, + block_tokens=16, + block_hidden=64, + block_intermediate=8, + threads=32, +): + assert num_tokens % block_tokens == 0 + assert hidden_size % block_hidden == 0 + assert intermediate_size % block_intermediate == 0 + + packed_hidden = hidden_size // 2 + packed_block_hidden = block_hidden // 2 + token_blocks = num_tokens // block_tokens + hidden_blocks = hidden_size // block_hidden + intermediate_blocks = intermediate_size // block_intermediate + scale = 1.44269504 # log2(e) + + @T.prim_func + def main( + X: T.Tensor((num_tokens, hidden_size), T.float4_e2m1fn), + X_scale: T.Tensor((token_blocks, hidden_blocks), "uint32"), + W_gate: T.Tensor((intermediate_size, hidden_size), T.float4_e2m1fn), + W_gate_scale: T.Tensor((intermediate_blocks, hidden_blocks), "uint32"), + W_up: T.Tensor((intermediate_size, hidden_size), T.float4_e2m1fn), + W_up_scale: T.Tensor((intermediate_blocks, hidden_blocks), "uint32"), + token_selected_experts: T.Tensor((num_tokens,), "int32"), + token_final_scales: T.Tensor((num_tokens,), "float32"), + Output: T.Tensor((num_tokens, intermediate_size), "float32"), + ): + with T.Kernel(intermediate_blocks, token_blocks, threads=threads) as (bx, by): + # Public tensors use native packed FP4. The mxf4nvf4 ldmatrix path + # consumes the raw packed bytes, so create byte views inside the + # kernel instead of exposing uint8 tensors in the API. + X_bytes = T.view(X, (num_tokens, packed_hidden), "uint8") + W_gate_bytes = T.view(W_gate, (intermediate_size, packed_hidden), "uint8") + W_up_bytes = T.view(W_up, (intermediate_size, packed_hidden), "uint8") + X_shared = T.alloc_shared((block_tokens, packed_block_hidden), "uint8") + gate_shared = T.alloc_shared((block_intermediate, packed_block_hidden), "uint8") + up_shared = T.alloc_shared((block_intermediate, packed_block_hidden), "uint8") + + gate_local = T.alloc_fragment((block_tokens, block_intermediate), "float32") + up_local = T.alloc_fragment((block_tokens, block_intermediate), "float32") + SFA_local = T.alloc_local((1,), "uint32") + SFB_local = T.alloc_local((1,), "uint32") + + T.annotate_layout( + { + X_shared: make_swizzled_layout(X_shared), + gate_shared: make_swizzled_layout(gate_shared), + up_shared: make_swizzled_layout(up_shared), + } + ) + + T.clear(gate_local) + for ko in T.serial(hidden_blocks): + T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared) + T.copy(W_gate_bytes[bx * block_intermediate, ko * packed_block_hidden], gate_shared) + SFA_local[0] = X_scale[by, ko] + SFB_local[0] = W_gate_scale[bx, ko] + T.nvfp4_gemm(X_shared, gate_shared, SFA_local, SFB_local, gate_local, transpose_B=True, clear_accum=(ko == 0)) + + T.clear(up_local) + for ko in T.serial(hidden_blocks): + T.copy(X_bytes[by * block_tokens, ko * packed_block_hidden], X_shared) + T.copy(W_up_bytes[bx * block_intermediate, ko * packed_block_hidden], up_shared) + SFA_local[0] = X_scale[by, ko] + SFB_local[0] = W_up_scale[bx, ko] + T.nvfp4_gemm(X_shared, up_shared, SFA_local, SFB_local, up_local, transpose_B=True, clear_accum=(ko == 0)) + + for t, i in T.Parallel(block_tokens, block_intermediate): + gate = gate_local[t, i] + up = up_local[t, i] + token_idx = by * block_tokens + t + out_idx = bx * block_intermediate + i + routed = T.if_then_else(token_selected_experts[token_idx] == 0, token_final_scales[token_idx], 0.0) + silu = gate * (1.0 / (1.0 + T.exp2(-gate * scale))) + Output[token_idx, out_idx] = up * silu * routed + + return main + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-n","--num-tokens", type=int, default=1024) + parser.add_argument("-c","--hidden-size", type=int, default=256) + parser.add_argument("-s","--intermediate-size", type=int, default=1024) + args = parser.parse_args() + num_tokens = args.num_tokens + hidden_size = args.hidden_size + intermediate_size = args.intermediate_size + one_scale = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) + + kernel = tilelang.compile( + fusedmoe_nvfp4_sm120(num_tokens, hidden_size, intermediate_size), + out_idx=[8], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(0) + x = torch.randint(0, 256, (num_tokens, hidden_size // 2), device="cuda", dtype=torch.uint8) + w_gate = torch.randint(0, 256, (intermediate_size, hidden_size // 2), device="cuda", dtype=torch.uint8) + w_up = torch.randint(0, 256, (intermediate_size, hidden_size // 2), device="cuda", dtype=torch.uint8) + token_blocks = num_tokens // 16 + hidden_blocks = hidden_size // 64 + intermediate_blocks = intermediate_size // 8 + x_scale = torch.full((token_blocks, hidden_blocks), one_scale, device="cuda", dtype=torch.uint32) + w_gate_scale = torch.full((intermediate_blocks, hidden_blocks), one_scale, device="cuda", dtype=torch.uint32) + w_up_scale = torch.full((intermediate_blocks, hidden_blocks), one_scale, device="cuda", dtype=torch.uint32) + selected = torch.zeros((num_tokens,), device="cuda", dtype=torch.int32) + routing = torch.ones((num_tokens,), device="cuda", dtype=torch.float32) + + out = kernel(x, x_scale, w_gate, w_gate_scale, w_up, w_up_scale, selected, routing) + + x_f32 = unpack_fp4_to_float(x, num_tokens, hidden_size) + gate_f32 = unpack_fp4_to_float(w_gate, intermediate_size, hidden_size) + up_f32 = unpack_fp4_to_float(w_up, intermediate_size, hidden_size) + gate = x_f32 @ gate_f32.T + up = x_f32 @ up_f32.T + ref = up * (gate * torch.sigmoid(gate)) + diff = (out.float() - ref).abs() + print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") + + warmup = int(os.environ.get("TL_NVFP4_MOE_WARMUP", "20")) + iters = int(os.environ.get("TL_NVFP4_MOE_ITERS", "100")) + for _ in range(warmup): + kernel(x, x_scale, w_gate, w_gate_scale, w_up, w_up_scale, selected, routing) + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(iters): + kernel(x, x_scale, w_gate, w_gate_scale, w_up, w_up_scale, selected, routing) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - start) * 1000 / iters + # Two FC1 GEMMs: gate and up. The SiLU/mul/routing epilogue is not included. + total_flops = 2 * num_tokens * hidden_size * intermediate_size * 2 + tflops = total_flops / (elapsed_ms / 1e3) / 1e12 + print(f"[BENCH] latency={elapsed_ms:.4f} ms, TFLOPS={tflops:.2f}, iters={iters}") diff --git a/examples/gemm_fp4/example_gemm_nvfp4_sm120.py b/examples/gemm_fp4/example_gemm_nvfp4_sm120.py index 61843e0d9c..f78db48f18 100644 --- a/examples/gemm_fp4/example_gemm_nvfp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_nvfp4_sm120.py @@ -1,15 +1,24 @@ """Minimal NVFP4 GEMM on SM120 using block-scaled MMA. -This example exercises ``mma.sync.aligned.kind::mxf4nvf4.block_scale`` for -FP4 E2M1 inputs with UE4M3 scale factors. It intentionally starts with a -single SM120 MMA shape (m16n8k64) before wiring NVFP4 into the generic GEMM API. +FlashInfer/TRT-LLM use "NVFP4" as packed FP4 E2M1 data plus FP8 E4M3/UE4M3 +scale factors with ``sf_vec_size=16``. The corresponding SM120 Tensor Core +instruction is ``mma.sync.aligned.kind::mxf4nvf4.block_scale``: both MMA +operands are FP4, the scale factors are FP8 E4M3, and accumulation is FP32. + +That is distinct from the non-scaled A8W4 examples, which use FP8 x FP4 +``kind::f8f6f4`` MMA. FP16/BF16 usually appear around NVFP4 as source/output +types for quantization or epilogues, not as the direct mxf4nvf4 MMA operands. + +This example intentionally starts with one SM120 MMA atom (m16n8k64) and a +uniform scale register before wiring NVFP4 into the generic GEMM API. """ import os +import time + import torch import tilelang import tilelang.language as T -from tilelang.cuda.intrinsics.macro.mma_macro_generator import TensorCoreIntrinEmitter from tilelang.layout import make_swizzled_layout @@ -32,12 +41,9 @@ -6.0, ] -# UE4M3 encoding for scale 1.0, packed four times into one uint32 register. -UE4M3_ONE_X4 = 0x38383838 - def unpack_fp4_to_uint8(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: - """Unpack (M, K//2) int8 -> (M, K) uint8, one FP4 value per byte.""" + """Unpack (M, K//2) uint8 -> (M, K) uint8 FP4 codes for the reference.""" flat = packed_int8.to(torch.uint8).reshape(M, K // 2) lo = flat & 0x0F hi = (flat >> 4) & 0x0F @@ -51,45 +57,35 @@ def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tens return lut[unpacked] +def pack_e4m3x4_to_u32(values: tuple[float, float, float, float]) -> int: + """Pack four FP8 E4M3 scale values into the uint32 register used by MMA.SF.""" + raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) + return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) + + +UE4M3_ONE_X4 = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) +UE4M3_HALF_X4 = pack_e4m3x4_to_u32((0.5, 0.5, 0.5, 0.5)) +UE4M3_TWO_X4 = pack_e4m3x4_to_u32((2.0, 2.0, 2.0, 2.0)) + + def matmul_nvfp4_sm120(M=16, N=8, K=64, out_dtype=T.float32, accum_dtype=T.float32): block_M, block_N, block_K = 16, 8, 64 threads = 32 - - emitter = TensorCoreIntrinEmitter( - a_dtype=T.float4_e2m1fn, - b_dtype=T.float4_e2m1fn, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=1, - block_col_warps=1, - warp_row_tiles=block_M, - warp_col_tiles=block_N, - chunk=block_K, - is_blockscaled=True, - scale_dtype=T.float8_e4m3, - scale_vec_size=16, - ) - - local_size_a = emitter.local_size_a - local_size_b = emitter.local_size_b - local_size_c = emitter.local_size_out + packed_K = K // 2 + packed_block_K = block_K // 2 @T.prim_func def main( - A: T.Tensor((M, K), "uint8"), - B: T.Tensor((N, K), "uint8"), + A: T.Tensor((M, packed_K), "uint8"), + B: T.Tensor((N, packed_K), "uint8"), SFA: T.Tensor((T.ceildiv(M, block_M), T.ceildiv(K, block_K)), "uint32"), SFB: T.Tensor((T.ceildiv(N, block_N), T.ceildiv(K, block_K)), "uint32"), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), "uint8") - B_shared = T.alloc_shared((block_N, block_K), "uint8") - C_shared = T.alloc_shared((1, 1, block_M, block_N), out_dtype) - A_local = T.alloc_local((local_size_a,), T.float4_e2m1fn) - B_local = T.alloc_local((local_size_b,), T.float4_e2m1fn) - C_local = T.alloc_local((local_size_c,), accum_dtype) + A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") + B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) SFA_local = T.alloc_local((1,), "uint32") SFB_local = T.alloc_local((1,), "uint32") @@ -102,23 +98,19 @@ def main( T.clear(C_local) for ko in T.serial(T.ceildiv(K, block_K)): - T.copy(A[by * block_M, ko * block_K], A_shared) - T.copy(B[bx * block_N, ko * block_K], B_shared) + T.copy(A[by * block_M, ko * packed_block_K], A_shared) + T.copy(B[bx * block_N, ko * packed_block_K], B_shared) SFA_local[0] = SFA[by, ko] SFB_local[0] = SFB[bx, ko] - emitter.ldmatrix_a(A_local, A_shared, 0) - emitter.ldmatrix_b(B_local, B_shared, 0) - emitter.mma_blockscaled(A_local, B_local, C_local, SFA_local, SFB_local) + T.nvfp4_gemm(A_shared, B_shared, SFA_local, SFB_local, C_local, transpose_B=True, clear_accum=(ko == 0)) - emitter.stmatrix(C_local, C_shared) - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[0, 0, i, j] + T.copy(C_local, C[by * block_M, bx * block_N]) return main if __name__ == "__main__": - M, N, K = 16, 8, 64 + M, N, K = 1024, 1024, 256 print(f"Running NVFP4 SM120 block-scaled MMA: M={M}, N={N}, K={K}") kernel = tilelang.compile( @@ -137,25 +129,47 @@ def main( f.write(kernel.get_kernel_source()) torch.manual_seed(0) - a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) - b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) - a_unpacked = unpack_fp4_to_uint8(a_packed, M, K) - b_unpacked = unpack_fp4_to_uint8(b_packed, N, K) - - sfa = torch.full((1, 1), UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) - sfb = torch.full((1, 1), UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) - - a_zero = torch.zeros((M, K), device="cuda", dtype=torch.uint8) - b_zero = torch.zeros((N, K), device="cuda", dtype=torch.uint8) + a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) + b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) + + # To fit for mma.m16n8k64, we need to pad the scale factors to the nearest multiples accordingly. + scale_shape_a = ((M + 15) // 16, (K + 63) // 64) + scale_shape_b = ((N + 7) // 8, (K + 63) // 64) + sfa = torch.full(scale_shape_a, UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) + sfb = torch.full(scale_shape_b, UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) + + a_zero = torch.zeros((M, K // 2), device="cuda", dtype=torch.uint8) + b_zero = torch.zeros((N, K // 2), device="cuda", dtype=torch.uint8) c_zero = kernel(a_zero, b_zero, sfa, sfb) print(f"[ZERO] max_abs={c_zero.abs().max().item():.4f}") - a_one = torch.full((M, K), 2, device="cuda", dtype=torch.uint8) - b_one = torch.full((N, K), 2, device="cuda", dtype=torch.uint8) + one_pair = (2 | (2 << 4)) # two FP4 E2M1 1.0 values packed in one byte + a_one = torch.full((M, K // 2), one_pair, device="cuda", dtype=torch.uint8) + b_one = torch.full((N, K // 2), one_pair, device="cuda", dtype=torch.uint8) c_one = kernel(a_one, b_one, sfa, sfb) print(f"[ONE] first={c_one[0, 0].item():.4f}, expected={float(K):.4f}") - c = kernel(a_unpacked, b_unpacked, sfa, sfb) + # Exercise the scale path explicitly. SFA=0.5 and SFB=2.0 should preserve + # the effective product for all-one FP4 inputs. + sfa_half = torch.full(scale_shape_a, UE4M3_HALF_X4, device="cuda", dtype=torch.uint32) + sfb_two = torch.full(scale_shape_b, UE4M3_TWO_X4, device="cuda", dtype=torch.uint32) + c_scaled = kernel(a_one, b_one, sfa_half, sfb_two) + print(f"[SCALE] first={c_scaled[0, 0].item():.4f}, expected={float(K):.4f}") + + c = kernel(a_packed, b_packed, sfa, sfb) ref = unpack_fp4_to_float(a_packed, M, K) @ unpack_fp4_to_float(b_packed, N, K).T diff = (c.float() - ref).abs() print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") + + warmup = int(os.environ.get("TL_NVFP4_WARMUP", "20")) + iters = int(os.environ.get("TL_NVFP4_ITERS", "100")) + for _ in range(warmup): + kernel(a_packed, b_packed, sfa, sfb) + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(iters): + kernel(a_packed, b_packed, sfa, sfb) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - start) * 1000 / iters + tflops = 2 * M * N * K / (elapsed_ms / 1e3) / 1e12 + print(f"[BENCH] latency={elapsed_ms:.4f} ms, TFLOPS={tflops:.2f}, iters={iters}") diff --git a/tilelang/cuda/op/gemm/gemm_mma.py b/tilelang/cuda/op/gemm/gemm_mma.py index 0a1157a871..598845b8f9 100644 --- a/tilelang/cuda/op/gemm/gemm_mma.py +++ b/tilelang/cuda/op/gemm/gemm_mma.py @@ -17,13 +17,22 @@ GEMM_INST_MMA = "cuda.mma" +def _annotation_value(annotations: dict, key: str, default): + value = annotations.get(key, default) + return getattr(value, "value", value) + + class GemmMMA(GemmBase): intrin_emitter_cls = TensorCoreIntrinEmitter def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tirx.Var | None = None): - m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_INST_MMA) + if self.is_blockscaled: + m_warp, n_warp = 1, 1 + else: + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_INST_MMA) warp_row_tiles = int(self.M // m_warp) warp_col_tiles = int(self.N // n_warp) + annotations = getattr(self.gemm_node, "annotations", {}) emitter = self.intrin_emitter_cls( a_dtype=self.in_dtype, b_dtype=self.in_dtype_b, @@ -36,6 +45,9 @@ def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tirx.V warp_col_tiles=warp_col_tiles, chunk=self.chunk, thread_var=thread_var, + is_blockscaled=self.is_blockscaled, + scale_dtype=str(_annotation_value(annotations, "scale_dtype", "float8_e4m3")), + scale_vec_size=int(_annotation_value(annotations, "scale_vec_size", 16)), ) return emitter @@ -103,6 +115,25 @@ def lower( assert is_full_region(C_region), "Fragment output C must be a full region" + if self.is_blockscaled: + if not self.is_gemm_ss(): + raise ValueError("SM120 NVFP4 block-scaled MMA currently supports shared/shared tiles only") + + @T.prim_func + def _gemm_blockscaled_ssr() -> None: + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype_b) + SFA_buf = self.SFARegion.buffer + SFB_buf = self.SFBRegion.buffer + if clear_accum: + T.clear(C_buf) + for ki in T.serial(0, (block_K // micro_size_k)): + mma_emitter.ldmatrix_a(A_local, A_region, ki) + mma_emitter.ldmatrix_b(B_local, B_region, ki) + mma_emitter.mma_blockscaled(A_local, B_local, C_buf, SFA_buf, SFB_buf, ki) + + return _Simplify(_gemm_blockscaled_ssr, inline_let=True) + if self.is_gemm_ss(): @T.prim_func diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 60aaef834b..7ab325f7c5 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -70,6 +70,7 @@ from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 from .gemm_op import ( # noqa: F401 gemm, + nvfp4_gemm, wgmma_gemm, tcgen05_gemm, tcgen05_gemm_blockscaled, diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index 81ae93cf74..064005a44d 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -233,6 +233,108 @@ def wgmma_gemm( ) +def nvfp4_gemm( + A: BufferLikeType, + B: BufferLikeType, + SFA: BufferLikeType, + SFB: BufferLikeType, + C: BufferLikeType, + transpose_A: bool = False, + transpose_B: bool = True, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + scale_dtype: str = "float8_e4m3", + scale_vec_size: int = 16, +) -> tirx.PrimExpr: + """SM120 NVFP4 block-scaled GEMM tile op. + + This is a thin SM120-only frontend for the currently supported + ``mxf4nvf4.block_scale`` MMA path. A and B are packed FP4 tiles (normally + exposed as ``T.float4_e2m1fn`` tensors and viewed as packed bytes before + shared-memory staging), while SFA/SFB are scale-register buffers containing + packed FP8 E4M3 values. + """ + + def legalize(arg): + if isinstance(arg, tirx.Var) and T.has_let_value(arg): + return T.get_let_value(arg).buffer + return arg + + A = legalize(A) + B = legalize(B) + C = legalize(C) + SFA = legalize(SFA) + SFB = legalize(SFB) + + A_region = to_buffer_region(A) + B_region = to_buffer_region(B) + C_region = to_buffer_region(C) + SFA_region = to_buffer_region(SFA) + SFB_region = to_buffer_region(SFB) + + A_shape = retrieve_shape(A_region) + B_shape = retrieve_shape(B_region) + C_shape = retrieve_shape(C_region) + + M, N = C_shape + M_A = A_shape[-1] if transpose_A else A_shape[-2] + K = A_shape[-2] if transpose_A else A_shape[-1] + N_B = B_shape[-2] if transpose_B else B_shape[-1] + K_B = B_shape[-1] if transpose_B else B_shape[-2] + assert prim_expr_equal(M_A, M), f"T.nvfp4_gemm M shape check failed: M_A = {M_A}, M_C = {M}" + assert prim_expr_equal(N_B, N), f"T.nvfp4_gemm N shape check failed: N_B = {N_B}, N_C = {N}" + assert prim_expr_equal(K, K_B), f"T.nvfp4_gemm K shape check failed: K_A = {K}, K_B = {K_B}" + # The current SM120 mxf4nvf4 ldmatrix path consumes packed bytes in shared + # memory (two FP4 values per byte). The GEMM node still needs the logical + # K so the emitter selects m16n8k64. + logical_K = K * 2 if str(A_region.buffer.dtype) == "uint8" and str(B_region.buffer.dtype) == "uint8" else K + + A_stride = retrieve_stride(A_region) + B_stride = retrieve_stride(B_region) + A_offset = retrieve_offset(A_region) + B_offset = retrieve_offset(B_region) + C_coords = [r.min for r in C_region.region] + + A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) + B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape]) + C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape]) + SFA_arg = buffer_region_to_tile_region(SFA_region, "r", list(retrieve_shape(SFA_region))) + SFB_arg = buffer_region_to_tile_region(SFB_region, "r", list(retrieve_shape(SFB_region))) + + return tirx.call_intrin( + "handle", + tirx.op.Op.get("tl.tileop.gemm"), + A_arg, + B_arg, + C_arg, + transpose_A, + transpose_B, + M, + N, + logical_K, + policy, + clear_accum, + A_stride[-2], + B_stride[-2], + A_offset[-1], + B_offset[-1], + 1, + 0, + tirx.const(0, dtype="int32"), + C_coords[0], + C_coords[1], + SFA_arg, + SFB_arg, + tirx.const(0, dtype="int32"), + tirx.const(0, dtype="int32"), + annotations={ + "is_nvfp4": tirx.IntImm("int32", 1), + "scale_dtype": tirx.StringImm(scale_dtype), + "scale_vec_size": tirx.IntImm("int32", scale_vec_size), + }, + ) + + def tcgen05_gemm( A: BufferLikeType, B: BufferLikeType, diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index a46146fafd..ff4700d1c2 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1000,18 +1000,21 @@ def ptx_mma_blockscaled( ``mma.sync.aligned.kind::mxf4nvf4.block_scale`` with FP4 A/B, FP32 accumulators, and UE4M3 scale factors. """ + def _string_imm(value): + return tvm.tirx.StringImm(str(value)) + return call_intrin( "handle", _tvm_op.Op.get("tl.ptx_mma_blockscaled"), - dtype, - shape, - A_layout, - B_layout, - A_dtype, - B_dtype, - C_dtype, - SF_dtype, - scale_vec_size, + _string_imm(dtype), + _string_imm(shape), + _string_imm(A_layout), + _string_imm(B_layout), + _string_imm(A_dtype), + _string_imm(B_dtype), + _string_imm(C_dtype), + _string_imm(SF_dtype), + IntImm("int32", scale_vec_size), multiplicand_a, a_index, multiplicand_b, diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index bfab31806a..a011251234 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -113,6 +113,8 @@ def accum_dtype(self) -> str: @property def chunk(self) -> int: + if self.is_blockscaled: + return self.K return self.A.shape[-2] if self.trans_A else self.A.shape[-1] @property From 8aa292aae73f739fc001a96eed57c482ed912813 Mon Sep 17 00:00:00 2001 From: wahao Date: Fri, 5 Jun 2026 15:10:32 +0800 Subject: [PATCH 16/19] feat(sm100): add mxf4nvf4 feature for ptx code, meta/descriptor, connected in codegen_cuda.cc, called at tcgen05_macro_generator.py --- src/cuda/codegen/codegen_cuda.cc | 34 ++-- src/cuda/op/gemm.cc | 7 + src/op/builtin.cc | 2 +- src/op/tcgen5_meta.h | 50 ++++++ .../cuda/instruction/tcgen05mma.h | 62 ++++++++ .../macro/tcgen05_macro_generator.py | 147 +++++++++++++++++- 6 files changed, 291 insertions(+), 11 deletions(-) diff --git a/src/cuda/codegen/codegen_cuda.cc b/src/cuda/codegen/codegen_cuda.cc index 59db2a85e8..0ede7da0bb 100644 --- a/src/cuda/codegen/codegen_cuda.cc +++ b/src/cuda/codegen/codegen_cuda.cc @@ -3408,24 +3408,39 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string sfa_offset = this->PrintExpr(op->args[10]); std::string sfb_ref = this->PrintExpr(op->args[11]); std::string sfb_offset = this->PrintExpr(op->args[12]); - // args[13], [14] reserved for future mask/flags + bool use_mxf4nvf4 = Downcast(op->args[13])->value != 0; + // args[14] reserved for future mask/flags bool enable_ws = Downcast(op->args[15])->value; bool enable_2cta = Downcast(op->args[16])->value; ICHECK(!(enable_ws && enable_2cta)) << "Block-scaled TCGEN05 does not support combining .ws and 2CTA"; + ICHECK(!use_mxf4nvf4 || !enable_ws) + << "mxf4nvf4 block-scaled TCGEN05 currently supports SS only"; auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype); need_tcgen05mma_instruction_h_ = true; this->PrintIndent(); - std::string tcgen05_call = - "tl::(tcgen05_name)<(ABType), (USE_2CTA)>(uint64_t((desc_a) + " - "(A_offset)), " - "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast((C))) " - "+ (C_offset), " - "(scale_out), static_cast((desc_val)), " - "(*reinterpret_cast((SFA))) + (SFA_offset), " - "(*reinterpret_cast((SFB))) + (SFB_offset));\n"; + std::string tcgen05_call; + if (use_mxf4nvf4) { + tcgen05_call = + "tl::tcgen05mma_mxf4nvf4_blockscaled_ss<(USE_2CTA)>(" + "uint64_t((desc_a) + (A_offset)), " + "uint64_t((desc_b) + (B_offset)), " + "(*reinterpret_cast((C))) + (C_offset), " + "(scale_out), static_cast((desc_val)), " + "(*reinterpret_cast((SFA))) + (SFA_offset), " + "(*reinterpret_cast((SFB))) + (SFB_offset));\n"; + } else { + tcgen05_call = + "tl::(tcgen05_name)<(ABType), (USE_2CTA)>(uint64_t((desc_a) + " + "(A_offset)), " + "uint64_t((desc_b) + (B_offset)), (*reinterpret_cast((C))) " + "+ (C_offset), " + "(scale_out), static_cast((desc_val)), " + "(*reinterpret_cast((SFA))) + (SFA_offset), " + "(*reinterpret_cast((SFB))) + (SFB_offset));\n"; + } tl::codegen::Replacer replacer; replacer.register_rule("(ABType)", tl::codegen::ptx::DTypeEnumToString(dtype_enum)); @@ -3439,6 +3454,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(tcgen05_name)", enable_ws ? "tcgen05mma_blockscaled_ws_ss" : "tcgen05mma_blockscaled_ss"); + replacer.register_rule("(scale_out)", scale_out); replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr)); replacer.register_rule("(SFA)", sfa_ref); diff --git a/src/cuda/op/gemm.cc b/src/cuda/op/gemm.cc index d86b7e6032..4837308fb8 100644 --- a/src/cuda/op/gemm.cc +++ b/src/cuda/op/gemm.cc @@ -381,6 +381,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { b_sf_id); return Integer(static_cast(desc)); }); + refl::GlobalDef().def("tl.get_tcgen5_mxf4nvf4_blockscaled_instr_desc", + [](int atom_m, int atom_n, DataType ab_dtype, + int scale_in_a, int scale_in_b, bool is_mxfp4) { + uint32_t desc = GetTCGEN5MXF4NVF4BlockScaledInstrDesc( + atom_m, atom_n, ab_dtype, scale_in_a, scale_in_b, is_mxfp4); + return Integer(static_cast(desc)); + }); } } // namespace tl diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 2f793fd35a..53ae00ca22 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -252,7 +252,7 @@ TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts) Integer(CallEffectKind::kOpaque)); TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_blockscaled_ss) - .set_num_inputs(16) + .set_num_inputs(17) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 60f6fcaa43..2b3a74c006 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -313,6 +313,56 @@ inline uint32_t GetTCGEN5BlockScaledInstrDesc(int atom_m, int atom_n, return desc; } +// Build block-scaled instruction descriptor for mxf4nvf4.block_scale +// Bit layout: InstrDescriptorBlockScaled (see CUTLASS mma_sm100_desc.hpp) +inline uint32_t GetTCGEN5MXF4NVF4BlockScaledInstrDesc(int atom_m, int atom_n, + DataType ab_dtype, + int scale_in_a, + int scale_in_b, + bool is_mxfp4) { + +ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; +ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; +ICHECK(scale_in_a == 1 || scale_in_a == -1); +ICHECK(scale_in_b == 1 || scale_in_b == -1); +ICHECK(ab_dtype.is_float4_e2m1fn()) << "ab_dtype must be float4_e2m1fn for mxf4nvf4"; + + +auto set_bits = [](uint32_t value, int start, int width) -> uint32_t { +uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1); +return (value & mask) << start; +}; + + +uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u; +uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u; +uint32_t n_dim = static_cast(atom_n >> 3); +uint32_t m_dim = static_cast(atom_m >> 7); + +uint32_t desc = 0; +// bit 0-1 reserved +desc |= set_bits(0, 2, 1); // sparse_flag +// bit 3 reserved +// bit 4-5: b_sf_id, enforce to be 0 for scale_vec::4X, see here https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x +// bit 6 reserved +desc |= set_bits(1, 7, 3); // a_format, enforce to be 1 for e2m1 format +desc |= set_bits(1, 10, 2); // b_format, enforce to be 1 for e2m1 format +// bit 12 reserved +desc |= set_bits(a_neg, 13, 1); // a_negate +desc |= set_bits(b_neg, 14, 1); // b_negate +// bit 15-16, enforce to be 0 as AB is fixed non-transposed +desc |= set_bits(n_dim, 17, 6); // n_dim +desc |= set_bits(is_mxfp4 ? 1 : 0, 23, 1); // scale_format = 1 (E8M0) for mxfp4, 0 (E4M3) for nvf4 +// bit 24-26 reserved +desc |= set_bits(m_dim, 27, 2); // m_dim +// bit 29-30: a_sf_id, enforce to be 0 for scale_vec::4X, +desc |= set_bits(static_cast(a_sf_id), 29, 2); // a_sf_id +// bit 31 reserved + +return desc; +} + + } // namespace tl } // namespace tvm diff --git a/src/tl_templates/cuda/instruction/tcgen05mma.h b/src/tl_templates/cuda/instruction/tcgen05mma.h index c719e05e10..ab13477a10 100644 --- a/src/tl_templates/cuda/instruction/tcgen05mma.h +++ b/src/tl_templates/cuda/instruction/tcgen05mma.h @@ -801,4 +801,66 @@ TL_DEVICE void tcgen05mma_blockscaled_ss( desc_a, desc_b, tmem_c, scalec, desc_val, tmem_sfa, tmem_sfb); } +// NVFP4 block-scaled uses mxf4nvf4 instead of the mxf8f6f4 instruction family +// above. Scale factors are addressed through TMEM operands. +template +TL_DEVICE void tcgen05mma_mxf4nvf4_blockscaled_ss( + uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/, + uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/, + uint32_t const & /*desc_val*/, uint32_t const & /*tmem_sfa*/, + uint32_t const & /*tmem_sfb*/) { + static_assert(always_false_v>, + "tl::tcgen05mma_mxf4nvf4_blockscaled_ss: unsupported CTA group"); +} + +template <> +TL_DEVICE void tcgen05mma_mxf4nvf4_blockscaled_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, uint32_t const &tmem_sfa, + uint32_t const &tmem_sfb) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, " + "%2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X " + "[%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +} + +template <> +TL_DEVICE void tcgen05mma_mxf4nvf4_blockscaled_ss( + uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scalec, uint32_t const &desc_val, uint32_t const &tmem_sfa, + uint32_t const &tmem_sfb) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || \ + (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, " + "%2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X " + "[%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +} + } // namespace tl diff --git a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py index fa963fe386..f68af5b56f 100644 --- a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py @@ -445,6 +445,150 @@ def get_tcgen5_blockscaled_instr_desc( ) return lift(desc) + def tcgen05mma_mxf4nvf4_blockscaled( + self, + A_buf: Buffer, + B_buf: Buffer, + C_local_buf: Buffer, + SFA_tmem, + SFB_tmem, + mbar, + clear_accum: PrimExpr = False, + ): + """Emit a mxf4nvf4 block-scaled TCGEN5MMA (SS variant with TMEM scale factors). + + Uses ``tcgen05.mma.cta_group::1|2.kind::mxf4nvf4.block_scale`` PTX instruction. + Scale factors must already reside in tensor memory. + """ + m_dim = self.block_row_warps * self.warp_row_tiles + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + + assert k_dim >= micro_size_k + assert (not self.a_transposed) and (not self.b_transposed), "A and B must be non-transposed/K-major for mxf4nvf4 block-scaled" + + a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) + + a_elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8 + + if len(self.meta) != 5: + self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=False) + if len(self.meta) != 5: + raise ValueError( + f"Unsupported TCGEN5MMA configuration for block-scaled: M={m_dim}, N={n_dim}, " + f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) + atom_m, atom_n, _, _, enable_2cta = self.tcgen05_meta_unpacked + atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m + + a_swizzle_atom_elems = atom_m_per_cta if a_swizzle_mode.is_none() else (a_swizzle_mode.swizzle_byte_size() * 8) // DataType(self.a_dtype).bits + # Block-scaled A LBO/SBO differ from regular SS (uses atom_m_per_cta instead of m_dim) + + # Below computation assumes A is K-major + if a_swizzle_mode.is_none(): + a_leading_byte_offset = 16 // 2 # per 8x2 tile, each has 8 bytes + a_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + else: + a_leading_byte_offset = 1 # not used, fixed to 1 + a_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + + a_params = TCGEN05DescriptorParams( + swizzle_mode=int(a_swizzle_mode), + leading_byte_offset=int(a_leading_byte_offset >> 4), + stride_byte_offset=int(a_stride_byte_offset >> 4), + swizzle_atom_elems=a_swizzle_atom_elems, + k_atom_size=max(a_swizzle_atom_elems // micro_size_k, 1), + elems_in_bytes=a_elems_in_bytes, + is_k_major=True, + ) + + # Below computation assumes B is K-major + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + b_elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8 + n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim + b_swizzle_atom_elems = ( + n_dim_per_cta + if b_swizzle_mode.is_none() + else (b_swizzle_mode.swizzle_byte_size() * 8) // DataType(self.b_dtype).bits + ) + + if b_swizzle_mode.is_none(): + b_leading_byte_offset = 16 // 2 # per 8x2 tile, each has 8 bytes + b_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + else: + b_leading_byte_offset = 1 # not used, fixed to 1 + b_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + b_params = TCGEN05DescriptorParams( + swizzle_mode=int(b_swizzle_mode), + leading_byte_offset=int(b_leading_byte_offset >> 4), + stride_byte_offset=int(b_stride_byte_offset >> 4), + swizzle_atom_elems=b_swizzle_atom_elems, + k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1), + elems_in_bytes=b_elems_in_bytes, + is_k_major=True, + ) + + sf_dtype = DataType(SFA_tmem.dtype) + is_mxfp4 = sf_dtype == DataType("float8_e8m0fnu") or sf_dtype == DataType("uint8") # UE8M0 storage + + inst_desc = _ffi_api.get_tcgen5_mxf4nvf4_blockscaled_instr_desc( + atom_m, + atom_n, + DataType(self.a_dtype), + 1, + 1, + is_mxfp4, + ) + instr_desc = lift(inst_desc) + + if isinstance(SFA_tmem, BufferRegion): + sfa_data = SFA_tmem.buffer.data + elif isinstance(SFA_tmem, Buffer): + sfa_data = SFA_tmem.data + else: + raise ValueError(f"Unsupported SFA_tmem type: {type(SFA_tmem)}") + + if isinstance(SFB_tmem, BufferRegion): + sfb_data = SFB_tmem.buffer.data + elif isinstance(SFB_tmem, Buffer): + sfb_data = SFB_tmem.data + else: + raise ValueError(f"Unsupported SFB_tmem type: {type(SFB_tmem)}") + + desc_a = T.alloc_tcgen05_smem_desc() + desc_b = T.alloc_tcgen05_smem_desc() + self.init_tcgen05_a_desc(desc_a, A_buf, a_params) + self.init_tcgen05_b_desc(desc_b, B_buf, b_params) + + num_inst_m = m_dim // atom_m_per_cta + num_inst_n = n_dim // atom_n + num_k_atoms = self.tcgen05_num_k_atoms + + @T.macro + def _warp_mma_mxf4nvf4_blockscaled(desc_a, desc_b, C_local_buf, sfa_data, sfb_data, mbar): + for j in T.unroll(num_inst_n): + for i in T.unroll(num_inst_m): + for ki in T.unroll(0, num_k_atoms): + self.tcgen05_blockscaled_atom( + desc_a, + desc_b, + C_local_buf, + sfa_data, + sfb_data, + i, + j, + ki, + a_params, + b_params, + instr_desc, + clear_accum, + use_mxf4nvf4=True, + ) + self.tcgen05_atom_arrive(mbar) + + return _warp_mma_mxf4nvf4_blockscaled(desc_a, desc_b, C_local_buf, sfa_data, sfb_data, mbar) + + def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment: raise NotImplementedError @@ -1042,6 +1186,7 @@ def tcgen05_blockscaled_atom( b_params: TCGEN05DescriptorParams, instr_desc: PrimExpr, clear_accum: PrimExpr = False, + use_mxf4nvf4: bool = False, ): """Emit a single TCGEN05MMA block-scaled SS instruction. @@ -1120,7 +1265,7 @@ def _bs_atom(desc_a, desc_b, C_local_buf, sfa_data, sfb_data): 0, sfb_data, 0, - 0, + 1 if use_mxf4nvf4 else 0, 0, enable_ws, enable_2cta, From 4dd1a5e8f275c9ae447c7ca925953f1e5181a49f Mon Sep 17 00:00:00 2001 From: wahao Date: Fri, 5 Jun 2026 18:15:45 +0800 Subject: [PATCH 17/19] feat(sm100): based on blockscale fw,assign is_nvfp4 annotation to additional tcgen05mma_mxf4nvf4_blockscaled, add example to verify --- examples/gemm_fp4/example_gemm_nvfp4_sm100.py | 170 ++++++++++++++++++ tilelang/cuda/op/gemm/gemm_tcgen05.py | 78 +++++--- tilelang/language/gemm_op.py | 10 +- 3 files changed, 231 insertions(+), 27 deletions(-) create mode 100644 examples/gemm_fp4/example_gemm_nvfp4_sm100.py diff --git a/examples/gemm_fp4/example_gemm_nvfp4_sm100.py b/examples/gemm_fp4/example_gemm_nvfp4_sm100.py new file mode 100644 index 0000000000..c04c9d5e28 --- /dev/null +++ b/examples/gemm_fp4/example_gemm_nvfp4_sm100.py @@ -0,0 +1,170 @@ +"""Minimal NVFP4 GEMM on SM100/SM110 using TCGEN05 block-scaled MMA. + +This validates the frontend-to-PTX path for +``tcgen05.mma.kind::mxf4nvf4.block_scale``. A/B are declared as native +``T.float4_e2m1fn`` tensors, then viewed as packed bytes for the current SMEM +staging path. SFA/SFB are uint32 words, each packing four FP8 E4M3 scale +values for one K64 tile. +""" + +import os +import time + +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def pack_e4m3x4_to_u32(values: tuple[float, float, float, float]) -> int: + raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) + return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) + + +def unpack_fp4_to_float(packed, rows, cols): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) + flat = packed.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[unpacked] + + +def matmul_nvfp4_sm100( + M=128, + N=128, + K=64, + block_M=128, + block_N=128, + block_K=64, + out_dtype=T.float32, + accum_dtype=T.float32, + threads=128, +): + packed_K = K // 2 + packed_block_K = block_K // 2 + k_tiles = T.ceildiv(K, block_K) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.float4_e2m1fn), + B: T.Tensor((N, K), T.float4_e2m1fn), + SFA: T.Tensor((T.ceildiv(K, block_K) * M,), "uint32"), + SFB: T.Tensor((T.ceildiv(K, block_K) * N,), "uint32"), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_bytes = T.view(A, (M, packed_K), "uint8") + B_bytes = T.view(B, (N, packed_K), "uint8") + + A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") + B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") + SFA_shared = T.alloc_shared((block_M,), "uint32") + SFB_shared = T.alloc_shared((block_N,), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, 4], "uint32") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + mbar = T.alloc_barrier(1) + + tx = T.get_thread_binding() + T.use_swizzle(8) + + for ko in T.serial(k_tiles): + T.copy(A_bytes[by * block_M, ko * packed_block_K], A_shared) + T.copy(B_bytes[bx * block_N, ko * packed_block_K], B_shared) + T.copy(SFA[ko * M + by * block_M], SFA_shared) + T.copy(SFB[ko * N + bx * block_N], SFB_shared) + T.sync_threads() + + if tx < 32: + T.tcgen05_sf_warp_transpose(SFA_shared) + T.tcgen05_sf_warp_transpose(SFB_shared) + T.fence_proxy_async() + T.tcgen05_cp_warpx4(SFA_shared, SFA_tmem) + T.tcgen05_cp_warpx4(SFB_shared, SFB_tmem) + + if 32 <= tx and tx < 64: + T.tcgen05_gemm_blockscaled( + A_shared, + B_shared, + C_tmem, + SFA_tmem, + SFB_tmem, + transpose_B=True, + mbar=mbar, + clear_accum=(ko == 0), + is_nvfp4=True, + ) + T.mbarrier_wait_parity(mbar, ko % 2) + T.sync_threads() + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +if __name__ == "__main__": + M = int(os.environ.get("TL_NVFP4_SM100_M", "128")) + N = int(os.environ.get("TL_NVFP4_SM100_N", "128")) + K = int(os.environ.get("TL_NVFP4_SM100_K", "64")) + one_scale = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) + + print(f"Running SM100 NVFP4 GEMM: M={M}, N={N}, K={K}") + kernel = tilelang.compile( + matmul_nvfp4_sm100(M, N, K), + out_idx=[4], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(0) + a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) + b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) + sfa = torch.full(((K + 63) // 64 * M,), one_scale, device="cuda", dtype=torch.uint32) + sfb = torch.full(((K + 63) // 64 * N,), one_scale, device="cuda", dtype=torch.uint32) + + c = kernel(a, b, sfa, sfb) + ref = unpack_fp4_to_float(a, M, K) @ unpack_fp4_to_float(b, N, K).T + diff = (c.float() - ref).abs() + print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") + + warmup = int(os.environ.get("TL_NVFP4_SM100_WARMUP", "20")) + iters = int(os.environ.get("TL_NVFP4_SM100_ITERS", "100")) + for _ in range(warmup): + kernel(a, b, sfa, sfb) + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(iters): + kernel(a, b, sfa, sfb) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - start) * 1000 / iters + tflops = 2 * M * N * K / (elapsed_ms / 1e3) / 1e12 + print(f"[BENCH] latency={elapsed_ms:.4f} ms, TFLOPS={tflops:.2f}, iters={iters}") diff --git a/tilelang/cuda/op/gemm/gemm_tcgen05.py b/tilelang/cuda/op/gemm/gemm_tcgen05.py index 8c38b20501..e6e7805657 100644 --- a/tilelang/cuda/op/gemm/gemm_tcgen05.py +++ b/tilelang/cuda/op/gemm/gemm_tcgen05.py @@ -35,6 +35,11 @@ GEMM_INST_TCGEN05 = "cuda.tcgen05" +def _annotation_bool(annotations: dict, key: str) -> bool: + value = annotations.get(key, 0) + return bool(getattr(value, "value", value)) + + class GemmTCGEN5(GemmBase): """GEMM operator for Blackwell (SM100) TCGEN5MMA instructions. @@ -107,7 +112,7 @@ def infer_layout(self, target: Target, thread_nums: int): b_is_k_major = self.trans_B annotations = getattr(self.gemm_node, "annotations", {}) - use_2cta = bool(annotations.get("use_2cta", 0)) + use_2cta = _annotation_bool(annotations, "use_2cta") mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K, disable_2cta=not use_2cta) if self.is_blockscaled or self.is_gemm_ss(): @@ -171,7 +176,7 @@ def lower( raise ValueError(f"TCGEN5MMA supports gemm_ss and gemm_ts, got A scope {self.A.scope()}, B scope {self.B.scope()}") annotations = getattr(self.gemm_node, "annotations", {}) - use_2cta = bool(annotations.get("use_2cta", 0)) + use_2cta = _annotation_bool(annotations, "use_2cta") mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K, disable_2cta=not use_2cta) atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.meta @@ -266,7 +271,8 @@ def _lower_blockscaled(self, mma_emitter, thread_bounds, thread_var, mbar_phase_ del mbar_phase_expr annotations = getattr(self.gemm_node, "annotations", {}) - use_2cta = bool(annotations.get("use_2cta", 0)) + use_2cta = _annotation_bool(annotations, "use_2cta") + is_nvfp4 = _annotation_bool(annotations, "is_nvfp4") mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K, disable_2cta=not use_2cta) _atom_m, _atom_n, _atom_k, _enable_ws, enable_2cta = (int(x) for x in mma_emitter.meta) @@ -280,32 +286,54 @@ def _lower_blockscaled(self, mma_emitter, thread_bounds, thread_var, mbar_phase_ @T.prim_func def _gemm_blockscaled_cond() -> None: if cluster_cond and thread_var // 32 == thread_bounds.min // warp_size: - mma_emitter.tcgen05mma_blockscaled( - A_shared, - B_shared, - C_local, - SFA_tmem, - SFB_tmem, - mbarptr, - clear_accum, - sf_a_id, - sf_b_id, - ) + if is_nvfp4: + mma_emitter.tcgen05mma_mxf4nvf4_blockscaled( + A_shared, + B_shared, + C_local, + SFA_tmem, + SFB_tmem, + mbarptr, + clear_accum, + ) + else: + mma_emitter.tcgen05mma_blockscaled( + A_shared, + B_shared, + C_local, + SFA_tmem, + SFB_tmem, + mbarptr, + clear_accum, + sf_a_id, + sf_b_id, + ) @T.prim_func def _gemm_blockscaled() -> None: if cluster_cond: - mma_emitter.tcgen05mma_blockscaled( - A_shared, - B_shared, - C_local, - SFA_tmem, - SFB_tmem, - mbarptr, - clear_accum, - sf_a_id, - sf_b_id, - ) + if is_nvfp4: + mma_emitter.tcgen05mma_mxf4nvf4_blockscaled( + A_shared, + B_shared, + C_local, + SFA_tmem, + SFB_tmem, + mbarptr, + clear_accum, + ) + else: + mma_emitter.tcgen05mma_blockscaled( + A_shared, + B_shared, + C_local, + SFA_tmem, + SFB_tmem, + mbarptr, + clear_accum, + sf_a_id, + sf_b_id, + ) return ( _Simplify(_gemm_blockscaled, inline_let=True) diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index 064005a44d..4df0aa954f 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -395,6 +395,7 @@ def tcgen05_gemm_blockscaled( sf_b_id: int = 0, *, use_2cta: bool = False, + is_nvfp4: bool = False, ) -> tirx.PrimExpr: """Explicit Blackwell TCGEN05 block-scaled GEMM without an implicit wait. @@ -428,7 +429,11 @@ def tcgen05_gemm_blockscaled( use_2cta: Whether to request true ``cta_group::2`` lowering. """ - ann = {"use_2cta": int(use_2cta)} if use_2cta else None + ann = {} + if use_2cta: + ann["use_2cta"] = int(use_2cta) + if is_nvfp4: + ann["is_nvfp4"] = 1 # Re-read normalized regions below after let legalization. @@ -464,6 +469,7 @@ def legalize(arg): K = A_shape[-2] if transpose_A else A_shape[-1] K_B = B_shape[-1] if transpose_B else B_shape[-2] assert prim_expr_equal(K, K_B), f"T.tcgen05_gemm_blockscaled K shape check failed: K_A = {K}, K_B = {K_B}" + logical_K = K * 2 if is_nvfp4 and str(A_region.buffer.dtype) == "uint8" and str(B_region.buffer.dtype) == "uint8" else K if use_2cta: assert prim_expr_equal(M_A, M) and prim_expr_equal(N_B * 2, N), ( f"T.tcgen05_gemm_blockscaled 2CTA shape check failed: M_A = {M_A}, expected M_C = {M}; N_B = {N_B}, expected N_C / 2 = {N} / 2" @@ -517,7 +523,7 @@ def legalize(arg): transpose_B, M, N, - K, + logical_K, policy, clear_accum, stride_a, From 1290a96dba075c41e6c25088ac67e77c09f12220 Mon Sep 17 00:00:00 2001 From: wahao Date: Fri, 12 Jun 2026 09:21:03 +0800 Subject: [PATCH 18/19] fix(sm100): correct NVFP4 mxf4nvf4 block-scaled GEMM by staging FP4 operands dense (SW64) instead of gap-expanded align16b (uniform-SF verified, rel_err=0 on Thor) --- examples/gemm_fp4/example_gemm_nvfp4_sm100.py | 113 ++++++++++++++--- examples/gemm_fp4/gemm_nvfp4_sm100.cu | 119 ++++++++++++++++++ src/cuda/op/gemm.cc | 11 ++ src/layout/gemm_layouts.cc | 66 ++++++++++ src/layout/layout.cc | 4 + src/layout/layout.h | 2 + src/op/tcgen5_meta.h | 71 ++++++----- .../macro/tcgen05_macro_generator.py | 76 +++++++---- tilelang/cuda/op/gemm/gemm_tcgen05.py | 3 + tilelang/cuda/transform/__init__.py | 4 +- tilelang/language/gemm_op.py | 26 +++- tilelang/layout/__init__.py | 1 + tilelang/layout/swizzle.py | 15 +++ 13 files changed, 425 insertions(+), 86 deletions(-) create mode 100644 examples/gemm_fp4/gemm_nvfp4_sm100.cu diff --git a/examples/gemm_fp4/example_gemm_nvfp4_sm100.py b/examples/gemm_fp4/example_gemm_nvfp4_sm100.py index c04c9d5e28..f450cef418 100644 --- a/examples/gemm_fp4/example_gemm_nvfp4_sm100.py +++ b/examples/gemm_fp4/example_gemm_nvfp4_sm100.py @@ -1,10 +1,10 @@ """Minimal NVFP4 GEMM on SM100/SM110 using TCGEN05 block-scaled MMA. This validates the frontend-to-PTX path for -``tcgen05.mma.kind::mxf4nvf4.block_scale``. A/B are declared as native -``T.float4_e2m1fn`` tensors, then viewed as packed bytes for the current SMEM -staging path. SFA/SFB are uint32 words, each packing four FP8 E4M3 scale -values for one K64 tile. +``tcgen05.mma.kind::mxf4nvf4.block_scale``. A/B are declared and staged as +native ``T.float4_e2m1fn`` tensors so the shared operand gets the swizzled +(align16b -> SWIZZLE_128B) layout the MMA requires. SFA/SFB are uint32 words, +each packing four FP8 E4M3 scale values for one K64 tile. """ import os @@ -40,22 +40,37 @@ def pack_e4m3x4_to_u32(values: tuple[float, float, float, float]) -> int: return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) -def unpack_fp4_to_float(packed, rows, cols): +def unpack_fp4_to_float(packed, rows, cols, high_first=False): lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) flat = packed.to(torch.uint8).reshape(rows, cols // 2) lo = flat & 0x0F hi = (flat >> 4) & 0x0F - unpacked = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + first, second = (hi, lo) if high_first else (lo, hi) + unpacked = torch.stack([first, second], dim=-1).reshape(rows, cols).to(torch.int64) return lut[unpacked] +def pack_repeated_nibble(codes, cols): + packed = torch.empty((codes.numel(), cols // 2), device=codes.device, dtype=torch.uint8) + byte = (codes.to(torch.uint8) & 0x0F) | ((codes.to(torch.uint8) & 0x0F) << 4) + packed[:] = byte[:, None] + return packed + + +def pack_k_pattern(row_count, codes): + lo = codes[0::2].to(torch.uint8) & 0x0F + hi = codes[1::2].to(torch.uint8) & 0x0F + row = lo | (hi << 4) + return row.repeat(row_count, 1) + + def matmul_nvfp4_sm100( M=128, N=128, - K=64, + K=128, block_M=128, block_N=128, - block_K=64, + block_K=128, out_dtype=T.float32, accum_dtype=T.float32, threads=128, @@ -73,9 +88,13 @@ def main( C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + # DENSE-packed FP4 SMEM staging for tcgen05 mxf4nvf4: stage the operands + # as uint8 [M, K/2] (two e2m1 per byte). With TMA on, infer_shared_layout's + # >=8-bit branch gives this a half_bank (SW64) swizzle for K/2==64 bytes, + # matching CUTLASS's dense e2m1 SW64 K-major operand. (mxf4nvf4 consumes + # FP4 dense, unlike f8f6f4 which uses the 1-byte-per-FP4 align16b layout.) A_bytes = T.view(A, (M, packed_K), "uint8") B_bytes = T.view(B, (N, packed_K), "uint8") - A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") SFA_shared = T.alloc_shared((block_M,), "uint32") @@ -104,6 +123,7 @@ def main( T.fence_proxy_async() T.tcgen05_cp_warpx4(SFA_shared, SFA_tmem) T.tcgen05_cp_warpx4(SFB_shared, SFB_tmem) + T.sync_threads() if 32 <= tx and tx < 64: T.tcgen05_gemm_blockscaled( @@ -130,31 +150,90 @@ def main( if __name__ == "__main__": M = int(os.environ.get("TL_NVFP4_SM100_M", "128")) N = int(os.environ.get("TL_NVFP4_SM100_N", "128")) - K = int(os.environ.get("TL_NVFP4_SM100_K", "64")) + K = int(os.environ.get("TL_NVFP4_SM100_K", "128")) + block_K = 128 # mxf4nvf4 needs a swizzled (SWIZZLE_128B) operand; 64B rows -> block_K=128 + device_major, device_minor = torch.cuda.get_device_capability() + arch = os.environ.get("TL_NVFP4_SM100_ARCH", f"sm_{device_major}{device_minor}") + cuda_root = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + cuda_target = os.environ.get("TL_NVFP4_SM100_CUDA_TARGET_DIR", "aarch64-linux") + cuda_include = os.environ.get("TL_NVFP4_SM100_CUDA_INCLUDE", f"{cuda_root}/targets/{cuda_target}/include") one_scale = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) - print(f"Running SM100 NVFP4 GEMM: M={M}, N={N}, K={K}") + print(f"Running SM100 NVFP4 GEMM: M={M}, N={N}, K={K}, arch={arch}") kernel = tilelang.compile( matmul_nvfp4_sm100(M, N, K), out_idx=[4], - target="cuda", + target={"kind": "cuda", "arch": arch}, pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + # mxf4nvf4 needs a swizzled K-major operand; sub-byte FP4 only gets the + # tcgen05mma swizzled SMEM layout on the TMA path (see infer_shared_layout). + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DEVICE_COMPILE_FLAGS: [ + f"--target-directory={cuda_target}", + f"-I{cuda_include}", + ], }, ) print("Compilation succeeded!") + if os.environ.get("TL_NVFP4_DUMP_CUDA", "0") != "0": + with open(os.path.join(os.path.dirname(__file__), "gemm_nvfp4_sm100.cu"), "w") as f: + f.write(kernel.get_kernel_source()) torch.manual_seed(0) - a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) - b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) - sfa = torch.full(((K + 63) // 64 * M,), one_scale, device="cuda", dtype=torch.uint32) - sfb = torch.full(((K + 63) // 64 * N,), one_scale, device="cuda", dtype=torch.uint32) + pattern = os.environ.get("TL_NVFP4_SM100_PATTERN") + if pattern == "ones": + a = torch.full((M, K // 2), 0x22, device="cuda", dtype=torch.uint8) + b = torch.full((N, K // 2), 0x22, device="cuda", dtype=torch.uint8) + elif pattern == "a_ones_b_n": + a = torch.full((M, K // 2), 0x22, device="cuda", dtype=torch.uint8) + codes = (torch.arange(N, device="cuda", dtype=torch.uint8) % 7) + 1 + b = pack_repeated_nibble(codes, K) + elif pattern == "a_ones_b_k": + a = torch.full((M, K // 2), 0x22, device="cuda", dtype=torch.uint8) + codes = (torch.arange(K, device="cuda", dtype=torch.uint8) % 7) + 1 + b = pack_k_pattern(N, codes) + else: + a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) + b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) + # SF words must match the kernel signature: ceildiv(K, block_K) words per row. + # (Uniform scale=1.0 here, so the per-K64-atom SF advance is exercised in M2.) + sf_tiles = (K + block_K - 1) // block_K + sfa = torch.full((sf_tiles * M,), one_scale, device="cuda", dtype=torch.uint32) + sfb = torch.full((sf_tiles * N,), one_scale, device="cuda", dtype=torch.uint32) c = kernel(a, b, sfa, sfb) ref = unpack_fp4_to_float(a, M, K) @ unpack_fp4_to_float(b, N, K).T diff = (c.float() - ref).abs() print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") + if os.environ.get("TL_NVFP4_SM100_PRINT_SAMPLE"): + print(f"[SAMPLE] c00={c[0, 0].item():.4f}, c01={c[0, 1].item():.4f}, ref00={ref[0, 0].item():.4f}") + print( + f"[SAMPLE] c_min={c.float().min().item():.4f}, c_max={c.float().max().item():.4f}, " + f"ref_min={ref.min().item():.4f}, ref_max={ref.max().item():.4f}, " + f"num_equal={(diff == 0).sum().item()}/{diff.numel()}" + ) + nz = c.float() != 0 + row_nz = nz.sum(dim=1) + col_nz = nz.sum(dim=0) + print(f"[SAMPLE] row_nz_first16={row_nz[:16].tolist()}") + print(f"[SAMPLE] row_nz_last16={row_nz[-16:].tolist()}") + print(f"[SAMPLE] col_nz_first32={col_nz[:32].tolist()}") + print(f"[SAMPLE] col_nz_last32={col_nz[-32:].tolist()}") + print(f"[SAMPLE] c_row0_first16={c[0, :16].float().tolist()}") + print(f"[SAMPLE] c_row0_mid48_80={c[0, 48:80].float().tolist()}") + print(f"[SAMPLE] c_row0_last16={c[0, -16:].float().tolist()}") + print(f"[SAMPLE] c_col0_first16={c[:16, 0].float().tolist()}") + print(f"[SAMPLE] ref_row0_first32={ref[0, :32].float().tolist()}") + for a_high_first in (False, True): + for b_high_first in (False, True): + ref_probe = unpack_fp4_to_float(a, M, K, a_high_first) @ unpack_fp4_to_float(b, N, K, b_high_first).T + probe_diff = (c.float() - ref_probe).abs() + print( + f"[PROBE] a_high_first={int(a_high_first)}, b_high_first={int(b_high_first)}, " + f"max_abs_diff={probe_diff.max().item():.4f}, " + f"rel_err={probe_diff.sum().item() / (ref_probe.abs().sum().item() + 1e-10):.6f}" + ) warmup = int(os.environ.get("TL_NVFP4_SM100_WARMUP", "20")) iters = int(os.environ.get("TL_NVFP4_SM100_ITERS", "100")) diff --git a/examples/gemm_fp4/gemm_nvfp4_sm100.cu b/examples/gemm_fp4/gemm_nvfp4_sm100.cu new file mode 100644 index 0000000000..14b6fefca9 --- /dev/null +++ b/examples/gemm_fp4/gemm_nvfp4_sm100.cu @@ -0,0 +1,119 @@ +#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 +#define _tl_orig_alignas alignas +#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) +#include +#undef alignas +#define alignas _tl_orig_alignas +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_BF16 +#include +#endif + +extern "C" __global__ void main_kernel(const fp4_e2_t* __restrict__ A_bytes, const fp4_e2_t* __restrict__ B_bytes, float* __restrict__ C, const uint* __restrict__ SFA, const uint* __restrict__ SFB); +extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(const fp4_e2_t* __restrict__ A_bytes, const fp4_e2_t* __restrict__ B_bytes, float* __restrict__ C, const uint* __restrict__ SFA, const uint* __restrict__ SFB) { + extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; + void* A_shared = ((void*)((char*)buf_dyn_shmem + 0)); + void* C_shared = ((void*)((char*)buf_dyn_shmem + 0)); + void* B_shared = ((void*)((char*)buf_dyn_shmem + 8192)); + void* SFA_shared = ((void*)((char*)buf_dyn_shmem + 16384)); + void* SFB_shared = ((void*)((char*)buf_dyn_shmem + 16896)); + __shared__ __align__(16) uint64_t mbar_mem[1]; + auto mbar = reinterpret_cast(mbar_mem); + __shared__ __align__(16) uint C_tmem[1]; + __shared__ __align__(16) uint sfb_data[1]; + __shared__ __align__(16) uint sfa_data[1]; + float C_local[128]; + if (tl::tl_shuffle_elect<0>()) { + mbar[0].init(1); + } + tl::fence_barrier_init(); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_allocate((&(C_tmem[0])), 128); + tl::tmem_allocate((&(sfb_data[0])), 32); + tl::tmem_allocate((&(sfa_data[0])), 32); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + const dim3 blockIdx = tl::rasterization2DRow<8>(); + #pragma unroll + for (int i = 0; i < 4; ++i) { + *(uint4*)(((uchar*)A_shared) + ((((i * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16))) = *(uint4*)(((uchar*)A_bytes) + ((i * 2048) + (((int)threadIdx.x) * 16))); + } + #pragma unroll + for (int i_1 = 0; i_1 < 4; ++i_1) { + *(uint4*)(((uchar*)B_shared) + ((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16))) = *(uint4*)(((uchar*)B_bytes) + ((i_1 * 2048) + (((int)threadIdx.x) * 16))); + } + ((uint*)SFA_shared)[((int)threadIdx.x)] = SFA[((int)threadIdx.x)]; + ((uint*)SFB_shared)[((int)threadIdx.x)] = SFB[((int)threadIdx.x)]; + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (((int)threadIdx.x) < 32) { + void* chunk_ptr = (&(((uint*)SFA_shared)[0])); + tl::tcgen05_sf_warp_transpose(reinterpret_cast(chunk_ptr)); + void* chunk_ptr_1 = (&(((uint*)SFB_shared)[0])); + tl::tcgen05_sf_warp_transpose(reinterpret_cast(chunk_ptr_1)); + tl::fence_proxy_async(); + tl::tcgen05_before_thread_sync(); + tl::__sync_thread_partial(3, 32); + tl::tcgen05_after_thread_sync(); + void* chunk_ptr_2 = (&(((uint*)SFA_shared)[0])); + tl::tcgen05_cp(tl::make_sf_smem_desc(reinterpret_cast(chunk_ptr_2)), (*reinterpret_cast(sfa_data)) + 0); + void* chunk_ptr_3 = (&(((uint*)SFB_shared)[0])); + tl::tcgen05_cp(tl::make_sf_smem_desc(reinterpret_cast(chunk_ptr_3)), (*reinterpret_cast(sfb_data)) + 0); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if ((32 <= ((int)threadIdx.x)) && (((int)threadIdx.x) < 64)) { + { + tl::Tcgen05SMemDescriptor v; + tl::Tcgen05SMemDescriptor v_1; + tl::initialize_tcgen05_descriptor(v, (&(((uchar*)A_shared)[0])), 1, 32, 0, 0, 4); + tl::initialize_tcgen05_descriptor(v_1, (&(((uchar*)B_shared)[0])), 1, 32, 0, 0, 4); + #pragma unroll + for (int ki = 0; ki < 2; ++ki) { + tl::tcgen05mma_mxf4nvf4_blockscaled_ss(uint64_t(v + (ki * 32)), uint64_t(v_1 + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(136316032), (*reinterpret_cast(sfa_data)) + 0, (*reinterpret_cast(sfb_data)) + 0); + } + tl::tcgen05_mma_arrive((&(mbar[0]))); + } + } + mbar[0].wait(0); + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + tl::tcgen05_ld_32dp32bNx<128, false>(C_tmem[0], 0, (&(C_local[0]))); + #pragma unroll + for (int i_2 = 0; i_2 < 32; ++i_2) { + *(float4*)(((float*)C_shared) + ((((int)threadIdx.x) * 128) + (i_2 * 4))) = *(float4*)(C_local + (i_2 * 4)); + } + tl::tcgen05_before_thread_sync(); + __syncthreads(); + tl::tcgen05_after_thread_sync(); + if (tl::tl_shuffle_elect<128>()) { + tl::fence_proxy_async(); + tl::tma_store((&(C[0])), (&(((float*)C_shared)[0])), 65536); + tl::tma_store_arrive(); + tl::tma_store_wait<0, true>(); + } + if ((((int)threadIdx.x) >> 5) == 0) { + tl::tmem_deallocate((&(sfb_data[0])), 32); + tl::tmem_deallocate((&(sfa_data[0])), 32); + tl::tmem_deallocate((&(C_tmem[0])), 128); + } +} + diff --git a/src/cuda/op/gemm.cc b/src/cuda/op/gemm.cc index 4837308fb8..75a6082f37 100644 --- a/src/cuda/op/gemm.cc +++ b/src/cuda/op/gemm.cc @@ -34,6 +34,14 @@ constexpr const char *kCudaMMA = "cuda.mma"; constexpr const char *kCudaWGMMA = "cuda.wgmma"; constexpr const char *kCudaTCGEN05 = "cuda.tcgen05"; +bool HasIntAnnotation(const GemmNode &op, const char *key) { + if (auto val = op.annotations_.Get(key)) { + const auto *int_val = val->as(); + return int_val && int_val->value != 0; + } + return false; +} + bool CheckWgmma(const GemmNode &op) { if (op.b_.scope() != "shared.dyn" && op.b_.scope() != "shared") { return false; @@ -81,6 +89,9 @@ bool AllowTcgen5Mma(const GemmNode &op, Target target) { return false; DataType ab_dtype = (op.a_.scope() == "shared.tmem") ? op.b_->dtype : op.a_->dtype; + if (HasIntAnnotation(op, "is_nvfp4")) { + ab_dtype = DataType::Float4E2M1FN(); + } return GetTCGEN5MMAMeta(op.m_, op.n_, op.k_, ab_dtype, op.c_->dtype).first; } diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 73127761c5..f491a2aa3b 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -525,6 +525,59 @@ static Layout MakeAlign16BSwizzleLayout2D(int stride, int continuous, return Layout(Array{stride, continuous}, {tc, ts, index}); } +// Dense-packed FP4 swizzle for tcgen05 mxf4nvf4. +// +// Unlike MakeAlign16BSwizzleLayout2D (f8f6f4 / unpacksmem: one FP4 per 8-bit +// container), the mxf4nvf4 MMA consumes FP4 DENSE-packed (two e2m1 per byte). +// We therefore swizzle at *byte* granularity: a logical FP4 element j maps to +// packed byte (j/2) plus nibble (j%2); the byte coordinate is XOR-swizzled with +// the standard >=8-bit bank pattern (SW128/SW64/SW32 chosen by byte width), and +// the result is re-expanded to an FP4-element index (byte*2 + nibble) so the +// codegen div_factor=2 packing lands two FP4 per byte. +static Layout MakeDenseFp4SwizzleLayout2D(int stride, int continuous, + int element_size) { + ICHECK(element_size == 4) + << "Dense FP4 swizzle is for e2m1 (4-bit), got element_size=" + << element_size; + ICHECK(stride % 8 == 0) << "stride=" << stride; + ICHECK(continuous % 2 == 0) << "continuous (FP4 count) must be even, got " + << continuous; + const int byte_continuous = continuous / 2; // packed bytes along K + const int vector_size = 16; // 128 / 8 (byte domain) + // Pick the largest bank swizzle whose period divides the packed byte width. + int bank; // number of 16B columns XOR-mixed: 8=SW128, 4=SW64, 2=SW32, 1=none + if (byte_continuous % (vector_size * 8) == 0) + bank = 8; + else if (byte_continuous % (vector_size * 4) == 0) + bank = 4; + else if (byte_continuous % (vector_size * 2) == 0) + bank = 2; + else + bank = 1; + + Var i = InputPlaceholder(0); + Var j = InputPlaceholder(1); + PrimExpr ts = FloorDiv(i, 8); + PrimExpr s = FloorMod(i, 8); + PrimExpr b = FloorDiv(j, 2); // packed byte coordinate + PrimExpr nib = FloorMod(j, 2); // nibble within the byte + PrimExpr tc = FloorDiv(FloorDiv(b, vector_size), bank); + PrimExpr c = FloorMod(FloorDiv(b, vector_size), bank); + PrimExpr vec = FloorMod(b, vector_size); + PrimExpr c_swizzle; + if (bank == 8) + c_swizzle = xor8x8(c, s); + else if (bank == 4) + c_swizzle = xor4x4(c, FloorDiv(s, 2)); + else if (bank == 2) + c_swizzle = xor2x2(c, FloorDiv(s, 4)); + else + c_swizzle = c; + PrimExpr byte_index = vec + (c_swizzle + s * bank) * vector_size; + PrimExpr index = byte_index * 2 + nib; + return Layout(Array{stride, continuous}, {tc, ts, index}); +} + // Layout swizzling for 128 bytes static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, int element_size) { @@ -567,6 +620,19 @@ Layout makeAlign16BSwizzleLayout(const Buffer &buffer) { return base->Expand(leading_shape); } +Layout makeDenseFp4SwizzleLayout(const Buffer &buffer) { + auto info = GetSwizzleShapeInfoChecked(buffer); + auto base = MakeDenseFp4SwizzleLayout2D(static_cast(info.stride), + static_cast(info.continuous), + info.element_size); + Array leading_shape; + leading_shape.reserve(buffer->shape.size() - 2); + for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { + leading_shape.push_back(buffer->shape[i]); + } + return base->Expand(leading_shape); +} + // Detail implementation please ref to // bitblas::tl::mfma_layout::make_mfma_swizzle_layout Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 42faac06c2..733c0ac2bb 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -1162,6 +1162,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](const Buffer &buffer) { return makeAlign16BSwizzleLayout(buffer); }) + .def("tl.make_dense_fp4_swizzled_layout", + [](const Buffer &buffer) { + return makeDenseFp4SwizzleLayout(buffer); + }) .def("tl.make_linear_layout", [](Array shape) { return makeLinearLayout(shape); }) .def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); }) diff --git a/src/layout/layout.h b/src/layout/layout.h index 61e787d3af..1cb66833da 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -287,6 +287,8 @@ Layout makeFullBankSwizzleLayout(const Buffer &buffer); Layout makeHalfBankSwizzleLayout(const Buffer &buffer); Layout makeQuarterBankSwizzleLayout(const Buffer &buffer); Layout makeAlign16BSwizzleLayout(const Buffer &buffer); +// Dense-packed FP4 (e2m1, two per byte) swizzle for tcgen05 mxf4nvf4 operands. +Layout makeDenseFp4SwizzleLayout(const Buffer &buffer); // Swizzle mode for shared memory layouts (nvidia only) // Smaller enum value = smaller swizzle granularity diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index 2b3a74c006..46d68aad39 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -320,46 +320,45 @@ inline uint32_t GetTCGEN5MXF4NVF4BlockScaledInstrDesc(int atom_m, int atom_n, int scale_in_a, int scale_in_b, bool is_mxfp4) { + ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; + ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; + ICHECK(scale_in_a == 1 || scale_in_a == -1); + ICHECK(scale_in_b == 1 || scale_in_b == -1); + ICHECK(ab_dtype.is_float4_e2m1fn()) + << "ab_dtype must be float4_e2m1fn for mxf4nvf4"; -ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; -ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; -ICHECK(scale_in_a == 1 || scale_in_a == -1); -ICHECK(scale_in_b == 1 || scale_in_b == -1); -ICHECK(ab_dtype.is_float4_e2m1fn()) << "ab_dtype must be float4_e2m1fn for mxf4nvf4"; - + auto set_bits = [](uint32_t value, int start, int width) -> uint32_t { + uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1); + return (value & mask) << start; + }; -auto set_bits = [](uint32_t value, int start, int width) -> uint32_t { -uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1); -return (value & mask) << start; -}; + uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u; + uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u; + uint32_t n_dim = static_cast(atom_n >> 3); + uint32_t m_dim = static_cast(atom_m >> 7); + uint32_t desc = 0; + // bit 0-1 reserved + desc |= set_bits(0, 2, 1); // sparse_flag + // bit 3 reserved + // bit 4-5: b_sf_id, enforce to be 0 for scale_vec::4X + // bit 6 reserved + desc |= set_bits(1, 7, 3); // a_format, enforce to be 1 for e2m1 format + desc |= set_bits(1, 10, 2); // b_format, enforce to be 1 for e2m1 format + // bit 12 reserved + desc |= set_bits(a_neg, 13, 1); // a_negate + desc |= set_bits(b_neg, 14, 1); // b_negate + // bit 15-16, enforce to be 0 as AB is fixed non-transposed + desc |= set_bits(n_dim, 17, 6); // n_dim + // scale_format = 1 (E8M0) for MXFP4, 0 (E4M3) for NVFP4 + desc |= set_bits(is_mxfp4 ? 1 : 0, 23, 1); + // bit 24-26 reserved + desc |= set_bits(m_dim, 27, 2); // m_dim + // bit 29-30: a_sf_id, enforce to be 0 for scale_vec::4X + desc |= set_bits(0, 29, 2); // a_sf_id + // bit 31 reserved -uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u; -uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u; -uint32_t n_dim = static_cast(atom_n >> 3); -uint32_t m_dim = static_cast(atom_m >> 7); - -uint32_t desc = 0; -// bit 0-1 reserved -desc |= set_bits(0, 2, 1); // sparse_flag -// bit 3 reserved -// bit 4-5: b_sf_id, enforce to be 0 for scale_vec::4X, see here https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x -// bit 6 reserved -desc |= set_bits(1, 7, 3); // a_format, enforce to be 1 for e2m1 format -desc |= set_bits(1, 10, 2); // b_format, enforce to be 1 for e2m1 format -// bit 12 reserved -desc |= set_bits(a_neg, 13, 1); // a_negate -desc |= set_bits(b_neg, 14, 1); // b_negate -// bit 15-16, enforce to be 0 as AB is fixed non-transposed -desc |= set_bits(n_dim, 17, 6); // n_dim -desc |= set_bits(is_mxfp4 ? 1 : 0, 23, 1); // scale_format = 1 (E8M0) for mxfp4, 0 (E4M3) for nvf4 -// bit 24-26 reserved -desc |= set_bits(m_dim, 27, 2); // m_dim -// bit 29-30: a_sf_id, enforce to be 0 for scale_vec::4X, -desc |= set_bits(static_cast(a_sf_id), 29, 2); // a_sf_id -// bit 31 reserved - -return desc; + return desc; } diff --git a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py index f68af5b56f..9b4acfddf7 100644 --- a/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py +++ b/tilelang/cuda/intrinsics/macro/tcgen05_macro_generator.py @@ -465,14 +465,16 @@ def tcgen05mma_mxf4nvf4_blockscaled( k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles assert k_dim >= micro_size_k - assert (not self.a_transposed) and (not self.b_transposed), "A and B must be non-transposed/K-major for mxf4nvf4 block-scaled" + assert not self.a_transposed, "A must be non-transposed for mxf4nvf4 block-scaled" a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout) a_elems_in_bytes = (DataType(self.a_dtype).bits + 7) // 8 if len(self.meta) != 5: - self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=False) + # mxf4nvf4 uses a single CTA in this lowering path. Picking the + # generic 2CTA-preferred meta leaves part of C owned by the peer CTA. + self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim, disable_2cta=True) if len(self.meta) != 5: raise ValueError( f"Unsupported TCGEN5MMA configuration for block-scaled: M={m_dim}, N={n_dim}, " @@ -481,16 +483,23 @@ def tcgen05mma_mxf4nvf4_blockscaled( atom_m, atom_n, _, _, enable_2cta = self.tcgen05_meta_unpacked atom_m_per_cta = atom_m // 2 if enable_2cta else atom_m - a_swizzle_atom_elems = atom_m_per_cta if a_swizzle_mode.is_none() else (a_swizzle_mode.swizzle_byte_size() * 8) // DataType(self.a_dtype).bits - # Block-scaled A LBO/SBO differ from regular SS (uses atom_m_per_cta instead of m_dim) + # Row stride in the gap-expanded align16b SMEM is measured in element slots + # (1 byte per sub-byte element), matching the validated plain-FP4 path + # (compute_tcgen05_a_desc_params): swizzle_byte_size // elems_in_bytes. + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // a_elems_in_bytes - # Below computation assumes A is K-major + # Below computation assumes A is K-major. + # tcgen05 mxf4nvf4 requires a swizzled K-major SMEM operand (CUTLASS + # always selects SW128/SW64/SW32, never SWIZZLE_NONE). A linear/no-swizzle + # staging mis-addresses the core matrices, so reject it here. LBO/SBO mirror + # the validated FP4 path in compute_tcgen05_b_desc_params(). if a_swizzle_mode.is_none(): - a_leading_byte_offset = 16 // 2 # per 8x2 tile, each has 8 bytes - a_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. - else: - a_leading_byte_offset = 1 # not used, fixed to 1 - a_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + raise ValueError( + "mxf4nvf4 block-scaled requires a swizzled K-major A operand " + "(SWIZZLE_128B); a linear/no-swizzle shared layout is unsupported." + ) + a_leading_byte_offset = 16 # >>4 == 1, canonical for swizzled K-major + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() a_params = TCGEN05DescriptorParams( swizzle_mode=int(a_swizzle_mode), @@ -502,22 +511,23 @@ def tcgen05mma_mxf4nvf4_blockscaled( is_k_major=True, ) - # Below computation assumes B is K-major + # The PTX mxf4nvf4 opcode has fixed transpose bits, but the shared-memory + # descriptor still needs a 16-byte-aligned physical operand layout. b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) b_elems_in_bytes = (DataType(self.b_dtype).bits + 7) // 8 n_dim_per_cta = n_dim // 2 if enable_2cta else n_dim - b_swizzle_atom_elems = ( - n_dim_per_cta - if b_swizzle_mode.is_none() - else (b_swizzle_mode.swizzle_byte_size() * 8) // DataType(self.b_dtype).bits - ) + b_swizzle_atom_elems = b_swizzle_mode.swizzle_byte_size() // b_elems_in_bytes - if b_swizzle_mode.is_none(): - b_leading_byte_offset = 16 // 2 # per 8x2 tile, each has 8 bytes - b_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + if self.b_transposed: + if b_swizzle_mode.is_none(): + raise ValueError( + "mxf4nvf4 block-scaled requires a swizzled K-major B operand " + "(SWIZZLE_128B); a linear/no-swizzle shared layout is unsupported." + ) + b_leading_byte_offset = 16 # >>4 == 1, canonical for swizzled K-major + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() else: - b_leading_byte_offset = 1 # not used, fixed to 1 - b_stride_byte_offset = 8 * k_dim // 2 # the offset from the first 8 rows to the next 8 rows. + raise ValueError("mxf4nvf4 currently requires transpose_B=True (K-major B)") b_params = TCGEN05DescriptorParams( swizzle_mode=int(b_swizzle_mode), leading_byte_offset=int(b_leading_byte_offset >> 4), @@ -525,16 +535,18 @@ def tcgen05mma_mxf4nvf4_blockscaled( swizzle_atom_elems=b_swizzle_atom_elems, k_atom_size=max(b_swizzle_atom_elems // micro_size_k, 1), elems_in_bytes=b_elems_in_bytes, - is_k_major=True, + is_k_major=self.b_transposed, ) sf_dtype = DataType(SFA_tmem.dtype) is_mxfp4 = sf_dtype == DataType("float8_e8m0fnu") or sf_dtype == DataType("uint8") # UE8M0 storage + # The operands are staged DENSE as uint8 (two e2m1 per byte), so self.a_dtype + # is uint8; the mxf4nvf4 instruction descriptor must still describe e2m1. inst_desc = _ffi_api.get_tcgen5_mxf4nvf4_blockscaled_instr_desc( atom_m, atom_n, - DataType(self.a_dtype), + DataType("float4_e2m1fn"), 1, 1, is_mxfp4, @@ -562,7 +574,11 @@ def tcgen05mma_mxf4nvf4_blockscaled( num_inst_m = m_dim // atom_m_per_cta num_inst_n = n_dim // atom_n - num_k_atoms = self.tcgen05_num_k_atoms + # k_dim is the logical e2m1 K extent (e.g. 128). Each mxf4nvf4 MMA atom + # consumes 64 e2m1 along K (== 32 dense uint8 bytes; the byte stride is + # applied as ki*32 in tcgen05_blockscaled_atom). + assert k_dim % 64 == 0, "mxf4nvf4 block-scaled expects K to be a multiple of 64" + num_k_atoms = k_dim // 64 @T.macro def _warp_mma_mxf4nvf4_blockscaled(desc_a, desc_b, C_local_buf, sfa_data, sfb_data, mbar): @@ -1243,8 +1259,16 @@ def tcgen05_blockscaled_atom( k_dim if n_dim_per_cta // b_swizzle_atom_elems > 1 else 1 ) - A_byte_offset = A_elem_offset * a_elems_in_bytes - B_byte_offset = B_elem_offset * b_elems_in_bytes + if use_mxf4nvf4: + # Operands are dense uint8 (two e2m1 per byte). Offsets are in uint8 + # (=byte) units; each mxf4nvf4 K-atom spans 64 e2m1 == 32 dense bytes + # along K. Row stride is `*_swizzle_atom_elems` (uint8 elements). + MXF4NVF4_ATOM_BYTES = 32 + A_byte_offset = inst_m_idx * atom_m_per_cta * a_swizzle_atom_elems + ki * MXF4NVF4_ATOM_BYTES + B_byte_offset = inst_n_idx * atom_n * b_swizzle_atom_elems + ki * MXF4NVF4_ATOM_BYTES + else: + A_byte_offset = A_elem_offset * a_elems_in_bytes + B_byte_offset = B_elem_offset * b_elems_in_bytes tmem_col_step = atom_n // (128 // atom_m_per_cta) C_offset = (inst_m_idx * n_dim + inst_n_idx * tmem_col_step) * accum_dtype_in_bits // 32 diff --git a/tilelang/cuda/op/gemm/gemm_tcgen05.py b/tilelang/cuda/op/gemm/gemm_tcgen05.py index e6e7805657..8c03970fd2 100644 --- a/tilelang/cuda/op/gemm/gemm_tcgen05.py +++ b/tilelang/cuda/op/gemm/gemm_tcgen05.py @@ -274,6 +274,9 @@ def _lower_blockscaled(self, mma_emitter, thread_bounds, thread_var, mbar_phase_ use_2cta = _annotation_bool(annotations, "use_2cta") is_nvfp4 = _annotation_bool(annotations, "is_nvfp4") mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K, disable_2cta=not use_2cta) + if is_nvfp4: + atom_m, atom_n, atom_k, _enable_ws, enable_2cta = (int(x) for x in mma_emitter.meta) + mma_emitter.meta = [atom_m, atom_n, atom_k, 0, enable_2cta] _atom_m, _atom_n, _atom_k, _enable_ws, enable_2cta = (int(x) for x in mma_emitter.meta) analyzer = Analyzer() diff --git a/tilelang/cuda/transform/__init__.py b/tilelang/cuda/transform/__init__.py index 68ed60a297..f8351fb3b9 100644 --- a/tilelang/cuda/transform/__init__.py +++ b/tilelang/cuda/transform/__init__.py @@ -39,7 +39,9 @@ def LowerBlackwell2SM(): fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerBlackwell2SM() # type: ignore + if hasattr(_ffi_api, "LowerBlackwell2SM"): + return _ffi_api.LowerBlackwell2SM() # type: ignore + return lambda f: f def LowerHopperIntrin(): diff --git a/tilelang/language/gemm_op.py b/tilelang/language/gemm_op.py index 4df0aa954f..d58969adf0 100644 --- a/tilelang/language/gemm_op.py +++ b/tilelang/language/gemm_op.py @@ -429,7 +429,7 @@ def tcgen05_gemm_blockscaled( use_2cta: Whether to request true ``cta_group::2`` lowering. """ - ann = {} + ann = {"is_tcgen05": 1} if use_2cta: ann["use_2cta"] = int(use_2cta) if is_nvfp4: @@ -464,12 +464,26 @@ def legalize(arg): assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor" M, N = C_shape - M_A = A_shape[-1] if transpose_A else A_shape[-2] - N_B = B_shape[-2] if transpose_B else B_shape[-1] - K = A_shape[-2] if transpose_A else A_shape[-1] - K_B = B_shape[-1] if transpose_B else B_shape[-2] + is_packed_nvfp4 = ( + is_nvfp4 and str(A_region.buffer.dtype) == "uint8" and str(B_region.buffer.dtype) == "uint8" + ) + if is_packed_nvfp4: + assert not transpose_A, "Packed NVFP4 tcgen05_gemm_blockscaled expects non-transposed A" + M_A = A_shape[-2] + K = A_shape[-1] * 2 + if transpose_B: + N_B = B_shape[-2] + K_B = B_shape[-1] * 2 + else: + K_B = B_shape[-2] + N_B = B_shape[-1] * 2 + else: + M_A = A_shape[-1] if transpose_A else A_shape[-2] + N_B = B_shape[-2] if transpose_B else B_shape[-1] + K = A_shape[-2] if transpose_A else A_shape[-1] + K_B = B_shape[-1] if transpose_B else B_shape[-2] assert prim_expr_equal(K, K_B), f"T.tcgen05_gemm_blockscaled K shape check failed: K_A = {K}, K_B = {K_B}" - logical_K = K * 2 if is_nvfp4 and str(A_region.buffer.dtype) == "uint8" and str(B_region.buffer.dtype) == "uint8" else K + logical_K = K if use_2cta: assert prim_expr_equal(M_A, M) and prim_expr_equal(N_B * 2, N), ( f"T.tcgen05_gemm_blockscaled 2CTA shape check failed: M_A = {M_A}, expected M_C = {M}; N_B = {N_B}, expected N_C / 2 = {N} / 2" diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index e108270546..05ae2b705a 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -12,6 +12,7 @@ make_half_bank_swizzled_layout, # noqa: F401 make_quarter_bank_swizzled_layout, # noqa: F401 make_align16b_swizzled_layout, # noqa: F401 + make_dense_fp4_swizzled_layout, # noqa: F401 make_linear_layout, # noqa: F401 make_gemm_fragment_8x8, # noqa: F401 make_gemm_fragment_8x8_transposed, # noqa: F401 diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index ad3554d21f..690a22972a 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -174,6 +174,21 @@ def make_align16b_swizzled_layout(buffer: BufferLikeType): return make_tcgen05mma_swizzled_layout(buffer) +def make_dense_fp4_swizzled_layout(buffer: BufferLikeType): + """Dense-packed FP4 (e2m1, two elements per byte) swizzle for tcgen05 mxf4nvf4. + + Unlike the ALIGN16B layout (one FP4 per 8-bit container, used by f8f6f4), + mxf4nvf4 consumes FP4 dense-packed. The swizzle is applied at byte + granularity over the packed bytes (K/2), matching CUTLASS's dense e2m1 + SW64/SW128 K-major operand layout. + + Args: + buffer: BufferLikeType with a sub-byte dtype (e.g. float4_e2m1fn) + """ + buf, _, _ = _get_buffer_info(buffer) + return _ffi_api.make_dense_fp4_swizzled_layout(buf) + + def make_linear_layout(buffer_or_load_or_region: BufferLikeType): """ Create a row-major linear layout for any dimension. From 1111bfd202801dd5604c381532443e7b4b2f06f3 Mon Sep 17 00:00:00 2001 From: wahao Date: Mon, 15 Jun 2026 23:40:00 +0800 Subject: [PATCH 19/19] feat(examples): add SM100 NVFP4 fused MoE example (mxf4nvf4 gate/up + SiLU + routing); align SM100/SM120 NVFP4 MoE with FlashInfer compute paradigm --- .../gemm_fp4/example_fusedmoe_nvfp4_sm100.py | 219 +++++++++ examples/gemm_fp4/example_gemm_fp4_sm100.py | 420 +++++++++++------- examples/gemm_fp4/example_gemm_fp4_sm120.py | 338 ++++++++------ examples/gemm_fp4/example_gemm_nvfp4_sm100.py | 249 ----------- examples/gemm_fp4/example_gemm_nvfp4_sm120.py | 175 -------- examples/gemm_fp4/gemm_nvfp4_sm100.cu | 119 ----- src/layout/gemm_layouts.cc | 66 --- src/layout/layout.cc | 4 - src/layout/layout.h | 2 - tilelang/layout/__init__.py | 1 - tilelang/layout/swizzle.py | 15 - 11 files changed, 672 insertions(+), 936 deletions(-) create mode 100644 examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py delete mode 100644 examples/gemm_fp4/example_gemm_nvfp4_sm100.py delete mode 100644 examples/gemm_fp4/example_gemm_nvfp4_sm120.py delete mode 100644 examples/gemm_fp4/gemm_nvfp4_sm100.cu diff --git a/examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py b/examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py new file mode 100644 index 0000000000..6d8daef224 --- /dev/null +++ b/examples/gemm_fp4/example_fusedmoe_nvfp4_sm100.py @@ -0,0 +1,219 @@ +"""Native NVFP4 fused MoE gate/up (FC1) example on SM100/SM110 (TCGEN05). + +Matches the FlashInfer/TRT-LLM NVFP4 MoE compute paradigm (cf. flashinfer +`trtllm_fp4_block_scale_moe` / `cute_dsl_fused_moe_nvfp4`): activations AND expert +weights are NVFP4 (FP4 E2M1 data + FP8 E4M3 block scales, sf_vec_size=16); per +expert tile we run two block-scaled GEMMs (gate, up), then SiLU(gate) * up, then +apply routing. The down (FC2) projection and top-k routing are intentionally +omitted (single shared-expert demo), mirroring the SM120 NVFP4 MoE example. + +This is the SM100/SM110 (DRIVE Thor) counterpart of example_fusedmoe_nvfp4_sm120.py: +it uses tcgen05.mma.kind::mxf4nvf4.block_scale with FP4 operands staged DENSE +(2 e2m1/byte) and E4M3 scale factors moved into TMEM. NOT yet hardware-verified. +""" + +import argparse +import os +import time + +import torch +import tilelang +import tilelang.language as T + + +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def unpack_fp4_to_float(packed, rows, cols): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) + flat = packed.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + codes = torch.stack([lo, hi], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[codes] + + +def fusedmoe_nvfp4_sm100(num_tokens, hidden, intermediate, + block_token=128, block_hidden=64, block_expert=128, threads=128): + packed_hidden = hidden // 2 + packed_block_hidden = block_hidden // 2 + k_tiles = T.ceildiv(hidden, block_hidden) + silu_scale = 1.44269504 # log2(e) + + @T.prim_func + def main( + X: T.Tensor((num_tokens, hidden), T.float4_e2m1fn), + X_scale: T.Tensor((T.ceildiv(hidden, block_hidden) * num_tokens,), "uint32"), + W_gate: T.Tensor((intermediate, hidden), T.float4_e2m1fn), + W_gate_scale: T.Tensor((T.ceildiv(hidden, block_hidden) * intermediate,), "uint32"), + W_up: T.Tensor((intermediate, hidden), T.float4_e2m1fn), + W_up_scale: T.Tensor((T.ceildiv(hidden, block_hidden) * intermediate,), "uint32"), + token_selected_experts: T.Tensor((num_tokens,), "int32"), + token_final_scales: T.Tensor((num_tokens,), "float32"), + Output: T.Tensor((num_tokens, intermediate), "float32"), + ): + with T.Kernel(T.ceildiv(intermediate, block_expert), T.ceildiv(num_tokens, block_token), + threads=threads) as (bx, by): + # NVFP4 operands staged DENSE (two e2m1 per byte) as uint8. + X_bytes = T.view(X, (num_tokens, packed_hidden), "uint8") + Wg_bytes = T.view(W_gate, (intermediate, packed_hidden), "uint8") + Wu_bytes = T.view(W_up, (intermediate, packed_hidden), "uint8") + X_shared = T.alloc_shared((block_token, packed_block_hidden), "uint8") + gate_shared = T.alloc_shared((block_expert, packed_block_hidden), "uint8") + up_shared = T.alloc_shared((block_expert, packed_block_hidden), "uint8") + sfx_shared = T.alloc_shared((block_token,), "uint32") + sfg_shared = T.alloc_shared((block_expert,), "uint32") + sfu_shared = T.alloc_shared((block_expert,), "uint32") + + gate_tmem = T.alloc_tmem([block_token, block_expert], "float32") + up_tmem = T.alloc_tmem([block_token, block_expert], "float32") + sfx_tmem = T.alloc_tmem([block_token, 4], "uint32") + sfg_tmem = T.alloc_tmem([block_token, 4], "uint32") + sfu_tmem = T.alloc_tmem([block_token, 4], "uint32") + gate_mbar = T.alloc_barrier(1) + up_mbar = T.alloc_barrier(1) + + gate_local = T.alloc_fragment((block_token, block_expert), "float32") + up_local = T.alloc_fragment((block_token, block_expert), "float32") + + tx = T.get_thread_binding() + T.use_swizzle(8) + + for ko in T.serial(k_tiles): + T.copy(X_bytes[by * block_token, ko * packed_block_hidden], X_shared) + T.copy(Wg_bytes[bx * block_expert, ko * packed_block_hidden], gate_shared) + T.copy(Wu_bytes[bx * block_expert, ko * packed_block_hidden], up_shared) + T.copy(X_scale[ko * num_tokens + by * block_token], sfx_shared) + T.copy(W_gate_scale[ko * intermediate + bx * block_expert], sfg_shared) + T.copy(W_up_scale[ko * intermediate + bx * block_expert], sfu_shared) + T.sync_threads() + + if tx < 32: + T.tcgen05_sf_warp_transpose(sfx_shared) + T.tcgen05_sf_warp_transpose(sfg_shared) + T.tcgen05_sf_warp_transpose(sfu_shared) + T.fence_proxy_async() + T.tcgen05_cp_warpx4(sfx_shared, sfx_tmem) + T.tcgen05_cp_warpx4(sfg_shared, sfg_tmem) + T.tcgen05_cp_warpx4(sfu_shared, sfu_tmem) + T.sync_threads() + + if 32 <= tx and tx < 64: + T.tcgen05_gemm_blockscaled(X_shared, gate_shared, gate_tmem, sfx_tmem, sfg_tmem, + transpose_B=True, mbar=gate_mbar, + clear_accum=(ko == 0), is_nvfp4=True) + T.mbarrier_wait_parity(gate_mbar, ko % 2) + if 32 <= tx and tx < 64: + T.tcgen05_gemm_blockscaled(X_shared, up_shared, up_tmem, sfx_tmem, sfu_tmem, + transpose_B=True, mbar=up_mbar, + clear_accum=(ko == 0), is_nvfp4=True) + T.mbarrier_wait_parity(up_mbar, ko % 2) + T.sync_threads() + + T.copy(gate_tmem, gate_local) + T.copy(up_tmem, up_local) + for t, i in T.Parallel(block_token, block_expert): + token_idx = by * block_token + t + gate = gate_local[t, i] + silu = gate * (1.0 / (1.0 + T.exp2(-gate * silu_scale))) + routed = T.if_then_else(token_selected_experts[token_idx] == 0, + token_final_scales[token_idx], 0.0) + Output[token_idx, bx * block_expert + i] = up_local[t, i] * silu * routed + + return main + + +def _decode_sf_full(sf, rows, hidden, block_hidden): + """Decode flat group-major E4M3 scale words -> [rows, hidden] float scale.""" + sf_tiles = (hidden + block_hidden - 1) // block_hidden + blocks_per_tile = block_hidden // 16 # 4 E4M3 per uint32 + raw = sf.view(torch.uint8).reshape(sf_tiles, rows, 4)[:, :, :blocks_per_tile].contiguous() + sc = raw.view(torch.float8_e4m3fn).float() # [tile, row, blocks_per_tile] + sc = sc.permute(1, 0, 2).reshape(rows, sf_tiles * blocks_per_tile) # [row, hidden//16] + return sc.repeat_interleave(16, dim=1) # [row, hidden] + + +def make_sf(rows, hidden, block_hidden): + """Random per-(row, 16-block) E4M3 scale words, group-major flat layout.""" + sf_tiles = (hidden + block_hidden - 1) // block_hidden + blocks_per_tile = block_hidden // 16 + n_blocks = sf_tiles * blocks_per_tile # == hidden // 16 + vals = torch.rand(rows, n_blocks, device="cuda") * 1.75 + 0.25 + e4 = vals.to(torch.float8_e4m3fn).view(torch.uint8) + out = torch.zeros(sf_tiles * rows, 4, device="cuda", dtype=torch.uint8) + for t in range(sf_tiles): + out[t * rows:(t + 1) * rows, :blocks_per_tile] = e4[:, t * blocks_per_tile:(t + 1) * blocks_per_tile] + return out.contiguous().view(torch.uint32).reshape(sf_tiles * rows) + + +def main(): + parser = argparse.ArgumentParser(description="NVFP4 fused MoE (FC1) on SM100/SM110") + parser.add_argument("-n", "--num-tokens", type=int, default=1024) + parser.add_argument("-c", "--hidden-size", type=int, default=256) + parser.add_argument("-s", "--intermediate-size", type=int, default=1024) + parser.add_argument("--warmup", type=int, default=int(os.environ.get("TL_MOE_WARMUP", "20"))) + parser.add_argument("--iters", type=int, default=int(os.environ.get("TL_MOE_ITERS", "100"))) + args = parser.parse_args() + num_tokens, hidden, intermediate = args.num_tokens, args.hidden_size, args.intermediate_size + block_hidden = 64 + + print(f"Running SM100 NVFP4 fused MoE (kind::mxf4nvf4.block_scale): " + f"tokens={num_tokens}, hidden={hidden}, intermediate={intermediate}") + + kernel = tilelang.compile( + fusedmoe_nvfp4_sm100(num_tokens, hidden, intermediate, block_hidden=block_hidden), + out_idx=[8], target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(0) + x = torch.randint(0, 256, (num_tokens, hidden // 2), device="cuda", dtype=torch.uint8) + wg = torch.randint(0, 256, (intermediate, hidden // 2), device="cuda", dtype=torch.uint8) + wu = torch.randint(0, 256, (intermediate, hidden // 2), device="cuda", dtype=torch.uint8) + sfx = make_sf(num_tokens, hidden, block_hidden) + sfg = make_sf(intermediate, hidden, block_hidden) + sfu = make_sf(intermediate, hidden, block_hidden) + selected = torch.zeros((num_tokens,), device="cuda", dtype=torch.int32) + routing = torch.ones((num_tokens,), device="cuda", dtype=torch.float32) + + out = kernel(x, sfx, wg, sfg, wu, sfu, selected, routing) + + # Reference (flashinfer NVFP4 numerics): dequantize, gate/up GEMM, SiLU(gate)*up, route. + sx = _decode_sf_full(sfx, num_tokens, hidden, block_hidden) + sg = _decode_sf_full(sfg, intermediate, hidden, block_hidden) + su = _decode_sf_full(sfu, intermediate, hidden, block_hidden) + x_f = unpack_fp4_to_float(x, num_tokens, hidden) * sx + gate = x_f @ (unpack_fp4_to_float(wg, intermediate, hidden) * sg).T + up = x_f @ (unpack_fp4_to_float(wu, intermediate, hidden) * su).T + routed = torch.where(selected == 0, routing, torch.zeros_like(routing)).unsqueeze(1) + ref = up * (gate * torch.sigmoid(gate)) * routed + + diff = (out.float() - ref).abs() + max_diff = diff.max().item() + rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) + print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") + print("[PASS] MoE numerical verification" if rel_err < 0.05 else "[WARN] large diff (unverified)") + + for _ in range(args.warmup): + kernel(x, sfx, wg, sfg, wu, sfu, selected, routing) + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(args.iters): + kernel(x, sfx, wg, sfg, wu, sfu, selected, routing) + torch.cuda.synchronize() + elapsed = (time.perf_counter() - start) / max(args.iters, 1) * 1000 + total_flops = 2 * num_tokens * hidden * intermediate * 2 # gate + up + print(f"Latency: {elapsed:.4f} ms") + if elapsed > 0: + print(f"TFLOPS: {total_flops / (elapsed / 1e3) / 1e12:.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_fp4/example_gemm_fp4_sm100.py b/examples/gemm_fp4/example_gemm_fp4_sm100.py index 9ff596f49f..987c670ab9 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm100.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm100.py @@ -1,12 +1,17 @@ -"""FP4 (float4_e2m1fn) GEMM on SM100/SM110 (B200/Thor) using TCGEN05 async MMA. +"""FP4 / NVFP4 GEMM on SM100/SM110 (B200 / DRIVE Thor) using TCGEN05 MMA. -Uses tcgen05.mma.kind::f8f6f4 with TMEM accumulator. -The TMA path keeps global FP4 packed, then writes it into SMEM using the -ALIGN16B unpacksmem layout expected by tcgen05.mma. +Default: plain FP4 (float4_e2m1fn) GEMM via tcgen05.mma.kind::f8f6f4 (FP4 staged +gap-expanded in the ALIGN16B unpacksmem SMEM layout). + +With ``--nvfp4``: NVFP4 block-scaled GEMM via tcgen05.mma.kind::mxf4nvf4.block_scale +-- FP4 data plus one E4M3 scale factor per 16 elements along K. FP4 operands are +staged DENSE (two e2m1 per byte) as uint8, which the >=8-bit swizzle path lays out +to match the dense e2m1 K-major operand the mxf4nvf4 MMA expects. Supported: SM100 (B100/B200), SM101/SM110 (DRIVE Thor), SM103 (B300). """ +import argparse import os import time @@ -16,92 +21,41 @@ FP4_E2M1_TO_FLOAT = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, ] -def matmul_fp4_sm100( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - num_stages=2, - threads=256, - transpose_b=True, -): - """FP4 GEMM using TCGEN05 async MMA + TMEM.""" - A_shape = (M, K) - A_shared_shape = (block_M, block_K) - if transpose_b: - B_shape = (N, K) - B_shared_shape = (block_N, block_K) - - @T.prim_func - def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - mbar = T.alloc_barrier(1) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.tcgen05_gemm( - A_shared, - B_shared, - C_tmem, - transpose_A=False, - transpose_B=True, - mbar=mbar, - clear_accum=(k == 0), - ) - T.mbarrier_wait_parity(mbar, k % 2) +def unpack_fp4_to_float(packed, rows, cols, high_first=False): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) + flat = packed.to(torch.uint8).reshape(rows, cols // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + first, second = (hi, lo) if high_first else (lo, hi) + unpacked = torch.stack([first, second], dim=-1).reshape(rows, cols).to(torch.int64) + return lut[unpacked] + - T.copy(C_tmem, C_local) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) +def pack_e4m3x4_to_u32(values): + raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) + return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) - return main - B_shape = (K, N) - B_shared_shape = (block_K, block_N) +# --------------------------------------------------------------------------- +# Plain FP4 (kind::f8f6f4): A/B are native float4_e2m1fn, transpose_B (TN). +# --------------------------------------------------------------------------- +def matmul_fp4_sm100(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, + accum_dtype, num_stages=1, threads=128): @T.prim_func def main( - A: T.Tensor(A_shape, in_dtype), - B: T.Tensor(B_shape, in_dtype), + A: T.Tensor((M, K), in_dtype), + B: T.Tensor((N, K), in_dtype), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) + A_shared = T.alloc_shared((block_M, block_K), in_dtype) + B_shared = T.alloc_shared((block_N, block_K), in_dtype) C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) mbar = T.alloc_barrier(1) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -109,15 +63,11 @@ def main( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) T.tcgen05_gemm( - A_shared, - B_shared, - C_tmem, - transpose_A=False, - transpose_B=False, - mbar=mbar, - clear_accum=(k == 0), + A_shared, B_shared, C_tmem, + transpose_A=False, transpose_B=True, + mbar=mbar, clear_accum=(k == 0), ) T.mbarrier_wait_parity(mbar, k % 2) @@ -128,59 +78,74 @@ def main( return main -def unpack_fp4_to_float(packed_int8, M, K): - lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) - flat = packed_int8.to(torch.uint8).reshape(M, K // 2) - lo = flat & 0x0F - hi = (flat >> 4) & 0x0F - unpacked = torch.stack([lo, hi], dim=-1).reshape(M, K).to(torch.int64) - return lut[unpacked] +# --------------------------------------------------------------------------- +# NVFP4 (kind::mxf4nvf4.block_scale): FP4 + E4M3 block scale (one per 16 along K). +# FP4 staged DENSE as uint8[M, K/2]; one K64 MMA atom per tile (block_K=64) so the +# 4 scales of that tile pack into a single uint32 SF word. +# --------------------------------------------------------------------------- +def matmul_nvfp4_sm100(M, N, K, block_M, block_N, block_K, out_dtype, accum_dtype, threads=128): + packed_K = K // 2 + packed_block_K = block_K // 2 + k_tiles = T.ceildiv(K, block_K) + + @T.prim_func + def main( + A: T.Tensor((M, K), T.float4_e2m1fn), + B: T.Tensor((N, K), T.float4_e2m1fn), + SFA: T.Tensor((T.ceildiv(K, block_K) * M,), "uint32"), + SFB: T.Tensor((T.ceildiv(K, block_K) * N,), "uint32"), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + # DENSE-packed FP4 SMEM staging (two e2m1 per byte). The >=8-bit swizzle + # path gives this the bank swizzle matching CUTLASS's dense e2m1 K-major + # operand for mxf4nvf4. + A_bytes = T.view(A, (M, packed_K), "uint8") + B_bytes = T.view(B, (N, packed_K), "uint8") + A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") + B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") + SFA_shared = T.alloc_shared((block_M,), "uint32") + SFB_shared = T.alloc_shared((block_N,), "uint32") + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, 4], "uint32") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + mbar = T.alloc_barrier(1) -M = int(os.environ.get("TL_FP4_M", "256")) -N = int(os.environ.get("TL_FP4_N", "256")) -K = int(os.environ.get("TL_FP4_K", "256")) -block_M = int(os.environ.get("TL_FP4_BLOCK_M", "128")) -block_N = int(os.environ.get("TL_FP4_BLOCK_N", "64")) -block_K = int(os.environ.get("TL_FP4_BLOCK_K", "128")) -in_dtype = T.float4_e2m1fn -out_dtype = T.float32 -accum_dtype = T.float32 -num_stages = 1 -threads = 128 - -input_mode = os.environ.get("TL_FP4_INPUT_MODE", "random") -transpose_b = os.environ.get("TL_FP4_TRANSPOSE_B", "1") != "0" -print(f"Running FP4 GEMM (SM100/SM110 TCGEN05): M={M}, N={N}, K={K}, input_mode={input_mode}, transpose_b={transpose_b}") - -func = matmul_fp4_sm100( - M, - N, - K, - block_M, - block_N, - block_K, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, - transpose_b, -) - -jit_kernel = tilelang.compile( - func, - out_idx=[2], - target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, -) - -print("Compilation succeeded!") - -torch.manual_seed(42) + tx = T.get_thread_binding() + T.use_swizzle(8) + + for ko in T.serial(k_tiles): + T.copy(A_bytes[by * block_M, ko * packed_block_K], A_shared) + T.copy(B_bytes[bx * block_N, ko * packed_block_K], B_shared) + T.copy(SFA[ko * M + by * block_M], SFA_shared) + T.copy(SFB[ko * N + bx * block_N], SFB_shared) + T.sync_threads() + + if tx < 32: + T.tcgen05_sf_warp_transpose(SFA_shared) + T.tcgen05_sf_warp_transpose(SFB_shared) + T.fence_proxy_async() + T.tcgen05_cp_warpx4(SFA_shared, SFA_tmem) + T.tcgen05_cp_warpx4(SFB_shared, SFB_tmem) + T.sync_threads() + + if 32 <= tx and tx < 64: + T.tcgen05_gemm_blockscaled( + A_shared, B_shared, C_tmem, SFA_tmem, SFB_tmem, + transpose_B=True, mbar=mbar, + clear_accum=(ko == 0), is_nvfp4=True, + ) + T.mbarrier_wait_parity(mbar, ko % 2) + T.sync_threads() + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main def make_random_fp4(rows, cols, mode): @@ -189,41 +154,152 @@ def make_random_fp4(rows, cols, mode): hi = torch.randint(0, 8, (rows, cols // 2), device="cuda", dtype=torch.uint8) return (lo | (hi << 4)).to(torch.int8) if mode == "low_nibble": - lo = torch.randint(0, 16, (rows, cols // 2), device="cuda", dtype=torch.uint8) - return lo.to(torch.int8) + return torch.randint(0, 16, (rows, cols // 2), device="cuda", dtype=torch.uint8).to(torch.int8) if mode == "high_nibble": - hi = torch.randint(0, 16, (rows, cols // 2), device="cuda", dtype=torch.uint8) - return (hi << 4).to(torch.int8) + return (torch.randint(0, 16, (rows, cols // 2), device="cuda", dtype=torch.uint8) << 4).to(torch.int8) if mode == "random": return torch.randint(0, 256, (rows, cols // 2), device="cuda", dtype=torch.uint8).to(torch.int8) - raise ValueError(f"Unsupported TL_FP4_INPUT_MODE={mode}") - - -a_packed = make_random_fp4(M, K, input_mode) -b_packed = make_random_fp4(N, K, input_mode) if transpose_b else make_random_fp4(K, N, input_mode) - -a_zero = torch.zeros(M, K // 2, device="cuda", dtype=torch.int8) -b_zero = torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) if transpose_b else torch.zeros(K, N // 2, device="cuda", dtype=torch.int8) -c_zero = jit_kernel(a_zero, b_zero) -assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" -print("[PASS] zeros in -> zeros out") - -c = jit_kernel(a_packed, b_packed) -a_float = unpack_fp4_to_float(a_packed, M, K) -b_float = unpack_fp4_to_float(b_packed, N, K) if transpose_b else unpack_fp4_to_float(b_packed, K, N) -ref_c = a_float @ (b_float.T if transpose_b else b_float) - -diff = (c.float() - ref_c).abs() -max_diff = diff.max().item() -rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) -print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") -print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") - -torch.cuda.synchronize() -start = time.perf_counter() -for _ in range(100): - jit_kernel(a_packed, b_packed) -torch.cuda.synchronize() -elapsed = (time.perf_counter() - start) / 100 * 1000 -print(f"Latency: {elapsed:.4f} ms") -print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") + raise ValueError(f"Unsupported input mode={mode}") + + +def _arch_and_flags(args): + device_major, device_minor = torch.cuda.get_device_capability() + arch = args.arch or f"sm_{device_major}{device_minor}" + cuda_root = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + cuda_include = args.cuda_include or f"{cuda_root}/targets/{args.cuda_target}/include" + flags = [f"--target-directory={args.cuda_target}", f"-I{cuda_include}"] + return arch, flags + + +def run_fp4(args): + M, N, K = args.m, args.n, args.k + block_M, block_N, block_K = 128, 64, 128 + arch, device_flags = _arch_and_flags(args) + print(f"Running SM100 FP4 GEMM (kind::f8f6f4): M={M}, N={N}, K={K}, arch={arch}, input_mode={args.input_mode}") + + func = matmul_fp4_sm100(M, N, K, block_M, block_N, block_K, T.float4_e2m1fn, T.float32, T.float32) + jit_kernel = tilelang.compile( + func, out_idx=[2], target={"kind": "cuda", "arch": arch}, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DEVICE_COMPILE_FLAGS: device_flags, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(42) + a_packed = make_random_fp4(M, K, args.input_mode) + b_packed = make_random_fp4(N, K, args.input_mode) + + a_zero = torch.zeros(M, K // 2, device="cuda", dtype=torch.int8) + b_zero = torch.zeros(N, K // 2, device="cuda", dtype=torch.int8) + assert jit_kernel(a_zero, b_zero).abs().max().item() == 0.0, "Zero test failed" + print("[PASS] zeros in -> zeros out") + + c = jit_kernel(a_packed, b_packed) + ref = unpack_fp4_to_float(a_packed, M, K) @ unpack_fp4_to_float(b_packed, N, K).T + diff = (c.float() - ref).abs() + max_diff = diff.max().item() + rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) + print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") + print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") + + _bench(lambda: jit_kernel(a_packed, b_packed), M, N, K, args) + + +def run_nvfp4(args): + M, N, K = args.m, args.n, args.k + block_M, block_N, block_K = 128, 128, 64 # block_K=64: 4 scales == 1 SF word per tile + arch, device_flags = _arch_and_flags(args) + sf_tiles = (K + block_K - 1) // block_K + blocks_per_tile = block_K // 16 + + print(f"Running SM100 NVFP4 GEMM (kind::mxf4nvf4.block_scale): M={M}, N={N}, K={K}, arch={arch}, " + f"scale=random") + + kernel = tilelang.compile( + matmul_nvfp4_sm100(M, N, K, block_M, block_N, block_K, T.float32, T.float32), + out_idx=[4], target={"kind": "cuda", "arch": arch}, + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DEVICE_COMPILE_FLAGS: device_flags, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(0) + a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) + b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) + + def make_sf(rows): + # Random per-(row, K16-block) E4M3 scales. The reference decodes the actual + # E4M3-quantized values from the same bytes, so verification stays exact. + n_blocks = sf_tiles * blocks_per_tile # == K // 16 + vals = torch.rand(rows, n_blocks, device="cuda") * 1.75 + 0.25 # ~[0.25, 2.0) + e4 = vals.to(torch.float8_e4m3fn).view(torch.uint8) + out = torch.zeros(sf_tiles * rows, 4, device="cuda", dtype=torch.uint8) + for t in range(sf_tiles): + out[t * rows:(t + 1) * rows, :blocks_per_tile] = e4[:, t * blocks_per_tile:(t + 1) * blocks_per_tile] + return out.contiguous().view(torch.uint32).reshape(sf_tiles * rows) + + def decode_sf_full(sf, rows): + raw = sf.view(torch.uint8).reshape(sf_tiles, rows, 4)[:, :, :blocks_per_tile].contiguous() + sc = raw.view(torch.float8_e4m3fn).float() + sc = sc.permute(1, 0, 2).reshape(rows, sf_tiles * blocks_per_tile) + return sc.repeat_interleave(16, dim=1) + + sfa, sfb = make_sf(M), make_sf(N) + c = kernel(a, b, sfa, sfb) + sa_full, sb_full = decode_sf_full(sfa, M), decode_sf_full(sfb, N) + ref = (unpack_fp4_to_float(a, M, K) * sa_full) @ (unpack_fp4_to_float(b, N, K) * sb_full).T + diff = (c.float() - ref).abs() + max_diff = diff.max().item() + rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) + print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") + print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") + + _bench(lambda: kernel(a, b, sfa, sfb), M, N, K, args) + + +def _bench(fn, M, N, K, args): + for _ in range(args.warmup): + fn() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(args.iters): + fn() + torch.cuda.synchronize() + elapsed = (time.perf_counter() - start) / max(args.iters, 1) * 1000 + print(f"Latency: {elapsed:.4f} ms") + if elapsed > 0: + print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FP4 / NVFP4 GEMM on SM100/SM110") + parser.add_argument("--nvfp4", action="store_true", + help="run NVFP4 block-scaled GEMM with random E4M3 scales (default: plain FP4)") + parser.add_argument("--m", type=int, default=None) + parser.add_argument("--n", type=int, default=None) + parser.add_argument("--k", type=int, default=None) + parser.add_argument("--input-mode", default=os.environ.get("TL_FP4_INPUT_MODE", "random"), + help="[fp4] random|positive|low_nibble|high_nibble") + parser.add_argument("--warmup", type=int, default=int(os.environ.get("TL_FP4_WARMUP", "20"))) + parser.add_argument("--iters", type=int, default=int(os.environ.get("TL_FP4_ITERS", "100"))) + parser.add_argument("--arch", default=os.environ.get("TL_FP4_ARCH")) + parser.add_argument("--cuda-target", default=os.environ.get("TL_FP4_CUDA_TARGET_DIR", "aarch64-linux")) + parser.add_argument("--cuda-include", default=os.environ.get("TL_FP4_CUDA_INCLUDE")) + args = parser.parse_args() + + # Size defaults differ by variant; env vars TL_FP4_{M,N,K} override, --m/--n/--k win. + default_mnk = 128 if args.nvfp4 else 256 + args.m = args.m if args.m is not None else int(os.environ.get("TL_FP4_M", default_mnk)) + args.n = args.n if args.n is not None else int(os.environ.get("TL_FP4_N", default_mnk)) + args.k = args.k if args.k is not None else int(os.environ.get("TL_FP4_K", default_mnk)) + + if args.nvfp4: + run_nvfp4(args) + else: + run_fp4(args) diff --git a/examples/gemm_fp4/example_gemm_fp4_sm120.py b/examples/gemm_fp4/example_gemm_fp4_sm120.py index 98f32f4edc..6148d39363 100644 --- a/examples/gemm_fp4/example_gemm_fp4_sm120.py +++ b/examples/gemm_fp4/example_gemm_fp4_sm120.py @@ -1,47 +1,65 @@ -"""FP4 (float4_e2m1fn) GEMM on SM120 (RTX 5080/5090) using fragment-based MMA. +"""FP4 / NVFP4 GEMM on SM120 (RTX 5080/5090) using fragment-based mma.sync. -Uses mma.sync.aligned.kind::f8f6f4 instructions (not TCGEN05/TMEM). -FP4 data is pre-unpacked to uint8 on the host (1 byte per element, low -nibble holds the 4-bit value). Shared memory stores uint8, making the -layout identical to INT8 and satisfying ldmatrix 16B alignment. -The << 2 bit-shift before MMA places FP4 data at bits 2-5 as required. +Default: plain FP4 (float4_e2m1fn) GEMM via mma.sync.aligned.kind::f8f6f4. +FP4 data is pre-unpacked to uint8 on the host (1 byte/element, low nibble holds +the 4-bit value); shared memory stores uint8 like INT8 (ldmatrix 16B aligned). + +With ``--nvfp4``: NVFP4 block-scaled GEMM via +mma.sync.aligned.kind::mxf4nvf4.block_scale -- both operands are dense-packed FP4 +(two e2m1 per byte) with FP8 E4M3 scale factors (sf_vec_size=16) held in registers. Addresses https://github.com/tile-ai/tilelang/issues/1592 """ +import argparse import os import time + import torch import tilelang import tilelang.language as T +from tilelang.layout import make_swizzled_layout + + +FP4_E2M1_TO_FLOAT = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def unpack_fp4_to_uint8(packed_int8, M, K): + """Unpack (M, K//2) -> (M, K) uint8, 1 FP4 per byte in low nibble.""" + flat = packed_int8.to(torch.uint8).reshape(M, K // 2) + lo = flat & 0x0F + hi = (flat >> 4) & 0x0F + return torch.stack([lo, hi], dim=-1).reshape(M, K).contiguous() -def matmul_fp4( - M, - N, - K, - block_M, - block_N, - block_K, - out_dtype, - accum_dtype, - num_stages=2, - threads=128, -): - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) +def unpack_fp4_to_float(packed_int8, M, K): + lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) + return lut[unpack_fp4_to_uint8(packed_int8, M, K).to(torch.int64)] + + +def pack_e4m3x4_to_u32(values): + raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) + return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) + + +# --------------------------------------------------------------------------- +# Plain FP4 (kind::f8f6f4): A/B unpacked to uint8 (1 byte/FP4), fragment T.gemm. +# --------------------------------------------------------------------------- +def matmul_fp4_sm120(M, N, K, block_M, block_N, block_K, out_dtype, accum_dtype, + num_stages=2, threads=128): @T.prim_func def main( - A: T.Tensor(A_shape, "uint8"), - B: T.Tensor(B_shape, "uint8"), + A: T.Tensor((M, K), "uint8"), + B: T.Tensor((N, K), "uint8"), C: T.Tensor((M, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, "uint8") - B_shared = T.alloc_shared(B_shared_shape, "uint8") + A_shared = T.alloc_shared((block_M, block_K), "uint8") + B_shared = T.alloc_shared((block_N, block_K), "uint8") C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) @@ -55,114 +73,168 @@ def main( return main -FP4_E2M1_TO_FLOAT = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, -] +# --------------------------------------------------------------------------- +# NVFP4 (kind::mxf4nvf4.block_scale): A/B dense-packed FP4 (uint8 [*, K/2]), +# FP8 E4M3 scale per 16 elements held in registers, fragment T.nvfp4_gemm. +# --------------------------------------------------------------------------- +def matmul_nvfp4_sm120(M, N, K, out_dtype, accum_dtype): + block_M, block_N, block_K = 16, 8, 64 + threads = 32 + packed_block_K = block_K // 2 + @T.prim_func + def main( + A: T.Tensor((M, K // 2), "uint8"), + B: T.Tensor((N, K // 2), "uint8"), + SFA: T.Tensor((T.ceildiv(M, block_M), T.ceildiv(K, block_K)), "uint32"), + SFB: T.Tensor((T.ceildiv(N, block_N), T.ceildiv(K, block_K)), "uint32"), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") + B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + SFA_local = T.alloc_local((1,), "uint32") + SFB_local = T.alloc_local((1,), "uint32") -def unpack_fp4_to_uint8(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: - """Unpack (M, K//2) int8 -> (M, K) uint8, 1 FP4 per byte in low nibble.""" - flat = packed_int8.to(torch.uint8).reshape(M, K // 2) - lo = flat & 0x0F - hi = (flat >> 4) & 0x0F - return torch.stack([lo, hi], dim=-1).reshape(M, K).contiguous() + T.annotate_layout({ + A_shared: make_swizzled_layout(A_shared), + B_shared: make_swizzled_layout(B_shared), + }) + T.clear(C_local) + for ko in T.serial(T.ceildiv(K, block_K)): + T.copy(A[by * block_M, ko * packed_block_K], A_shared) + T.copy(B[bx * block_N, ko * packed_block_K], B_shared) + SFA_local[0] = SFA[by, ko] + SFB_local[0] = SFB[bx, ko] + T.nvfp4_gemm(A_shared, B_shared, SFA_local, SFB_local, C_local, + transpose_B=True, clear_accum=(ko == 0)) -def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: - """Unpack (M, K//2) int8 -> (M, K) float32 via e2m1 lookup table.""" - lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) - unpacked = unpack_fp4_to_uint8(packed_int8, M, K).to(torch.int64) - return lut[unpacked] - - -M, N, K = 256, 256, 256 -block_M, block_N, block_K = 128, 128, 128 -out_dtype = T.float32 -accum_dtype = T.float32 - -print(f"Running FP4 GEMM: M={M}, N={N}, K={K}") -print(f" block_M={block_M}, block_N={block_N}, block_K={block_K}") - -func = matmul_fp4( - M, - N, - K, - block_M, - block_N, - block_K, - out_dtype, - accum_dtype, - num_stages=2, - threads=128, -) - -jit_kernel = tilelang.compile( - func, - out_idx=[2], - target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, -) - -print("Compilation succeeded!") -if os.environ.get("TL_FP4_DUMP_CUDA", "0") != "0": - with open(os.path.join(os.path.dirname(__file__), "gemm_fp4_sm120.cu"), "w") as f: - f.write(jit_kernel.get_kernel_source()) - -torch.manual_seed(42) - -# Create packed FP4 data (2 per byte), then pre-unpack to uint8 -a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) -b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) -a_unpacked = unpack_fp4_to_uint8(a_packed, M, K) -b_unpacked = unpack_fp4_to_uint8(b_packed, N, K) - -# --- Test 1: zeros --- -a_zero = torch.zeros(M, K, device="cuda", dtype=torch.uint8) -b_zero = torch.zeros(N, K, device="cuda", dtype=torch.uint8) -c_zero = jit_kernel(a_zero, b_zero) -assert c_zero.abs().max().item() == 0.0, f"Zero test failed: max={c_zero.abs().max().item()}" -print("[PASS] zeros in -> zeros out") - -# --- Test 2: numerical verification --- -c = jit_kernel(a_unpacked, b_unpacked) - -a_float = unpack_fp4_to_float(a_packed, M, K) -b_float = unpack_fp4_to_float(b_packed, N, K) -ref_c = a_float @ b_float.T - -diff = (c.float() - ref_c).abs() -max_diff = diff.max().item() -rel_err = diff.sum().item() / (ref_c.abs().sum().item() + 1e-10) -print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") -if max_diff < 1.0: - print("[PASS] numerical verification (max_abs_diff < 1.0)") -else: - print("[WARN] large diff -- may indicate layout or data flow issue") - -# --- Benchmark --- -torch.cuda.synchronize() -start = time.perf_counter() -for _ in range(100): - jit_kernel(a_unpacked, b_unpacked) -torch.cuda.synchronize() -elapsed = (time.perf_counter() - start) / 100 * 1000 -print(f"Latency: {elapsed:.4f} ms") -print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_fp4(args): + M, N, K = args.m, args.n, args.k + block_M, block_N, block_K = 128, 128, 128 + print(f"Running SM120 FP4 GEMM (kind::f8f6f4): M={M}, N={N}, K={K}") + + func = matmul_fp4_sm120(M, N, K, block_M, block_N, block_K, T.float32, T.float32, num_stages=2) + jit_kernel = tilelang.compile( + func, out_idx=[2], target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(42) + a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8).to(torch.int8) + a_unpacked = unpack_fp4_to_uint8(a_packed, M, K) + b_unpacked = unpack_fp4_to_uint8(b_packed, N, K) + + a_zero = torch.zeros(M, K, device="cuda", dtype=torch.uint8) + b_zero = torch.zeros(N, K, device="cuda", dtype=torch.uint8) + assert jit_kernel(a_zero, b_zero).abs().max().item() == 0.0, "Zero test failed" + print("[PASS] zeros in -> zeros out") + + c = jit_kernel(a_unpacked, b_unpacked) + ref = unpack_fp4_to_float(a_packed, M, K) @ unpack_fp4_to_float(b_packed, N, K).T + diff = (c.float() - ref).abs() + max_diff = diff.max().item() + rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) + print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") + print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") + + _bench(lambda: jit_kernel(a_unpacked, b_unpacked), M, N, K, args) + + +def run_nvfp4(args): + M, N, K = args.m, args.n, args.k + block_M, block_N, block_K = 16, 8, 64 + print(f"Running SM120 NVFP4 GEMM (kind::mxf4nvf4.block_scale): M={M}, N={N}, K={K}, scale=uniform") + + kernel = tilelang.compile( + matmul_nvfp4_sm120(M, N, K, T.float32, T.float32), + out_idx=[4], target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + print("Compilation succeeded!") + + torch.manual_seed(0) + a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) + b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) + + # Scale factors: 2D [ceil(M/16), ceil(K/64)] uint32, 4 E4M3 packed per word. + # NOTE: the exact SM120 mma.sync SF-register -> (row, K16-block) mapping is not + # modeled in this reference, so verification uses UNIFORM scale (=1.0) for an + # exact data check, plus a scale spot-check (0.5 x 2.0 preserves the product). + # Random per-block-scale verification is a follow-up (needs the SF layout). + sf_m, sf_n, sf_k = (M + block_M - 1) // block_M, (N + block_N - 1) // block_N, (K + block_K - 1) // block_K + one = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) + half = pack_e4m3x4_to_u32((0.5, 0.5, 0.5, 0.5)) + two = pack_e4m3x4_to_u32((2.0, 2.0, 2.0, 2.0)) + sfa = torch.full((sf_m, sf_k), one, device="cuda", dtype=torch.uint32) + sfb = torch.full((sf_n, sf_k), one, device="cuda", dtype=torch.uint32) + + c = kernel(a, b, sfa, sfb) + ref = unpack_fp4_to_float(a, M, K) @ unpack_fp4_to_float(b, N, K).T + diff = (c.float() - ref).abs() + max_diff = diff.max().item() + rel_err = diff.sum().item() / (ref.abs().sum().item() + 1e-10) + print(f"[NUMERICAL] max_abs_diff={max_diff:.4f}, rel_err={rel_err:.6f}") + print("[PASS] numerical verification" if max_diff < 1.0 else "[WARN] large diff") + + # Scale spot-check: SFA=0.5, SFB=2.0 must reproduce the unit-scale result. + c_scaled = kernel( + a, b, + torch.full((sf_m, sf_k), half, device="cuda", dtype=torch.uint32), + torch.full((sf_n, sf_k), two, device="cuda", dtype=torch.uint32), + ) + sdiff = (c_scaled.float() - ref).abs().max().item() + print(f"[SCALE] max_abs_diff(SFA=0.5 x SFB=2.0 vs unit)={sdiff:.4f} ({'PASS' if sdiff < 1.0 else 'WARN'})") + + _bench(lambda: kernel(a, b, sfa, sfb), M, N, K, args) + + +def _bench(fn, M, N, K, args): + for _ in range(args.warmup): + fn() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(args.iters): + fn() + torch.cuda.synchronize() + elapsed = (time.perf_counter() - start) / max(args.iters, 1) * 1000 + print(f"Latency: {elapsed:.4f} ms") + if elapsed > 0: + print(f"TFLOPS: {2 * M * N * K / (elapsed / 1e3) / 1e12:.2f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FP4 / NVFP4 GEMM on SM120") + parser.add_argument("--nvfp4", action="store_true", + help="run NVFP4 block-scaled GEMM with random E4M3 scales (default: plain FP4)") + parser.add_argument("--m", type=int, default=None) + parser.add_argument("--n", type=int, default=None) + parser.add_argument("--k", type=int, default=None) + parser.add_argument("--warmup", type=int, default=int(os.environ.get("TL_FP4_WARMUP", "20"))) + parser.add_argument("--iters", type=int, default=int(os.environ.get("TL_FP4_ITERS", "100"))) + args = parser.parse_args() + + default_mnk = (1024, 1024, 256) if args.nvfp4 else (256, 256, 256) + args.m = args.m if args.m is not None else int(os.environ.get("TL_FP4_M", default_mnk[0])) + args.n = args.n if args.n is not None else int(os.environ.get("TL_FP4_N", default_mnk[1])) + args.k = args.k if args.k is not None else int(os.environ.get("TL_FP4_K", default_mnk[2])) + + if args.nvfp4: + run_nvfp4(args) + else: + run_fp4(args) diff --git a/examples/gemm_fp4/example_gemm_nvfp4_sm100.py b/examples/gemm_fp4/example_gemm_nvfp4_sm100.py deleted file mode 100644 index f450cef418..0000000000 --- a/examples/gemm_fp4/example_gemm_nvfp4_sm100.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Minimal NVFP4 GEMM on SM100/SM110 using TCGEN05 block-scaled MMA. - -This validates the frontend-to-PTX path for -``tcgen05.mma.kind::mxf4nvf4.block_scale``. A/B are declared and staged as -native ``T.float4_e2m1fn`` tensors so the shared operand gets the swizzled -(align16b -> SWIZZLE_128B) layout the MMA requires. SFA/SFB are uint32 words, -each packing four FP8 E4M3 scale values for one K64 tile. -""" - -import os -import time - -import torch -import tilelang -import tilelang.language as T - - -FP4_E2M1_TO_FLOAT = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, -] - - -def pack_e4m3x4_to_u32(values: tuple[float, float, float, float]) -> int: - raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) - return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) - - -def unpack_fp4_to_float(packed, rows, cols, high_first=False): - lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed.device) - flat = packed.to(torch.uint8).reshape(rows, cols // 2) - lo = flat & 0x0F - hi = (flat >> 4) & 0x0F - first, second = (hi, lo) if high_first else (lo, hi) - unpacked = torch.stack([first, second], dim=-1).reshape(rows, cols).to(torch.int64) - return lut[unpacked] - - -def pack_repeated_nibble(codes, cols): - packed = torch.empty((codes.numel(), cols // 2), device=codes.device, dtype=torch.uint8) - byte = (codes.to(torch.uint8) & 0x0F) | ((codes.to(torch.uint8) & 0x0F) << 4) - packed[:] = byte[:, None] - return packed - - -def pack_k_pattern(row_count, codes): - lo = codes[0::2].to(torch.uint8) & 0x0F - hi = codes[1::2].to(torch.uint8) & 0x0F - row = lo | (hi << 4) - return row.repeat(row_count, 1) - - -def matmul_nvfp4_sm100( - M=128, - N=128, - K=128, - block_M=128, - block_N=128, - block_K=128, - out_dtype=T.float32, - accum_dtype=T.float32, - threads=128, -): - packed_K = K // 2 - packed_block_K = block_K // 2 - k_tiles = T.ceildiv(K, block_K) - - @T.prim_func - def main( - A: T.Tensor((M, K), T.float4_e2m1fn), - B: T.Tensor((N, K), T.float4_e2m1fn), - SFA: T.Tensor((T.ceildiv(K, block_K) * M,), "uint32"), - SFB: T.Tensor((T.ceildiv(K, block_K) * N,), "uint32"), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - # DENSE-packed FP4 SMEM staging for tcgen05 mxf4nvf4: stage the operands - # as uint8 [M, K/2] (two e2m1 per byte). With TMA on, infer_shared_layout's - # >=8-bit branch gives this a half_bank (SW64) swizzle for K/2==64 bytes, - # matching CUTLASS's dense e2m1 SW64 K-major operand. (mxf4nvf4 consumes - # FP4 dense, unlike f8f6f4 which uses the 1-byte-per-FP4 align16b layout.) - A_bytes = T.view(A, (M, packed_K), "uint8") - B_bytes = T.view(B, (N, packed_K), "uint8") - A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") - B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") - SFA_shared = T.alloc_shared((block_M,), "uint32") - SFB_shared = T.alloc_shared((block_N,), "uint32") - - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - SFA_tmem = T.alloc_tmem([block_M, 4], "uint32") - SFB_tmem = T.alloc_tmem([block_M, 4], "uint32") - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - mbar = T.alloc_barrier(1) - - tx = T.get_thread_binding() - T.use_swizzle(8) - - for ko in T.serial(k_tiles): - T.copy(A_bytes[by * block_M, ko * packed_block_K], A_shared) - T.copy(B_bytes[bx * block_N, ko * packed_block_K], B_shared) - T.copy(SFA[ko * M + by * block_M], SFA_shared) - T.copy(SFB[ko * N + bx * block_N], SFB_shared) - T.sync_threads() - - if tx < 32: - T.tcgen05_sf_warp_transpose(SFA_shared) - T.tcgen05_sf_warp_transpose(SFB_shared) - T.fence_proxy_async() - T.tcgen05_cp_warpx4(SFA_shared, SFA_tmem) - T.tcgen05_cp_warpx4(SFB_shared, SFB_tmem) - T.sync_threads() - - if 32 <= tx and tx < 64: - T.tcgen05_gemm_blockscaled( - A_shared, - B_shared, - C_tmem, - SFA_tmem, - SFB_tmem, - transpose_B=True, - mbar=mbar, - clear_accum=(ko == 0), - is_nvfp4=True, - ) - T.mbarrier_wait_parity(mbar, ko % 2) - T.sync_threads() - - T.copy(C_tmem, C_local) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - -if __name__ == "__main__": - M = int(os.environ.get("TL_NVFP4_SM100_M", "128")) - N = int(os.environ.get("TL_NVFP4_SM100_N", "128")) - K = int(os.environ.get("TL_NVFP4_SM100_K", "128")) - block_K = 128 # mxf4nvf4 needs a swizzled (SWIZZLE_128B) operand; 64B rows -> block_K=128 - device_major, device_minor = torch.cuda.get_device_capability() - arch = os.environ.get("TL_NVFP4_SM100_ARCH", f"sm_{device_major}{device_minor}") - cuda_root = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - cuda_target = os.environ.get("TL_NVFP4_SM100_CUDA_TARGET_DIR", "aarch64-linux") - cuda_include = os.environ.get("TL_NVFP4_SM100_CUDA_INCLUDE", f"{cuda_root}/targets/{cuda_target}/include") - one_scale = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) - - print(f"Running SM100 NVFP4 GEMM: M={M}, N={N}, K={K}, arch={arch}") - kernel = tilelang.compile( - matmul_nvfp4_sm100(M, N, K), - out_idx=[4], - target={"kind": "cuda", "arch": arch}, - pass_configs={ - # mxf4nvf4 needs a swizzled K-major operand; sub-byte FP4 only gets the - # tcgen05mma swizzled SMEM layout on the TMA path (see infer_shared_layout). - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DEVICE_COMPILE_FLAGS: [ - f"--target-directory={cuda_target}", - f"-I{cuda_include}", - ], - }, - ) - print("Compilation succeeded!") - if os.environ.get("TL_NVFP4_DUMP_CUDA", "0") != "0": - with open(os.path.join(os.path.dirname(__file__), "gemm_nvfp4_sm100.cu"), "w") as f: - f.write(kernel.get_kernel_source()) - - torch.manual_seed(0) - pattern = os.environ.get("TL_NVFP4_SM100_PATTERN") - if pattern == "ones": - a = torch.full((M, K // 2), 0x22, device="cuda", dtype=torch.uint8) - b = torch.full((N, K // 2), 0x22, device="cuda", dtype=torch.uint8) - elif pattern == "a_ones_b_n": - a = torch.full((M, K // 2), 0x22, device="cuda", dtype=torch.uint8) - codes = (torch.arange(N, device="cuda", dtype=torch.uint8) % 7) + 1 - b = pack_repeated_nibble(codes, K) - elif pattern == "a_ones_b_k": - a = torch.full((M, K // 2), 0x22, device="cuda", dtype=torch.uint8) - codes = (torch.arange(K, device="cuda", dtype=torch.uint8) % 7) + 1 - b = pack_k_pattern(N, codes) - else: - a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) - b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) - # SF words must match the kernel signature: ceildiv(K, block_K) words per row. - # (Uniform scale=1.0 here, so the per-K64-atom SF advance is exercised in M2.) - sf_tiles = (K + block_K - 1) // block_K - sfa = torch.full((sf_tiles * M,), one_scale, device="cuda", dtype=torch.uint32) - sfb = torch.full((sf_tiles * N,), one_scale, device="cuda", dtype=torch.uint32) - - c = kernel(a, b, sfa, sfb) - ref = unpack_fp4_to_float(a, M, K) @ unpack_fp4_to_float(b, N, K).T - diff = (c.float() - ref).abs() - print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") - if os.environ.get("TL_NVFP4_SM100_PRINT_SAMPLE"): - print(f"[SAMPLE] c00={c[0, 0].item():.4f}, c01={c[0, 1].item():.4f}, ref00={ref[0, 0].item():.4f}") - print( - f"[SAMPLE] c_min={c.float().min().item():.4f}, c_max={c.float().max().item():.4f}, " - f"ref_min={ref.min().item():.4f}, ref_max={ref.max().item():.4f}, " - f"num_equal={(diff == 0).sum().item()}/{diff.numel()}" - ) - nz = c.float() != 0 - row_nz = nz.sum(dim=1) - col_nz = nz.sum(dim=0) - print(f"[SAMPLE] row_nz_first16={row_nz[:16].tolist()}") - print(f"[SAMPLE] row_nz_last16={row_nz[-16:].tolist()}") - print(f"[SAMPLE] col_nz_first32={col_nz[:32].tolist()}") - print(f"[SAMPLE] col_nz_last32={col_nz[-32:].tolist()}") - print(f"[SAMPLE] c_row0_first16={c[0, :16].float().tolist()}") - print(f"[SAMPLE] c_row0_mid48_80={c[0, 48:80].float().tolist()}") - print(f"[SAMPLE] c_row0_last16={c[0, -16:].float().tolist()}") - print(f"[SAMPLE] c_col0_first16={c[:16, 0].float().tolist()}") - print(f"[SAMPLE] ref_row0_first32={ref[0, :32].float().tolist()}") - for a_high_first in (False, True): - for b_high_first in (False, True): - ref_probe = unpack_fp4_to_float(a, M, K, a_high_first) @ unpack_fp4_to_float(b, N, K, b_high_first).T - probe_diff = (c.float() - ref_probe).abs() - print( - f"[PROBE] a_high_first={int(a_high_first)}, b_high_first={int(b_high_first)}, " - f"max_abs_diff={probe_diff.max().item():.4f}, " - f"rel_err={probe_diff.sum().item() / (ref_probe.abs().sum().item() + 1e-10):.6f}" - ) - - warmup = int(os.environ.get("TL_NVFP4_SM100_WARMUP", "20")) - iters = int(os.environ.get("TL_NVFP4_SM100_ITERS", "100")) - for _ in range(warmup): - kernel(a, b, sfa, sfb) - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(iters): - kernel(a, b, sfa, sfb) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - start) * 1000 / iters - tflops = 2 * M * N * K / (elapsed_ms / 1e3) / 1e12 - print(f"[BENCH] latency={elapsed_ms:.4f} ms, TFLOPS={tflops:.2f}, iters={iters}") diff --git a/examples/gemm_fp4/example_gemm_nvfp4_sm120.py b/examples/gemm_fp4/example_gemm_nvfp4_sm120.py deleted file mode 100644 index f78db48f18..0000000000 --- a/examples/gemm_fp4/example_gemm_nvfp4_sm120.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Minimal NVFP4 GEMM on SM120 using block-scaled MMA. - -FlashInfer/TRT-LLM use "NVFP4" as packed FP4 E2M1 data plus FP8 E4M3/UE4M3 -scale factors with ``sf_vec_size=16``. The corresponding SM120 Tensor Core -instruction is ``mma.sync.aligned.kind::mxf4nvf4.block_scale``: both MMA -operands are FP4, the scale factors are FP8 E4M3, and accumulation is FP32. - -That is distinct from the non-scaled A8W4 examples, which use FP8 x FP4 -``kind::f8f6f4`` MMA. FP16/BF16 usually appear around NVFP4 as source/output -types for quantization or epilogues, not as the direct mxf4nvf4 MMA operands. - -This example intentionally starts with one SM120 MMA atom (m16n8k64) and a -uniform scale register before wiring NVFP4 into the generic GEMM API. -""" - -import os -import time - -import torch -import tilelang -import tilelang.language as T -from tilelang.layout import make_swizzled_layout - - -FP4_E2M1_TO_FLOAT = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, -] - - -def unpack_fp4_to_uint8(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: - """Unpack (M, K//2) uint8 -> (M, K) uint8 FP4 codes for the reference.""" - flat = packed_int8.to(torch.uint8).reshape(M, K // 2) - lo = flat & 0x0F - hi = (flat >> 4) & 0x0F - return torch.stack([lo, hi], dim=-1).reshape(M, K).contiguous() - - -def unpack_fp4_to_float(packed_int8: torch.Tensor, M: int, K: int) -> torch.Tensor: - """Unpack FP4 E2M1 data to float32 through a lookup table.""" - lut = torch.tensor(FP4_E2M1_TO_FLOAT, dtype=torch.float32, device=packed_int8.device) - unpacked = unpack_fp4_to_uint8(packed_int8, M, K).to(torch.int64) - return lut[unpacked] - - -def pack_e4m3x4_to_u32(values: tuple[float, float, float, float]) -> int: - """Pack four FP8 E4M3 scale values into the uint32 register used by MMA.SF.""" - raw = torch.tensor(values, dtype=torch.float32).to(torch.float8_e4m3fn).view(torch.uint8) - return sum(int(byte) << (8 * i) for i, byte in enumerate(raw.tolist())) - - -UE4M3_ONE_X4 = pack_e4m3x4_to_u32((1.0, 1.0, 1.0, 1.0)) -UE4M3_HALF_X4 = pack_e4m3x4_to_u32((0.5, 0.5, 0.5, 0.5)) -UE4M3_TWO_X4 = pack_e4m3x4_to_u32((2.0, 2.0, 2.0, 2.0)) - - -def matmul_nvfp4_sm120(M=16, N=8, K=64, out_dtype=T.float32, accum_dtype=T.float32): - block_M, block_N, block_K = 16, 8, 64 - threads = 32 - packed_K = K // 2 - packed_block_K = block_K // 2 - - @T.prim_func - def main( - A: T.Tensor((M, packed_K), "uint8"), - B: T.Tensor((N, packed_K), "uint8"), - SFA: T.Tensor((T.ceildiv(M, block_M), T.ceildiv(K, block_K)), "uint32"), - SFB: T.Tensor((T.ceildiv(N, block_N), T.ceildiv(K, block_K)), "uint32"), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared((block_M, packed_block_K), "uint8") - B_shared = T.alloc_shared((block_N, packed_block_K), "uint8") - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - SFA_local = T.alloc_local((1,), "uint32") - SFB_local = T.alloc_local((1,), "uint32") - - T.annotate_layout( - { - A_shared: make_swizzled_layout(A_shared), - B_shared: make_swizzled_layout(B_shared), - } - ) - - T.clear(C_local) - for ko in T.serial(T.ceildiv(K, block_K)): - T.copy(A[by * block_M, ko * packed_block_K], A_shared) - T.copy(B[bx * block_N, ko * packed_block_K], B_shared) - SFA_local[0] = SFA[by, ko] - SFB_local[0] = SFB[bx, ko] - T.nvfp4_gemm(A_shared, B_shared, SFA_local, SFB_local, C_local, transpose_B=True, clear_accum=(ko == 0)) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -if __name__ == "__main__": - M, N, K = 1024, 1024, 256 - print(f"Running NVFP4 SM120 block-scaled MMA: M={M}, N={N}, K={K}") - - kernel = tilelang.compile( - matmul_nvfp4_sm120(M, N, K), - out_idx=[4], - target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, - ) - print("Compilation succeeded!") - - if os.environ.get("TL_NVFP4_DUMP_CUDA", "0") != "0": - with open(os.path.join(os.path.dirname(__file__), "gemm_nvfp4_sm120.cu"), "w") as f: - f.write(kernel.get_kernel_source()) - - torch.manual_seed(0) - a_packed = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8) - b_packed = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8) - - # To fit for mma.m16n8k64, we need to pad the scale factors to the nearest multiples accordingly. - scale_shape_a = ((M + 15) // 16, (K + 63) // 64) - scale_shape_b = ((N + 7) // 8, (K + 63) // 64) - sfa = torch.full(scale_shape_a, UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) - sfb = torch.full(scale_shape_b, UE4M3_ONE_X4, device="cuda", dtype=torch.uint32) - - a_zero = torch.zeros((M, K // 2), device="cuda", dtype=torch.uint8) - b_zero = torch.zeros((N, K // 2), device="cuda", dtype=torch.uint8) - c_zero = kernel(a_zero, b_zero, sfa, sfb) - print(f"[ZERO] max_abs={c_zero.abs().max().item():.4f}") - - one_pair = (2 | (2 << 4)) # two FP4 E2M1 1.0 values packed in one byte - a_one = torch.full((M, K // 2), one_pair, device="cuda", dtype=torch.uint8) - b_one = torch.full((N, K // 2), one_pair, device="cuda", dtype=torch.uint8) - c_one = kernel(a_one, b_one, sfa, sfb) - print(f"[ONE] first={c_one[0, 0].item():.4f}, expected={float(K):.4f}") - - # Exercise the scale path explicitly. SFA=0.5 and SFB=2.0 should preserve - # the effective product for all-one FP4 inputs. - sfa_half = torch.full(scale_shape_a, UE4M3_HALF_X4, device="cuda", dtype=torch.uint32) - sfb_two = torch.full(scale_shape_b, UE4M3_TWO_X4, device="cuda", dtype=torch.uint32) - c_scaled = kernel(a_one, b_one, sfa_half, sfb_two) - print(f"[SCALE] first={c_scaled[0, 0].item():.4f}, expected={float(K):.4f}") - - c = kernel(a_packed, b_packed, sfa, sfb) - ref = unpack_fp4_to_float(a_packed, M, K) @ unpack_fp4_to_float(b_packed, N, K).T - diff = (c.float() - ref).abs() - print(f"[NUMERICAL] max_abs_diff={diff.max().item():.4f}, rel_err={diff.sum().item() / (ref.abs().sum().item() + 1e-10):.6f}") - - warmup = int(os.environ.get("TL_NVFP4_WARMUP", "20")) - iters = int(os.environ.get("TL_NVFP4_ITERS", "100")) - for _ in range(warmup): - kernel(a_packed, b_packed, sfa, sfb) - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(iters): - kernel(a_packed, b_packed, sfa, sfb) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - start) * 1000 / iters - tflops = 2 * M * N * K / (elapsed_ms / 1e3) / 1e12 - print(f"[BENCH] latency={elapsed_ms:.4f} ms, TFLOPS={tflops:.2f}, iters={iters}") diff --git a/examples/gemm_fp4/gemm_nvfp4_sm100.cu b/examples/gemm_fp4/gemm_nvfp4_sm100.cu deleted file mode 100644 index 14b6fefca9..0000000000 --- a/examples/gemm_fp4/gemm_nvfp4_sm100.cu +++ /dev/null @@ -1,119 +0,0 @@ -#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1940 -#define _tl_orig_alignas alignas -#define alignas(N) _tl_orig_alignas((N) <= 64 ? (N) : 64) -#include -#undef alignas -#define alignas _tl_orig_alignas -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifdef ENABLE_BF16 -#include -#endif - -extern "C" __global__ void main_kernel(const fp4_e2_t* __restrict__ A_bytes, const fp4_e2_t* __restrict__ B_bytes, float* __restrict__ C, const uint* __restrict__ SFA, const uint* __restrict__ SFB); -extern "C" __global__ void __launch_bounds__(128, 1) main_kernel(const fp4_e2_t* __restrict__ A_bytes, const fp4_e2_t* __restrict__ B_bytes, float* __restrict__ C, const uint* __restrict__ SFA, const uint* __restrict__ SFB) { - extern __shared__ __align__(1024) uchar buf_dyn_shmem[]; - void* A_shared = ((void*)((char*)buf_dyn_shmem + 0)); - void* C_shared = ((void*)((char*)buf_dyn_shmem + 0)); - void* B_shared = ((void*)((char*)buf_dyn_shmem + 8192)); - void* SFA_shared = ((void*)((char*)buf_dyn_shmem + 16384)); - void* SFB_shared = ((void*)((char*)buf_dyn_shmem + 16896)); - __shared__ __align__(16) uint64_t mbar_mem[1]; - auto mbar = reinterpret_cast(mbar_mem); - __shared__ __align__(16) uint C_tmem[1]; - __shared__ __align__(16) uint sfb_data[1]; - __shared__ __align__(16) uint sfa_data[1]; - float C_local[128]; - if (tl::tl_shuffle_elect<0>()) { - mbar[0].init(1); - } - tl::fence_barrier_init(); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_allocate((&(C_tmem[0])), 128); - tl::tmem_allocate((&(sfb_data[0])), 32); - tl::tmem_allocate((&(sfa_data[0])), 32); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - const dim3 blockIdx = tl::rasterization2DRow<8>(); - #pragma unroll - for (int i = 0; i < 4; ++i) { - *(uint4*)(((uchar*)A_shared) + ((((i * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16))) = *(uint4*)(((uchar*)A_bytes) + ((i * 2048) + (((int)threadIdx.x) * 16))); - } - #pragma unroll - for (int i_1 = 0; i_1 < 4; ++i_1) { - *(uint4*)(((uchar*)B_shared) + ((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16))) = *(uint4*)(((uchar*)B_bytes) + ((i_1 * 2048) + (((int)threadIdx.x) * 16))); - } - ((uint*)SFA_shared)[((int)threadIdx.x)] = SFA[((int)threadIdx.x)]; - ((uint*)SFB_shared)[((int)threadIdx.x)] = SFB[((int)threadIdx.x)]; - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (((int)threadIdx.x) < 32) { - void* chunk_ptr = (&(((uint*)SFA_shared)[0])); - tl::tcgen05_sf_warp_transpose(reinterpret_cast(chunk_ptr)); - void* chunk_ptr_1 = (&(((uint*)SFB_shared)[0])); - tl::tcgen05_sf_warp_transpose(reinterpret_cast(chunk_ptr_1)); - tl::fence_proxy_async(); - tl::tcgen05_before_thread_sync(); - tl::__sync_thread_partial(3, 32); - tl::tcgen05_after_thread_sync(); - void* chunk_ptr_2 = (&(((uint*)SFA_shared)[0])); - tl::tcgen05_cp(tl::make_sf_smem_desc(reinterpret_cast(chunk_ptr_2)), (*reinterpret_cast(sfa_data)) + 0); - void* chunk_ptr_3 = (&(((uint*)SFB_shared)[0])); - tl::tcgen05_cp(tl::make_sf_smem_desc(reinterpret_cast(chunk_ptr_3)), (*reinterpret_cast(sfb_data)) + 0); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if ((32 <= ((int)threadIdx.x)) && (((int)threadIdx.x) < 64)) { - { - tl::Tcgen05SMemDescriptor v; - tl::Tcgen05SMemDescriptor v_1; - tl::initialize_tcgen05_descriptor(v, (&(((uchar*)A_shared)[0])), 1, 32, 0, 0, 4); - tl::initialize_tcgen05_descriptor(v_1, (&(((uchar*)B_shared)[0])), 1, 32, 0, 0, 4); - #pragma unroll - for (int ki = 0; ki < 2; ++ki) { - tl::tcgen05mma_mxf4nvf4_blockscaled_ss(uint64_t(v + (ki * 32)), uint64_t(v_1 + (ki * 32)), (*reinterpret_cast(C_tmem)) + 0, ((0 < ki) ? 1 : 0), static_cast(136316032), (*reinterpret_cast(sfa_data)) + 0, (*reinterpret_cast(sfb_data)) + 0); - } - tl::tcgen05_mma_arrive((&(mbar[0]))); - } - } - mbar[0].wait(0); - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - tl::tcgen05_ld_32dp32bNx<128, false>(C_tmem[0], 0, (&(C_local[0]))); - #pragma unroll - for (int i_2 = 0; i_2 < 32; ++i_2) { - *(float4*)(((float*)C_shared) + ((((int)threadIdx.x) * 128) + (i_2 * 4))) = *(float4*)(C_local + (i_2 * 4)); - } - tl::tcgen05_before_thread_sync(); - __syncthreads(); - tl::tcgen05_after_thread_sync(); - if (tl::tl_shuffle_elect<128>()) { - tl::fence_proxy_async(); - tl::tma_store((&(C[0])), (&(((float*)C_shared)[0])), 65536); - tl::tma_store_arrive(); - tl::tma_store_wait<0, true>(); - } - if ((((int)threadIdx.x) >> 5) == 0) { - tl::tmem_deallocate((&(sfb_data[0])), 32); - tl::tmem_deallocate((&(sfa_data[0])), 32); - tl::tmem_deallocate((&(C_tmem[0])), 128); - } -} - diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index f491a2aa3b..73127761c5 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -525,59 +525,6 @@ static Layout MakeAlign16BSwizzleLayout2D(int stride, int continuous, return Layout(Array{stride, continuous}, {tc, ts, index}); } -// Dense-packed FP4 swizzle for tcgen05 mxf4nvf4. -// -// Unlike MakeAlign16BSwizzleLayout2D (f8f6f4 / unpacksmem: one FP4 per 8-bit -// container), the mxf4nvf4 MMA consumes FP4 DENSE-packed (two e2m1 per byte). -// We therefore swizzle at *byte* granularity: a logical FP4 element j maps to -// packed byte (j/2) plus nibble (j%2); the byte coordinate is XOR-swizzled with -// the standard >=8-bit bank pattern (SW128/SW64/SW32 chosen by byte width), and -// the result is re-expanded to an FP4-element index (byte*2 + nibble) so the -// codegen div_factor=2 packing lands two FP4 per byte. -static Layout MakeDenseFp4SwizzleLayout2D(int stride, int continuous, - int element_size) { - ICHECK(element_size == 4) - << "Dense FP4 swizzle is for e2m1 (4-bit), got element_size=" - << element_size; - ICHECK(stride % 8 == 0) << "stride=" << stride; - ICHECK(continuous % 2 == 0) << "continuous (FP4 count) must be even, got " - << continuous; - const int byte_continuous = continuous / 2; // packed bytes along K - const int vector_size = 16; // 128 / 8 (byte domain) - // Pick the largest bank swizzle whose period divides the packed byte width. - int bank; // number of 16B columns XOR-mixed: 8=SW128, 4=SW64, 2=SW32, 1=none - if (byte_continuous % (vector_size * 8) == 0) - bank = 8; - else if (byte_continuous % (vector_size * 4) == 0) - bank = 4; - else if (byte_continuous % (vector_size * 2) == 0) - bank = 2; - else - bank = 1; - - Var i = InputPlaceholder(0); - Var j = InputPlaceholder(1); - PrimExpr ts = FloorDiv(i, 8); - PrimExpr s = FloorMod(i, 8); - PrimExpr b = FloorDiv(j, 2); // packed byte coordinate - PrimExpr nib = FloorMod(j, 2); // nibble within the byte - PrimExpr tc = FloorDiv(FloorDiv(b, vector_size), bank); - PrimExpr c = FloorMod(FloorDiv(b, vector_size), bank); - PrimExpr vec = FloorMod(b, vector_size); - PrimExpr c_swizzle; - if (bank == 8) - c_swizzle = xor8x8(c, s); - else if (bank == 4) - c_swizzle = xor4x4(c, FloorDiv(s, 2)); - else if (bank == 2) - c_swizzle = xor2x2(c, FloorDiv(s, 4)); - else - c_swizzle = c; - PrimExpr byte_index = vec + (c_swizzle + s * bank) * vector_size; - PrimExpr index = byte_index * 2 + nib; - return Layout(Array{stride, continuous}, {tc, ts, index}); -} - // Layout swizzling for 128 bytes static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, int element_size) { @@ -620,19 +567,6 @@ Layout makeAlign16BSwizzleLayout(const Buffer &buffer) { return base->Expand(leading_shape); } -Layout makeDenseFp4SwizzleLayout(const Buffer &buffer) { - auto info = GetSwizzleShapeInfoChecked(buffer); - auto base = MakeDenseFp4SwizzleLayout2D(static_cast(info.stride), - static_cast(info.continuous), - info.element_size); - Array leading_shape; - leading_shape.reserve(buffer->shape.size() - 2); - for (size_t i = 0; i + 2 < buffer->shape.size(); ++i) { - leading_shape.push_back(buffer->shape[i]); - } - return base->Expand(leading_shape); -} - // Detail implementation please ref to // bitblas::tl::mfma_layout::make_mfma_swizzle_layout Layout makeMatrixCoreSwizzleLayout(int stride, int continuous, int element_size, diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 733c0ac2bb..42faac06c2 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -1162,10 +1162,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](const Buffer &buffer) { return makeAlign16BSwizzleLayout(buffer); }) - .def("tl.make_dense_fp4_swizzled_layout", - [](const Buffer &buffer) { - return makeDenseFp4SwizzleLayout(buffer); - }) .def("tl.make_linear_layout", [](Array shape) { return makeLinearLayout(shape); }) .def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); }) diff --git a/src/layout/layout.h b/src/layout/layout.h index 1cb66833da..61e787d3af 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -287,8 +287,6 @@ Layout makeFullBankSwizzleLayout(const Buffer &buffer); Layout makeHalfBankSwizzleLayout(const Buffer &buffer); Layout makeQuarterBankSwizzleLayout(const Buffer &buffer); Layout makeAlign16BSwizzleLayout(const Buffer &buffer); -// Dense-packed FP4 (e2m1, two per byte) swizzle for tcgen05 mxf4nvf4 operands. -Layout makeDenseFp4SwizzleLayout(const Buffer &buffer); // Swizzle mode for shared memory layouts (nvidia only) // Smaller enum value = smaller swizzle granularity diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index 05ae2b705a..e108270546 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -12,7 +12,6 @@ make_half_bank_swizzled_layout, # noqa: F401 make_quarter_bank_swizzled_layout, # noqa: F401 make_align16b_swizzled_layout, # noqa: F401 - make_dense_fp4_swizzled_layout, # noqa: F401 make_linear_layout, # noqa: F401 make_gemm_fragment_8x8, # noqa: F401 make_gemm_fragment_8x8_transposed, # noqa: F401 diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 690a22972a..ad3554d21f 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -174,21 +174,6 @@ def make_align16b_swizzled_layout(buffer: BufferLikeType): return make_tcgen05mma_swizzled_layout(buffer) -def make_dense_fp4_swizzled_layout(buffer: BufferLikeType): - """Dense-packed FP4 (e2m1, two elements per byte) swizzle for tcgen05 mxf4nvf4. - - Unlike the ALIGN16B layout (one FP4 per 8-bit container, used by f8f6f4), - mxf4nvf4 consumes FP4 dense-packed. The swizzle is applied at byte - granularity over the packed bytes (K/2), matching CUTLASS's dense e2m1 - SW64/SW128 K-major operand layout. - - Args: - buffer: BufferLikeType with a sub-byte dtype (e.g. float4_e2m1fn) - """ - buf, _, _ = _get_buffer_info(buffer) - return _ffi_api.make_dense_fp4_swizzled_layout(buf) - - def make_linear_layout(buffer_or_load_or_region: BufferLikeType): """ Create a row-major linear layout for any dimension.