Skip to content

[ROCM] Fix buffer_load_lds support for gfx950#2248

Draft
benenzhu wants to merge 36 commits into
tile-ai:mainfrom
benenzhu:feat-c-remove_vmcnt0
Draft

[ROCM] Fix buffer_load_lds support for gfx950#2248
benenzhu wants to merge 36 commits into
tile-ai:mainfrom
benenzhu:feat-c-remove_vmcnt0

Conversation

@benenzhu

@benenzhu benenzhu commented May 22, 2026

Copy link
Copy Markdown
Contributor

Why the existing AMD double buffer is slow

On the ROCm path, every T.copy(gmem → smem) currently lowers to a global_load → VGPR → ds_write pair, and the lowering inserts an s_waitcnt vmcnt(0) between the load and the store so the data is in-register before it gets written to LDS. That vmcnt(0) is a full wait-for-all-outstanding-loads — it doesn't just block the lane that needs the result, it stalls the whole wave on every issued global load. The practical consequence is that pipelining is dead: any T.Pipelined/double-buffer schedule the user wrote turns into back-to-back synchronous copies, because each stage's vmcnt(0) drains the previous stage's loads before the next compute can start.

Background

gfx950 exposes a real async global→shared path: buffer_load_dword{,x2,x4} ... lds reads from global and writes directly into LDS without staging through a VGPR, and s_waitcnt vmcnt(N) plays the same role as cp.async.wait_group. The catch is that the hardware writes each lane's element to lds_base + 4 * lane_id — the destination is implicitly lane-contiguous. Tilelang normally hands shared memory to GEMM through an XOR bank-swizzle so MMA reads are bank-conflict-free, and that swizzle makes the per-lane LDS index non-contiguous, so a naive lowering would scatter the loads to the wrong banks. We also can't just drop the swizzle — it's the primary perf gate for MMA. That's why AMD has been stuck on the synchronous "load into VGPR, then store to LDS" path until now.

What this PR does

