Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e7ac334
[ROCm] Expose HIP kernel n_regs / n_spills on JITKernel via clang rem…
benenzhu May 15, 2026
7f594d1
fix for final merge
benenzhu May 16, 2026
ad72ada
fix for final merge
benenzhu May 16, 2026
23be555
fix for final merge
benenzhu May 16, 2026
f276441
add a flag to save temp files
benenzhu May 16, 2026
600b516
add a flag to save temp files
benenzhu May 16, 2026
f0fd44f
M2: add cp_async_gs_lds_with_rsrc<N> device templates for gfx950
benenzhu May 17, 2026
adb4364
M3: register ptx_cp_async_lds / ptx_make_buffer_resource / ptx_cp_asy…
benenzhu May 17, 2026
e38d5fd
M4: HIP codegen handlers for buffer_load_lds ops + hoisted-resource A…
benenzhu May 17, 2026
1053bcb
M5: route eligible 16B copies to ptx_cp_async_lds on gfx950
benenzhu May 17, 2026
7ac1d96
M6: HoistBufferResource Python pass + pipeline wiring (gfx950 only)
benenzhu May 17, 2026
3c82e3b
M7: layout SwizzleDelta API + ROCm swizzle-swap in lower_tile_op
benenzhu May 17, 2026
ae29f93
M8a: re-enable MergeSharedMemoryAllocations in OptimizeForTarget
benenzhu May 17, 2026
d101b38
M8b: integrate ptx_cp_async_lds with vec-loop folding + buf-merge + c…
benenzhu May 17, 2026
22fd402
M5-harden: replace XOR-call gate with real affine lane-contiguity proof
benenzhu May 17, 2026
91878e4
M9: VisitExpr_(CallNode) swizzle-swap on tl::ptx_cp_async_lds (real f…
benenzhu May 17, 2026
4736c94
M6.5: AMD vmcnt wait-count scaling in HoistBufferResource
benenzhu May 17, 2026
c3c0f9a
add a flag to save temp files
benenzhu May 17, 2026
e6ff3f0
M9-safe: downgrade ptx_cp_async_lds to ptx_cp_async when swap can't p…
benenzhu May 17, 2026
79c6580
add a flag to save temp files
benenzhu May 18, 2026
ff6cb4e
format file
benenzhu May 18, 2026
1007a03
M10: chunk-block-aware binding via early CBA hook in InferLayout
benenzhu May 21, 2026
0a864a8
M11: always emit ptx_cp_async_lds; let M9 swap or downgrade
benenzhu May 21, 2026
df611d3
docs: annotate FullBank/makeGemmABLayout with inline derivations
benenzhu May 21, 2026
c5b1973
Revert "docs: annotate FullBank/makeGemmABLayout with inline derivati…
benenzhu May 22, 2026
781eb74
lower_ptx_async_copy: drop dead IsLdsLaneContiguous + dst_check_index
benenzhu May 22, 2026
7163b44
parallel: drop redundant CBA call in ComputePlanCandidate
benenzhu May 22, 2026
3c4d1d2
Merge branch 'main' into feat-c-remove_vmcnt0
benenzhu May 22, 2026
03343e2
parallel: apply clang-format
benenzhu May 22, 2026
e9d7bdb
review fixes for ptx_cp_async_lds family
benenzhu May 22, 2026
3d95235
parallel: gate early CBA hook on fragment validation
benenzhu May 23, 2026
7f8b103
parallel: scope early CBA hook to gfx950 only
benenzhu May 23, 2026
c78c587
zz
benenzhu May 27, 2026
2718f49
drop bare ptx_cp_async_lds codegen, fall back to sync cp_async_gs
benenzhu May 27, 2026
25f61a2
clean code
benenzhu May 27, 2026
a0312d5
clean code
benenzhu May 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 0be336 to 8435b8
100 changes: 92 additions & 8 deletions src/backend/rocm/codegen/codegen_hip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ std::optional<DataType> GetAccessPtrElementType(const PrimExpr &expr) {
}

int GetTileLangCPAsyncTransferBytes(const CallNode *op) {
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, num_elems, [predicate])";
// Accepts ptx_cp_async / ptx_cp_async_lds (3 or 4 args: dst, src, num_elems,
// [predicate]) and ptx_cp_async_lds_rsrc (5 args: dst, src, num_elems,
// rsrc_var, base_var) -- only args[0..2] are read here.
ICHECK(op->args.size() == 3 || op->args.size() == 4 || op->args.size() == 5)
<< "tl::ptx_cp_async family expects 3-5 arguments (dst_access_ptr, "
"src_access_ptr, num_elems, ...)";
const auto *num_elems_imm = op->args[2].as<IntImmNode>();
ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got "
<< op->args[2];
Expand Down Expand Up @@ -1169,9 +1172,31 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}
this->stream << ");\n";
};
if (op->op.same_as(builtin::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
if (op->op.same_as(tl::ptx_make_buffer_resource())) {
// Expression form: emits make_wave_buffer_resource((const void*)(ptr)).
// The enclosing LetStmt visitor recognises this Call and emits `auto x =`.
ICHECK(op->args.size() == 1)
<< "ptx_make_buffer_resource expects 1 argument (global_ptr)";
std::string ptr = this->PrintExpr(op->args[0]);
os << "make_wave_buffer_resource((const void*)(" << ptr << "))";
} else if (op->op.same_as(tl::ptx_cp_async_lds_rsrc())) {
// args = [dst, src, num_elems, rsrc_var, base_var]. arg 2 is the logical
// element count inherited from the ptx_cp_async_lds call that
// HoistBufferResource rewrote into this rsrc form -- the helper does the
// src/dst width-equality and {4,8,16} validation that the plain
// ptx_cp_async path also relies on.
ICHECK(op->args.size() == 5) << "ptx_cp_async_lds_rsrc expects 5 arguments";
std::string dst = this->PrintExpr(op->args[0]);
std::string src = this->PrintExpr(op->args[1]);
int total_bytes = GetTileLangCPAsyncTransferBytes(op);
std::string size = std::to_string(total_bytes);
std::string rsrc = this->PrintExpr(op->args[3]);
std::string base = this->PrintExpr(op->args[4]);
this->PrintIndent();
this->stream << "tl::cp_async_gs_lds_with_rsrc<" << size << ">(" << dst
<< ", " << src << ", " << rsrc << ", " << base << ");\n";
} else if (op->op.same_as(builtin::ptx_cp_async())) {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// builtin::ptx_cp_async stores byte width directly in arg 2.
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";
Expand All @@ -1189,7 +1214,18 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< ", " << src << ", " << condition << ");\n";
}
} else if (op->op.same_as(tl::ptx_cp_async())) {
} else if (op->op.same_as(tl::ptx_cp_async()) ||
op->op.same_as(tl::ptx_cp_async_lds())) {
// Both store logical element count in arg 2; convert to bytes via
// GetTileLangCPAsyncTransferBytes.
//
// tl::ptx_cp_async_lds is normally rewritten to ptx_cp_async_lds_rsrc
// by the HoistBufferResource pass. If a call survives the rewrite
// (e.g. an access_ptr shape _extract_buffer_var can't pattern-match,
// or the pass found nothing to hoist), fall back to the synchronous
// tl::cp_async_gs<bytes> path here -- correctness is preserved at
// the cost of giving up the buffer_load_dwordx4...lds fast path for
// that particular call. Treat both ops identically in codegen.
int total_bytes = GetTileLangCPAsyncTransferBytes(op);
std::string dst = this->PrintExpr(op->args[0]);
std::string src = this->PrintExpr(op->args[1]);
Expand All @@ -1207,6 +1243,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt("tl::cp_async_commit");
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
// AMDGPU s_waitcnt vmcnt field is 6-bit (max 63); clamp to keep the
// "n"(cnt) immediate constraint in tl::cp_async_wait valid.
if (n > 63)
n = 63;
std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
print_extern_call_stmt(func_name, 1);
} else if (op->op.same_as(builtin::create_barriers())) {
Expand Down Expand Up @@ -1693,7 +1733,33 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}

void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tl::attr::kLexicalAllocScope) {
if (op->attr_key == "buffer_resource_var") {
// Hoisted resource descriptor from the HoistBufferResource Python pass.
// Emits: auto {rsrc_var} = make_wave_buffer_resource((const
// void*)({buf_var}));
auto rsrc_var = Downcast<Var>(op->node);
std::string rsrc_vid = AllocVarID(rsrc_var.get());
std::string buf_ptr = PrintExpr(op->value);
this->PrintIndent();
this->stream << "auto " << rsrc_vid
<< " = make_wave_buffer_resource((const void*)(" << buf_ptr
<< "));\n";
this->VisitStmt(op->body);
return;
} else if (op->attr_key == "buffer_base_var") {
// Hoisted readfirstlane base address from the HoistBufferResource pass.
// Emits: uint32_t {base_var} = __builtin_amdgcn_readfirstlane(
// (uint32_t)(uintptr_t)({buf_var}));
auto base_var = Downcast<Var>(op->node);
std::string base_vid = AllocVarID(base_var.get());
std::string buf_ptr = PrintExpr(op->value);
this->PrintIndent();
this->stream << "uint32_t " << base_vid
<< " = __builtin_amdgcn_readfirstlane("
<< "(uint32_t)(uintptr_t)(" << buf_ptr << "));\n";
this->VisitStmt(op->body);
return;
} else if (op->attr_key == tl::attr::kLexicalAllocScope) {
PrintIndent();
stream << "{\n";
int scope = BeginScope();
Expand Down Expand Up @@ -1728,6 +1794,24 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
CodeGenC::VisitStmt_(op);
}

void CodeGenTileLangHIP::VisitStmt_(const BindNode *op) {
// For Bind(var = ptx_make_buffer_resource(buf)), emit `auto x = ...;`
// instead of the C-typed declaration the base class would produce. The
// return type is int32x4_t and naming it explicitly is brittle across
// backends, so `auto` keeps the template lookup in make_wave_buffer_resource
// responsible for the type. The body that follows the bind is handled by
// the enclosing SeqStmt visitor.
if (auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tl::ptx_make_buffer_resource())) {
std::string value = PrintExpr(op->value);
PrintIndent();
stream << "auto " << AllocVarID(op->var.get()) << " = " << value << ";\n";
return;
}
}
CodeGenC::VisitStmt_(op);
}

