[Metal] M5 Cooperative Tensor T.gemm #2252
Expose TileLang-owned cooperative tensor builtins so Metal MPP lowering does not depend on extra TVM fork APIs.
Add a shape-aware MPP instruction choice for shared-output Metal GEMM while preserving simdgroup fallback for fragments and unsupported tiles.
Generate Metal 4 MPP matmul2d code for cooperative tensor intrinsics and keep source-only codegen separate from runtime compilation.
Split Metal GEMM lowering into simdgroup and cooperative tensor emitters so M5 tiles use MPP while fragment accumulators keep the existing path.
Keep generic allocation and storage rewrites away from opaque Metal cooperative tensor scopes to avoid invalid scope analysis.
Add runtime and source-only coverage for non-square MPP GEMM so the new cooperative tensor path is reproducible in CI and on M5.
Point the submodule at the macOS SDK-guarded Metal 4 runtime update used by cooperative tensor shaders.
Add a reference page covering the two Metal GEMM paths, selection rules, current limitations, and planned follow-up work.
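The selection rule summarized above (MPP for shared-output tiles, simdgroup fallback for fragments and unsupported tiles) might be sketched roughly as follows. This is a hypothetical illustration, not the logic in `src/metal/op/gemm.cc`: the function name `select_metal_gemm_inst`, the `"local.fragment"` scope string, and the exact divisibility test are assumptions. Only the two instruction names and the 16×32×16 tile shape appear elsewhere in this PR.

```python
# Hypothetical sketch of the GEMM path selection; not TileLang's actual code.
COOP_TILE_M, COOP_TILE_N, COOP_TILE_K = 16, 32, 16  # tile shape quoted in the review


def select_metal_gemm_inst(m: int, n: int, k: int, c_scope: str) -> str:
    """Return the instruction family a per-simdgroup (m, n, k) tile would use."""
    # Assumption: fragment accumulators always stay on the simdgroup path.
    if c_scope == "local.fragment":
        return "metal.simdgroup"
    # Assumption: a tile counts as supported when it splits evenly into 16x32x16 steps.
    fits = m % COOP_TILE_M == 0 and n % COOP_TILE_N == 0 and k % COOP_TILE_K == 0
    return "metal.cooperative_tensor" if fits else "metal.simdgroup"
```

Under these assumptions a 64×64×32 shared-output tile takes the cooperative tensor path, while the same tile with a fragment accumulator, or an 8×8×8 tile, falls back to simdgroup.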
📝 Walkthrough
This PR adds Metal 4 cooperative tensor GEMM support to TileLang. It introduces four new builtin intrinsics (
Changes
Metal 4 Cooperative Tensor GEMM
Sequence Diagram(s)
sequenceDiagram
participant Frontend as TileLang Frontend (Python)
participant Emitter as MPSIntrinEmitter
participant GemmMetal as GemmMetal / GemmMetalSimdGroup
participant MetalOp as src/metal/op/gemm.cc
participant Transform as TVM Transforms
participant Codegen as CodeGenTileLangMetal
participant MPP as MetalPerformancePrimitives
Frontend->>GemmMetal: T.gemm(A, B, C, clear_accum=True)
GemmMetal->>MetalOp: SelectInst(target) → metal.cooperative_tensor or metal.simdgroup
MetalOp-->>GemmMetal: instruction + warp partition
GemmMetal->>Emitter: MPSIntrinEmitter(use_cooperative_tensor=True/False)
Emitter-->>GemmMetal: ldmatrix_a/b, mma, simdgroup_copy calls
GemmMetal-->>Transform: PrimFunc with cooperative_tensor scope buffers
Transform->>Transform: exempt metal.cooperative_tensor from storage_rewrite/allreduce/LCA
Transform-->>Codegen: lowered PrimFunc
Codegen->>Codegen: CooperativeTensorUseCollector scans body
Codegen->>MPP: emit matmul2d_descriptor + matmul2d objects
Codegen->>MPP: emit cooperative_tensor_load / matmul2d.run() / cooperative_tensor_store
Codegen-->>Frontend: Metal shader source (MSL)
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
# Conflicts:
#   3rdparty/tvm
#   src/metal/codegen/codegen_metal.cc
#   src/metal/op/copy.cc
#   src/metal/op/fill.cc
#   src/metal/op/gemm.cc
#   tilelang/cuda/intrinsics/layout/mma_layout.py
#   tilelang/metal/intrinsics/metal_macro_generator.py
#   tilelang/metal/op/gemm/__init__.py
#   tilelang/metal/op/gemm/gemm_metal.py
#   tilelang/metal/transform/__init__.py
#   tilelang/metal/transform/metal_fragment_to_simdgroup.py
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/metal/intrinsics/metal_macro_generator.py (1)
30-60: ⚠️ Potential issue | 🟡 Minor
Explicitly pass `use_cooperative_tensor=True` in GemmMetal instantiations for code clarity.
While GemmMetal is intentionally designed for cooperative tensor mode (evidenced by `GEMM_INST_METAL_COOPERATIVE_TENSOR` policy selection in `_make_mps_emitter`), the instantiations at lines 179 and 236 rely on the default parameter value instead of explicitly passing it. This makes the intent less obvious and could be confusing for maintainers. GemmMetalSimdGroup correctly sets `use_cooperative_tensor=False` explicitly; GemmMetal should do the same with `use_cooperative_tensor=True` at both instantiation sites.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/metal/intrinsics/metal_macro_generator.py` around lines 30 - 60, Locate the two instantiations of the GemmMetal class (in the _make_mps_emitter function) that currently do not explicitly pass the use_cooperative_tensor parameter. Add use_cooperative_tensor=True as an explicit argument to both GemmMetal instantiation calls to match the clarity and consistency pattern already established by GemmMetalSimdGroup, which explicitly passes use_cooperative_tensor=False. This makes the cooperative tensor design intent clear to maintainers reading the code.
🧹 Nitpick comments (4)
tilelang/metal/intrinsics/metal_macro_generator.py (1)
95-145: 💤 Low value
Consider using tuple unpacking for cleaner indexing.
The cooperative tensor load intrinsic call correctly matches the upstream contract in `builtin.py:1264-1293`. The logic for transposed vs non-transposed row/col indexing is correct.
Minor style suggestion from static analysis: at line 119, consider `buffer[(*extra, row_idx, col_idx)]` instead of `buffer[extra + (row_idx, col_idx)]`.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/metal/intrinsics/metal_macro_generator.py` around lines 95 - 145, In the _warp_ldmatrix_a macro function, replace the tuple concatenation syntax for buffer indexing with tuple unpacking for improved readability. Change the buffer access from buffer[extra + (row_idx, col_idx)] to use the unpacking operator syntax buffer[(*extra, row_idx, col_idx)] where the buffer is being accessed with the extra, row_idx, and col_idx values.
Source: Linters/SAST tools
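As a small standalone illustration of the two indexing spellings (plain tuples here, with made-up index values rather than the TIR expressions used in the macro), both forms build the same index tuple when `extra` is a tuple. The unpacking form also accepts a list, where concatenation with a tuple raises `TypeError`:

```python
# Illustrative values only; in the macro these would be TIR index expressions.
extra = (2, 3)
row_idx, col_idx = 5, 7

concatenated = extra + (row_idx, col_idx)   # requires `extra` to be a tuple
unpacked = (*extra, row_idx, col_idx)       # works for any iterable
assert concatenated == unpacked

extra_list = [2, 3]
from_list = (*extra_list, row_idx, col_idx)  # still fine
try:
    extra_list + (row_idx, col_idx)          # list + tuple is not allowed
    concat_failed = False
except TypeError:
    concat_failed = True
```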
src/op/builtin.h (1)
368-372: ⚡ Quick win
Add doxygen documentation for the new cooperative tensor intrinsics.
The four new cooperative tensor Op declarations lack documentation comments, unlike most other intrinsics in this file (see lines 287–366 for TMA intrinsics). Adding brief doxygen comments describing the signature and purpose of each intrinsic would improve maintainability.
For example, based on usage in `src/metal/op/fill.cc` and `src/metal/op/copy.cc`, `cooperative_tensor_fill` appears to take `(data, tile_idx, fill_value, tile_m, tile_n)`, while `cooperative_tensor_store` takes 11 parameters including destination pointer, stride, and tile dimensions.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/builtin.h` around lines 368 - 372, Add Doxygen documentation comments above each of the four cooperative tensor Op declarations (cooperative_tensor_fill, cooperative_tensor_load, cooperative_tensor_store, and cooperative_tensor_multiply_accumulate) following the same documentation style used for the TMA intrinsics in the file. Each comment should briefly describe the function's purpose and list its parameters and their types (for example, cooperative_tensor_fill takes data, tile_idx, fill_value, tile_m, tile_n, while cooperative_tensor_store takes destination pointer, stride, and tile dimensions along with others). Ensure the documentation format matches the existing doxygen comments in the file for consistency.
3rdparty/tvm (1)
1-1: Metal 4 shader compilation support in TVM submodule is legitimate.
The TVM commit `11c1968acf0e95f2ac1d76b0dd9ffd44c8072b30` is valid and from the active TileLang fork. The update modifies only `src/runtime/metal/metal_module.mm` (23 insertions, 4 deletions) to enable Metal 4 shader compilation, exactly as the PR objectives describe.
Consider documenting this submodule update in your CHANGELOG or PR description to clarify the Metal 4 feature enablement for future maintainers.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@3rdparty/tvm` at line 1, The TVM submodule update that enables Metal 4 shader compilation support is not documented in the project's CHANGELOG or PR description, which could create confusion for future maintainers about the purpose of this change. Add an entry to your CHANGELOG documenting the TVM submodule update to commit 11c1968acf0e95f2ac1d76b0dd9ffd44c8072b30, clearly explaining that this change enables Metal 4 shader compilation support by modifying src/runtime/metal/metal_module.mm. Additionally, update your PR description to reference this feature enablement and link to the corresponding CHANGELOG entry for clarity.
src/metal/op/copy.cc (1)
181-189: ⚡ Quick win
Dead code in tile size computation.
The conditional block (lines 183-186) sets `kTileN` and `kTileM` to the exact same values they were just assigned on lines 181-182, making it a no-op. The subsequent check on line 187 (`if (kTileN > warp_N)`) can never be true since `kTileN` was just set to `warp_N` on line 181.
♻️ Proposed cleanup
 int kTileN = warp_N;
 int kTileM = kTileSize;
-if (warp_tiles > 0 && warp_M > kTileSize) {
-  kTileN = warp_N;
-  kTileM = kTileSize;
-}
-if (kTileN > warp_N) {
-  kTileN = warp_N;
-}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/metal/op/copy.cc` around lines 181 - 189, The conditional block checking if warp_tiles > 0 and warp_M > kTileSize is assigning the same values to kTileN and kTileM that were just set unconditionally on the previous lines, making it redundant dead code. Additionally, the subsequent if condition checking if kTileN > warp_N can never be true since kTileN was just assigned to warp_N. Remove the redundant conditional block (the one checking warp_tiles > 0 && warp_M > kTileSize) and the unreachable if condition that follows it, keeping only the initial assignments of kTileN and kTileM unless there is additional logic that should be applied based on the warp_tiles and warp_M conditions.
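To make the no-op claim concrete, here is a Python transliteration of the flagged C++ logic next to the proposed cleanup. The function names are made up for this sketch, and the default `k_tile_size=16` is assumed from the `kTileSize` constant quoted elsewhere in this review. The loop checks that both versions agree over a small grid of inputs.

```python
import itertools


def tile_dims_original(warp_m, warp_n, warp_tiles, k_tile_size=16):
    """Transliteration of the flagged block, including the redundant branches."""
    k_tile_n = warp_n
    k_tile_m = k_tile_size
    if warp_tiles > 0 and warp_m > k_tile_size:
        k_tile_n = warp_n        # same value as above
        k_tile_m = k_tile_size   # same value as above
    if k_tile_n > warp_n:        # never true: k_tile_n == warp_n here
        k_tile_n = warp_n
    return k_tile_m, k_tile_n


def tile_dims_simplified(warp_m, warp_n, warp_tiles, k_tile_size=16):
    """Proposed cleanup: only the two unconditional assignments remain."""
    return k_tile_size, warp_n


# Both versions agree for every combination in this grid.
for wm, wn, wt in itertools.product(range(0, 129, 16), range(0, 129, 16), range(0, 5)):
    assert tile_dims_original(wm, wn, wt) == tile_dims_simplified(wm, wn, wt)
```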
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmark/matmul_metal/benchmark_matmul_metal.py`:
- Line 247: The bare `except Exception as e:` catch at line 247 triggers Ruff's
BLE001 rule which flags blind exception catching. Since this is intentional to
keep the benchmark sweep running after bad configurations, either narrow the
exception type to catch only specific exceptions that could be raised by a bad
config, or add a local waiver comment like `# noqa: BLE001` followed by a
comment explaining the intentional broad catch is needed to continue the
benchmark sweep despite configuration errors.
In `@src/metal/codegen/codegen_metal.cc`:
- Around line 692-709: The persistent C tensor allocation generates fixed symbol
names (__pct_desc, __pct_op, __pct_cN) unconditionally, causing duplicate
definitions when multiple C buffers are marked for inlining. Additionally, the
direct GEMM path at lines 1554-1585 uses local descriptors with arbitrary
dimensions but still references the persistent __pct_c tensors that were created
with 16×32×16 shape, creating a mismatch when actual dimensions differ. Fix this
by: (1) generating unique symbol names per buffer allocation using a Var-keyed
prefix instead of hardcoded __pct names in all allocation sites (lines 692-709,
1326-1329, 1361-1365, 1472-1481), and (2) in the direct GEMM path, validate that
the descriptor dimensions match the persistent tensor shapes (16, 32, 16); if
dimensions don't match, skip the persistent tensor optimization and use
non-elided storage instead.
In `@src/metal/op/copy.cc`:
- Around line 123-127: The divisibility check in copy.cc is inconsistent with
fill.cc and the actual cooperative tensor GEMM micro tile dimension of 16×32.
Change the kTileSize constant and kTileElems calculation in the copy operation
to use the correct tile dimensions (16×32 = 512 elements instead of 16×16 = 256
elements) to match the tile size checks in fill.cc and ensure buffers that pass
the copy divisibility check will also pass the fill lowering requirements.
In `@testing/python/metal/test_metal_gemm_v2_linux.py`:
- Line 201: The assertion in
assert_metal_gemm_v2_global_cooperative_tensor_codegen currently hard-codes the
value 128 in the assertion check for max_total_threads_per_threadgroup, but the
function accepts a threads parameter that may differ from the default. Replace
the hard-coded 128 value with the threads parameter so that the assertion
correctly validates the requested thread count instead of always expecting 128,
allowing non-default callers to pass the assertion correctly.
In `@tilelang/metal/op/gemm/gemm_metal.py`:
- Around line 205-276: The c_bytes_per_thread calculation in the lower method
uses a hardcoded tile size of 64 bytes, but this doesn't match the actual
cooperative tensor micro-tile size being used. Move the c_bytes_per_thread
calculation to after the MPSIntrinEmitter is created (after line 239 where
mps_emitter is instantiated) and replace the hardcoded 64 value with the actual
micro-tile dimensions from the emitter: use micro_size_x * micro_size_y (which
are extracted from mps_emitter on lines 249-250) multiplied by the appropriate
element size in bytes to calculate the correct bytes per thread, which will
ensure the inner_k_steps heuristic is based on the actual tile size being used.
📒 Files selected for processing (36)
- 3rdparty/tvm
- benchmark/matmul_metal/benchmark_matmul_metal.py
- docs/compiler_internals/metal_tilelang_development.md
- docs/index.md
- src/metal/codegen/codegen_metal.cc
- src/metal/codegen/codegen_metal.h
- src/metal/op/copy.cc
- src/metal/op/fill.cc
- src/metal/op/gemm.cc
- src/metal/op/utils.h
- src/metal/target_utils.cc
- src/metal/target_utils.h
- src/op/builtin.cc
- src/op/builtin.h
- src/op/gemm.cc
- src/op/gemm.h
- src/transform/layout_inference.cc
- src/transform/lower_device_kernel_launch.cc
- src/transform/lower_thread_allreduce.cc
- src/transform/plan_update_buffer_allocation_location.cc
- src/transform/storage_rewrite.cc
- testing/python/metal/test_metal_gemm_v2.py
- testing/python/metal/test_metal_gemm_v2_linux.py
- testing/python/metal/test_metal_simdgroup_store.py
- tilelang/cuda/intrinsics/layout/mma_layout.py
- tilelang/engine/lower.py
- tilelang/language/annotations.py
- tilelang/language/builtin.py
- tilelang/language/gemm_op.py
- tilelang/metal/intrinsics/metal_macro_generator.py
- tilelang/metal/op/gemm/__init__.py
- tilelang/metal/op/gemm/gemm_metal.py
- tilelang/metal/target.py
- tilelang/metal/transform/__init__.py
- tilelang/metal/transform/metal_fragment_to_simdgroup.py
- tilelang/utils/language.py
💤 Files with no reviewable changes (1)
- tilelang/language/gemm_op.py
        f"{mode:>10s} | {block_text:>16s} | {threads:>4d} | {swizzle_text:>8s} | "
        f"{tl:>10.1f} TFLOPS | {torch_ratio:>7.0f}% | {mlx_text}"
    )
except Exception as e:
Satisfy Ruff for the intentional sweep catch.
This catch keeps the benchmark sweep running after a bad config, but Ruff BLE001 flags blind Exception; either narrow it or add a local waiver with intent.
Proposed fix
- except Exception as e:
+ except Exception as e:  # noqa: BLE001 - keep benchmark sweeps running after per-config failures
🧰 Tools
🪛 Ruff (0.15.17)
[warning] 247-247: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@benchmark/matmul_metal/benchmark_matmul_metal.py` at line 247, The bare
`except Exception as e:` catch at line 247 triggers Ruff's BLE001 rule which
flags blind exception catching. Since this is intentional to keep the benchmark
sweep running after bad configurations, either narrow the exception type to
catch only specific exceptions that could be raised by a bad config, or add a
local waiver comment like `# noqa: BLE001` followed by a comment explaining the
intentional broad catch is needed to continue the benchmark sweep despite
configuration errors.
Source: Linters/SAST tools
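For context, a minimal sketch of the sweep pattern the waiver is meant to protect. The helper name `run_sweep` and its arguments are hypothetical and not taken from `benchmark_matmul_metal.py`. The point is only that one failing configuration is recorded and the loop moves on.

```python
def run_sweep(configs, run_one):
    """Run `run_one` on every config, recording failures instead of aborting."""
    results, failures = [], []
    for cfg in configs:
        try:
            results.append((cfg, run_one(cfg)))
        except Exception as e:  # noqa: BLE001 - keep the sweep running after per-config failures
            failures.append((cfg, repr(e)))
    return results, failures
```

Narrowing the `except` clause instead would require knowing every exception type a bad configuration can raise, which is why a documented local waiver may be the more practical choice for a benchmark script.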
this->PrintIndent();
stream
    << "constexpr auto __pct_desc = mpp::tensor_ops::matmul2d_descriptor("
    << "16, 32, 16, false, false, true, "
    << "mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);"
       "\n";
this->PrintIndent();
stream << "mpp::tensor_ops::matmul2d<__pct_desc, "
          "metal::execution_simdgroup> __pct_op;\n";
for (int t = 0; t < num_c_tiles; t++) {
  this->PrintIndent();
  stream << "auto __pct_c" << t
         << " = __pct_op.get_destination_cooperative_tensor<"
         << "decltype(__pct_op.get_left_input_cooperative_tensor<half, "
            "half, float>()), "
         << "decltype(__pct_op.get_right_input_cooperative_tensor<half, "
            "half, float>()), float>();\n";
}
🧩 Analysis chain
🏁 Script executed:
# First, let's understand the file structure and verify the mentioned lines exist
wc -l src/metal/codegen/codegen_metal.cc
Repository: tile-ai/tilelang
Length of output: 100
🏁 Script executed:
# Check the specific lines mentioned in the review
sed -n '692,709p' src/metal/codegen/codegen_metal.cc
Repository: tile-ai/tilelang
Length of output: 908
🏁 Script executed:
# Check the other mentioned line ranges
sed -n '1326,1329p' src/metal/codegen/codegen_metal.cc
sed -n '1361,1365p' src/metal/codegen/codegen_metal.cc
sed -n '1472,1481p' src/metal/codegen/codegen_metal.cc
sed -n '1554,1585p' src/metal/codegen/codegen_metal.cc
Repository: tile-ai/tilelang
Length of output: 2865
🏁 Script executed:
# Search for all occurrences of __pct_ symbol generation patterns
rg '__pct_' src/metal/codegen/codegen_metal.cc -A 2 -B 2
Repository: tile-ai/tilelang
Length of output: 3115
🏁 Script executed:
# Look for the descriptor generation pattern
rg 'matmul2d_descriptor' src/metal/codegen/codegen_metal.cc -A 3 -B 3
Repository: tile-ai/tilelang
Length of output: 1544
🏁 Script executed:
# Search for symbol name generation with context to understand variable tracking
rg '__pct_desc|__pct_op|__pct_c' src/metal/codegen/codegen_metal.cc -n | head -50
Repository: tile-ai/tilelang
Length of output: 1289
🏁 Script executed:
# Look at the context around persistent C allocation to understand when it happens
sed -n '680,720p' src/metal/codegen/codegen_metal.cc
Repository: tile-ai/tilelang
Length of output: 1814
🏁 Script executed:
# Search for how persistent C buffers are allocated and in what context
rg 'persistent.*cooperation|cooperation.*persistent' src/metal/codegen/codegen_metal.cc -B 5 -A 5
Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for the scope/function where these allocations happen to understand multi-buffer risk
sed -n '650,730p' src/metal/codegen/codegen_metal.cc | head -100
Repository: tile-ai/tilelang
Length of output: 3482
🏁 Script executed:
# Check if there's any validation or gating that prevents non-16x32x16 from reaching line 1585
sed -n '1545,1590p' src/metal/codegen/codegen_metal.cc
Repository: tile-ai/tilelang
Length of output: 2510
Two cooperative C tensors in the same kernel scope will produce duplicate Metal symbol definitions; non-16x32x16 GEMM shapes reuse mismatched persistent tensors.
The persistent C allocation (lines 694–709) emits fixed __pct_desc, __pct_op, and __pct_cN names unconditionally whenever can_inline_c is true. If a kernel allocates multiple C buffers marked for inlining, each generates the same symbol names, causing linker/compiler errors from duplicate definitions.
Additionally, the direct GEMM path (lines 1554–1585) creates local __desc and __op with correct dimensions when can_reuse_pct_op is false (M, N, K ≠ 16, 32, 16), but still references the persistent __pct_c tensors created with the 16×32×16 descriptor at allocation time. This creates a shape/dtype mismatch between the temporary cooperative tensors and the destination buffer.
Assign a Var-keyed symbol prefix per buffer, and validate that direct paths use descriptors compatible with the persistent tensor shapes or fall back to non-elided storage.
Also applies to: 1326–1329, 1361–1365, 1472–1481, 1554–1585
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/metal/codegen/codegen_metal.cc` around lines 692 - 709, The persistent C
tensor allocation generates fixed symbol names (__pct_desc, __pct_op, __pct_cN)
unconditionally, causing duplicate definitions when multiple C buffers are
marked for inlining. Additionally, the direct GEMM path at lines 1554-1585 uses
local descriptors with arbitrary dimensions but still references the persistent
__pct_c tensors that were created with 16×32×16 shape, creating a mismatch when
actual dimensions differ. Fix this by: (1) generating unique symbol names per
buffer allocation using a Var-keyed prefix instead of hardcoded __pct names in
all allocation sites (lines 692-709, 1326-1329, 1361-1365, 1472-1481), and (2)
in the direct GEMM path, validate that the descriptor dimensions match the
persistent tensor shapes (16, 32, 16); if dimensions don't match, skip the
persistent tensor optimization and use non-elided storage instead.
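One way the suggested per-buffer prefix could look, sketched in Python rather than in the C++ codegen. The helper `pct_symbols` and the naming scheme are hypothetical, and the sketch assumes buffer variable names are already unique within a kernel. It only shows that two inlined C buffers no longer emit the same identifiers.

```python
def pct_symbols(buffer_name: str, num_c_tiles: int) -> dict:
    """Derive descriptor, op, and tile symbol names from one buffer's name."""
    # Assumption: `buffer_name` is unique per kernel, so the prefix is too.
    prefix = f"__pct_{buffer_name}"
    return {
        "desc": f"{prefix}_desc",
        "op": f"{prefix}_op",
        "tiles": [f"{prefix}_c{t}" for t in range(num_c_tiles)],
    }


def all_names(symbols: dict) -> set:
    """Flatten one buffer's symbols into a set for collision checks."""
    return {symbols["desc"], symbols["op"], *symbols["tiles"]}


first = pct_symbols("C_local", 2)
second = pct_symbols("D_local", 2)
# With fixed names both buffers would emit `__pct_desc`; here the sets are disjoint.
assert all_names(first).isdisjoint(all_names(second))
```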
constexpr int kTileSize = 16;
constexpr int kTileElems = kTileSize * kTileSize;
TVM_FFI_ICHECK(total_elements % kTileElems == 0)
    << "cooperative_tensor buffer size must be multiple of " << kTileElems
    << ", got " << total_elements;
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Check cooperative tensor buffer allocation sizes
# Search for cooperative tensor allocations
rg -nP --type=py 'alloc.*scope="metal\.cooperative_tensor"' -A 2 -B 2
# Search for cooperative tensor tile size constants
rg -n 'kTileM.*=.*16|kTileN.*=.*32' src/metal/op/
Repository: tile-ai/tilelang
Length of output: 1623
🏁 Script executed:
# Read fill.cc lines 30-45 to see the full divisibility requirement
sed -n '30,45p' src/metal/op/fill.cc
# Read gemm.cc lines 30-40 to check documentation
sed -n '30,40p' src/metal/op/gemm.cc
# Read copy.cc around the flagged lines for full context
sed -n '115,135p' src/metal/op/copy.ccRepository: tile-ai/tilelang
Length of output: 1856
Align tile size divisibility checks across copy and fill operations.
The divisibility check in copy.cc requires kTileElems = 256 (16×16), but fill.cc (lines 36-41) requires divisibility by 512 (16×32). A buffer with 256 elements would pass the copy check but fail the fill lowering. The cooperative tensor GEMM micro tile is 16×32 (documented in gemm.cc), so both copy and fill operations should align on that tile size.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/metal/op/copy.cc` around lines 123 - 127, The divisibility check in
copy.cc is inconsistent with fill.cc and the actual cooperative tensor GEMM
micro tile dimension of 16×32. Change the kTileSize constant and kTileElems
calculation in the copy operation to use the correct tile dimensions (16×32 =
512 elements instead of 16×16 = 256 elements) to match the tile size checks in
fill.cc and ensure buffers that pass the copy divisibility check will also pass
the fill lowering requirements.
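The mismatch is easy to reproduce with plain arithmetic. In this sketch the two constants mirror the values quoted in the review (16×16 for copy, 16×32 for fill), while the function names are made up for illustration:

```python
COPY_TILE_ELEMS = 16 * 16  # divisor currently checked in copy.cc, per the review
FILL_TILE_ELEMS = 16 * 32  # divisor checked in fill.cc, per the review


def passes_copy_check(total_elements: int) -> bool:
    return total_elements % COPY_TILE_ELEMS == 0


def passes_fill_check(total_elements: int) -> bool:
    return total_elements % FILL_TILE_ELEMS == 0


# Sizes that are odd multiples of 256 slip through copy but not fill.
mismatched = [n for n in (256, 512, 768, 1024)
              if passes_copy_check(n) and not passes_fill_check(n)]
```

Every multiple of 512 is also a multiple of 256, so tightening the copy check to 512 cannot reject a buffer that the fill check would accept.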
assert "const device half* __src" in src_code
assert "[[simdgroup_index_in_threadgroup]]" in src_code
assert "__metal_get_thread_index_in_simdgroup" in src_code
assert "max_total_threads_per_threadgroup(128)" in src_code
Assert the requested thread count instead of hard-coding 128.
assert_metal_gemm_v2_global_cooperative_tensor_codegen accepts threads, but Line 201 always expects max_total_threads_per_threadgroup(128), so non-default callers would fail for the wrong reason.
Proposed fix
- assert "max_total_threads_per_threadgroup(128)" in src_code
+ assert f"max_total_threads_per_threadgroup({threads})" in src_code
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@testing/python/metal/test_metal_gemm_v2_linux.py` at line 201, The assertion
in assert_metal_gemm_v2_global_cooperative_tensor_codegen currently hard-codes
the value 128 in the assertion check for max_total_threads_per_threadgroup, but
the function accepts a threads parameter that may differ from the default.
Replace the hard-coded 128 value with the threads parameter so that the
assertion correctly validates the requested thread count instead of always
expecting 128, allowing non-default callers to pass the assertion correctly.
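A standalone sketch of the parameterized check follows. The helper `assert_threadgroup_limit` and the sample source string are hypothetical and only imitate the shape of the assertion in the test file.

```python
def assert_threadgroup_limit(src_code: str, threads: int = 128) -> None:
    """Check that the generated source declares the requested thread count."""
    expected = f"max_total_threads_per_threadgroup({threads})"
    assert expected in src_code, f"missing {expected}"


# Made-up fragment standing in for generated Metal source.
sample_src = "[[max_total_threads_per_threadgroup(256)]] kernel void main0()"
assert_threadgroup_limit(sample_src, threads=256)   # passes

try:
    assert_threadgroup_limit(sample_src)            # default 128 does not match
    default_matched = True
except AssertionError:
    default_matched = False
```

With the hard-coded 128, a caller that compiles with 256 threads would hit the second case even though the generated code is correct.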
    @staticmethod
    def _get_padded_stride(buffer):
        continuous = int(buffer.shape[-1])
        element_bits = int(tvm.DataType(buffer.dtype).bits)
        padded = continuous
        if (element_bits * continuous) % 256 == 0:
            padded += 128 // element_bits
        return padded

    def lower(
        self,
        layout_map: dict,
        target: Target,
        thread_bounds: Range,
        thread_var: tir.Var,
        mbar_phase_expr: tir.PrimExpr | None = None,
    ):
        thread_nums = thread_bounds.extent
        _, m_warp, n_warp = self._make_mps_emitter(target, int(thread_nums))
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)

        from tilelang.metal.intrinsics.metal_macro_generator import MPSIntrinEmitter

            @T.prim_func
            def _gemm_ss_shared() -> None:
                A_local = T.alloc_local((warp_rows * 64), a_dtype, scope="metal.simdgroup")
                B_local = T.alloc_local((warp_cols * 64), b_dtype, scope="metal.simdgroup")
                C_simd = T.alloc_local((num_simd_c * 64), accum_dtype, scope="metal.simdgroup")
                if clear_accum:
                    for _i in T.serial(num_simd_c):
                        T.make_filled_simdgroup_matrix(C_simd.data, _i, T.cast(0, accum_dtype))
                else:
                    mps_emitter.simd_load(C_simd, C_buf)
                for ki in T.serial(0, (block_K // micro_size_k)):
                    mps_emitter.ldmatrix_a(A_local, A_region, ki)
                    mps_emitter.ldmatrix_b(B_local, B_region, ki)
                    mps_emitter.mma(A_local, B_local, C_simd)

                mps_emitter.simd_store(C_simd, C_buf)

            return _Simplify(_gemm_ss_shared, inline_let=True)
        else:
        a_stride = self._get_padded_stride(self.A) if self.is_gemm_ss() else None
        b_stride = self._get_padded_stride(self.B) if self.is_gemm_ss() else None

        c_bytes_per_thread = warp_row_tiles * warp_col_tiles * 64
        inner_k_steps = 2 if c_bytes_per_thread <= 128 else 1
        output_dtype = self.accum_dtype
        accum_dtype = T.float32 if self.is_gemm_gg() and str(output_dtype) in ("float16", "bfloat16") else output_dtype
        mps_emitter = MPSIntrinEmitter(
            a_dtype=self.a_dtype,
            b_dtype=self.b_dtype,
            accum_dtype=accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
            thread_var=thread_var,
            a_stride_override=a_stride,
            b_stride_override=b_stride,
            inner_k_steps=inner_k_steps,
        )

        a_dtype = self.a_dtype
        b_dtype = self.b_dtype
        warp_rows = mps_emitter.warp_rows
        warp_cols = mps_emitter.warp_cols
        num_simd_c = warp_rows * warp_cols
        block_K = mps_emitter.chunk
        micro_size_x = mps_emitter.micro_size_x
        micro_size_y = mps_emitter.micro_size_y
        micro_size_k = mps_emitter.micro_size_k
        inner_k_steps = mps_emitter.inner_k_steps
        a_tile_elems = micro_size_x * micro_size_k
        b_tile_elems = micro_size_k * micro_size_y
        c_tile_elems = micro_size_x * micro_size_y

        A_region = self.ARegion
        B_region = self.BRegion
        C_region = self.CRegion
        C_buf = C_region.buffer
        clear_accum = self.clear_accum
        c_in_cooperative_tensor = is_metal_cooperative_tensor(C_buf) or is_fragment(C_buf)
        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"

        if not (self.is_gemm_ss() or self.is_gemm_gg()):
            raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
c_bytes_per_thread calculation uses simdgroup tile size (64) instead of cooperative tensor tile size.
At line 232, c_bytes_per_thread = warp_row_tiles * warp_col_tiles * 64 uses the fixed value 64, which corresponds to the simdgroup 8x8 micro-tile. However, GemmMetal uses cooperative tensor mode with 16x32=512 element micro-tiles.
This affects the inner_k_steps heuristic on line 233. If the intent is to measure register pressure per thread, the calculation should account for the actual tile size being used.
🔧 Suggested fix
- c_bytes_per_thread = warp_row_tiles * warp_col_tiles * 64
+ # Cooperative tensor micro-tile is 16x32 = 512 elements
+ ct_micro_elems = 16 * 32
+ c_bytes_per_thread = warp_row_tiles * warp_col_tiles * ct_micro_elems
inner_k_steps = 2 if c_bytes_per_thread <= 128 else 1
Or alternatively, compute this after creating the emitter to use consistent values:
+ micro_size_x_ct = 16
+ micro_size_y_ct = 32
+ c_tile_elems_approx = micro_size_x_ct * micro_size_y_ct
+ c_bytes_per_thread = warp_row_tiles * warp_col_tiles * c_tile_elems_approx
inner_k_steps = 2 if c_bytes_per_thread <= 128 else 1
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/metal/op/gemm/gemm_metal.py` around lines 205 - 276, The
c_bytes_per_thread calculation in the lower method uses a hardcoded tile size of
64 bytes, but this doesn't match the actual cooperative tensor micro-tile size
being used. Move the c_bytes_per_thread calculation to after the
MPSIntrinEmitter is created (after line 239 where mps_emitter is instantiated)
and replace the hardcoded 64 value with the actual micro-tile dimensions from
the emitter: use micro_size_x * micro_size_y (which are extracted from
mps_emitter on lines 249-250) multiplied by the appropriate element size in
bytes to calculate the correct bytes per thread, which will ensure the
inner_k_steps heuristic is based on the actual tile size being used.
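To make the arithmetic in this comment concrete, the heuristic can be looked at in isolation. The following is only a sketch: the helper name `inner_k_steps_for` and the constants `SIMDGROUP_FACTOR` and `COOP_TENSOR_FACTOR` are hypothetical and do not exist in the PR. Only the `<= 128` threshold, the hardcoded 64, and the 16x32 tile size are taken from the code and the comment above.

```python
def inner_k_steps_for(warp_row_tiles: int, warp_col_tiles: int, tile_factor: int) -> int:
    """Hypothetical stand-alone copy of the heuristic quoted in the review.

    `tile_factor` is the per-tile multiplier: 64 in the current code,
    16 * 32 in the reviewer's suggested fix.
    """
    c_bytes_per_thread = warp_row_tiles * warp_col_tiles * tile_factor
    return 2 if c_bytes_per_thread <= 128 else 1


SIMDGROUP_FACTOR = 64          # value hardcoded in the diff (8x8 micro-tile)
COOP_TENSOR_FACTOR = 16 * 32   # 512, the cooperative tensor micro-tile from the comment

# Compare both factors for a few small per-warp tile counts.
for rows, cols in [(1, 1), (1, 2), (2, 2)]:
    print(
        rows,
        cols,
        inner_k_steps_for(rows, cols, SIMDGROUP_FACTOR),
        inner_k_steps_for(rows, cols, COOP_TENSOR_FACTOR),
    )
```

With the factor 64, the threshold is met only when `warp_row_tiles * warp_col_tiles` is at most 2. With the factor 512, the product already exceeds 128 for a single tile, so substituting the tile size alone would always yield `inner_k_steps = 1`. If the two-step mode is meant to stay reachable, the threshold would presumably need to be rescaled together with the factor.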
Needs tile-ai/tvm#44
Summary
This PR extends the existing Metal backend with a Metal 4 cooperative tensor path for T.gemm.
The Metal backend already supported simdgroup-based GEMM lowering. This PR adds cooperative tensor as a new fast path on supported Apple GPUs, while keeping simdgroup as the compatibility path for older devices and non-Metal-4 targets.
Motivation
Metal 4 cooperative tensor exposes Apple's tensor-core-like matrix compute path. On supported hardware, it provides a substantially faster GEMM implementation than the existing simdgroup path.
This PR adds that path to TileLang so Metal T.gemm can use the newer hardware capability while preserving the existing simdgroup implementation for compatibility.
Design Notes
Although cooperative tensor is conceptually the Metal-side counterpart of CUDA tensor core programming, the programming model is not a direct CUDA clone.
As a practical approximation for CUDA reviewers, Apple GPUs expose a less CUDA-like split between register and threadgroup storage; both are backed by a more hardware-managed on-chip memory system. Because of that, explicit threadgroup staging is not automatically a faster path than feeding cooperative tensor operands directly.
This PR therefore keeps CUDA-shaped shared staging as a compatibility path, but optimizes the direct cooperative-tensor path as the Metal fast path.
T.gemm remains the frontend abstraction, and Metal-specific instruction choice stays inside the Metal backend.
What Changed
At a high level, this PR adds:
Detailed lowering rules and implementation notes are documented separately in the Metal compiler internals doc.
Impact on TileLang
The main TileLang-level impact is that Metal now has a dedicated high-performance GEMM path that reflects Metal's own matrix programming model.
In particular:
T.gemm remains the user-facing abstraction.
Compatibility
This PR is intended to be backward-compatible for existing Metal users.
The existing simdgroup path is still present and tested. The new cooperative tensor path is only used when the target/runtime capability allows it, so building TileLang with a newer SDK should not force all Metal kernels to require Metal 4.
One important compatibility check is that this PR has been validated on GitHub Actions with macOS 26 and M1 hardware. That environment exposes the newer SDK at build time but does not support cooperative tensor in hardware, so passing there verifies that the backend correctly falls back to the simdgroup path on unsupported devices.
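The fallback behavior described here can be pictured as a small decision function. This is a hedged sketch, not the PR's implementation: `GemmPath`, `GemmRequest`, `select_gemm_path`, and both boolean fields are hypothetical names introduced for illustration. The capability gate comes from the paragraph above; the tile-support gate is an assumption based on the PR's stated fallback for unsupported tiles.

```python
from dataclasses import dataclass
from enum import Enum


class GemmPath(Enum):
    COOPERATIVE_TENSOR = "cooperative_tensor"
    SIMDGROUP = "simdgroup"


@dataclass(frozen=True)
class GemmRequest:
    # Hypothetical flags; the real backend derives these from the target/runtime.
    device_supports_cooperative_tensor: bool
    tile_supported: bool


def select_gemm_path(req: GemmRequest) -> GemmPath:
    """Pick the fast path only when both gates pass; otherwise fall back."""
    if not req.device_supports_cooperative_tensor:
        # e.g. a newer SDK at build time but older hardware at run time
        return GemmPath.SIMDGROUP
    if not req.tile_supported:
        return GemmPath.SIMDGROUP
    return GemmPath.COOPERATIVE_TENSOR


# The macOS 26 + M1 CI case: SDK is new, hardware lacks cooperative tensor.
m1_ci = GemmRequest(device_supports_cooperative_tensor=False, tile_supported=True)
print(select_gemm_path(m1_ci).value)
```

The point of the sketch is that the SDK version never appears as an input: the choice depends on what the device can run, which is what the M1 CI run exercises.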
Testing
Test coverage includes:
Validated locally with:
pip install .
python -m pytest testing/python/metal/ -q -x
python -m pre_commit run --all-files
GitHub Actions additionally validated the macOS 26 + M1 fallback case described above.
Summary by CodeRabbit
Release Notes
New Features
cooperative_tensor_fill, cooperative_tensor_load, cooperative_tensor_store, cooperative_tensor_multiply_accumulate).
Documentation
Chores