The rewrite lives in lower_tile_op.cc plus a small helper in layout.{h,cc}. For every tl::ptx_cp_async_lds call we try to move the swizzle from the LDS side to the global side: if the destination layout contains a Bitwise(XOR, j, h(i)) term, we fold the inverse delta into the source tvm_access_ptr so the LDS index becomes linear again while the global load picks up the permutation (which is free — it's just an address rewrite on the gmem side). We then prove lane-contiguity on the resulting LDS index by simplifying it against a synthetic thread var and checking the dependence is a single constant stride. If the proof fails the call is downgraded back to tl::ptx_cp_async, so this is purely an optimization — no kernel becomes incorrect.

The PR also lifts the per-call cost. cp_async_gs_lds_with_rsrc<N> takes a pre-computed BufferResource descriptor plus a wave-uniform base address, and computing those per call would emit four readfirstlanes and the resource bit-cast on every load. HoistBufferResource (new Python pass, gfx950-only) collects every source buffer touched by ptx_cp_async_lds, hoists one resource/base pair per buffer to the kernel prologue as AttrStmts, and rewrites the call sites to the _rsrc variant. The same pass scales the ptx_wait_group(n) counts: s_waitcnt vmcnt(N) counts individual buffer_loads rather than commit-groups, so a wait-for-N-groups must become wait-for-(N × loads-per-group). wait_group(0) (wait-all) is the natural sentinel and stays unchanged.

Codegen (codegen_hip.{cc,h}) gets the matching templated intrinsics, and cp_async_gs_lds<N> / cp_async_gs_lds_with_rsrc<N> are added to src/tl_templates/hip/copy.h. The codegen translates the logical element count carried by the call back to bytes via the existing GetTileLangCPAsyncTransferBytes convention, so the vec-loop folding in vectorize_loop.cc keeps working.

Why it's gfx950-only

buffer_load ... lds exists on earlier CDNA, but the wait-count scaling and the bank-conflict-free swizzle delta we rely on were only validated against gfx950. On every other target the routing in lower_ptx_async_copy.cc and the HoistBufferResource pass return unchanged, so existing kernels keep their current lowering.

CI and performance

Here I have a CI for gfx950 common kernels. It can pass the correct check, and have some performance boost on many kenrels.

Currently, flash_attn_fwd have 15 vgpr up. But should be fine with the performance boost. Will investigate later for the reasons of the boost.

benenzhu/tile-kernel-bench-cdna4#23 (comment)

Bench results vs main

Status: OK (no regression beyond 5.00%)

Wins (14)

op shape dtype base → current Δ
flash_attn_fwd b1_h32_s8192_d128 fp16 227.568 → 323.214 TFLOPS +42.03%
flash_attn_fwd b1_h32_s8192_d128_causal fp16 161.118 → 195.348 TFLOPS +21.24%
flash_attn_fwd b8_h32_s2048_d128 fp16 221.783 → 303.424 TFLOPS +36.81%
flash_attn_fwd b8_h32_s2048_d128_causal fp16 126.065 → 153.371 TFLOPS +21.66%
flash_attn_fwd b8_h32_s4096_d128 fp16 227.853 → 322.193 TFLOPS +41.40%
flash_attn_fwd b8_h32_s4096_d128_causal fp16 129.414 → 158.541 TFLOPS +22.51%
gemm 1024x1024x16384_NN bf16 50.846 → 63.890 TFLOPS +25.65%
gemm 1024x8192x8192_NN bf16 331.562 → 429.121 TFLOPS +29.42%
gemm 2048x2048x2048_NN bf16 180.084 → 223.118 TFLOPS +23.90%
gemm 4096x4096x4096_NN bf16 358.475 → 395.988 TFLOPS +10.46%
gemm 4096x4096x8192_NN bf16 359.186 → 406.000 TFLOPS +13.03%
gemm 8192x8192x16384_NN bf16 575.298 → 722.953 TFLOPS +25.67%
gemm 8192x8192x16384_NT bf16 788.202 → 1133.422 TFLOPS +43.80%
gemm 8192x8192x8192_NN bf16 403.935 → 433.821 TFLOPS +7.40%

VGPR down (4) — compiler win

op shape dtype base → current Δ
flash_attn_fwd b1_h32_s8192_d128_causal fp16 138 → 123 -15 VGPR
flash_attn_fwd b8_h32_s2048_d128_causal fp16 138 → 123 -15 VGPR
flash_attn_fwd b8_h32_s4096_d128_causal fp16 138 → 123 -15 VGPR
gemm 8192x8192x16384_NN bf16 256 → 207 -49 VGPR

VGPR up (11) — compiler regression

op shape dtype base → current Δ
flash_attn_fwd b1_h32_s8192_d128 fp16 140 → 154 +14 VGPR
flash_attn_fwd b8_h32_s2048_d128 fp16 140 → 155 +15 VGPR
flash_attn_fwd b8_h32_s4096_d128 fp16 140 → 155 +15 VGPR
gemm 1024x1024x16384_NN bf16 125 → 126 +1 VGPR
gemm 1024x8192x8192_NN bf16 125 → 126 +1 VGPR
gemm 2048x2048x2048_NN bf16 125 → 126 +1 VGPR
gemm 4096x4096x4096_NN bf16 125 → 126 +1 VGPR
gemm 4096x4096x8192_NN bf16 125 → 126 +1 VGPR
gemm 8192x8192x1024_NN bf16 125 → 126 +1 VGPR
gemm 8192x8192x16384_NT bf16 185 → 192 +7 VGPR
gemm 8192x8192x8192_NN bf16 125 → 126 +1 VGPR

Spills down (1) — compiler win

op shape dtype base → current Δ
gemm 8192x8192x16384_NN bf16 5 → 0 -5 spill+sc/4
Full compare table (45 rows)
op           shape                        dtype  tile base              tile cur                TFLOPS base   TFLOPS cur    ΔTFLOPS  TB/s base   TB/s cur      ΔTB/s
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
add_3d_large 512x512x32768                fp16   D2blk1024              D2blk1024                     0.000        0.000        n/a      5.801      5.801     -0.00%
flash_attn_fwd b1_h32_s8192_d128            fp16   M128_N128_t128_s1      M128_N128_t128_s1           227.568      323.214    +42.03%      0.056      0.079    +42.03%
flash_attn_fwd b1_h32_s8192_d128_causal     fp16   M128_N128_t128_s1      M128_N128_t128_s1           161.118      195.348    +21.24%      0.079      0.095    +21.25%
flash_attn_fwd b8_h32_s2048_d128            fp16   M128_N128_t128_s1      M128_N128_t128_s1           221.783      303.424    +36.81%      0.217      0.296    +36.81%
flash_attn_fwd b8_h32_s2048_d128_causal     fp16   M128_N128_t128_s1      M128_N128_t128_s1           126.065      153.371    +21.66%      0.246      0.300    +21.66%
flash_attn_fwd b8_h32_s4096_d128            fp16   M128_N128_t128_s1      M128_N128_t128_s1           227.853      322.193    +41.40%      0.111      0.157    +41.40%
flash_attn_fwd b8_h32_s4096_d128_causal     fp16   M128_N128_t128_s1      M128_N128_t128_s1           129.414      158.541    +22.51%      0.126      0.155    +22.51%
gemm         1024x1024x16384_NN           bf16   128x128x32             128x128x32                   50.846       63.890    +25.65%      0.102      0.129    +25.65%
gemm         1024x1024x16384_NT           bf16   128x128x32             128x128x32                   71.003       70.997     -0.01%      0.143      0.143     -0.01%
gemm         1024x8192x8192_NN            bf16   128x128x32             128x128x32                  331.562      429.121    +29.42%      0.405      0.524    +29.42%
gemm         1024x8192x8192_NT            bf16   128x128x32             128x128x32                  509.715      510.167     +0.09%      0.622      0.623     +0.09%
gemm         2048x2048x2048_NN            bf16   128x128x32             128x128x32                  180.084      223.118    +23.90%      0.264      0.327    +23.90%
gemm         2048x2048x2048_NT            bf16   128x128x32             128x128x32                  253.394      257.960     +1.80%      0.371      0.378     +1.80%
gemm         4096x4096x4096_NN            bf16   128x128x32             128x128x32                  358.475      395.988    +10.46%      0.263      0.290    +10.46%
gemm         4096x4096x4096_NT            bf16   128x128x32             128x128x32                  504.254      513.754     +1.88%      0.369      0.376     +1.88%
gemm         4096x4096x8192_NN            bf16   128x128x32             128x128x32                  359.186      406.000    +13.03%      0.219      0.248    +13.03%
gemm         4096x4096x8192_NT            bf16   128x128x32             128x128x32                  527.678      524.376     -0.63%      0.322      0.320     -0.63%
gemm         8192x8192x1024_NN            bf16   128x128x32             128x128x32                  338.153      353.677     +4.59%      0.413      0.432     +4.59%
gemm         8192x8192x1024_NT            bf16   128x128x32             128x128x32                  547.132      548.002     +0.16%      0.668      0.669     +0.16%
gemm         8192x8192x16384_NN           bf16   256x256x64             256x256x64                  575.298      722.953    +25.67%      0.176      0.221    +25.67%
gemm         8192x8192x16384_NT           bf16   256x256x64             256x256x64                  788.202     1133.422    +43.80%      0.241      0.346    +43.80%
gemm         8192x8192x8192_NN            bf16   128x128x32             128x128x32                  403.935      433.821     +7.40%      0.148      0.159     +7.40%
gemm         8192x8192x8192_NT            bf16   128x128x32             128x128x32                  661.448      661.590     +0.02%      0.242      0.242     +0.02%
gemm_fp8     1024x1024x16384              fp8    128x128x128            128x128x128                 183.978      184.810     +0.45%      0.202      0.203     +0.45%
gemm_fp8     1024x1024x16384_pre          fp8    128x128x128            128x128x128                 157.700      158.897     +0.76%      0.173      0.175     +0.76%
gemm_fp8     2048x2048x16384              fp8    128x128x128            128x128x128                 688.713      683.779     -0.72%      0.420      0.417     -0.72%
gemm_fp8     2048x2048x16384_pre          fp8    128x128x128            128x128x128                 564.941      566.995     +0.36%      0.345      0.346     +0.36%
gemm_fp8     4096x4096x16384              fp8    128x128x128            128x128x128                1298.309     1296.839     -0.11%      0.475      0.475     -0.11%
gemm_fp8     4096x4096x16384_pre          fp8    128x128x128            128x128x128                1102.778     1104.726     +0.18%      0.404      0.405     +0.18%
gemm_fp8     8192x8192x8192               fp8    128x128x128            128x128x128                1267.361     1266.134     -0.10%      0.464      0.464     -0.10%
gemm_fp8     8192x8192x8192_pre           fp8    128x128x128            128x128x128                1089.186     1093.692     +0.41%      0.399      0.401     +0.41%
gemv         32768x8192                   fp16   N64_rt16               N64_rt16                      2.198        2.208     +0.46%      2.198      2.209     +0.46%
gemv         8192x32768                   fp16   N64_rt16               N64_rt16                      1.547        1.549     +0.15%      1.547      1.549     +0.15%
gemv         8192x8192                    fp16   N64_rt16               N64_rt16                      1.679        1.679     +0.05%      1.679      1.680     +0.05%
layernorm    16384x8192                   bf16   blk_m1                 blk_m1                        0.000        0.000        n/a      4.564      4.550     -0.30%
layernorm    32768x8192                   bf16   blk_m1                 blk_m1                        0.000        0.000        n/a      4.761      4.755     -0.12%
layernorm    4096x8192                    bf16   blk_m1                 blk_m1                        0.000        0.000        n/a      3.519      3.508     -0.31%
mla_decode   b128_h128_kv8192_d512_pe64   fp16   N32_H64_split4_t256    N32_H64_split4_t256         234.460      238.171     +1.58%      0.998      1.014     +1.58%
mla_decode   b64_h128_kv4096_d512_pe64    fp16   N32_H64_split4_t256    N32_H64_split4_t256         215.529      217.078     +0.72%      0.944      0.951     +0.72%
rmsnorm      16384x8192                   bf16   blk_m1                 blk_m1                        0.000        0.000        n/a      3.915      3.914     -0.03%
rmsnorm      32768x8192                   bf16   blk_m1                 blk_m1                        0.000        0.000        n/a      4.137      4.149     +0.29%
rmsnorm      4096x8192                    bf16   blk_m1                 blk_m1                        0.000        0.000        n/a      2.750      2.746     -0.17%
softmax      16384x8192                   bf16   -                      -                             0.000        0.000        n/a      4.213      4.206     -0.16%
softmax      32768x8192                   bf16   -                      -                             0.000        0.000        n/a      4.469      4.458     -0.27%
softmax      4096x8192                    bf16   -                      -                             0.000        0.000        n/a      3.107      3.107     +0.00%

Self-hosted runner: 06-05 · workflow run

Summary by CodeRabbit

  • New Features

    • GFX950 async global→shared copy path with buffer-load-based helpers and pre-hoisted buffer resources.
    • Hoist-buffer-resource pass to move resource descriptors to kernel prologue.
    • Layout swizzle-delta support for swizzle-aware remapping.
  • Improvements

    • Broader cp.async recognition across vectorization and transform passes.
    • Reduced async-copy overhead, stricter transfer/arity validation, and AMD wait-count clamping for safety.

Review Change Stack

benenzhu and others added 29 commits May 16, 2026 05:42
…arks

Adds a triton-style view of per-kernel AMD GPU resource usage on the
tilelang JITKernel, queryable as kernel.n_regs / n_spills /
n_max_threads (with a richer resource_usage dict mapping kernel name
to a KernelResourceUsage dataclass).

Implementation:

* tilelang/jit/adapter/hip_resource_info.py — passes
  -Rpass-analysis=kernel-resource-usage to hipcc, parses the per-kernel
  remarks (Function Name / VGPRs / VGPRs Spill / TotalSGPRs / etc.)
  out of the captured stdio, and *strips* those lines before the
  output is printed or included in error messages, so autotune logs
  don't drown in remark blocks while real warnings/errors still
  surface. Includes JSON (de)serialization helpers.

* tilelang/contrib/hipcc.py — adds the remark flag, parses + filters
  the output. Same on the LibraryGenerator HIP path
  (tilelang/jit/adapter/libgen.py); HIP compiles always pipe stdio
  there so the filter has something to act on (verbose=True still
  prints the filtered output).

* tilelang/jit/kernel.py — opens a thread-local recorder window
  around lower() on HIP and exposes the parsed dict as lazy
  resource_usage / n_regs / n_spills / n_max_threads properties.

* tilelang/cache/kernel_cache.py — persists the parsed dict as
  resource_usage.json next to kernel_lib.so on cache miss; reloads
  it on cache hit. This way subsequent runs don't lose the resource
  view to the cache, without paying the runtime API / ctypes cost.
  Older cache entries (no JSON file) silently degrade to None.

Verified on MI355X (gfx950) with a small elementwise add: cache miss
and cache hit both report n_regs=5, n_spills=0; zero remark lines
leak to stdout/stderr.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Extends src/tl_templates/hip/copy.h with two new HIP device templates that
emit buffer_load_dwordx4 ... lds (the gfx950 direct global-to-LDS DMA that
bypasses VGPRs):

- cp_async_gs_lds<N>: self-contained variant; computes the buffer resource
  descriptor and base address per call.
- cp_async_gs_lds_with_rsrc<N>: variant taking a pre-hoisted resource
  descriptor + base address. This is what the HoistBufferResource Python
  pass (Round 5 / M6) will rewrite calls to use, amortising the
  readfirstlane overhead across unrolled loops.

Both templates only emit the direct-DMA path for N == 16; smaller copies
fall back to the existing cp_async_gs<N>. The 16-byte path requires that
the LDS destination be lane-contiguous (base + lane_id * 16); the
swizzle-swap optimisation in lower_tile_op.cc (Round 6 / M7) guarantees
this by moving the XOR swizzle from the LDS store side to the global load
side.

Reuses the existing make_wave_buffer_resource helper at copy.h:22 rather
than redeclaring it. The inline-asm body is lifted from the reference
branch zty_opt_can_run_1120flops because the asm is hardware-pinned to
gfx950 and has no branch-specific dependencies.

Function-to-site mapping audit (Round 0 / M1) located the insertion
point for HIP codegen handlers at src/backend/rocm/codegen/codegen_hip.cc
(refactored from the reference's src/target/codegen_hip.cc); those
handlers land in Round 3 / M4.

Verification: USE_ROCM=ON pip install -e . succeeds; `import tilelang`
loads. Inline-asm validation deferred to JIT-compile time (header-only
template; uninstantiated until codegen emits the call).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…nc_lds_rsrc TIR ops

Declares three new TIR builtin ops on the tl namespace and registers them
in builtin.cc. Inserted right after the existing ptx_cp_async registration
so they appear in a coherent block:

- ptx_cp_async_lds(dst, src, bytes): same shape as ptx_cp_async but
  signals to codegen that the call lowers to cp_async_gs_lds<N> (the
  hardware buffer_load ... lds path added in M2).
- ptx_make_buffer_resource(global_ptr): single-arg op that lowers to
  make_wave_buffer_resource((const void*)(global_ptr)).
- ptx_cp_async_lds_rsrc(dst, src, bytes, rsrc, base): extended form
  carrying the pre-hoisted resource descriptor and base address. The
  HoistBufferResource Python pass (M6) rewrites ptx_cp_async_lds calls
  to this form once per kernel.

All three use the same call-effect kind (kOpaque) as ptx_cp_async since
they have global memory side effects. set_num_inputs(-1) on the *_lds
variants matches ptx_cp_async, which already uses -1 to support both
predicated and non-predicated forms.

Verification: pip install -e . succeeds; `import tilelang` and
`from tvm import tir` both succeed.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…ttrStmt

Adds codegen support in src/backend/rocm/codegen/codegen_hip.cc for the
three TIR ops registered in M3, plus the AttrStmt/LetStmt machinery that
the HoistBufferResource Python pass (M6) will emit.

CallNode handlers (VisitExpr_):
- tl::ptx_make_buffer_resource(ptr) -> expression
  make_wave_buffer_resource((const void*)(ptr))
- tl::ptx_cp_async_lds_rsrc(dst, src, bytes, rsrc, base) -> statement
  tl::cp_async_gs_lds_with_rsrc<bytes>(dst, src, rsrc, base);
- tl::ptx_cp_async_lds(dst, src, bytes [, pred]) -> statement
  tl::cp_async_gs_lds<bytes>(dst, src);  (predicated form falls back to
  cp_async_gs_conditional<bytes> like the regular ptx_cp_async)

LetStmt visitor (new): when the bound value is a ptx_make_buffer_resource
Call, emit `auto x = ...;` instead of letting the base CodeGenC try to
print a C-typed declaration for the int32x4_t result.

AttrStmt branches (extended):
- "buffer_resource_var": emit `auto {rsrc_vid} = make_wave_buffer_resource(
  (const void*)({buf_ptr}));`
- "buffer_base_var":     emit `uint32_t {base_vid} =
  __builtin_amdgcn_readfirstlane((uint32_t)(uintptr_t)({buf_ptr}));`

These match the prologue shape demonstrated in
/root/tile-kernel-bench-cdna4/_fast.cpp lines 10-13. The hoisting pass
in M6 will wrap the kernel body with these AttrStmts so the descriptors
materialise at kernel entry rather than per call.

Verification: pip install -e . succeeds; `import tilelang` succeeds.
Inner-loop emission of cp_async_gs_lds_with_rsrc will be exercised once
M5 (injection decision) routes T.copy through the new ops.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Adds an opt-in `enable_buffer_load_lds` knob to InjectPTXAsyncCopy and
the underlying class, wired from src/backend/rocm/op/copy.cc only when
TargetIsGfx950(T.target) holds. Default false everywhere else, so CUDA
and non-gfx950 ROCm paths are unchanged.

Routing in MakeCPAsyncStmtFromLoads emits tl::ptx_cp_async_lds(dst, src,
total_bytes) instead of tl::ptx_cp_async(dst, src, num_elems) when ALL
of the following hold:
- the flag is on
- the copy is non-predicated
- total_bytes == 16 (matches the device template's only specialised N)
- the destination buffer scope is "shared" or "shared.dyn"
- the destination LDS index contains no bitwise_xor term (a conservative
  proxy for lane contiguity; before the M7 swizzle-swap optimisation
  lands, swizzled LDS layouts will contain XOR and so the routing
  safely no-ops back to the existing ptx_cp_async path)

Note arg 2 is byte width, not logical element count. The codegen handler
for ptx_cp_async_lds (M4) prints arg 2 verbatim as the template width
because the device template is currently only specialised for N == 16.
Keeping the logical-vs-byte distinction at the boundary avoids the
hazard Codex flagged where a blind op swap could otherwise emit
cp_async_gs_lds<1> or <8>.

Side artefacts: CopyIndexInfo gains a total_bytes field already computed
in PrepareCopyIndexInfo; MakeCPAsyncStmtFromLoads loses its static
qualifier so it can read enable_buffer_load_lds_ from `this`.

Downstream cp.async recognisers updated so they treat the new ops the
same as the existing ones:
- src/transform/lower_ptx_async_copy.cc AnalyzeCopyRegion
- src/transform/legalize_safe_memory_access.cc IsCPAsyncOp
- src/transform/vectorize_loop.cc Call dispatch
- src/transform/thread_storage_sync.cc is_cp_async lambda

Verification: pip install -e . succeeds; `import tilelang` succeeds.
Full bench correctness/perf evaluation is the M8 job after M6 (hoisting)
and M7 (swizzle-swap) land.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
New file tilelang/transform/hoist_buffer_resource.py with the
descriptor-hoist half of the reference implementation. Scans the
post-LowerAccessPtr PrimFunc body for tl.ptx_cp_async_lds calls,
extracts the source buffer Var from each call's tvm_access_ptr arg,
creates one __rsrc_<buf> and one __base_<buf> Var per unique source
buffer, wraps the body with buffer_resource_var / buffer_base_var
AttrStmts (consumed by the HIP codegen handlers from M4), and rewrites
the calls to ptx_cp_async_lds_rsrc carrying the hoisted vars.

Registered in tilelang/transform/__init__.py alongside the existing
HoistBroadcastValues / DecoupleTypeCast imports.

Inserted into tilelang/engine/phase.py OptimizeForTarget after
MergeIfStmt, before MakePackedAPI, matching Codex's directive site so
LowerAccessPtr has already lowered tl.access_ptr to tvm_access_ptr by
the time the pass runs.

The pass guards on target_is_gfx950(target) and is a no-op on every
other target. AMD vmcnt wait-count scaling (the second half of the
reference) is intentionally NOT included in this commit; it will land
as a separate milestone (M6.5) only if the M6 / M7 bench shows a
correctness failure attributable to async wait counts.

Verification: pip install -e . succeeds; `from tilelang.transform
import HoistBufferResource` imports the callable. End-to-end emission
will be exercised once M7 lands the swizzle-swap optimisation that
removes the XOR from LDS-side indices so M5's IsLdsLaneContiguous
gate begins routing copies to ptx_cp_async_lds.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Two changes that together let the M5 IsLdsLaneContiguous gate route the
target shape's g2s copies into the gfx950 buffer_load...lds path.

1. Layout swizzle-delta API (src/layout/layout.h, layout.cc):
   - New Optional<PrimExpr> swizzle_delta_ member on LayoutNode.
   - virtual PrimExpr SwizzleDelta(input_indices): substitutes the last
     InputDim() entries of input_indices into the stored delta (same
     convention as Forward), returns 0 when no delta is set.
   - bool HasSwizzle() / void SetSwizzleDelta(): trivial accessors.
   - LayoutNode::Expand propagates swizzle_delta_ through the variable
     remap that shifts InputPlaceholders by the leading-shape offset.

2. Swizzle factories now record their column-XOR delta
   (src/layout/gemm_layouts.cc) so HasSwizzle() returns true for the
   layouts the GEMM tile-op actually produces:
   - MakeQuarterBankSwizzleLayout2D: (xor2x2(c, s>>2) - c) * vec
   - MakeHalfBankSwizzleLayout2D:    (xor4x4(c, s>>1) - c) * vec
   - MakeFullBankSwizzleLayout2D:    (xor8x8(c, s) - c) * vec

3. Swizzle-swap in src/transform/lower_tile_op.cc BufferStoreNode
   visitor: when TargetIsRocm && !is_ptx_ && IsSharedBuffer(buffer) &&
   layout has a swizzle AND the store value is a direct global
   BufferLoad of matching arity AND the layout output has unit leading
   dim, rewrite

       shared[Forward(local)] = global[base + local]

   to

       shared[Forward(local) - delta(local)] = global[base + local + delta(local)]

   XOR is self-inverse so net data movement is unchanged, but the LDS
   destination becomes lane-contiguous and the cp.async injector's
   IsLdsLaneContiguous check stops rejecting the store. Other layouts
   (no HasSwizzle, non-ROCm, ptx path, non-shared dst, non-direct-load
   stores) fall through to the existing Forward-only path.

Diagnostic LOG(INFO) noise from the reference branch is intentionally
omitted. Verification: pip install -e . succeeds; `import tilelang`
succeeds. End-to-end emission + perf is the M8 job.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Without this pass, kernels with multiple `shared.dyn` buffers (which
includes the bench's matmul with A_shared + B_shared) fail
LowerDeviceKernelLaunch with "Only one dynamic shared memory
allocation is allowed". The pass was commented out on this branch with
no explanation; re-enabling it lets the bench compile through to
codegen + kernel launch.

This is a prerequisite for M8 (perf gate). Pre-existing constraint, not
related to the buffer_load_lds work; landed as a separate small
milestone for traceability per AC-3.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…odegen num_elems convention

The Round 1 emission had an outer `for vec=0..8` loop wrapping each
tl::cp_async_gs_lds_with_rsrc<16> call, causing 8x overlapping LDS
writes per thread and a runtime GPU coredump. Root cause: tl::ptx_cp_async
stores its transfer width as a *logical element count* and several
downstream passes (vectorize_loop.cc::MutatePTXCPAsyncExpr_,
loop_vectorize.cc dispatch, merge_shared_memory_allocations.cc rewrite)
key off the op identity to widen/remap. Earlier I emitted ptx_cp_async_lds
with arg 2 in *bytes*, so those passes either skipped the call (no
widening = vec loop survived) or could not remap the merged shared
buffer.

This commit makes the ptx_cp_async_lds family follow the same logical-
element-count convention as tl::ptx_cp_async:

- src/transform/lower_ptx_async_copy.cc: MakeCPAsyncStmtFromLoads's LDS
  branch now passes `num_elems` (not `total_bytes`) as arg 2.
- src/backend/rocm/codegen/codegen_hip.cc: the tl::ptx_cp_async and
  tl::ptx_cp_async_lds CallNode handlers are merged; both go through
  GetTileLangCPAsyncTransferBytes to derive the template byte width.
  The tl::ptx_cp_async_lds_rsrc handler does the same conversion
  inline (it has 5 args so cannot reuse GetTileLang...).
- src/transform/vectorize_loop.cc: GetCPAsyncBitsPerCall and
  MutatePTXCPAsyncExpr_ ICHECKs widen to accept tl::ptx_cp_async_lds.
  This lets the vec(k) widening multiply num_elems by k and produces
  one call covering the full vec range, so the surrounding vec loop's
  body becomes uniform and the loop is consumed by the vectorizer.
- src/transform/loop_vectorize.cc: same widening for the
  ScalarToVector pipeline.
- src/transform/merge_shared_memory_allocations.cc: extend the
  cp_async dst-ptr remap to recognise tl::ptx_cp_async_lds so the
  merged dyn-shared buffer pointer is substituted correctly.

Result on the bench (8192x8192x8192 NT, tile 256x256x64): emission now
matches `_fast.cpp`'s outer-loop shape verbatim
(tl::cp_async_gs_lds_with_rsrc<16> per i_1 iteration, no inner vec
loop). Bench compile passes, kernel launches without coredump, ~867
TFLOPS. Correctness still fails because the swizzle-swap from M7 is
not engaging on the CDNA layout, leaving the XOR on the LDS-store
side; that is the next debug step.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Codex Round 1 review flagged the previous IsLdsLaneContiguous as unsafe:
it only rejected expressions containing a tir.bitwise_xor call, but the
swizzle layouts on this branch expand the XOR into shift/and/add bit
arithmetic (see e.g. xor2x2 in src/layout/gemm_layouts.cc:370). The
expanded form passed the gate even when the LDS destination was not
lane-contiguous.

Replace IsLdsLaneContiguous with a multi-sample affine proof modelled on
/root/backuptilelang/src/transform/inject_ptx_async_copy.cc::IsLdsContiguous:

- Walk the dst index, collect free Vars.
- Pick a thread-like Var (name contains "thread" or equals "tx"/"tid").
- Compute f(0), f(1), expected_stride = f(1) - f(0); require IntImm.
- Sample at points 2..1023; require analyzer.CanProveEqual(f(k) - f(0),
  k * stride) at every sample. Covers wave-32/64 boundaries, warp tiles,
  bank-swizzle phases (powers of two up to 64), and the 256-thread block
  boundaries the bench uses.

Returns false if no thread-like var is found or any sample disagrees, so
the LDS route safely no-ops back to the existing ptx_cp_async path until
M9 actually moves the swizzle off the LDS side.

Removes ContainsBitwiseXor helper (no longer needed). Build + import OK.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…ix for AC-2/AC-1)

The M7 swap in VisitStmt_(BufferStoreNode) never fired on this branch
because rocm::Copy::Lower calls InjectPTXAsyncCopy inline and the
BufferStores into A_shared / B_shared are converted to
tl::ptx_cp_async_lds Calls before LowerTileOpPass re-visits the
lowered Stmt (`return IRMutatorWithAnalyzer::VisitStmt(lowered)` at
~line 1130). Diagnosed by adding LOG(WARNING) at the top of the
BufferStore visitor: it only ever saw C_local, A_local, B_local, C
- never the shared buffers.

Per Codex Round 1 review's directive, add the swap on the resulting
Call node instead. New VisitExpr_(CallNode) branch that, for
tl::ptx_cp_async_lds on a ROCm target with a swizzled remapped shared
destination and a direct global source BufferLoad:
  1. resolve_load() pulls BufferLoad from the tl::access_ptr arg
     (handles direct BufferLoad and let-bound BufferLoad via
     let_bindings_)
  2. compute swizzled = layout->Forward(dst_logical_indices) and
     delta = layout->SwizzleDelta(dst_logical_indices)
  3. build new dst access_ptr against buffer_remap_[dst_buf] with
     swizzled[last] - delta on the last dim (lane-contiguous LDS)
  4. build new src access_ptr with src_logical_indices[last] + delta
     on the last dim (XOR moved to the global side; self-inverse)
  5. return rewritten Call directly so the default arg visitor does
     not re-apply the swizzled layout to the destination

Rank-difference between dst and src is allowed: the LDS dst typically
has a leading Expand dim that the global src does not, but the
"last dim" of each maps to the same physical column so the swap is
well-defined on the last dim alone.

Bench result with M5-harden + M9 (M6.5 still pending): correctness
PASS, TFLOPS 866 (still below 1000 floor — M6.5 wait-count scaling
should close the gap). Emission shape now matches _fast.cpp exactly:
LDS dst is `i_1 * 4096 + threadIdx.x * 8` (linear), global src carries
the bank-swizzle pattern.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Codex Round 1 review flagged the generated cp_async_wait<1> vs the
target's cp_async_wait<8> as a likely 1000 TFLOPS blocker, and after
M9 fixed the swizzle direction the bench correctness passed at only
866 TFLOPS - confirming Codex's analysis.

Add _get_loads_per_group + _fix_amd_wait_counts to
tilelang/transform/hoist_buffer_resource.py, matching the reference
algorithm but rewriting on this branch's IR shape:

- _is_async_load_call recognises Evaluate(Call(ptx_cp_async_lds OR
  ptx_cp_async_lds_rsrc, ...)).
- _is_commit_call recognises Evaluate(Call(ptx_commit_group, ...))
  (the existing pipeline emits these directly; the reference
  branch's async_commit_queue_scope AttrStmt is already lowered by
  the time we run).
- _find_for_with_commit walks the body to find the innermost For
  whose subtree contains a commit; _count_async_loads counts async
  loads inside one iteration of that For, multiplying by IntImm
  loop extents for nested unrolls.
- _fix_amd_wait_counts rewrites Evaluate(Call(ptx_wait_group, [n]))
  to Evaluate(Call(ptx_wait_group, [n * loads_per_group])) when
  n > 0. wait_group(0) (wait-all) stays unchanged.

The pass runs the scaling AFTER the existing descriptor hoist so
the post-hoist body (now using ptx_cp_async_lds_rsrc) is the input
to the load counter; the recognised loads include the rsrc form.

Bench result with M5-harden + M9 + M6.5:
  [iter] correctness: PASS
  [iter] latency: 0.9803 ms
  [iter] TFLOPS:   1121.62
  [iter] TB/s:     0.411

Above the AC-1 hard floor of 1000. Matches the reference branch's
"1120 TFLOPS now." commit message. Emitted .hipi has the target
tl::cp_async_wait<8> at the K-loop wait and tl::cp_async_wait<0>
at the final wait; .s contains 16 buffer_load_dwordx4 ... lds
occurrences.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…roduce linear LDS

The M9 swizzle-swap subtracts SwizzleDelta from the LAST output dim of
Forward(dst_indices). That cancels the XOR cleanly for layouts whose
swizzle is confined to the last column (the bench's A_shared layout,
and B_shared in the bench's symmetric NT shape), but layouts where the
swizzle spreads across multiple output dims leave residual swizzle in
the higher output dims AND over-correct src. The dst LDS index is no
longer lane-contiguous, so the emitted buffer_load_dwordx4 ... lds
writes to wrong addresses.

Concrete reproducer: python gemm/example_gemm.py at the default
1024x1024x16384 shape. A's LDS dst comes out as
`i_1 * 1024 + threadIdx.x * 8` (clean linear) but B's comes out as
`((tx & 15) >> 3) * 2048 + i_2 * 512 + (tx >> 4) * 64 + (tx & 7) * 8`
(NOT lane-contiguous), and B's global src ends up with a spurious
`- ((tx & 7) * 8)` term. Bench data 99.7% mismatched.

Fix has two parts:

1. Restructure the M9 handler so EITHER outcome (swap-or-skip) is a
   well-defined action: if the candidate gate (`dst_ap` / `src_ap` /
   shared dst / global src / remap / HasSwizzle / non-empty indices)
   does not hold, immediately downgrade the op from
   tl::ptx_cp_async_lds back to tl::ptx_cp_async (same arg shape) and
   recurse with the default visitor. Without this, a call that the
   guard rejected would keep its ptx_cp_async_lds op and codegen would
   emit the LDS template against an unmodified (swizzled) dst.

2. After the swap math runs, sample each post-swap new_dst_indices[d]
   against a thread-like Var: compute f(0), f(1), expected_stride,
   then require analyzer.CanProveEqual(f(k)-f(0), k*stride) at sample
   points 2..63. If any dim is non-affine in the thread var (bit-
   extract `(tx & m) >> s` terms etc.), downgrade as above.

Verified on two shapes:
- python gemm/example_gemm.py (1024x1024x16384):
  All check passed. Latency 0.0443 ms.
- bash scripts/iter_buffer_load.sh (8192x8192x8192):
  correctness PASS, TFLOPS 1116.84 (A still uses LDS fast path; B
  symmetric here also passes the linearity check).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Existing CBA helper in ComputePlanCandidate only runs when source_buffer is
undefined. For T.copy(B, B_shared) at kCommon/kStrict, source_buffer chain
takes B_shared and routes through ComputeLoopLayoutFromBuffer, so CBA never
fires and the default PlanLoopPartition flatten puts 16 lanes on N (crossing
the FullBank tc=2 boundary at lane 8 -> +1920B jump).

Add an early hook at the top of ParallelOp::InferLayout that runs CBA at all
3 levels and sets loop_layout_ unconditionally when an eligible target buffer
exists. CBA's gate (last_dim_bytes > 128B and divisible) keeps it inert on
A_shared / NT-B_shared / C_local cases.

Effect:
- NN 1024^3 K=16384 128x128x32: correctness PASS (was wrong values before)
- NT 8192^2 K=8192 256x256x64 stages=2: 1111 TFLOPS, correctness PASS
  (NT B last-dim = K = 64 = 1 bank cycle, gate skips; perf unchanged)

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Previously the IsLdsLaneContiguous gate at lower_ptx_async_copy.cc:574
rejected any swizzled (XOR-laden) LDS index, falling through to
ptx_cp_async. That cut M9 out of the loop for FullBank tc>1 cases (NN B
under the new CBA binding from M10), even though M9's SwizzleDelta swap
would produce a lane-contiguous index.

Drop the IsLdsLaneContiguous gate. Always emit ptx_cp_async_lds when the
destination is shared + 16B + non-predicated. M9 then either rewrites
the call (LDS path stays) or downgrades to ptx_cp_async if post-swap
is non-affine.

Effect:
- NN 8192^3 128x128x32 stages=3 threads=128: 443 TFLOPS (was 370)
- NN 8192^3 256x256x32 stages=3 threads=256: 602 TFLOPS, PASS
- NN 8192^3 256x256x64 stages=2 threads=256: 760 TFLOPS, PASS
- NT 8192^2 K=8192 256x256x64 stages=2 threads=512: 1116 TFLOPS, PASS

A_shared still downgrades (M9 affine check fails because HalfBank's ts
boundary splits the warp), so A continues on cp_async path. Fixing A
needs CBA-style binding on the stride dim — open work.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Adds in-source Chinese comments tracing the FullBank swizzle math
(ts/s/tc/c/vec/c_swizzle/index, with concrete dimensions for the NN
B_shared=32x128 bf16 case) so the layout's role in the M10/M11
chunk-block-aware binding work is documented at the source. Pure
comment additions, no behavior change.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
After always-emit-LDS landed (let the LowerTileOp swizzle-swap visitor
own the lane-contiguity decision and downgrade when it can't), the
IsLdsLaneContiguous sampler and the dst_check_index parameter to
MakeCPAsyncStmtFromLoads are no longer reachable. Remove them so the
upstream-facing diff has no dead code; behavior is unchanged.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The chunk-block-aware binding now lands via the early hook at the top
of ParallelOp::InferLayout (runs at all 3 levels, before source_buffer
dispatch). By the time PlanLoopPartition's ComputePlanCandidate runs,
loop_layout_ is already set when CBA applies, so the second
ComputeChunkBlockAwarePlanCandidate call site here is dead in practice.
Removed it; behavior is unchanged.

Verified:
- NT 8192^2 K=8192 256x256x64 stages=2 threads=512: 1114 TFLOPS, PASS
- NN 1024x1024x16384 128x128x32 stages=3 threads=128: 64 TFLOPS, PASS

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Migrate buffer_load_lds + hoist + AMD vmcnt scaling work to upstream's
tir->tirx namespace rename and LetStmtNode->BindNode change. Pick up
all main improvements (Metal backend, SM75 MMA, cluster TMA, etc.)
and the tilelang_repo-side n_regs/n_spills exposure.

Conflicts resolved:
- src/backend/rocm/codegen/codegen_hip.{cc,h}: keep new BindNode
  override for ptx_make_buffer_resource on top of main's
  AllocateNode->AllocBufferNode rename.
- src/transform/lower_ptx_async_copy.cc: keep gfx950 ptx_cp_async_lds
  routing block; use main's bare Array type alias.
- src/transform/ptx_async_copy_injector.h: combine main's tirx::Stmt
  type with the extra enable_buffer_load_lds parameter.
- src/transform/lower_tile_op.cc: rename tir:: -> tirx:: in the
  swizzle-swap affine prover.
- tilelang/transform/hoist_buffer_resource.py: port to tvm.tirx;
  drop dead LetStmt branch; bump op names tir.* -> tirx.*.
- tilelang/contrib/hipcc.py: take main's _resolve_artifact_paths.
- tilelang/contrib/hip_resource_info.py: drop docstring whitespace nit.
- tilelang/jit/kernel.py: drop dead Py3.9 ParamSpec compat shim.
- 3rdparty/tvm: take main's pointer (8435b89).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@github-actions

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai

coderabbitai Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5ccf9cf6-5d56-48e0-857c-99b2b9fad974

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds gfx950-optimized global→LDS async copy support: new TL intrinsics and 16-byte LDS templates, swizzle-delta layout metadata and propagation, HIP emission and hoisted resource locals, injector routing for eligible transfers, LowerTile index rewrites, ParallelOp planning, and supporting pass updates.

Changes

gfx950 Async LDS Copy and Swizzle Optimization

Layer / File(s) Summary
Layout swizzle delta infrastructure
src/layout/layout.h, src/layout/layout.cc, src/layout/gemm_layouts.cc
LayoutNode gains optional swizzle_delta_ metadata with SwizzleDelta(), HasSwizzle(), and SetSwizzleDelta() accessors; delta propagates through Expand(); quarter/half/full-bank swizzle layout factories compute and store the delta value.
New async copy intrinsics (contracts and registration)
src/op/builtin.h, src/op/builtin.cc
Declares and registers ptx_cp_async_lds, ptx_make_buffer_resource, ptx_cp_async_lds_rsrc as kOpaque TL ops with documented signatures and input-count specs.
LDS copy template implementations
src/tl_templates/hip/copy.h
Adds cp_async_gs_lds<N> and cp_async_gs_lds_with_rsrc<N> templates implementing 16-byte buffer_load_dwordx4 ... lds inline asm on gfx950; hoisted-rsrc variant reduces per-call SGPR overhead; other sizes fall back to standard cp_async_gs<N>.
HIP codegen for async copies and buffer resources
src/backend/rocm/codegen/codegen_hip.h, src/backend/rocm/codegen/codegen_hip.cc
Emits make_wave_buffer_resource(...), tl::cp_async_gs[_lds][_with_rsrc]<bytes> and conditional templates; converts logical element counts to bytes; clamps ptx_wait_group immediates to 63; emits AttrStmt-bound resource/base locals and special-cases Bind for hoisted resource exprs.
CP.async injector & LDS routing
src/transform/lower_ptx_async_copy.cc, src/transform/ptx_async_copy_injector.h
Extends PTXAsyncCopyInjector with enable_buffer_load_lds flag, threads total_bytes through copy prep and CopyIndexInfo, and conditionally emits ptx_cp_async_lds/ptx_cp_async_lds_rsrc for eligible non-predicated 16-byte global→shared transfers; updates purity analysis and API wiring.
HoistBufferResource pass and pipeline integration
tilelang/transform/hoist_buffer_resource.py, tilelang/transform/__init__.py, tilelang/engine/phase.py, src/backend/rocm/op/copy.cc
Python HoistBufferResource() pass hoists buffer resource descriptors used by ptx_cp_async_lds into prologue AttrStmt bindings, rewrites calls to ptx_cp_async_lds_rsrc, and scales AMD wait-counts by loads-per-group; re-exported and integrated into OptimizeForTarget.
LowerTileOp swizzle rewrite & store optimization
src/transform/lower_tile_op.cc
ROCm-specific rewrite of ptx_cp_async_lds that applies layout->Forward(...) plus SwizzleDelta(...) to indices with downgrade paths; also implements a non-PTX store-side swizzle-swap for direct global-load stores.
Chunk-block-aware planning for ParallelOp
src/op/parallel.h, src/op/parallel.cc
ParallelOp gains early chunk-block-aware plan candidate computation and ComputeChunkBlockAwarePlanCandidate to detect eligible shared-buffer patterns and return a Fragment used to short-circuit later layout inference.
Supporting transform updates
src/transform/legalize_safe_memory_access.cc, src/transform/loop_vectorize.cc, src/transform/merge_shared_memory_allocations.cc, src/transform/thread_storage_sync.cc, src/transform/vectorize_loop.cc
Multiple passes updated to recognize ptx_cp_async_lds/ptx_cp_async_lds_rsrc as CP.async ops for legality, vectorization, merging, sync, and widening; vectorize preserves hoisted resource args for the 5-arg form.

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly Related PRs

  • tile-ai/tilelang#2242: Related edits to bank-swizzle layout logic in src/layout/gemm_layouts.cc that interact with swizzle-delta propagation.
  • tile-ai/tilelang#2058: Changes to cp_async_gs<16> template paths that connect with the gfx950 buffer_load LDS route added here.
  • tile-ai/tilelang#2002: Pipeline reordering around tilelang transforms similar to the HoistBufferResource insertion in this PR.

Suggested Reviewers

  • LeiWang1999

Poem

🐇 I hoisted descriptors with a twitchy hop,

Swizzle deltas slid so copies don't stop,
Gfx950 hums with buffer_load delight,
Async loads land tidy in LDS night,
A rabbit applauds this optimized flight!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 23.94% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title '[ROCM] Fix buffer_load_lds support for gfx950' accurately and concisely summarizes the main change: enabling buffer_load_lds async copies for gfx950 AMD GPUs, which is the core objective described in the PR objectives.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@benenzhu

Copy link
Copy Markdown
Contributor Author

cc @zhangnju

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)

404-436: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve annotated layout/predicate precedence over CBA auto-planning.

At Line 411, the CBA override runs before annotation adoption, so explicit kParallelLoopLayout/kParallelLoopPredicate can be silently bypassed when CBA matches. That changes declared layout semantics and can drop intended predicate guards.

Proposed fix
-  if (!loop_layout_.defined()) {
+  // Keep explicit annotations as highest-priority contract.
+  if (!loop_layout_.defined() && !annotated_layout_unbound_.defined()) {
     // Reuse the same vec_size calculation as ComputePlanCandidate.
     auto maybe_remapped_root =
         IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
@@
   }
   if (!loop_layout_.defined() && annotated_layout_unbound_.defined()) {
     loop_layout_ =
         annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds);
     if (annotated_predicate_.defined()) {
       predicate_ = annotated_predicate_.value();
     }
   } else if (!loop_layout_.defined() && source_buffer.defined() &&
              allow_layout_propgate) {
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/op/parallel.cc` around lines 404 - 436, The CBA auto-plan
(ComputeChunkBlockAwarePlanCandidate + loop_layout_ assignment) is run before
honoring explicit annotations, which lets CBA silently override annotated
kParallelLoopLayout/kParallelLoopPredicate; reorder logic so
annotated_layout_unbound_ is applied first: if
annotated_layout_unbound_.defined(), call
annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds) into
loop_layout_ and set predicate_ from annotated_predicate_ (if defined) before
running ComputeChunkBlockAwarePlanCandidate/GetVectorizeSize; only run the CBA
fallback when loop_layout_ remains undefined after adopting annotations. Use the
existing symbols loop_layout_, annotated_layout_unbound_, annotated_predicate_,
ComputeChunkBlockAwarePlanCandidate and the surrounding block that computes
vector_size to implement this reorder.
🧹 Nitpick comments (2)
tilelang/transform/hoist_buffer_resource.py (2)

108-125: 💤 Low value

Symbolic loop extents are not multiplied into the load count.

When a For loop has a non-IntImm extent (e.g., a dynamic/symbolic extent), the multiplier stays unchanged. This could undercount async loads for dynamically-sized inner loops, leading to wait counts that are too small.

Given the PR targets unrolled tile-copy loops on gfx950 (which typically have constant extents), this is likely acceptable, but consider logging a warning or asserting IntImm when the pass is active.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/transform/hoist_buffer_resource.py` around lines 108 - 125,
_count_async_loads currently ignores non-IntImm loop extents in tir.For, so
symbolic/dynamic extents aren't multiplied into the async load count and can
undercount waits; update _count_async_loads to detect when stmt is a tir.For
with a non-IntImm extent and either 1) conservatively treat the extent as
unknown by logging a warning or raising/asserting (e.g., assert
isinstance(stmt.extent, tir.IntImm)) to fail fast, or 2) conservatively multiply
by a safe upper bound if available; locate the tir.For handling in
_count_async_loads and add the chosen behavior (warning/assert) to ensure
dynamic extents don’t silently produce incorrect counts.

90-105: 💤 Low value

The hasattr(stmt, "body") fallthrough is overly generic.

This branch will match any node type with a body attribute, which could include unexpected statement types. Consider explicitly handling the known cases (e.g., tir.LetStmt, tir.AssertStmt) or at minimum documenting what this branch is intended to cover.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tilelang/transform/hoist_buffer_resource.py` around lines 90 - 105, The
fallback in _find_for_with_commit uses a generic hasattr(stmt, "body") which is
too broad; replace it with explicit type checks for the statement kinds that
actually wrap a body (e.g., check isinstance(stmt, tir.LetStmt) and
isinstance(stmt, tir.AssertStmt) and any other known wrappers in this project
such as tir.IfThenElse or tir.AttrStmt) and recursively call
_find_for_with_commit only for those cases, or add a short comment listing the
intended covered types if keeping a generic branch; update the branch for
hasattr(stmt, "body") to handle only the explicit classes (and return None
otherwise) so unrelated nodes with a body attribute do not get traversed
unexpectedly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/backend/rocm/codegen/codegen_hip.cc`:
- Around line 926-955: The ptx_cp_async_lds_rsrc branch computes total_bytes
from the destination element type only and doesn't validate source/destination
element-width equality or enforce allowed transfer widths; update the branch
handling in codegen_hip.cc (the ptx_cp_async_lds_rsrc branch that uses
GetAccessPtrElementType, num_elems_imm, total_bits/total_bytes and emits
tl::cp_async_gs_lds_with_rsrc) to: retrieve/ICHECK a src element type via
GetAccessPtrElementType(op->args[1]) like you do for dst, ICHECK both src and
dst element types exist and have identical bits() and lanes(), compute
per-element bytes = dst_elem_type.bits()*dst_elem_type.lanes()/8 and ICHECK it
is one of {4,8,16}, then compute total_bytes as before and emit the call using
that validated size so mismatched or unsupported widths fail early.

In `@src/layout/layout.cc`:
- Around line 518-531: In LayoutNode::SwizzleDelta ensure you check that
input_indices.size() is at least InputDim() before using input_indices[offset +
i]; add a precondition at the start (e.g., an explicit check/ICHECK/CHECK and a
clear error message) that fails fast when input_indices.size() < InputDim(), or
alternatively clamp/guard the loop so Substitute never indexes past
input_indices; reference the function LayoutNode::SwizzleDelta, the variables
input_indices, InputDim(), offset, and InputPlaceholder(i) when adding the check
and message.

In `@src/op/builtin.h`:
- Around line 526-527: The documentation for the PTX intrinsic ptx_cp_async_lds
(and the other affected intrinsic docstrings) currently states the third
argument is "bytes" but the lowering code treats arg3 as a logical element count
and computes byte size from the element dtype; update the docstrings to say the
third parameter is "num_elems" (logical element count) and explicitly note that
byte width is derived from the element dtype at lowering so callers should pass
element count not raw bytes for functions like ptx_cp_async_lds and the
similarly-documented intrinsics around the same block.

In `@src/transform/vectorize_loop.cc`:
- Around line 683-684: The ICHECK that currently allows only tl::ptx_cp_async()
and tl::ptx_cp_async_lds() should also accept tl::ptx_cp_async_lds_rsrc() so the
hoisted-resource path isn't hard-failed; update the guards in
MutatePTXCPAsyncExpr_ and GetCPAsyncBitsPerCall (the ICHECKs that reference
tl::ptx_cp_async() and tl::ptx_cp_async_lds()) to include
tl::ptx_cp_async_lds_rsrc() as an allowed op, ensuring all places that validate
cp.async variants (including the occurrences around the other noted blocks)
treat the new op the same as the existing ones.

---

Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 404-436: The CBA auto-plan (ComputeChunkBlockAwarePlanCandidate +
loop_layout_ assignment) is run before honoring explicit annotations, which lets
CBA silently override annotated kParallelLoopLayout/kParallelLoopPredicate;
reorder logic so annotated_layout_unbound_ is applied first: if
annotated_layout_unbound_.defined(), call
annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds) into
loop_layout_ and set predicate_ from annotated_predicate_ (if defined) before
running ComputeChunkBlockAwarePlanCandidate/GetVectorizeSize; only run the CBA
fallback when loop_layout_ remains undefined after adopting annotations. Use the
existing symbols loop_layout_, annotated_layout_unbound_, annotated_predicate_,
ComputeChunkBlockAwarePlanCandidate and the surrounding block that computes
vector_size to implement this reorder.

---

Nitpick comments:
In `@tilelang/transform/hoist_buffer_resource.py`:
- Around line 108-125: _count_async_loads currently ignores non-IntImm loop
extents in tir.For, so symbolic/dynamic extents aren't multiplied into the async
load count and can undercount waits; update _count_async_loads to detect when
stmt is a tir.For with a non-IntImm extent and either 1) conservatively treat
the extent as unknown by logging a warning or raising/asserting (e.g., assert
isinstance(stmt.extent, tir.IntImm)) to fail fast, or 2) conservatively multiply
by a safe upper bound if available; locate the tir.For handling in
_count_async_loads and add the chosen behavior (warning/assert) to ensure
dynamic extents don’t silently produce incorrect counts.
- Around line 90-105: The fallback in _find_for_with_commit uses a generic
hasattr(stmt, "body") which is too broad; replace it with explicit type checks
for the statement kinds that actually wrap a body (e.g., check isinstance(stmt,
tir.LetStmt) and isinstance(stmt, tir.AssertStmt) and any other known wrappers
in this project such as tir.IfThenElse or tir.AttrStmt) and recursively call
_find_for_with_commit only for those cases, or add a short comment listing the
intended covered types if keeping a generic branch; update the branch for
hasattr(stmt, "body") to handle only the explicit classes (and return None
otherwise) so unrelated nodes with a body attribute do not get traversed
unexpectedly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: af15cc00-88dc-44d3-bf99-8bd52162e26b

