[ROCM] Fix buffer_load_lds support for gfx950 #2248
…arks
Adds a triton-style view of per-kernel AMD GPU resource usage on the tilelang JITKernel, queryable as kernel.n_regs / n_spills / n_max_threads (with a richer resource_usage dict mapping kernel name to a KernelResourceUsage dataclass).
Implementation:
* tilelang/jit/adapter/hip_resource_info.py: passes -Rpass-analysis=kernel-resource-usage to hipcc, parses the per-kernel remarks (Function Name / VGPRs / VGPRs Spill / TotalSGPRs / etc.) out of the captured stdio, and *strips* those lines before the output is printed or included in error messages, so autotune logs don't drown in remark blocks while real warnings/errors still surface. Includes JSON (de)serialization helpers.
* tilelang/contrib/hipcc.py: adds the remark flag, parses + filters the output. Same on the LibraryGenerator HIP path (tilelang/jit/adapter/libgen.py); HIP compiles always pipe stdio there so the filter has something to act on (verbose=True still prints the filtered output).
* tilelang/jit/kernel.py: opens a thread-local recorder window around lower() on HIP and exposes the parsed dict as lazy resource_usage / n_regs / n_spills / n_max_threads properties.
* tilelang/cache/kernel_cache.py: persists the parsed dict as resource_usage.json next to kernel_lib.so on cache miss; reloads it on cache hit. This way subsequent runs don't lose the resource view to the cache, without paying the runtime API / ctypes cost. Older cache entries (no JSON file) silently degrade to None.
Verified on MI355X (gfx950) with a small elementwise add: cache miss and cache hit both report n_regs=5, n_spills=0; zero remark lines leak to stdout/stderr.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
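A minimal sketch of the parse-and-strip step described above, for readers who want to see the shape of it. The sample compiler output, the `ResourceUsage` dataclass, and `parse_and_strip` are all hypothetical stand-ins (not the PR's `KernelResourceUsage` or its helpers), and the exact remark text emitted by hipcc may differ between LLVM versions.

```python
import re
from dataclasses import dataclass

# Hypothetical sample of captured hipcc output; real remark text may differ.
SAMPLE = """\
remark: add.hip:7:0: Function Name: add_kernel [-Rpass-analysis=kernel-resource-usage]
remark: add.hip:7:0:     TotalSGPRs: 18 [-Rpass-analysis=kernel-resource-usage]
remark: add.hip:7:0:     VGPRs: 5 [-Rpass-analysis=kernel-resource-usage]
remark: add.hip:7:0:     VGPRs Spill: 0 [-Rpass-analysis=kernel-resource-usage]
add.hip:12:3: warning: unused variable 'x' [-Wunused-variable]
"""

REMARK_TAG = "[-Rpass-analysis=kernel-resource-usage]"
FIELD_RE = re.compile(
    r"remark: .*?:\d+:\d+:\s*(?P<key>[A-Za-z ]+?):\s*(?P<val>\S+)\s*\["
)
# Remark field name -> attribute on the illustrative dataclass below.
FIELDS = {"VGPRs": "n_regs", "VGPRs Spill": "n_spills", "TotalSGPRs": "n_sgprs"}


@dataclass
class ResourceUsage:  # illustrative only
    n_regs: int = 0
    n_spills: int = 0
    n_sgprs: int = 0


def parse_and_strip(output):
    """Return ({kernel name: ResourceUsage}, output with remark lines removed)."""
    usage, kept, current = {}, [], None
    for line in output.splitlines():
        if REMARK_TAG not in line:
            kept.append(line)  # real warnings/errors pass through untouched
            continue
        match = FIELD_RE.search(line)
        if match is None:
            continue  # unrecognised remark line: drop it, record nothing
        key, val = match.group("key").strip(), match.group("val")
        if key == "Function Name":
            current = usage.setdefault(val, ResourceUsage())
        elif current is not None and key in FIELDS:
            setattr(current, FIELDS[key], int(val))
    return usage, "\n".join(kept)


usage, filtered = parse_and_strip(SAMPLE)
print(usage["add_kernel"], filtered, sep="\n")
```

Keying the filter on the remark tag rather than on the word "remark" keeps unrelated diagnostics in the output.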
Extends src/tl_templates/hip/copy.h with two new HIP device templates that emit buffer_load_dwordx4 ... lds (the gfx950 direct global-to-LDS DMA that bypasses VGPRs):
- cp_async_gs_lds<N>: self-contained variant; computes the buffer resource descriptor and base address per call.
- cp_async_gs_lds_with_rsrc<N>: variant taking a pre-hoisted resource descriptor + base address. This is what the HoistBufferResource Python pass (Round 5 / M6) will rewrite calls to use, amortising the readfirstlane overhead across unrolled loops.
Both templates only emit the direct-DMA path for N == 16; smaller copies fall back to the existing cp_async_gs<N>. The 16-byte path requires that the LDS destination be lane-contiguous (base + lane_id * 16); the swizzle-swap optimisation in lower_tile_op.cc (Round 6 / M7) guarantees this by moving the XOR swizzle from the LDS store side to the global load side.
Reuses the existing make_wave_buffer_resource helper at copy.h:22 rather than redeclaring it. The inline-asm body is lifted from the reference branch zty_opt_can_run_1120flops because the asm is hardware-pinned to gfx950 and has no branch-specific dependencies.
Function-to-site mapping audit (Round 0 / M1) located the insertion point for HIP codegen handlers at src/backend/rocm/codegen/codegen_hip.cc (refactored from the reference's src/target/codegen_hip.cc); those handlers land in Round 3 / M4.
Verification: USE_ROCM=ON pip install -e . succeeds; `import tilelang` loads. Inline-asm validation deferred to JIT-compile time (header-only template; uninstantiated until codegen emits the call).
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…nc_lds_rsrc TIR ops
Declares three new TIR builtin ops on the tl namespace and registers them in builtin.cc. Inserted right after the existing ptx_cp_async registration so they appear in a coherent block:
- ptx_cp_async_lds(dst, src, bytes): same shape as ptx_cp_async but signals to codegen that the call lowers to cp_async_gs_lds<N> (the hardware buffer_load ... lds path added in M2).
- ptx_make_buffer_resource(global_ptr): single-arg op that lowers to make_wave_buffer_resource((const void*)(global_ptr)).
- ptx_cp_async_lds_rsrc(dst, src, bytes, rsrc, base): extended form carrying the pre-hoisted resource descriptor and base address. The HoistBufferResource Python pass (M6) rewrites ptx_cp_async_lds calls to this form once per kernel.
All three use the same call-effect kind (kOpaque) as ptx_cp_async since they have global memory side effects. set_num_inputs(-1) on the *_lds variants matches ptx_cp_async, which already uses -1 to support both predicated and non-predicated forms.
Verification: pip install -e . succeeds; `import tilelang` and `from tvm import tir` both succeed.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…ttrStmt
Adds codegen support in src/backend/rocm/codegen/codegen_hip.cc for the
three TIR ops registered in M3, plus the AttrStmt/LetStmt machinery that
the HoistBufferResource Python pass (M6) will emit.
CallNode handlers (VisitExpr_):
- tl::ptx_make_buffer_resource(ptr) -> expression
    make_wave_buffer_resource((const void*)(ptr))
- tl::ptx_cp_async_lds_rsrc(dst, src, bytes, rsrc, base) -> statement
    tl::cp_async_gs_lds_with_rsrc<bytes>(dst, src, rsrc, base);
- tl::ptx_cp_async_lds(dst, src, bytes [, pred]) -> statement
    tl::cp_async_gs_lds<bytes>(dst, src);
  (predicated form falls back to cp_async_gs_conditional<bytes> like the
  regular ptx_cp_async)
LetStmt visitor (new): when the bound value is a ptx_make_buffer_resource
Call, emit `auto x = ...;` instead of letting the base CodeGenC try to
print a C-typed declaration for the int32x4_t result.
AttrStmt branches (extended):
- "buffer_resource_var": emit
    `auto {rsrc_vid} = make_wave_buffer_resource((const void*)({buf_ptr}));`
- "buffer_base_var": emit
    `uint32_t {base_vid} = __builtin_amdgcn_readfirstlane((uint32_t)(uintptr_t)({buf_ptr}));`
These match the prologue shape demonstrated in
/root/tile-kernel-bench-cdna4/_fast.cpp lines 10-13. The hoisting pass
in M6 will wrap the kernel body with these AttrStmts so the descriptors
materialise at kernel entry rather than per call.
Verification: pip install -e . succeeds; `import tilelang` succeeds.
Inner-loop emission of cp_async_gs_lds_with_rsrc will be exercised once
M5 (injection decision) routes T.copy through the new ops.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
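To make the op-to-text mapping in the M4 commit above concrete, here is a small Python sketch that produces the same strings. It is not the C++ codegen; `emit_call` and `emit_attr` are hypothetical helper names, and the argument order shown for the predicated `cp_async_gs_conditional` call is an assumption, since the commit message does not spell it out.

```python
def emit_call(op, args):
    """Render one of the three new ops as the HIP text the commit describes."""
    if op == "ptx_make_buffer_resource":
        (ptr,) = args
        return f"make_wave_buffer_resource((const void*)({ptr}))"
    if op == "ptx_cp_async_lds_rsrc":
        dst, src, nbytes, rsrc, base = args
        return f"tl::cp_async_gs_lds_with_rsrc<{nbytes}>({dst}, {src}, {rsrc}, {base});"
    if op == "ptx_cp_async_lds":
        if len(args) == 4:
            # Predicated form falls back to the conditional template.
            # Assumed argument order: (dst, src, pred).
            dst, src, nbytes, pred = args
            return f"tl::cp_async_gs_conditional<{nbytes}>({dst}, {src}, {pred});"
        dst, src, nbytes = args
        return f"tl::cp_async_gs_lds<{nbytes}>({dst}, {src});"
    raise ValueError(f"unhandled op: {op}")


def emit_attr(key, var, buf_ptr):
    """Render the two AttrStmt prologue lines."""
    if key == "buffer_resource_var":
        return f"auto {var} = make_wave_buffer_resource((const void*)({buf_ptr}));"
    if key == "buffer_base_var":
        return (f"uint32_t {var} = "
                f"__builtin_amdgcn_readfirstlane((uint32_t)(uintptr_t)({buf_ptr}));")
    raise ValueError(f"unhandled attr: {key}")


print(emit_attr("buffer_resource_var", "__rsrc_A", "A"))
print(emit_attr("buffer_base_var", "__base_A", "A"))
print(emit_call("ptx_cp_async_lds_rsrc", ("dst", "src", 16, "__rsrc_A", "__base_A")))
```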
Adds an opt-in `enable_buffer_load_lds` knob to InjectPTXAsyncCopy and the underlying class, wired from src/backend/rocm/op/copy.cc only when TargetIsGfx950(T.target) holds. Default false everywhere else, so CUDA and non-gfx950 ROCm paths are unchanged.
Routing in MakeCPAsyncStmtFromLoads emits tl::ptx_cp_async_lds(dst, src, total_bytes) instead of tl::ptx_cp_async(dst, src, num_elems) when ALL of the following hold:
- the flag is on
- the copy is non-predicated
- total_bytes == 16 (matches the device template's only specialised N)
- the destination buffer scope is "shared" or "shared.dyn"
- the destination LDS index contains no bitwise_xor term (a conservative proxy for lane contiguity; before the M7 swizzle-swap optimisation lands, swizzled LDS layouts will contain XOR and so the routing safely no-ops back to the existing ptx_cp_async path)
Note arg 2 is byte width, not logical element count. The codegen handler for ptx_cp_async_lds (M4) prints arg 2 verbatim as the template width because the device template is currently only specialised for N == 16. Keeping the logical-vs-byte distinction at the boundary avoids the hazard Codex flagged where a blind op swap could otherwise emit cp_async_gs_lds<1> or <8>.
Side artefacts: CopyIndexInfo gains a total_bytes field already computed in PrepareCopyIndexInfo; MakeCPAsyncStmtFromLoads loses its static qualifier so it can read enable_buffer_load_lds_ from `this`.
Downstream cp.async recognisers updated so they treat the new ops the same as the existing ones:
- src/transform/lower_ptx_async_copy.cc AnalyzeCopyRegion
- src/transform/legalize_safe_memory_access.cc IsCPAsyncOp
- src/transform/vectorize_loop.cc Call dispatch
- src/transform/thread_storage_sync.cc is_cp_async lambda
Verification: pip install -e . succeeds; `import tilelang` succeeds. Full bench correctness/perf evaluation is the M8 job after M6 (hoisting) and M7 (swizzle-swap) land.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
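The five-way routing condition from this commit can be written as a single predicate. The sketch below is a Python paraphrase, not the C++ in MakeCPAsyncStmtFromLoads; `routes_to_lds`, `pick_op`, and their parameter names are hypothetical.

```python
LDS_SCOPES = {"shared", "shared.dyn"}


def routes_to_lds(enable_buffer_load_lds, predicated, total_bytes,
                  dst_scope, dst_index_has_xor):
    """True only when every condition listed in the commit message holds."""
    return (
        enable_buffer_load_lds          # opt-in flag (gfx950 only)
        and not predicated              # non-predicated copy
        and total_bytes == 16           # the only specialised template width
        and dst_scope in LDS_SCOPES     # destination lives in LDS
        and not dst_index_has_xor       # conservative lane-contiguity proxy
    )


def pick_op(**conditions):
    """Choose the op name; any failed condition keeps the existing path."""
    return "tl.ptx_cp_async_lds" if routes_to_lds(**conditions) else "tl.ptx_cp_async"


print(pick_op(enable_buffer_load_lds=True, predicated=False, total_bytes=16,
              dst_scope="shared.dyn", dst_index_has_xor=False))
print(pick_op(enable_buffer_load_lds=True, predicated=False, total_bytes=8,
              dst_scope="shared.dyn", dst_index_has_xor=False))
```

Because the flag defaults to false, every target other than gfx950 always takes the second branch.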
New file tilelang/transform/hoist_buffer_resource.py with the descriptor-hoist half of the reference implementation. Scans the post-LowerAccessPtr PrimFunc body for tl.ptx_cp_async_lds calls, extracts the source buffer Var from each call's tvm_access_ptr arg, creates one __rsrc_<buf> and one __base_<buf> Var per unique source buffer, wraps the body with buffer_resource_var / buffer_base_var AttrStmts (consumed by the HIP codegen handlers from M4), and rewrites the calls to ptx_cp_async_lds_rsrc carrying the hoisted vars.
Registered in tilelang/transform/__init__.py alongside the existing HoistBroadcastValues / DecoupleTypeCast imports. Inserted into tilelang/engine/phase.py OptimizeForTarget after MergeIfStmt, before MakePackedAPI, matching Codex's directive site so LowerAccessPtr has already lowered tl.access_ptr to tvm_access_ptr by the time the pass runs. The pass guards on target_is_gfx950(target) and is a no-op on every other target.
AMD vmcnt wait-count scaling (the second half of the reference) is intentionally NOT included in this commit; it will land as a separate milestone (M6.5) only if the M6 / M7 bench shows a correctness failure attributable to async wait counts.
Verification: pip install -e . succeeds; `from tilelang.transform import HoistBufferResource` imports the callable. End-to-end emission will be exercised once M7 lands the swizzle-swap optimisation that removes the XOR from LDS-side indices so M5's IsLdsLaneContiguous gate begins routing copies to ptx_cp_async_lds.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
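The hoisting rewrite can be illustrated on a toy IR without TVM. In the sketch below, `Call`, the tuple form of the access pointer, and `hoist_buffer_resources` are hypothetical; only the op names, the attribute keys, and the `__rsrc_<buf>` / `__base_<buf>` naming come from the commit message.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Call:  # toy stand-in for a TIR Call
    op: str
    args: tuple


def hoist_buffer_resources(body):
    """Create one (rsrc, base) pair per unique source buffer and rewrite calls.

    Each ptx_cp_async_lds call has args (dst, src, n), where src is a toy
    access pointer ("access_ptr", buffer_name, offset).
    Returns (prologue attrs, rewritten body).
    """
    prologue, hoisted, new_body = [], {}, []
    for call in body:
        if call.op != "ptx_cp_async_lds":
            new_body.append(call)
            continue
        dst, src, n = call.args
        buf = src[1]
        if buf not in hoisted:  # first sighting: emit the prologue once
            hoisted[buf] = (f"__rsrc_{buf}", f"__base_{buf}")
            prologue.append(("buffer_resource_var", hoisted[buf][0], buf))
            prologue.append(("buffer_base_var", hoisted[buf][1], buf))
        rsrc, base = hoisted[buf]
        new_body.append(Call("ptx_cp_async_lds_rsrc", (dst, src, n, rsrc, base)))
    return prologue, new_body


body = [
    Call("ptx_cp_async_lds", ("A_shared+0", ("access_ptr", "A", 0), 8)),
    Call("ptx_cp_async_lds", ("A_shared+8", ("access_ptr", "A", 8), 8)),
    Call("ptx_cp_async_lds", ("B_shared+0", ("access_ptr", "B", 0), 8)),
    Call("ptx_commit_group", ()),
]
prologue, new_body = hoist_buffer_resources(body)
for item in prologue + new_body:
    print(item)
```

Two copies from `A` share one descriptor pair, which is the amortisation the pass is after.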
Three changes that together let the M5 IsLdsLaneContiguous gate route the
target shape's g2s copies into the gfx950 buffer_load...lds path.
1. Layout swizzle-delta API (src/layout/layout.h, layout.cc):
- New Optional<PrimExpr> swizzle_delta_ member on LayoutNode.
- virtual PrimExpr SwizzleDelta(input_indices): substitutes the last
InputDim() entries of input_indices into the stored delta (same
convention as Forward), returns 0 when no delta is set.
- bool HasSwizzle() / void SetSwizzleDelta(): trivial accessors.
- LayoutNode::Expand propagates swizzle_delta_ through the variable
remap that shifts InputPlaceholders by the leading-shape offset.
2. Swizzle factories now record their column-XOR delta
(src/layout/gemm_layouts.cc) so HasSwizzle() returns true for the
layouts the GEMM tile-op actually produces:
- MakeQuarterBankSwizzleLayout2D: (xor2x2(c, s>>2) - c) * vec
- MakeHalfBankSwizzleLayout2D: (xor4x4(c, s>>1) - c) * vec
- MakeFullBankSwizzleLayout2D: (xor8x8(c, s) - c) * vec
3. Swizzle-swap in src/transform/lower_tile_op.cc BufferStoreNode
visitor: when TargetIsRocm && !is_ptx_ && IsSharedBuffer(buffer) &&
layout has a swizzle AND the store value is a direct global
BufferLoad of matching arity AND the layout output has unit leading
dim, rewrite
shared[Forward(local)] = global[base + local]
to
shared[Forward(local) - delta(local)] = global[base + local + delta(local)]
XOR is self-inverse so net data movement is unchanged, but the LDS
destination becomes lane-contiguous and the cp.async injector's
IsLdsLaneContiguous check stops rejecting the store. Other layouts
(no HasSwizzle, non-ROCm, ptx path, non-shared dst, non-direct-load
stores) fall through to the existing Forward-only path.
Diagnostic LOG(INFO) noise from the reference branch is intentionally
omitted. Verification: pip install -e . succeeds; `import tilelang`
succeeds. End-to-end emission + perf is the M8 job.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
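The claim in the M7 commit above that moving the XOR from the store side to the load side leaves net data movement unchanged can be checked on a small example. The sketch below uses a hypothetical 8x8 tile with a plain `c ^ s` column swizzle (not one of the PR's bank-swizzle factories) and compares the shared-memory contents produced by the two forms.

```python
ROWS, COLS = 8, 8


def delta(s, c):
    """Column offset added by the XOR swizzle: Forward(c) = c + delta = c ^ s."""
    return (c ^ (s % COLS)) - c


def store_swizzled_dst(glob):
    """Original form: shared[Forward(local)] = global[local]."""
    shared = [[None] * COLS for _ in range(ROWS)]
    for s in range(ROWS):
        for c in range(COLS):
            shared[s][c + delta(s, c)] = glob[s][c]
    return shared


def store_swizzled_src(glob):
    """Swapped form: shared[local] = global[local + delta(local)]."""
    shared = [[None] * COLS for _ in range(ROWS)]
    for s in range(ROWS):
        for c in range(COLS):
            shared[s][c] = glob[s][c + delta(s, c)]
    return shared


glob = [[s * COLS + c for c in range(COLS)] for s in range(ROWS)]
before = store_swizzled_dst(glob)
after = store_swizzled_src(glob)
print(before == after)
```

The two results agree because `c -> c ^ s` is its own inverse within a row; in the swapped form each loop iteration writes to destination column `c` directly, so the destination is linear in the iteration index.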
Without this pass, kernels with multiple `shared.dyn` buffers (which includes the bench's matmul with A_shared + B_shared) fail LowerDeviceKernelLaunch with "Only one dynamic shared memory allocation is allowed". The pass was commented out on this branch with no explanation; re-enabling it lets the bench compile through to codegen + kernel launch. This is a prerequisite for M8 (perf gate). Pre-existing constraint, not related to the buffer_load_lds work; landed as a separate small milestone for traceability per AC-3. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…odegen num_elems convention
The Round 1 emission had an outer `for vec=0..8` loop wrapping each tl::cp_async_gs_lds_with_rsrc<16> call, causing 8x overlapping LDS writes per thread and a runtime GPU coredump.
Root cause: tl::ptx_cp_async stores its transfer width as a *logical element count* and several downstream passes (vectorize_loop.cc::MutatePTXCPAsyncExpr_, loop_vectorize.cc dispatch, merge_shared_memory_allocations.cc rewrite) key off the op identity to widen/remap. Earlier I emitted ptx_cp_async_lds with arg 2 in *bytes*, so those passes either skipped the call (no widening = vec loop survived) or could not remap the merged shared buffer.
This commit makes the ptx_cp_async_lds family follow the same logical-element-count convention as tl::ptx_cp_async:
- src/transform/lower_ptx_async_copy.cc: MakeCPAsyncStmtFromLoads's LDS branch now passes `num_elems` (not `total_bytes`) as arg 2.
- src/backend/rocm/codegen/codegen_hip.cc: the tl::ptx_cp_async and tl::ptx_cp_async_lds CallNode handlers are merged; both go through GetTileLangCPAsyncTransferBytes to derive the template byte width. The tl::ptx_cp_async_lds_rsrc handler does the same conversion inline (it has 5 args so cannot reuse GetTileLang...).
- src/transform/vectorize_loop.cc: GetCPAsyncBitsPerCall and MutatePTXCPAsyncExpr_ ICHECKs widen to accept tl::ptx_cp_async_lds. This lets the vec(k) widening multiply num_elems by k and produces one call covering the full vec range, so the surrounding vec loop's body becomes uniform and the loop is consumed by the vectorizer.
- src/transform/loop_vectorize.cc: same widening for the ScalarToVector pipeline.
- src/transform/merge_shared_memory_allocations.cc: extend the cp_async dst-ptr remap to recognise tl::ptx_cp_async_lds so the merged dyn-shared buffer pointer is substituted correctly.
Result on the bench (8192x8192x8192 NT, tile 256x256x64): emission now matches `_fast.cpp`'s outer-loop shape verbatim (tl::cp_async_gs_lds_with_rsrc<16> per i_1 iteration, no inner vec loop). Bench compile passes, kernel launches without coredump, ~867 TFLOPS. Correctness still fails because the swizzle-swap from M7 is not engaging on the CDNA layout, leaving the XOR on the LDS-store side; that is the next debug step.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
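The element-count convention and the vec(k) widening can be summarised in a few lines. This is a sketch only: `transfer_bytes` and `widen` are hypothetical helpers, the 16-bit element type and the scalar starting count are illustrative choices, and the real conversion is done by the C++ codegen.

```python
def transfer_bytes(num_elems, elem_bits, lanes=1):
    """Convert a logical element count to the template byte width."""
    total_bits = num_elems * elem_bits * lanes
    if total_bits % 8 != 0:
        raise ValueError("transfer is not a whole number of bytes")
    return total_bits // 8


def widen(num_elems, k):
    """vec(k) widening: one call now covers k times as many elements."""
    return num_elems * k


# A scalar copy of one 16-bit element, widened by an 8-wide vec loop.
num_elems = widen(1, 8)
nbytes = transfer_bytes(num_elems, elem_bits=16)
print(f"tl::cp_async_gs_lds_with_rsrc<{nbytes}>(...)")
```

Because widening multiplies an element count, a pass that does not know the op carries a count (or that sees bytes in that slot) cannot fold the vec loop into the call, which is consistent with the surviving loop described in the commit.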
Codex Round 1 review flagged the previous IsLdsLaneContiguous as unsafe: it only rejected expressions containing a tir.bitwise_xor call, but the swizzle layouts on this branch expand the XOR into shift/and/add bit arithmetic (see e.g. xor2x2 in src/layout/gemm_layouts.cc:370). The expanded form passed the gate even when the LDS destination was not lane-contiguous.
Replace IsLdsLaneContiguous with a multi-sample affine proof modelled on /root/backuptilelang/src/transform/inject_ptx_async_copy.cc::IsLdsContiguous:
- Walk the dst index, collect free Vars.
- Pick a thread-like Var (name contains "thread" or equals "tx"/"tid").
- Compute f(0), f(1), expected_stride = f(1) - f(0); require IntImm.
- Sample at points 2..1023; require analyzer.CanProveEqual(f(k) - f(0), k * stride) at every sample. Covers wave-32/64 boundaries, warp tiles, bank-swizzle phases (powers of two up to 64), and the 256-thread block boundaries the bench uses.
Returns false if no thread-like var is found or any sample disagrees, so the LDS route safely no-ops back to the existing ptx_cp_async path until M9 actually moves the swizzle off the LDS side. Removes ContainsBitwiseXor helper (no longer needed). Build + import OK.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
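A minimal sketch of the sampling check, using plain Python integers instead of TIR expressions and an analyzer. `is_lane_contiguous` is a hypothetical name, and the swizzled example uses one possible arithmetic expansion of XOR (`a ^ b == a + b - 2 * (a & b)`), chosen to show an index with no literal XOR that is still not affine in the thread index.

```python
def is_lane_contiguous(f, max_sample=1023):
    """Check f(k) - f(0) == k * (f(1) - f(0)) at every sampled thread index."""
    base = f(0)
    stride = f(1) - base
    return all(f(k) - base == k * stride for k in range(2, max_sample + 1))


def linear_index(tx):
    # Linear destination: fixed offset plus 8 elements per thread.
    return 4096 + tx * 8


def expanded_xor_index(tx):
    # Row-dependent column swizzle written without a literal XOR operator.
    row, a = tx >> 3, tx & 7
    b = row & 7
    col = a + b - 2 * (a & b)  # equals a ^ b
    return row * 64 + col * 8


print(is_lane_contiguous(linear_index))
print(is_lane_contiguous(expanded_xor_index))
```

A syntactic "contains XOR" test would accept `expanded_xor_index`, while the sampled check rejects it at the first row boundary (thread 8).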
…ix for AC-2/AC-1)
The M7 swap in VisitStmt_(BufferStoreNode) never fired on this branch
because rocm::Copy::Lower calls InjectPTXAsyncCopy inline and the
BufferStores into A_shared / B_shared are converted to
tl::ptx_cp_async_lds Calls before LowerTileOpPass re-visits the
lowered Stmt (`return IRMutatorWithAnalyzer::VisitStmt(lowered)` at
~line 1130). Diagnosed by adding LOG(WARNING) at the top of the
BufferStore visitor: it only ever saw C_local, A_local, B_local, C
- never the shared buffers.
Per Codex Round 1 review's directive, add the swap on the resulting
Call node instead. New VisitExpr_(CallNode) branch that, for
tl::ptx_cp_async_lds on a ROCm target with a swizzled remapped shared
destination and a direct global source BufferLoad:
1. resolve_load() pulls BufferLoad from the tl::access_ptr arg
(handles direct BufferLoad and let-bound BufferLoad via
let_bindings_)
2. compute swizzled = layout->Forward(dst_logical_indices) and
delta = layout->SwizzleDelta(dst_logical_indices)
3. build new dst access_ptr against buffer_remap_[dst_buf] with
swizzled[last] - delta on the last dim (lane-contiguous LDS)
4. build new src access_ptr with src_logical_indices[last] + delta
on the last dim (XOR moved to the global side; self-inverse)
5. return rewritten Call directly so the default arg visitor does
not re-apply the swizzled layout to the destination
Rank-difference between dst and src is allowed: the LDS dst typically
has a leading Expand dim that the global src does not, but the
"last dim" of each maps to the same physical column so the swap is
well-defined on the last dim alone.
Bench result with M5-harden + M9 (M6.5 still pending): correctness
PASS, TFLOPS 866 (still below 1000 floor — M6.5 wait-count scaling
should close the gap). Emission shape now matches _fast.cpp exactly:
LDS dst is `i_1 * 4096 + threadIdx.x * 8` (linear), global src carries
the bank-swizzle pattern.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Codex Round 1 review flagged the generated cp_async_wait<1> vs the target's cp_async_wait<8> as a likely 1000 TFLOPS blocker, and after M9 fixed the swizzle direction the bench correctness passed at only 866 TFLOPS - confirming Codex's analysis.
Add _get_loads_per_group + _fix_amd_wait_counts to tilelang/transform/hoist_buffer_resource.py, matching the reference algorithm but rewriting on this branch's IR shape:
- _is_async_load_call recognises Evaluate(Call(ptx_cp_async_lds OR ptx_cp_async_lds_rsrc, ...)).
- _is_commit_call recognises Evaluate(Call(ptx_commit_group, ...)) (the existing pipeline emits these directly; the reference branch's async_commit_queue_scope AttrStmt is already lowered by the time we run).
- _find_for_with_commit walks the body to find the innermost For whose subtree contains a commit; _count_async_loads counts async loads inside one iteration of that For, multiplying by IntImm loop extents for nested unrolls.
- _fix_amd_wait_counts rewrites Evaluate(Call(ptx_wait_group, [n])) to Evaluate(Call(ptx_wait_group, [n * loads_per_group])) when n > 0. wait_group(0) (wait-all) stays unchanged.
The pass runs the scaling AFTER the existing descriptor hoist so the post-hoist body (now using ptx_cp_async_lds_rsrc) is the input to the load counter; the recognised loads include the rsrc form.
Bench result with M5-harden + M9 + M6.5:
[iter] correctness: PASS
[iter] latency: 0.9803 ms
[iter] TFLOPS: 1121.62
[iter] TB/s: 0.411
Above the AC-1 hard floor of 1000. Matches the reference branch's "1120 TFLOPS now." commit message. Emitted .hipi has the target tl::cp_async_wait<8> at the K-loop wait and tl::cp_async_wait<0> at the final wait; .s contains 16 buffer_load_dwordx4 ... lds occurrences.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
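The wait-count scaling can be sketched on a toy statement tree. Everything below is hypothetical scaffolding (tuples standing in for TIR nodes, and a made-up loop nest with two unrolled copy loops of extent 4); only the algorithm outline follows the commit message.

```python
# Toy statements: ("for", extent, [body]), ("load",), ("commit",),
# ("wait", n), or any other tuple such as ("compute",).


def contains_commit(stmt):
    if stmt[0] == "commit":
        return True
    return stmt[0] == "for" and any(contains_commit(s) for s in stmt[2])


def find_for_with_commit(stmts):
    """Innermost For whose subtree contains a commit, or None."""
    for stmt in stmts:
        if stmt[0] == "for" and contains_commit(stmt):
            inner = find_for_with_commit(stmt[2])
            return inner if inner is not None else stmt
    return None


def count_async_loads(stmts):
    """Loads in one iteration, multiplying through nested constant extents."""
    total = 0
    for stmt in stmts:
        if stmt[0] == "load":
            total += 1
        elif stmt[0] == "for":
            total += stmt[1] * count_async_loads(stmt[2])
    return total


def fix_wait_counts(stmts, loads_per_group):
    """Scale wait(n) to wait(n * loads_per_group) for n > 0; keep wait(0)."""
    out = []
    for stmt in stmts:
        if stmt[0] == "wait" and stmt[1] > 0:
            out.append(("wait", stmt[1] * loads_per_group))
        elif stmt[0] == "for":
            out.append(("for", stmt[1], fix_wait_counts(stmt[2], loads_per_group)))
        else:
            out.append(stmt)
    return out


program = [
    ("for", 128, [                    # K loop
        ("for", 4, [("load",)]),      # unrolled copies for one operand
        ("for", 4, [("load",)]),      # unrolled copies for the other
        ("commit",),
        ("wait", 1),
        ("compute",),
    ]),
    ("wait", 0),                      # final wait-all
]
k_loop = find_for_with_commit(program)
loads_per_group = count_async_loads(k_loop[2])
fixed = fix_wait_counts(program, loads_per_group)
print(loads_per_group, fixed[0][2][3], fixed[1])
```

The scaling is needed when the wait instruction counts individual outstanding loads rather than commit groups, so "one group in flight" has to be expressed as that group's load count.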
…roduce linear LDS
The M9 swizzle-swap subtracts SwizzleDelta from the LAST output dim of Forward(dst_indices). That cancels the XOR cleanly for layouts whose swizzle is confined to the last column (the bench's A_shared layout, and B_shared in the bench's symmetric NT shape), but layouts where the swizzle spreads across multiple output dims leave residual swizzle in the higher output dims AND over-correct src. The dst LDS index is no longer lane-contiguous, so the emitted buffer_load_dwordx4 ... lds writes to wrong addresses.
Concrete reproducer: python gemm/example_gemm.py at the default 1024x1024x16384 shape. A's LDS dst comes out as `i_1 * 1024 + threadIdx.x * 8` (clean linear) but B's comes out as `((tx & 15) >> 3) * 2048 + i_2 * 512 + (tx >> 4) * 64 + (tx & 7) * 8` (NOT lane-contiguous), and B's global src ends up with a spurious `- ((tx & 7) * 8)` term. Bench data 99.7% mismatched.
Fix has two parts:
1. Restructure the M9 handler so EITHER outcome (swap-or-skip) is a well-defined action: if the candidate gate (`dst_ap` / `src_ap` / shared dst / global src / remap / HasSwizzle / non-empty indices) does not hold, immediately downgrade the op from tl::ptx_cp_async_lds back to tl::ptx_cp_async (same arg shape) and recurse with the default visitor. Without this, a call that the guard rejected would keep its ptx_cp_async_lds op and codegen would emit the LDS template against an unmodified (swizzled) dst.
2. After the swap math runs, sample each post-swap new_dst_indices[d] against a thread-like Var: compute f(0), f(1), expected_stride, then require analyzer.CanProveEqual(f(k)-f(0), k*stride) at sample points 2..63. If any dim is non-affine in the thread var (bit-extract `(tx & m) >> s` terms etc.), downgrade as above.
Verified on two shapes:
- python gemm/example_gemm.py (1024x1024x16384): All check passed. Latency 0.0443 ms.
- bash scripts/iter_buffer_load.sh (8192x8192x8192): correctness PASS, TFLOPS 1116.84 (A still uses LDS fast path; B symmetric here also passes the linearity check).
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Existing CBA helper in ComputePlanCandidate only runs when source_buffer is undefined. For T.copy(B, B_shared) at kCommon/kStrict, source_buffer chain takes B_shared and routes through ComputeLoopLayoutFromBuffer, so CBA never fires and the default PlanLoopPartition flatten puts 16 lanes on N (crossing the FullBank tc=2 boundary at lane 8 -> +1920B jump).
Add an early hook at the top of ParallelOp::InferLayout that runs CBA at all 3 levels and sets loop_layout_ unconditionally when an eligible target buffer exists. CBA's gate (last_dim_bytes > 128B and divisible) keeps it inert on A_shared / NT-B_shared / C_local cases.
Effect:
- NN 1024^3 K=16384 128x128x32: correctness PASS (was wrong values before)
- NT 8192^2 K=8192 256x256x64 stages=2: 1111 TFLOPS, correctness PASS (NT B last-dim = K = 64 = 1 bank cycle, gate skips; perf unchanged)
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Previously the IsLdsLaneContiguous gate at lower_ptx_async_copy.cc:574 rejected any swizzled (XOR-laden) LDS index, falling through to ptx_cp_async. That cut M9 out of the loop for FullBank tc>1 cases (NN B under the new CBA binding from M10), even though M9's SwizzleDelta swap would produce a lane-contiguous index.
Drop the IsLdsLaneContiguous gate. Always emit ptx_cp_async_lds when the destination is shared + 16B + non-predicated. M9 then either rewrites the call (LDS path stays) or downgrades to ptx_cp_async if post-swap is non-affine.
Effect:
- NN 8192^3 128x128x32 stages=3 threads=128: 443 TFLOPS (was 370)
- NN 8192^3 256x256x32 stages=3 threads=256: 602 TFLOPS, PASS
- NN 8192^3 256x256x64 stages=2 threads=256: 760 TFLOPS, PASS
- NT 8192^2 K=8192 256x256x64 stages=2 threads=512: 1116 TFLOPS, PASS
A_shared still downgrades (M9 affine check fails because HalfBank's ts boundary splits the warp), so A continues on cp_async path. Fixing A needs CBA-style binding on the stride dim (open work).
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Adds in-source Chinese comments tracing the FullBank swizzle math (ts/s/tc/c/vec/c_swizzle/index, with concrete dimensions for the NN B_shared=32x128 bf16 case) so the layout's role in the M10/M11 chunk-block-aware binding work is documented at the source. Pure comment additions, no behavior change. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
…ons" This reverts commit df611d3.
After always-emit-LDS landed (let the LowerTileOp swizzle-swap visitor own the lane-contiguity decision and downgrade when it can't), the IsLdsLaneContiguous sampler and the dst_check_index parameter to MakeCPAsyncStmtFromLoads are no longer reachable. Remove them so the upstream-facing diff has no dead code; behavior is unchanged. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The chunk-block-aware binding now lands via the early hook at the top of ParallelOp::InferLayout (runs at all 3 levels, before source_buffer dispatch). By the time PlanLoopPartition's ComputePlanCandidate runs, loop_layout_ is already set when CBA applies, so the second ComputeChunkBlockAwarePlanCandidate call site here is dead in practice. Removed it; behavior is unchanged. Verified: - NT 8192^2 K=8192 256x256x64 stages=2 threads=512: 1114 TFLOPS, PASS - NN 1024x1024x16384 128x128x32 stages=3 threads=128: 64 TFLOPS, PASS Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Migrate buffer_load_lds + hoist + AMD vmcnt scaling work to upstream's
tir->tirx namespace rename and LetStmtNode->BindNode change. Pick up
all main improvements (Metal backend, SM75 MMA, cluster TMA, etc.)
and the tilelang_repo-side n_regs/n_spills exposure.
Conflicts resolved:
- src/backend/rocm/codegen/codegen_hip.{cc,h}: keep new BindNode
override for ptx_make_buffer_resource on top of main's
AllocateNode->AllocBufferNode rename.
- src/transform/lower_ptx_async_copy.cc: keep gfx950 ptx_cp_async_lds
routing block; use main's bare Array type alias.
- src/transform/ptx_async_copy_injector.h: combine main's tirx::Stmt
type with the extra enable_buffer_load_lds parameter.
- src/transform/lower_tile_op.cc: rename tir:: -> tirx:: in the
swizzle-swap affine prover.
- tilelang/transform/hoist_buffer_resource.py: port to tvm.tirx;
drop dead LetStmt branch; bump op names tir.* -> tirx.*.
- tilelang/contrib/hipcc.py: take main's _resolve_artifact_paths.
- tilelang/contrib/hip_resource_info.py: drop docstring whitespace nit.
- tilelang/jit/kernel.py: drop dead Py3.9 ParamSpec compat shim.
- 3rdparty/tvm: take main's pointer (8435b89).
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
📝 Walkthrough
Adds gfx950-optimized global→LDS async copy support: new TL intrinsics and 16-byte LDS templates, swizzle-delta layout metadata and propagation, HIP emission and hoisted resource locals, injector routing for eligible transfers, LowerTile index rewrites, ParallelOp planning, and supporting pass updates.
Changes
gfx950 Async LDS Copy and Swizzle Optimization
🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 passed | ❌ 1 failed (1 warning)
cc @zhangnju |
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/parallel.cc (1)
404-436: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Preserve annotated layout/predicate precedence over CBA auto-planning.
At Line 411, the CBA override runs before annotation adoption, so explicit kParallelLoopLayout/kParallelLoopPredicate can be silently bypassed when CBA matches. That changes declared layout semantics and can drop intended predicate guards.
Proposed fix
    -  if (!loop_layout_.defined()) {
    +  // Keep explicit annotations as highest-priority contract.
    +  if (!loop_layout_.defined() && !annotated_layout_unbound_.defined()) {
         // Reuse the same vec_size calculation as ComputePlanCandidate.
         auto maybe_remapped_root =
             IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
    @@
       }
       if (!loop_layout_.defined() && annotated_layout_unbound_.defined()) {
         loop_layout_ =
             annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds);
         if (annotated_predicate_.defined()) {
           predicate_ = annotated_predicate_.value();
         }
       } else if (!loop_layout_.defined() && source_buffer.defined() &&
                  allow_layout_propgate) {
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/parallel.cc` around lines 404 - 436, The CBA auto-plan (ComputeChunkBlockAwarePlanCandidate + loop_layout_ assignment) is run before honoring explicit annotations, which lets CBA silently override annotated kParallelLoopLayout/kParallelLoopPredicate; reorder logic so annotated_layout_unbound_ is applied first: if annotated_layout_unbound_.defined(), call annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds) into loop_layout_ and set predicate_ from annotated_predicate_ (if defined) before running ComputeChunkBlockAwarePlanCandidate/GetVectorizeSize; only run the CBA fallback when loop_layout_ remains undefined after adopting annotations. Use the existing symbols loop_layout_, annotated_layout_unbound_, annotated_predicate_, ComputeChunkBlockAwarePlanCandidate and the surrounding block that computes vector_size to implement this reorder.
🧹 Nitpick comments (2)
tilelang/transform/hoist_buffer_resource.py (2)
108-125: 💤 Low value
Symbolic loop extents are not multiplied into the load count.
When a For loop has a non-IntImm extent (e.g., a dynamic/symbolic extent), the multiplier stays unchanged. This could undercount async loads for dynamically-sized inner loops, leading to wait counts that are too small.
Given the PR targets unrolled tile-copy loops on gfx950 (which typically have constant extents), this is likely acceptable, but consider logging a warning or asserting IntImm when the pass is active.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/transform/hoist_buffer_resource.py` around lines 108 - 125, _count_async_loads currently ignores non-IntImm loop extents in tir.For, so symbolic/dynamic extents aren't multiplied into the async load count and can undercount waits; update _count_async_loads to detect when stmt is a tir.For with a non-IntImm extent and either 1) conservatively treat the extent as unknown by logging a warning or raising/asserting (e.g., assert isinstance(stmt.extent, tir.IntImm)) to fail fast, or 2) conservatively multiply by a safe upper bound if available; locate the tir.For handling in _count_async_loads and add the chosen behavior (warning/assert) to ensure dynamic extents don’t silently produce incorrect counts.
90-105: 💤 Low value
The hasattr(stmt, "body") fallthrough is overly generic.
This branch will match any node type with a body attribute, which could include unexpected statement types. Consider explicitly handling the known cases (e.g., tir.LetStmt, tir.AssertStmt) or at minimum documenting what this branch is intended to cover.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/transform/hoist_buffer_resource.py` around lines 90 - 105, The fallback in _find_for_with_commit uses a generic hasattr(stmt, "body") which is too broad; replace it with explicit type checks for the statement kinds that actually wrap a body (e.g., check isinstance(stmt, tir.LetStmt) and isinstance(stmt, tir.AssertStmt) and any other known wrappers in this project such as tir.IfThenElse or tir.AttrStmt) and recursively call _find_for_with_commit only for those cases, or add a short comment listing the intended covered types if keeping a generic branch; update the branch for hasattr(stmt, "body") to handle only the explicit classes (and return None otherwise) so unrelated nodes with a body attribute do not get traversed unexpectedly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/rocm/codegen/codegen_hip.cc`:
- Around line 926-955: The ptx_cp_async_lds_rsrc branch computes total_bytes
from the destination element type only and doesn't validate source/destination
element-width equality or enforce allowed transfer widths; update the branch
handling in codegen_hip.cc (the ptx_cp_async_lds_rsrc branch that uses
GetAccessPtrElementType, num_elems_imm, total_bits/total_bytes and emits
tl::cp_async_gs_lds_with_rsrc) to: retrieve/ICHECK a src element type via
GetAccessPtrElementType(op->args[1]) like you do for dst, ICHECK both src and
dst element types exist and have identical bits() and lanes(), compute
per-element bytes = dst_elem_type.bits()*dst_elem_type.lanes()/8 and ICHECK it
is one of {4,8,16}, then compute total_bytes as before and emit the call using
that validated size so mismatched or unsupported widths fail early.
In `@src/layout/layout.cc`:
- Around line 518-531: In LayoutNode::SwizzleDelta ensure you check that
input_indices.size() is at least InputDim() before using input_indices[offset +
i]; add a precondition at the start (e.g., an explicit check/ICHECK/CHECK and a
clear error message) that fails fast when input_indices.size() < InputDim(), or
alternatively clamp/guard the loop so Substitute never indexes past
input_indices; reference the function LayoutNode::SwizzleDelta, the variables
input_indices, InputDim(), offset, and InputPlaceholder(i) when adding the check
and message.
In `@src/op/builtin.h`:
- Around line 526-527: The documentation for the PTX intrinsic ptx_cp_async_lds
(and the other affected intrinsic docstrings) currently states the third
argument is "bytes" but the lowering code treats arg3 as a logical element count
and computes byte size from the element dtype; update the docstrings to say the
third parameter is "num_elems" (logical element count) and explicitly note that
byte width is derived from the element dtype at lowering so callers should pass
element count not raw bytes for functions like ptx_cp_async_lds and the
similarly-documented intrinsics around the same block.
In `@src/transform/vectorize_loop.cc`:
- Around line 683-684: The ICHECK that currently allows only tl::ptx_cp_async()
and tl::ptx_cp_async_lds() should also accept tl::ptx_cp_async_lds_rsrc() so the
hoisted-resource path isn't hard-failed; update the guards in
MutatePTXCPAsyncExpr_ and GetCPAsyncBitsPerCall (the ICHECKs that reference
tl::ptx_cp_async() and tl::ptx_cp_async_lds()) to include
tl::ptx_cp_async_lds_rsrc() as an allowed op, ensuring all places that validate
cp.async variants (including the occurrences around the other noted blocks)
treat the new op the same as the existing ones.
---
Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 404-436: The CBA auto-plan (ComputeChunkBlockAwarePlanCandidate +
loop_layout_ assignment) is run before honoring explicit annotations, which lets
CBA silently override annotated kParallelLoopLayout/kParallelLoopPredicate;
reorder logic so annotated_layout_unbound_ is applied first: if
annotated_layout_unbound_.defined(), call
annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds) into
loop_layout_ and set predicate_ from annotated_predicate_ (if defined) before
running ComputeChunkBlockAwarePlanCandidate/GetVectorizeSize; only run the CBA
fallback when loop_layout_ remains undefined after adopting annotations. Use the
existing symbols loop_layout_, annotated_layout_unbound_, annotated_predicate_,
ComputeChunkBlockAwarePlanCandidate and the surrounding block that computes
vector_size to implement this reorder.
---
Nitpick comments:
In `@tilelang/transform/hoist_buffer_resource.py`:
- Around line 108-125: _count_async_loads currently ignores non-IntImm loop
extents in tir.For, so symbolic/dynamic extents aren't multiplied into the async
load count and can undercount waits; update _count_async_loads to detect when
stmt is a tir.For with a non-IntImm extent and either 1) conservatively treat
the extent as unknown by logging a warning or raising/asserting (e.g., assert
isinstance(stmt.extent, tir.IntImm)) to fail fast, or 2) conservatively multiply
by a safe upper bound if available; locate the tir.For handling in
_count_async_loads and add the chosen behavior (warning/assert) to ensure
dynamic extents don’t silently produce incorrect counts.
- Around line 90-105: The fallback in _find_for_with_commit uses a generic
hasattr(stmt, "body") which is too broad; replace it with explicit type checks
for the statement kinds that actually wrap a body (e.g., check isinstance(stmt,
tir.LetStmt) and isinstance(stmt, tir.AssertStmt) and any other known wrappers
in this project such as tir.IfThenElse or tir.AttrStmt) and recursively call
_find_for_with_commit only for those cases, or add a short comment listing the
intended covered types if keeping a generic branch; update the branch for
hasattr(stmt, "body") to handle only the explicit classes (and return None
otherwise) so unrelated nodes with a body attribute do not get traversed
unexpectedly.
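The fail-fast option from the first nitpick can be sketched as below. This is a minimal illustration, not the pass's actual code: `AsyncLoad`, `Seq`, and `For` are hypothetical stand-ins for the real `tvm.tir` nodes, and `count_async_loads` / `scale_wait_group` are illustrative names. An integer extent models `tir.IntImm`; a string models a symbolic extent. The sketch assumes the count feeds the wait scaling (N groups becomes N × loads-per-group, with 0 kept as wait-all).

```python
from dataclasses import dataclass
from typing import List, Union


# Hypothetical stand-ins for TIR statements; the real pass walks tvm.tir nodes.
@dataclass
class AsyncLoad:
    """Models one ptx_cp_async_lds call."""


@dataclass
class Seq:
    stmts: List["Stmt"]


@dataclass
class For:
    extent: Union[int, str]  # int models tir.IntImm, str models a symbolic extent
    body: "Stmt"


Stmt = Union[AsyncLoad, Seq, For]


def count_async_loads(stmt: Stmt) -> int:
    """Count async loads, multiplying through constant loop extents."""
    if isinstance(stmt, AsyncLoad):
        return 1
    if isinstance(stmt, Seq):
        return sum(count_async_loads(s) for s in stmt.stmts)
    if isinstance(stmt, For):
        inner = count_async_loads(stmt.body)
        if inner == 0:
            return 0  # no loads inside, so the extent does not matter
        if not isinstance(stmt.extent, int):
            # Fail fast instead of silently undercounting.
            raise ValueError(f"symbolic loop extent {stmt.extent!r} around async loads")
        return stmt.extent * inner
    raise TypeError(f"unhandled statement type: {type(stmt).__name__}")


def scale_wait_group(n_groups: int, loads_per_group: int) -> int:
    """Convert a commit-group wait count into a per-load wait count."""
    if n_groups == 0:
        return 0  # wait-all sentinel stays unchanged
    return n_groups * loads_per_group
```

Raising on an unknown statement type also mirrors the second nitpick: only explicitly listed node kinds are traversed.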
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: af15cc00-88dc-44d3-bf99-8bd52162e26b
📒 Files selected for processing (22)
- src/backend/rocm/codegen/codegen_hip.cc
- src/backend/rocm/codegen/codegen_hip.h
- src/backend/rocm/op/copy.cc
- src/layout/gemm_layouts.cc
- src/layout/layout.cc
- src/layout/layout.h
- src/op/builtin.cc
- src/op/builtin.h
- src/op/parallel.cc
- src/op/parallel.h
- src/tl_templates/hip/copy.h
- src/transform/legalize_safe_memory_access.cc
- src/transform/loop_vectorize.cc
- src/transform/lower_ptx_async_copy.cc
- src/transform/lower_tile_op.cc
- src/transform/merge_shared_memory_allocations.cc
- src/transform/ptx_async_copy_injector.h
- src/transform/thread_storage_sync.cc
- src/transform/vectorize_loop.cc
- tilelang/engine/phase.py
- tilelang/transform/__init__.py
- tilelang/transform/hoist_buffer_resource.py
- codegen_hip: route ptx_cp_async_lds_rsrc through
GetTileLangCPAsyncTransferBytes so src/dst widths must match and
the final byte width is validated against {4,8,16}.
- layout: ICHECK_GE the index count in SwizzleDelta so a too-short
call fails with a clear contract error instead of OOB-reading.
- builtin.h: fix ptx_cp_async_lds / _rsrc docstrings -- arg 2 is
num_elems, not bytes (lowering derives byte width from the
access-ptr dtype).
- lower_tile_op: drop the "tx"/"tid"/"thread*" name-matching in the
affine lane-contiguity proof and key off the real threadIdx.x
binding tracked in thread_var_ so a future rename of the lane var
can't silently misclassify a non-affine LDS index as OK.
- vectorize_loop: accept ptx_cp_async_lds_rsrc in
MutatePTXCPAsyncExpr_ / GetCPAsyncBitsPerCall (it's already routed
here from VisitExpr_) and preserve the trailing (rsrc, base) args
through the rewrite so codegen still sees them.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
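The width check described in the first bullet can be sketched as follows. This is a hedged illustration only: `cp_async_transfer_bytes` and its parameters are hypothetical and do not reproduce the real `GetTileLangCPAsyncTransferBytes`. It assumes the element widths of source and destination must agree and that the final transfer size must be 4, 8, or 16 bytes.

```python
ALLOWED_TRANSFER_BYTES = (4, 8, 16)


def cp_async_transfer_bytes(src_bits, src_lanes, dst_bits, dst_lanes, num_elems):
    """Return the validated byte width of one async copy (illustrative only)."""
    # Source and destination element types must have the same width.
    if (src_bits, src_lanes) != (dst_bits, dst_lanes):
        raise ValueError("src/dst element widths differ")
    total_bits = dst_bits * dst_lanes * num_elems
    if total_bits % 8 != 0:
        raise ValueError("transfer is not a whole number of bytes")
    total_bytes = total_bits // 8
    # Only these widths map onto a supported load in this sketch.
    if total_bytes not in ALLOWED_TRANSFER_BYTES:
        raise ValueError(f"unsupported transfer width: {total_bytes} bytes")
    return total_bytes
```

For example, eight 16-bit elements give a 16-byte transfer, while three 16-bit elements (6 bytes) are rejected early instead of reaching codegen.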
The unconditional override in the early CBA branch of InferLayout was forcing the loop layout onto MMA accumulators like acc_o_l / C_local / dq, whose fragment is already in T.layout_map with its own MMA-derived binding. The follow-up ValidateCandidateAgainstFragments call would then throw a "Layout infer conflict between <buf> and <buf>" error. Only adopt the CBA candidate when it actually validates against the existing fragments; otherwise fall through to the normal source-buffer / free-inference paths so the MMA fragment's binding still wins. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The chunk-block-aware override exists solely to make buffer_load_dwordx4...lds usable on gfx950; firing it on CUDA / older AMD targets can only force the loop binding into a shape that conflicts with MMA fragment layouts (the CUDA CI hit this on acc_o_l / C_local / dq / C_local_accum). Gate on TargetIsGfx950 so every other target keeps its existing layout-inference behaviour untouched. The validation fallback added in the previous commit stays as defense-in-depth. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/op/parallel.cc (2)
417-430: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Apply the same `kCoalescedWidth` clamp in the early CBA path.
The comment says this reuses `ComputePlanCandidate(...)`'s vector-size calculation, but it stops before the `attr::kCoalescedWidth` override on Lines 731-744. If that annotation is present, the gfx950 fast path can build a different partition than the normal plan path, and the later code never corrects it because `loop_layout_` is already set.
Also applies to: 731-744
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/parallel.cc` around lines 417 - 430, The early CBA path computes vector_size using GetVectorizeSize on maybe_remapped_root but never applies the attr::kCoalescedWidth clamp like ComputePlanCandidate does, which can make the gfx950 fast path produce a different partition; update the CBA early path (the block that computes vector_size before calling ComputeChunkBlockAwarePlanCandidate and setting loop_layout_) to read attr::kCoalescedWidth (or call the same clamp routine used by ComputePlanCandidate) and clamp/reduce vector_size accordingly before the floormod loop and before passing vector_size into ComputeChunkBlockAwarePlanCandidate so loop_layout_ is computed consistently with the normal plan path.
416-449: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Preserve explicit parallel-loop annotations ahead of the gfx950 override.
This branch runs before the cached `kParallelLoopLayout`/`kParallelLoopPredicate` adoption, so once `cba` validates it silently bypasses the explicit annotation path on gfx950. That changes precedence from “annotation wins” to “CBA wins” and also skips installing the annotated predicate.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/parallel.cc` around lines 416 - 449, The gfx950 override (TargetIsGfx950 branch that computes cba via ComputeChunkBlockAwarePlanCandidate and may set loop_layout_) runs before handling explicit annotations (annotated_layout_unbound_ / annotated_predicate_) and thus can override and skip the annotated predicate; change control flow so explicit annotations take precedence: either move the annotated_layout_unbound_ / annotated_predicate_ block to execute before the TargetIsGfx950 block, or add a guard at the start of the gfx950 branch that returns/continues if annotated_layout_unbound_.defined() (and/or annotated_predicate_.defined()), ensuring ValidateCandidateAgainstFragments and loop_layout_ assignment only run when no explicit annotation exists.
🧹 Nitpick comments (1)
src/op/parallel.cc (1)
787-811: ⚡ Quick win
Gate CBA on actual swizzle metadata, not just buffer width.
`ComputeChunkBlockAwarePlanCandidate(...)` is documented as a swizzle-specific escape hatch, but the matcher here only checks for a non-`Fragment` layout plus last-dimension size heuristics. That means any wide unswizzled shared store on gfx950 can force the alternate binding even when there is no downstream swizzle-swap to unlock.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/op/parallel.cc` around lines 787 - 811, The matcher is enabling the ComputeChunkBlockAwarePlanCandidate escape hatch based only on layout not being Fragment and last-dimension width heuristics, which allows wide unswizzled buffers to trigger the alternate binding; change the gate to require explicit swizzle metadata instead of just buffer width: locate the block that reads T.layout_map[buffer] and the early returns (the code that checks Layout layout = T.layout_map[buffer]; if (layout.as<Fragment>()) return; ...), and add a check for the actual swizzle flag/metadata (the same metadata used elsewhere to detect swizzle-swap downstream — e.g., the buffer/store swizzle field or the function that detects swizzled buffers) and only proceed to consider ComputeChunkBlockAwarePlanCandidate when that swizzle metadata indicates a swizzle will be used; leave the existing size/element checks intact but make them subordinate to the new explicit swizzle presence check so unswizzled wide stores no longer trigger the alternate binding.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@src/op/parallel.cc`:
- Around line 417-430: The early CBA path computes vector_size using
GetVectorizeSize on maybe_remapped_root but never applies the
attr::kCoalescedWidth clamp like ComputePlanCandidate does, which can make the
gfx950 fast path produce a different partition; update the CBA early path (the
block that computes vector_size before calling
ComputeChunkBlockAwarePlanCandidate and setting loop_layout_) to read
attr::kCoalescedWidth (or call the same clamp routine used by
ComputePlanCandidate) and clamp/reduce vector_size accordingly before the
floormod loop and before passing vector_size into
ComputeChunkBlockAwarePlanCandidate so loop_layout_ is computed consistently
with the normal plan path.
- Around line 416-449: The gfx950 override (TargetIsGfx950 branch that computes
cba via ComputeChunkBlockAwarePlanCandidate and may set loop_layout_) runs
before handling explicit annotations (annotated_layout_unbound_ /
annotated_predicate_) and thus can override and skip the annotated predicate;
change control flow so explicit annotations take precedence: either move the
annotated_layout_unbound_ / annotated_predicate_ block to execute before the
TargetIsGfx950 block, or add a guard at the start of the gfx950 branch that
returns/continues if annotated_layout_unbound_.defined() (and/or
annotated_predicate_.defined()), ensuring ValidateCandidateAgainstFragments and
loop_layout_ assignment only run when no explicit annotation exists.
---
Nitpick comments:
In `@src/op/parallel.cc`:
- Around line 787-811: The matcher is enabling the
ComputeChunkBlockAwarePlanCandidate escape hatch based only on layout not being
Fragment and last-dimension width heuristics, which allows wide unswizzled
buffers to trigger the alternate binding; change the gate to require explicit
swizzle metadata instead of just buffer width: locate the block that reads
T.layout_map[buffer] and the early returns (the code that checks Layout layout =
T.layout_map[buffer]; if (layout.as<Fragment>()) return; ...), and add a check
for the actual swizzle flag/metadata (the same metadata used elsewhere to detect
swizzle-swap downstream — e.g., the buffer/store swizzle field or the function
that detects swizzled buffers) and only proceed to consider
ComputeChunkBlockAwarePlanCandidate when that swizzle metadata indicates a
swizzle will be used; leave the existing size/element checks intact but make
them subordinate to the new explicit swizzle presence check so unswizzled wide
stores no longer trigger the alternate binding.
HoistBufferResource rewrites every ptx_cp_async_lds it can match to the _rsrc form, which is the form codegen actually turns into the buffer_load_dwordx4 ... lds fast path. The two corner cases where the rewrite silently skips a call (empty buffer_vars early-return, or an access_ptr shape _extract_buffer_var can't pattern-match) used to land on a dedicated cp_async_gs_lds<N> codegen branch with its own buffer_load asm path -- duplicating the fast path and creating a second way to emit lds that could silently diverge from the _rsrc one. Collapse to a single safety net: treat ptx_cp_async_lds identically to ptx_cp_async in codegen so an unhoisted call just lowers to the synchronous tl::cp_async_gs<N>. Correct, no buffer_load_lds win for that particular call -- which is fine because in practice all calls get hoisted. Removes the now-dead cp_async_gs_lds<N> template too and updates the ptx_cp_async_lds docstring to reflect the new contract. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Why the existing AMD double buffer is slow
On the ROCm path, every `T.copy(gmem → smem)` currently lowers to a `global_load → VGPR → ds_write` pair, and the lowering inserts an `s_waitcnt vmcnt(0)` between the load and the store so the data is in-register before it gets written to LDS. That `vmcnt(0)` is a full wait-for-all-outstanding-loads: it doesn't just block the lane that needs the result, it stalls the whole wave on every issued global load. The practical consequence is that pipelining is dead: any `T.Pipelined`/double-buffer schedule the user wrote turns into back-to-back synchronous copies, because each stage's `vmcnt(0)` drains the previous stage's loads before the next compute can start.

Background
gfx950 exposes a real async global→shared path: `buffer_load_dword{,x2,x4} ... lds` reads from global and writes directly into LDS without staging through a VGPR, and `s_waitcnt vmcnt(N)` plays the same role as `cp.async.wait_group`. The catch is that the hardware writes each lane's element to `lds_base + 4 * lane_id`, so the destination is implicitly lane-contiguous. Tilelang normally hands shared memory to GEMM through an XOR bank-swizzle so MMA reads are bank-conflict-free, and that swizzle makes the per-lane LDS index non-contiguous, so a naive lowering would scatter the loads to the wrong banks. We also can't just drop the swizzle, since it's the primary perf gate for MMA. That's why AMD has been stuck on the synchronous "load into VGPR, then store to LDS" path until now.

What this PR does
The rewrite lives in `lower_tile_op.cc` plus a small helper in `layout.{h,cc}`. For every `tl::ptx_cp_async_lds` call we try to move the swizzle from the LDS side to the global side: if the destination layout contains a `Bitwise(XOR, j, h(i))` term, we fold the inverse delta into the source `tvm_access_ptr` so the LDS index becomes linear again while the global load picks up the permutation (which is free, since it's just an address rewrite on the gmem side). We then prove lane-contiguity on the resulting LDS index by simplifying it against a synthetic thread var and checking the dependence is a single constant stride. If the proof fails the call is downgraded back to `tl::ptx_cp_async`, so this is purely an optimization: no kernel becomes incorrect.

The PR also lifts the per-call cost. `cp_async_gs_lds_with_rsrc<N>` takes a pre-computed `BufferResource` descriptor plus a wave-uniform base address, and computing those per call would emit four `readfirstlane`s and the resource bit-cast on every load. `HoistBufferResource` (new Python pass, gfx950-only) collects every source buffer touched by `ptx_cp_async_lds`, hoists one resource/base pair per buffer to the kernel prologue as `AttrStmt`s, and rewrites the call sites to the `_rsrc` variant. The same pass scales the `ptx_wait_group(n)` counts: `s_waitcnt vmcnt(N)` counts individual `buffer_load`s rather than commit-groups, so a wait-for-N-groups must become wait-for-(N × loads-per-group). `wait_group(0)` (wait-all) is the natural sentinel and stays unchanged.

Codegen (`codegen_hip.{cc,h}`) gets the matching templated intrinsics, and `cp_async_gs_lds<N>` / `cp_async_gs_lds_with_rsrc<N>` are added to `src/tl_templates/hip/copy.h`. The codegen translates the logical element count carried by the call back to bytes via the existing `GetTileLangCPAsyncTransferBytes` convention, so the vec-loop folding in `vectorize_loop.cc` keeps working.

Why it's gfx950-only
`buffer_load ... lds` exists on earlier CDNA, but the wait-count scaling and the bank-conflict-free swizzle delta we rely on were only validated against gfx950. On every other target the routing in `lower_ptx_async_copy.cc` and the `HoistBufferResource` pass return unchanged, so existing kernels keep their current lowering.

CI and performance
I have a CI run for common gfx950 kernels. It passes the correctness check and shows a performance boost on many kernels.
benenzhu/tile-kernel-bench-cdna4#23 (comment)
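To make the swizzle swap from "What this PR does" concrete, here is a small numeric sketch. It is an illustration under assumptions, not the lowering itself: the tile shape, the mask function `h`, and the helper names are all hypothetical, and `lane_stride` checks contiguity by enumeration where the real pass proves it symbolically. Because XOR is its own inverse, storing `gmem[i][j]` at LDS column `j ^ h(i)` fills LDS exactly as storing `gmem[i][k ^ h(i)]` at the linear column `k`.

```python
ROWS, COLS = 4, 8  # COLS is a power of two so XOR stays within the row


def h(i):
    # Hypothetical row-dependent XOR mask; real layouts derive it from the bank mapping.
    return i % COLS


def store_side_swizzle(gmem):
    """Lane j reads gmem[i][j] and writes the permuted LDS column j ^ h(i)."""
    lds = [[None] * COLS for _ in range(ROWS)]
    for i in range(ROWS):
        for j in range(COLS):
            lds[i][j ^ h(i)] = gmem[i][j]
    return lds


def load_side_swizzle(gmem):
    """Lane k writes the linear LDS column k and reads the permuted gmem column."""
    lds = [[None] * COLS for _ in range(ROWS)]
    for i in range(ROWS):
        for k in range(COLS):
            lds[i][k] = gmem[i][k ^ h(i)]
    return lds


def lane_stride(index_of_lane, num_lanes):
    """Return the constant LDS stride across lanes, or None if it varies."""
    strides = {index_of_lane(l + 1) - index_of_lane(l) for l in range(num_lanes - 1)}
    return strides.pop() if len(strides) == 1 else None


gmem = [[i * COLS + j for j in range(COLS)] for i in range(ROWS)]
same_contents = store_side_swizzle(gmem) == load_side_swizzle(gmem)
swizzled_stride = lane_stride(lambda l: 1 * COLS + (l ^ h(1)), COLS)
linear_stride = lane_stride(lambda l: 1 * COLS + l, COLS)
```

Both variants produce the same LDS contents, but only the load-side form has a single constant stride per lane, which is the property the direct global-to-LDS load needs.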
Bench results vs `main`
Status: OK (no regression beyond 5.00%)

Wins (14)
- flash_attn_fwd b1_h32_s8192_d128
- flash_attn_fwd b1_h32_s8192_d128_causal
- flash_attn_fwd b8_h32_s2048_d128
- flash_attn_fwd b8_h32_s2048_d128_causal
- flash_attn_fwd b8_h32_s4096_d128
- flash_attn_fwd b8_h32_s4096_d128_causal
- gemm 1024x1024x16384_NN
- gemm 1024x8192x8192_NN
- gemm 2048x2048x2048_NN
- gemm 4096x4096x4096_NN
- gemm 4096x4096x8192_NN
- gemm 8192x8192x16384_NN
- gemm 8192x8192x16384_NT
- gemm 8192x8192x8192_NN

VGPR down (4): compiler win
- flash_attn_fwd b1_h32_s8192_d128_causal
- flash_attn_fwd b8_h32_s2048_d128_causal
- flash_attn_fwd b8_h32_s4096_d128_causal
- gemm 8192x8192x16384_NN

VGPR up (11): compiler regression
- flash_attn_fwd b1_h32_s8192_d128
- flash_attn_fwd b8_h32_s2048_d128
- flash_attn_fwd b8_h32_s4096_d128
- gemm 1024x1024x16384_NN
- gemm 1024x8192x8192_NN
- gemm 2048x2048x2048_NN
- gemm 4096x4096x4096_NN
- gemm 4096x4096x8192_NN
- gemm 8192x8192x1024_NN
- gemm 8192x8192x16384_NT
- gemm 8192x8192x8192_NN

Spills down (1): compiler win
- gemm 8192x8192x16384_NN

Full compare table (45 rows)
Self-hosted runner: 06-05 · workflow run
Summary by CodeRabbit
New Features
Improvements