Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions src/transform/inject_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,89 @@ using namespace ffi;
using tirx::GetSBlockReadWriteRegion;
namespace software_pipeline {

/*!
* \brief A StmtExprMutator that replaces Buffer objects in BufferLoad and
* BufferStore nodes with new Buffer objects from a replacement map.
* Used to reconcile old and new expanded buffer objects when sibling
* pipelines share buffers.
*/
class BufferReplacer : public StmtExprMutator {
public:
explicit BufferReplacer(const Map<Buffer, Buffer> &replacements)
: replacements_(replacements) {}

private:
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
if (auto new_buf = replacements_.Get(node->buffer)) {
node.CopyOnWrite()->buffer = new_buf.value();
}
return node;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (auto new_buf = replacements_.Get(node->buffer)) {
node.CopyOnWrite()->buffer = new_buf.value();
}
return node;
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
if (auto new_buf = replacements_.Get(op->buffer)) {
ObjectPtr<DeclBufferNode> n = make_object<DeclBufferNode>(*op);
n->buffer = new_buf.value();
return DeclBuffer(n);
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const AllocBufferNode *op) final {
if (auto new_buf = replacements_.Get(op->buffer)) {
ObjectPtr<AllocBufferNode> n = make_object<AllocBufferNode>(*op);
n->buffer = new_buf.value();
return AllocBuffer(n);
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const SBlockNode *op) final {
SBlock block = Downcast<SBlock>(StmtExprMutator::VisitStmt_(op));
bool changed = false;
Array<Buffer> alloc_buffers;
for (const auto &buf : block->alloc_buffers) {
if (auto new_buf = replacements_.Get(buf)) {
alloc_buffers.push_back(new_buf.value());
changed = true;
} else {
alloc_buffers.push_back(buf);
}
}
Array<BufferRegion> reads;
for (const auto &region : block->reads) {
if (auto new_buf = replacements_.Get(region->buffer)) {
reads.push_back(BufferRegion(new_buf.value(), region->region));
changed = true;
} else {
reads.push_back(region);
}
}
Array<BufferRegion> writes;
for (const auto &region : block->writes) {
if (auto new_buf = replacements_.Get(region->buffer)) {
writes.push_back(BufferRegion(new_buf.value(), region->region));
changed = true;
} else {
writes.push_back(region);
}
}
if (changed) {
SBlockNode *bn = block.CopyOnWrite();
bn->alloc_buffers = std::move(alloc_buffers);
bn->reads = std::move(reads);
bn->writes = std::move(writes);
}
return block;
}
Map<Buffer, Buffer> replacements_;
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.

class PipelineInjector : private StmtExprMutator {
public:
static Stmt Inject(const PrimFunc &func) {
Expand Down Expand Up @@ -360,6 +443,52 @@ class PipelineInjector : private StmtExprMutator {
CollectUsedPipelineBuffers(MakePipelineBody(pipeline_body_stmts),
buffer_data_to_buffer_, allocated_buffers_);

// Handle already-expanded buffers from previous sibling pipelines.
// When two T.Pipelined loops share the same alloc_shared buffer with
// different num_stages, the first pipeline expands the buffer (e.g.,
// (64,32) → (2,64,32)). The second pipeline must also expand it (to 5
// versions for num_stages=4), but it must expand the *original* 2D buffer,
// not the already-3D buffer from the first pipeline (which would produce
// a 4D buffer and crash LayoutInference).
//
// Fix: replace each already-expanded 3D buffer in pipeline_allocs with its
// original 2D form. After RewritePipeline creates a new expansion, we
// reconcile the two Buffer objects by replacing all references to the old
// expanded buffer with the new one.
{
Array<Buffer> fixed_allocs;
for (const auto &buf : pipeline_allocs) {
Buffer orig_2d;
Buffer old_expanded;
// Check if a pending remap's value (expanded buffer) matches this
// buffer by comparing their data Vars.
for (const auto &[remap_key, remap_val] : pending_buffer_remap_) {
if (remap_val->data.same_as(buf->data)) {
orig_2d = remap_key;
old_expanded = remap_val;
break;
}
}
if (orig_2d.defined()) {
// Track the old expanded buffer (3D from first pipeline) for
// later replacement with the new expanded buffer (3D from
// the second pipeline).
old_expanded_buffers_.Set(orig_2d, old_expanded);
// Use the original 2D buffer so RewritePipeline creates a
// fresh expansion with this pipeline's num_stages.
fixed_allocs.push_back(orig_2d);
LOG(WARNING)
<< "Buffer " << buf->name
<< " is shared across sibling Pipelined loops with "
"different num_stages. The second loop will re-expand "
"the original buffer instead of double-expanding.";
} else {
fixed_allocs.push_back(buf);
}
}
pipeline_allocs = std::move(fixed_allocs);
}

Optional<Array<Integer>> replayable_bind_mask;
if (auto replayable_bind_anno =
op->annotations.Get(kPipelineReplayableScalarBinds)) {
Expand Down Expand Up @@ -654,6 +783,27 @@ class PipelineInjector : private StmtExprMutator {
buffer_data_to_buffer_, pipeline_allocs, local_allocs, for_node,
pipeline_info, scalar_binding_blocks, target_);
Stmt pipeline = rewrite_result.pipeline;

// Reconcile old and new expanded buffers for siblings sharing buffers.
// The first pipeline's expansion created a buffer with shape (2,64,32),
// and the second pipeline creates one with 5 versions. We must unify them
// so that later passes (LayoutInference) see a single buffer object for
// each Var.
if (!old_expanded_buffers_.empty()) {
for (const auto &[orig_2d, old_expanded] : old_expanded_buffers_) {
for (const auto &[pipe_alloc, new_expanded] :
rewrite_result.buffer_remap) {
if (pipe_alloc.same_as(orig_2d)) {
old_expanded_to_new_.Set(old_expanded, new_expanded);
break;
}
}
}
if (!old_expanded_to_new_.empty()) {
pipeline = BufferReplacer(old_expanded_to_new_)(std::move(pipeline));
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

subtree_modified_ = true;

auto unwrap_outer_attrs = [](Stmt stmt) {
Expand Down Expand Up @@ -758,6 +908,32 @@ class PipelineInjector : private StmtExprMutator {
// Propagate to parent: if this subtree was modified, parent should know.
subtree_modified_ = outer_flag || children_modified;

// Reconcile old and new expanded buffer objects from sibling pipelines.
// When two sibling T.Pipelined loops share buffers, the first pipeline
// creates an old expanded buffer (e.g. (2,64,32)) and the second creates
// a new one (e.g. (5,64,32)). We must replace all references to the old
// buffer in the entire block body with the new one so LayoutInference
// sees a single consistent buffer object per Var.
//
// Compose replacement chains to handle 3+ sibling pipelines sharing the
// same buffer. Without composition, A->B followed by B->C would leave
// references to B in the IR, creating mixed buffer objects for the same
// data Var.
if (!old_expanded_to_new_.empty()) {
Map<Buffer, Buffer> composed;
for (const auto &[key, val] : old_expanded_to_new_) {
Buffer resolved = val;
while (auto next = old_expanded_to_new_.Get(resolved)) {
resolved = next.value();
}
composed.Set(key, resolved);
}
old_expanded_to_new_ = std::move(composed);
block = Downcast<SBlock>(
BufferReplacer(old_expanded_to_new_)(std::move(block)));
children_modified = true;
}

// Update alloc_buffers with any pending buffer remaps from pipeline
// rewriting. This handles buffers allocated in this block but
// multi-versioned during pipeline rewriting of inner loops.
Expand Down Expand Up @@ -957,6 +1133,11 @@ class PipelineInjector : private StmtExprMutator {
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers_;
Map<Buffer, Buffer> pending_buffer_remap_;
// Tracks old expanded buffers (from previous sibling pipelines) that need
// to be replaced. Key: original 2D buffer, Value: old 3D expanded buffer.
Map<Buffer, Buffer> old_expanded_buffers_;
// Maps old expanded buffers to new expanded buffers for reconciliation.
Map<Buffer, Buffer> old_expanded_to_new_;
std::vector<std::pair<Buffer, Buffer>> pending_layout_remapped_allocs_;
Optional<Target> target_;
Optional<String> global_symbol_;
Expand Down
72 changes: 72 additions & 0 deletions testing/python/issue/test_tilelang_issue_2309.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) 2025 Tile-AI Corporation.
# Licensed under the MIT License.
"""Tests for sibling pipeline shared buffer handling (issue #2309).

When two sibling T.Pipelined loops share the same alloc_shared buffer but use
different num_stages, the first pipeline expands the buffer from 2D to 3D, and
the second pipeline must not re-expand the already-3D buffer (which would
produce a 4D buffer and crash LayoutInference).
"""

import tilelang
import tilelang.language as T
import tilelang.testing


def run_asymmetric_pipeline():
"""Sibling pipelines sharing alloc_shared buffers with different num_stages."""
M, N, K = 512, 512, 512
BM, BN, BK = 64, 64, 32

@T.prim_func
def asymmetric_stages(
A: T.Tensor((M, K), T.float16),
B: T.Tensor((K, N), T.float16),
C: T.Tensor((M, N), T.float16),
):
K_half = K // 2
with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
A_shared = T.alloc_shared((BM, BK), T.float16)
B_shared = T.alloc_shared((BK, BN), T.float16)
C_local = T.alloc_fragment((BM, BN), T.float32)

T.clear(C_local)

for ko in T.Pipelined(T.ceildiv(K_half, BK), num_stages=2):
T.copy(A[by * BM, ko * BK], A_shared)
T.copy(B[ko * BK, bx * BN], B_shared)
T.gemm(A_shared, B_shared, C_local)

for ko in T.Pipelined(T.ceildiv(K_half, BK), num_stages=4):
T.copy(A[by * BM, K_half + ko * BK], A_shared)
T.copy(B[K_half + ko * BK, bx * BN], B_shared)
T.gemm(A_shared, B_shared, C_local)

T.copy(C_local, C[by * BM, bx * BN])

kernel = tilelang.compile(
asymmetric_stages,
out_idx=[2],
pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)

profiler = kernel.get_profiler()

def ref_program(A, B):
import torch

C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.float16)
return C

profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_sibling_pipeline_different_num_stages():
"""Reproducer for issue #2309: sibling pipelines with different num_stages
sharing alloc_shared buffers should compile and run correctly."""
run_asymmetric_pipeline()


if __name__ == "__main__":
tilelang.testing.main()
Loading