📥 Commits

Reviewing files that changed from the base of the PR and between 47288a2 and 03343e2.

📒 Files selected for processing (22)
  • src/backend/rocm/codegen/codegen_hip.cc
  • src/backend/rocm/codegen/codegen_hip.h
  • src/backend/rocm/op/copy.cc
  • src/layout/gemm_layouts.cc
  • src/layout/layout.cc
  • src/layout/layout.h
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/parallel.cc
  • src/op/parallel.h
  • src/tl_templates/hip/copy.h
  • src/transform/legalize_safe_memory_access.cc
  • src/transform/loop_vectorize.cc
  • src/transform/lower_ptx_async_copy.cc
  • src/transform/lower_tile_op.cc
  • src/transform/merge_shared_memory_allocations.cc
  • src/transform/ptx_async_copy_injector.h
  • src/transform/thread_storage_sync.cc
  • src/transform/vectorize_loop.cc
  • tilelang/engine/phase.py
  • tilelang/transform/__init__.py
  • tilelang/transform/hoist_buffer_resource.py

Comment thread src/backend/rocm/codegen/codegen_hip.cc
Comment thread src/layout/layout.cc
Comment thread src/op/builtin.h Outdated
Comment thread src/transform/lower_tile_op.cc Outdated
Comment thread src/transform/vectorize_loop.cc Outdated
- codegen_hip: route ptx_cp_async_lds_rsrc through
  GetTileLangCPAsyncTransferBytes so src/dst widths must match and
  the final byte width is validated against {4,8,16}.