void CodeGenTileLangHIP::VisitStmt_(const AllocBufferNode *op) {
std::string vid = AllocVarID(op->buffer->data.get());

Expand Down
1 change: 1 addition & 0 deletions src/backend/rocm/codegen/codegen_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class CodeGenTileLangHIP final : public CodeGenC {
void VisitExpr_(const ShuffleNode *op, std::ostream &os) final; // NOLINT(*)
void VisitStmt_(const AllocBufferNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
void VisitStmt_(const BindNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;

// Override this as a work around for __grid_constant__ parameter
Expand Down
4 changes: 3 additions & 1 deletion src/backend/rocm/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ struct Copy {
auto inject_result =
InjectPTXAsyncCopy(lowered_loop, /*enable_auto_async_copy=*/true,
/*async_without_async_commit_wait=*/
no_implicit_commit_wait || GetIsAsyncCopy(op));
no_implicit_commit_wait || GetIsAsyncCopy(op),
/*enable_buffer_load_lds=*/
TargetIsGfx950(T.target));
Stmt cp_async_loop = inject_result.stmt;
if (!inject_result.injected_ptx_async_copy) {
DLOG(WARNING) << "cp.async rewrite miss for copy src=" << op.src->name
Expand Down
15 changes: 12 additions & 3 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,10 @@ static Layout MakeQuarterBankSwizzleLayout2D(int stride, int continuous,
PrimExpr vec = FloorMod(j, vector_size);
PrimExpr c_swizzle = xor2x2(c, FloorDiv(s, 4));
PrimExpr index = vec + (c_swizzle + s * 2) * vector_size;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
PrimExpr swizzle_delta = (c_swizzle - c) * vector_size;
Layout result(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
const_cast<LayoutNode *>(result.get())->SetSwizzleDelta(swizzle_delta);
return result;
}

Layout makeQuarterBankSwizzleLayout(const Buffer &buffer) {
Expand Down Expand Up @@ -486,7 +489,10 @@ static Layout MakeHalfBankSwizzleLayout2D(int stride, int continuous,
PrimExpr vec = FloorMod(j, vector_size);
PrimExpr c_swizzle = xor4x4(c, FloorDiv(s, 2));
PrimExpr index = vec + (c_swizzle + s * 4) * vector_size;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
PrimExpr swizzle_delta = (c_swizzle - c) * vector_size;
Layout result(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
const_cast<LayoutNode *>(result.get())->SetSwizzleDelta(swizzle_delta);
return result;
}

Layout makeHalfBankSwizzleLayout(const Buffer &buffer) {
Expand Down Expand Up @@ -515,7 +521,10 @@ static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous,
PrimExpr vec = FloorMod(j, vector_size);
PrimExpr c_swizzle = xor8x8(c, s);
PrimExpr index = vec + (c_swizzle + s * 8) * vector_size;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
PrimExpr swizzle_delta = (c_swizzle - c) * vector_size;
Layout result(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
const_cast<LayoutNode *>(result.get())->SetSwizzleDelta(swizzle_delta);
return result;
}

Layout makeFullBankSwizzleLayout(const Buffer &buffer) {
Expand Down
28 changes: 27 additions & 1 deletion src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,33 @@ Layout LayoutNode::Expand(const Array<PrimExpr> &leading_shape) const {
new_forward_index.push_back(Substitute(e, vmap));
}

return Layout(new_input_size, new_forward_index);
Layout result(new_input_size, new_forward_index);
// Propagate swizzle_delta_ through Expand: substitute placeholder
// indices so the delta keeps referring to the same physical input
// dimension after the leading-shape prefix is added.
if (swizzle_delta_.defined()) {
const_cast<LayoutNode *>(result.get())
->SetSwizzleDelta(Substitute(swizzle_delta_.value(), vmap));
}
return result;
}

PrimExpr LayoutNode::SwizzleDelta(const Array<PrimExpr> &input_indices) const {
if (!swizzle_delta_.defined()) {
return IntImm(DataType::Int(32), 0);
}
// Substitute the last InputDim() entries of input_indices into
// swizzle_delta_, matching the convention Forward() uses.
ICHECK_GE(input_indices.size(), InputDim())
<< "SwizzleDelta requires at least " << InputDim() << " indices, but got "
<< input_indices.size();
PrimExpr delta = swizzle_delta_.value();
size_t offset = input_indices.size() - InputDim();
for (size_t i = 0; i < InputDim(); ++i) {
delta =
Substitute(delta, {{InputPlaceholder(i), input_indices[offset + i]}});
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return delta;
}

Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
Expand Down
23 changes: 23 additions & 0 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ class LayoutNode : public Object {

virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;

/*!
* \brief Get the XOR swizzle column delta on the last input dimension.
*
* For swizzled layouts (quarter/half/full bank) returns the column
* delta caused by the XOR: delta = (c_swizzled - c) * vector_size,
* substituted against the supplied indices. For non-swizzle layouts
* returns 0. Used by the swizzle-swap optimisation in lower_tile_op.cc
* to move the XOR off the LDS-store side and onto the global-load
* side when the target supports buffer_load ... lds direct DMA.
*/
virtual PrimExpr SwizzleDelta(const Array<PrimExpr> &input_indices) const;

/*! \brief Whether this layout carries a non-trivial swizzle delta. */
bool HasSwizzle() const { return swizzle_delta_.defined(); }

/*! \brief Set the swizzle delta expression (called by layout factories). */
void SetSwizzleDelta(PrimExpr delta) { swizzle_delta_ = delta; }

static void RegisterReflection();
TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object);
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
Expand All @@ -112,6 +130,11 @@ class LayoutNode : public Object {
void UpdateAnalyzer(arith::Analyzer *analyzer) const;
Array<PrimExpr> forward_index_;
Array<PrimExpr> input_size_;
/*!
* \brief Optional XOR swizzle delta in terms of InputPlaceholders, set
* by swizzle layout factories and propagated through Expand/Reshape.
*/
Optional<PrimExpr> swizzle_delta_;
};

/*!
Expand Down
15 changes: 15 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,21 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_cp_async_lds)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_make_buffer_resource)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_cp_async_lds_rsrc)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
52 changes: 52 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,58 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc();
*/
TVM_DLL const Op &ptx_cp_async();

/*!
* \brief Marker for an eligible async G2S copy on gfx950+.
*
* Emitted by LowerPTXAsyncCopy in place of ptx_cp_async for 16-byte
* non-predicated shared-memory writes whose LDS index is (post
* swizzle-swap, see lower_tile_op.cc) lane-contiguous. The
* HoistBufferResource pass then rewrites each call to
* ptx_cp_async_lds_rsrc with a pre-computed buffer resource
* descriptor + base address; that rsrc form is what codegen emits as
* the buffer_load_dwordx4 ... lds fast path.
*
* If a call survives the rewrite (e.g. an access_ptr the hoister
* can't pattern-match), codegen falls back to the synchronous
* tl::cp_async_gs<N> path -- correct, but no buffer_load_lds win.
*
* ptx_cp_async_lds(dst_access_ptr, src_access_ptr, num_elems)
*
* num_elems is the logical element count (NOT byte width). Lowering
* derives the {4, 8, 16} byte transfer width from the access-ptr dtype.
* Passing this as elements keeps vec-loop folding in vectorize_loop.cc
* (which multiplies the count when it widens a loop) consistent with
* the plain ptx_cp_async path.
*/
TVM_DLL const Op &ptx_cp_async_lds();

/*!
* \brief Create a buffer resource descriptor for async G2S LDS copy (gfx950+).
*
* ptx_make_buffer_resource(global_ptr)
*
* Returns an int32x4_t buffer resource descriptor via
* make_wave_buffer_resource (defined in src/tl_templates/hip/copy.h).
*/
TVM_DLL const Op &ptx_make_buffer_resource();

/*!
* \brief Truly async G2S copy with pre-computed buffer resource (gfx950+).
*
* Same as ptx_cp_async_lds but takes a pre-hoisted buffer resource
* descriptor + base address to avoid redundant readfirstlane /
* make_wave_buffer_resource calls inside unrolled loops. The
* HoistBufferResource Python pass rewrites ptx_cp_async_lds calls to this
* form once per kernel.
*
* ptx_cp_async_lds_rsrc(dst_access_ptr, src_access_ptr, num_elems, rsrc_var,
* base_var)
*
* num_elems uses the same convention as ptx_cp_async_lds -- logical
* element count, not bytes; lowering converts via the access-ptr dtype.
*/
TVM_DLL const Op &ptx_cp_async_lds_rsrc();

/*!
* \brief Pack two b16 value into a b32 value
*
Expand Down
Loading