- layout: ICHECK_GE the index count in SwizzleDelta so a too-short
  call fails with a clear contract error instead of OOB-reading.
- builtin.h: fix ptx_cp_async_lds / _rsrc docstrings -- arg 2 is
  num_elems, not bytes (lowering derives byte width from the
  access-ptr dtype).
- lower_tile_op: drop the "tx"/"tid"/"thread*" name-matching in the
  affine lane-contiguity proof and key off the real threadIdx.x
  binding tracked in thread_var_ so a future rename of the lane var
  can't silently misclassify a non-affine LDS index as OK.
- vectorize_loop: accept ptx_cp_async_lds_rsrc in
  MutatePTXCPAsyncExpr_ / GetCPAsyncBitsPerCall (it's already routed
  here from VisitExpr_) and preserve the trailing (rsrc, base) args
  through the rewrite so codegen still sees them.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@benenzhu benenzhu closed this May 22, 2026
@benenzhu benenzhu reopened this May 23, 2026
benenzhu and others added 2 commits May 23, 2026 13:18
The unconditional override in the early CBA branch of InferLayout was
forcing the loop layout onto MMA accumulators like acc_o_l / C_local /
dq, whose fragment is already in T.layout_map with its own MMA-derived
binding. The follow-up ValidateCandidateAgainstFragments call would
then throw a "Layout infer conflict between <buf> and <buf>" error.

Only adopt the CBA candidate when it actually validates against the
existing fragments; otherwise fall through to the normal source-buffer
/ free-inference paths so the MMA fragment's binding still wins.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The chunk-block-aware override exists solely to make
buffer_load_dwordx4...lds usable on gfx950; firing it on CUDA / older
AMD targets can only force the loop binding into a shape that conflicts
with MMA fragment layouts (the CUDA CI hit this on acc_o_l / C_local /
dq / C_local_accum). Gate on TargetIsGfx950 so every other target keeps
its existing layout-inference behaviour untouched. The validation
fallback added in the previous commit stays as defense-in-depth.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/op/parallel.cc (2)

417-430: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Apply the same kCoalescedWidth clamp in the early CBA path.

The comment says this reuses ComputePlanCandidate(...)'s vector-size calculation, but it stops before the attr::kCoalescedWidth override on Lines 731-744. If that annotation is present, the gfx950 fast path can build a different partition than the normal plan path, and the later code never corrects it because loop_layout_ is already set.

Also applies to: 731-744

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/op/parallel.cc` around lines 417 - 430, The early CBA path computes
vector_size using GetVectorizeSize on maybe_remapped_root but never applies the
attr::kCoalescedWidth clamp like ComputePlanCandidate does, which can make the
gfx950 fast path produce a different partition; update the CBA early path (the
block that computes vector_size before calling
ComputeChunkBlockAwarePlanCandidate and setting loop_layout_) to read
attr::kCoalescedWidth (or call the same clamp routine used by
ComputePlanCandidate) and clamp/reduce vector_size accordingly before the
floormod loop and before passing vector_size into
ComputeChunkBlockAwarePlanCandidate so loop_layout_ is computed consistently
with the normal plan path.

416-449: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve explicit parallel-loop annotations ahead of the gfx950 override.

This branch runs before the cached kParallelLoopLayout/kParallelLoopPredicate adoption, so once cba validates it silently bypasses the explicit annotation path on gfx950. That changes precedence from “annotation wins” to “CBA wins” and also skips installing the annotated predicate.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/op/parallel.cc` around lines 416 - 449, The gfx950 override
(TargetIsGfx950 branch that computes cba via ComputeChunkBlockAwarePlanCandidate
and may set loop_layout_) runs before handling explicit annotations
(annotated_layout_unbound_ / annotated_predicate_) and thus can override and
skip the annotated predicate; change control flow so explicit annotations take
precedence: either move the annotated_layout_unbound_ / annotated_predicate_
block to execute before the TargetIsGfx950 block, or add a guard at the start of
the gfx950 branch that returns/continues if annotated_layout_unbound_.defined()
(and/or annotated_predicate_.defined()), ensuring
ValidateCandidateAgainstFragments and loop_layout_ assignment only run when no
explicit annotation exists.
🧹 Nitpick comments (1)
src/op/parallel.cc (1)

787-811: ⚡ Quick win

Gate CBA on actual swizzle metadata, not just buffer width.

ComputeChunkBlockAwarePlanCandidate(...) is documented as a swizzle-specific escape hatch, but the matcher here only checks for a non-Fragment layout plus last-dimension size heuristics. That means any wide unswizzled shared store on gfx950 can force the alternate binding even when there is no downstream swizzle-swap to unlock.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/op/parallel.cc` around lines 787 - 811, The matcher is enabling the
ComputeChunkBlockAwarePlanCandidate escape hatch based only on layout not being
Fragment and last-dimension width heuristics, which allows wide unswizzled
buffers to trigger the alternate binding; change the gate to require explicit
swizzle metadata instead of just buffer width: locate the block that reads
T.layout_map[buffer] and the early returns (the code that checks Layout layout =
T.layout_map[buffer]; if (layout.as<Fragment>()) return; ...), and add a check
for the actual swizzle flag/metadata (the same metadata used elsewhere to detect
swizzle-swap downstream — e.g., the buffer/store swizzle field or the function
that detects swizzled buffers) and only proceed to consider
ComputeChunkBlockAwarePlanCandidate when that swizzle metadata indicates a
swizzle will be used; leave the existing size/element checks intact but make
them subordinate to the new explicit swizzle presence check so unswizzled wide
stores no longer trigger the alternate binding.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 417-430: The early CBA path computes vector_size using
GetVectorizeSize on maybe_remapped_root but never applies the
attr::kCoalescedWidth clamp like ComputePlanCandidate does, which can make the
gfx950 fast path produce a different partition; update the CBA early path (the
block that computes vector_size before calling
ComputeChunkBlockAwarePlanCandidate and setting loop_layout_) to read
attr::kCoalescedWidth (or call the same clamp routine used by
ComputePlanCandidate) and clamp/reduce vector_size accordingly before the
floormod loop and before passing vector_size into
ComputeChunkBlockAwarePlanCandidate so loop_layout_ is computed consistently
with the normal plan path.
- Around line 416-449: The gfx950 override (TargetIsGfx950 branch that computes
cba via ComputeChunkBlockAwarePlanCandidate and may set loop_layout_) runs
before handling explicit annotations (annotated_layout_unbound_ /
annotated_predicate_) and thus can override and skip the annotated predicate;
change control flow so explicit annotations take precedence: either move the
annotated_layout_unbound_ / annotated_predicate_ block to execute before the
TargetIsGfx950 block, or add a guard at the start of the gfx950 branch that
returns/continues if annotated_layout_unbound_.defined() (and/or
annotated_predicate_.defined()), ensuring ValidateCandidateAgainstFragments and
loop_layout_ assignment only run when no explicit annotation exists.

---

Nitpick comments:
In `@src/op/parallel.cc`:
- Around line 787-811: The matcher is enabling the
ComputeChunkBlockAwarePlanCandidate escape hatch based only on layout not being
Fragment and last-dimension width heuristics, which allows wide unswizzled
buffers to trigger the alternate binding; change the gate to require explicit
swizzle metadata instead of just buffer width: locate the block that reads
T.layout_map[buffer] and the early returns (the code that checks Layout layout =
T.layout_map[buffer]; if (layout.as<Fragment>()) return; ...), and add a check
for the actual swizzle flag/metadata (the same metadata used elsewhere to detect
swizzle-swap downstream — e.g., the buffer/store swizzle field or the function
that detects swizzled buffers) and only proceed to consider
ComputeChunkBlockAwarePlanCandidate when that swizzle metadata indicates a
swizzle will be used; leave the existing size/element checks intact but make
them subordinate to the new explicit swizzle presence check so unswizzled wide
stores no longer trigger the alternate binding.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ae22d266-aaa9-487b-8132-08058c22919e

📥 Commits

Reviewing files that changed from the base of the PR and between 3d95235 and 7f8b103.

📒 Files selected for processing (1)
  • src/op/parallel.cc

@benenzhu benenzhu marked this pull request as draft May 23, 2026 13:31
benenzhu and others added 4 commits May 27, 2026 09:30
HoistBufferResource rewrites every ptx_cp_async_lds it can match to
the _rsrc form, which is the form codegen actually turns into the
buffer_load_dwordx4 ... lds fast path. The two corner cases where the
rewrite silently skips a call (empty buffer_vars early-return, or an
access_ptr shape _extract_buffer_var can't pattern-match) used to land
on a dedicated cp_async_gs_lds<N> codegen branch with its own
buffer_load asm path -- duplicating the fast path and creating a
second way to emit lds that could silently diverge from the _rsrc one.

Collapse to a single safety net: treat ptx_cp_async_lds identically to
ptx_cp_async in codegen so an unhoisted call just lowers to the
synchronous tl::cp_async_gs<N>. Correct, no buffer_load_lds win for
that particular call -- which is fine because in practice all calls
get hoisted. Removes the now-dead cp_async_gs_lds<N> template too and
updates the ptx_cp_async_lds docstring to reflect the new contract.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant