From e7ac334788626daf369138363ee7ac50b0bf65c5 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Fri, 15 May 2026 17:31:22 +0000 Subject: [PATCH 01/34] [ROCm] Expose HIP kernel n_regs / n_spills on JITKernel via clang remarks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a triton-style view of per-kernel AMD GPU resource usage on the tilelang JITKernel, queryable as kernel.n_regs / n_spills / n_max_threads (with a richer resource_usage dict mapping kernel name to a KernelResourceUsage dataclass). Implementation: * tilelang/contrib/hip_resource_info.py — passes -Rpass-analysis=kernel-resource-usage to hipcc, parses the per-kernel remarks (Function Name / VGPRs / VGPRs Spill / TotalSGPRs / etc.) out of the captured stdio, and *strips* those lines before the output is printed or included in error messages, so autotune logs don't drown in remark blocks while real warnings/errors still surface. Includes JSON (de)serialization helpers. * tilelang/contrib/hipcc.py — adds the remark flag, parses + filters the output. Same on the LibraryGenerator HIP path (tilelang/jit/adapter/libgen.py); HIP compiles always pipe stdio there so the filter has something to act on (verbose=True still prints the filtered output). * tilelang/jit/kernel.py — opens a thread-local recorder window around lower() on HIP and exposes the parsed dict as lazy resource_usage / n_regs / n_spills / n_max_threads properties. * tilelang/cache/kernel_cache.py — persists the parsed dict as resource_usage.json next to kernel_lib.so on cache miss; reloads it on cache hit. This way subsequent runs don't lose the resource view to the cache, without paying the runtime API / ctypes cost. Older cache entries (no JSON file) silently degrade to None. Verified on MI355X (gfx950) with a small elementwise add: cache miss and cache hit both report n_regs=5, n_spills=0; zero remark lines leak to stdout/stderr. 
Co-Authored-By: Claude Opus 4 (1M context) --- tilelang/cache/kernel_cache.py | 22 +++- tilelang/contrib/hip_resource_info.py | 145 ++++++++++++++++++++++++++ tilelang/contrib/hipcc.py | 12 ++- tilelang/jit/adapter/libgen.py | 25 ++++- tilelang/jit/kernel.py | 49 ++++++++- 5 files changed, 248 insertions(+), 5 deletions(-) create mode 100644 tilelang/contrib/hip_resource_info.py diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 1fd9e68b2..bd67027b2 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -22,6 +22,7 @@ from tilelang.utils.language import get_prim_func_name from tilelang import env from tilelang.jit import JITKernel +from tilelang.contrib.hip_resource_info import dump_to_file, load_from_file from tilelang import __version__ import platform @@ -46,6 +47,7 @@ class KernelCache: host_kernel_path = "host_kernel.cu" kernel_lib_path = "kernel_lib.so" params_path = "params.pkl" + resource_usage_path = "resource_usage.json" cache_root_dir = "kernels" staging_root_dir = ".staging" @@ -459,6 +461,13 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non self.logger.debug(f"Saving kernel parameters to disk: {params_path}") KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) + # Persist parsed kernel-resource-usage remarks (HIP only; + # empty dict on other targets) so cache hits don't lose + # kernel.n_regs / n_spills / resource_usage. 
+ usage = getattr(kernel, "_resource_usage", None) or {} + if usage: + dump_to_file(usage, os.path.join(staging_path, self.resource_usage_path)) + missing_files = self._get_missing_complete_cache_files(staging_path) if missing_files: missing_names = ", ".join(os.path.basename(path) for path in missing_files) @@ -531,7 +540,7 @@ def _load_kernel_from_disk( except Exception: self.logger.exception("Error loading kernel parameters from disk") - return self._build_kernel( + kernel = self._build_kernel( func=func, host_kernel_source=host_kernel_source, device_kernel_source=device_kernel_source, @@ -545,6 +554,17 @@ def _load_kernel_from_disk( compile_flags=compile_flags, ) + # Restore parsed kernel-resource-usage if a previous compile + # persisted it; absent file is fine (older caches, non-HIP). + ru_path = os.path.join(cache_path, self.resource_usage_path) + if kernel is not None and os.path.exists(ru_path): + try: + kernel._resource_usage = load_from_file(ru_path) + except Exception: + self.logger.exception("Error loading kernel resource_usage from disk") + + return kernel + def _clear_disk_cache(self): """ Removes all cached kernels from disk. diff --git a/tilelang/contrib/hip_resource_info.py b/tilelang/contrib/hip_resource_info.py new file mode 100644 index 000000000..d053a03d7 --- /dev/null +++ b/tilelang/contrib/hip_resource_info.py @@ -0,0 +1,145 @@ +"""Parse AMD GPU per-kernel resource usage out of clang's +``-Rpass-analysis=kernel-resource-usage`` remarks and expose them on +JITKernel. 
+ +clang emits a block like:: + + remark: src.cc:9:0: Function Name: main_kernel [-Rpass-analysis=kernel-resource-usage] + remark: src.cc:9:0: TotalSGPRs: 16 [-Rpass-analysis=kernel-resource-usage] + remark: src.cc:9:0: VGPRs: 5 [-Rpass-analysis=kernel-resource-usage] + remark: src.cc:9:0: ScratchSize [bytes/lane]: 0 [-Rpass-analysis=kernel-resource-usage] + remark: src.cc:9:0: SGPRs Spill: 0 [-Rpass-analysis=kernel-resource-usage] + remark: src.cc:9:0: VGPRs Spill: 0 [-Rpass-analysis=kernel-resource-usage] + ... + +right alongside any real warnings/errors. We *parse and strip* those +lines before printing or raising, so autotune logs don't drown in +hundreds of remark blocks while real diagnostics still surface. +""" + +from __future__ import annotations + +import contextlib +import json +import re +import threading +from dataclasses import asdict, dataclass, field + +# A line is a kernel-resource-usage remark iff it carries this exact tag. +# clang appends the option name as ``[-Rpass-analysis=kernel-resource-usage]``. +_REMARK_TAG = "[-Rpass-analysis=kernel-resource-usage]" +# clang prints `::: remark: : [-Rpass-...]` +_REMARK_LINE_RE = re.compile(r"\bremark:\s*(?P[^:]+?):\s*(?P.*?)\s*" + re.escape(_REMARK_TAG) + r"\s*$") + + +@dataclass +class KernelResourceUsage: + """Resource counts as reported by clang's kernel-resource-usage pass. + + Field names mirror the remark labels (lower-cased, normalized) so we + can extend without breaking callers. + """ + + n_regs: int = 0 # VGPRs + # Total VGPR-equivalent spill pressure: the explicit `VGPRs Spill` count + # plus scratch memory in dwords (`ScratchSize [bytes/lane]` / 4). Matches + # how triton accounts for spills (its n_spills is scratch_bytes / 4) but + # also folds in clang's explicit spill count when present. 
+ n_spills: int = 0 + scratch_bytes: int = 0 # raw `ScratchSize [bytes/lane]` + n_max_threads: int | None = None # not in remarks; kept for API symmetry + extra: dict[str, str] = field(default_factory=dict) # raw remark key→value + + +_FLAG = "-Rpass-analysis=kernel-resource-usage" + +_RECORDER = threading.local() + + +def hipcc_remark_flag() -> str: + """The clang flag callers should pass to hipcc to enable the remark + output we parse here.""" + return _FLAG + + +def reset_recorder() -> None: + """Begin a fresh recording window on this thread.""" + _RECORDER.items = {} + + +def pop_recorded() -> dict[str, KernelResourceUsage]: + """Return everything recorded since the last ``reset_recorder`` and + clear the buffer.""" + items = getattr(_RECORDER, "items", {}) + _RECORDER.items = {} + return dict(items) + + +def filter_and_record(output: str) -> str: + """Strip kernel-resource-usage remarks from ``output``, parse them, + and append the parsed entries to the active recorder (if any). + Returns the filtered output with the remark lines removed.""" + if _REMARK_TAG not in output: + return output + + kept_lines: list[str] = [] + current_name: str | None = None + current: KernelResourceUsage | None = None + items = getattr(_RECORDER, "items", None) + + for line in output.splitlines(keepends=True): + m = _REMARK_LINE_RE.search(line.rstrip("\n").rstrip("\r")) + if m is None: + kept_lines.append(line) + continue + key = m.group("key").strip() + value = m.group("value").strip() + if key == "Function Name": + # finalize previous block + if items is not None and current_name and current is not None: + items[current_name] = current + current_name = value + current = KernelResourceUsage() + elif current is not None: + current.extra[key] = value + if key == "VGPRs": + with contextlib.suppress(ValueError): + current.n_regs = int(value) + elif key == "VGPRs Spill": + with contextlib.suppress(ValueError): + current.n_spills += int(value) + elif key.startswith("ScratchSize"): + # 
ScratchSize [bytes/lane] — fold scratch dwords into n_spills. + with contextlib.suppress(ValueError): + current.scratch_bytes = int(value) + current.n_spills += int(value) // 4 + # remark line is dropped (not added to kept_lines) + + if items is not None and current_name and current is not None: + items[current_name] = current + + return "".join(kept_lines) + + +def dump_to_file(usage: dict[str, KernelResourceUsage], path: str) -> None: + """Persist parsed resource usage so it survives kernel-cache hits.""" + data = {name: asdict(u) for name, u in usage.items()} + with open(path, "w") as f: + json.dump(data, f, indent=2, sort_keys=True) + + +def load_from_file(path: str) -> dict[str, KernelResourceUsage]: + """Inverse of ``dump_to_file``. Tolerant of missing / unknown fields + so older cache entries keep working when the dataclass evolves.""" + with open(path) as f: + data = json.load(f) + out: dict[str, KernelResourceUsage] = {} + for name, entry in data.items(): + out[name] = KernelResourceUsage( + n_regs=int(entry.get("n_regs", 0)), + n_spills=int(entry.get("n_spills", 0)), + scratch_bytes=int(entry.get("scratch_bytes", 0)), + n_max_threads=entry.get("n_max_threads"), + extra=dict(entry.get("extra", {})), + ) + return out diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py index 85eb661a2..f733703e4 100644 --- a/tilelang/contrib/hipcc.py +++ b/tilelang/contrib/hipcc.py @@ -15,6 +15,8 @@ from tvm.base import py_str from tvm.contrib.rocm import get_rocm_arch, find_rocm_path +from .hip_resource_info import filter_and_record, hipcc_remark_flag + def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None, verbose=False): """Compile HIP code with hipcc. 
@@ -69,19 +71,25 @@ def compile_hip(code, target_format="hsaco", arch=None, options=None, path_targe else: raise ValueError("options must be str or list of str") + # -Rpass-analysis=kernel-resource-usage prints per-kernel VGPR / + # scratch / spill counts as clang remarks; tilelang parses + strips + # them in `filter_and_record` so they don't drown autotune logs. + cmd += [hipcc_remark_flag()] + cmd += ["-o", file_target] cmd += [temp_code] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() + out_text = filter_and_record(py_str(out)) if verbose: - print(py_str(out)) + print(out_text) if proc.returncode != 0: msg = code msg += "\nCompilation error:\n" - msg += py_str(out) + msg += out_text raise RuntimeError(msg) with open(file_target, "rb") as f: diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 020ce04f0..125f693f6 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -21,6 +21,8 @@ from tilelang.env import TILELANG_TEMPLATE_PATH from tilelang.utils.target import target_get_mcpu +from tilelang.contrib.hip_resource_info import filter_and_record, hipcc_remark_flag + from .utils import is_cpu_target, is_cuda_target, is_hip_target logger = logging.getLogger(__name__) @@ -126,6 +128,10 @@ def compile_lib(self, timeout: float = None): f"--offload-arch={arch}", "--shared", src.name, + # Emit per-kernel resource-usage remarks; tilelang parses + # and strips them below before printing/raising so they + # don't pollute autotune logs. + hipcc_remark_flag(), ] command += [ "-I" + COMPOSABLE_KERNEL_INCLUDE_DIR, @@ -172,6 +178,12 @@ def compile_lib(self, timeout: float = None): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) + # On HIP we always capture stdio so the kernel-resource-usage + # remarks can be parsed + stripped before anything reaches the + # terminal; the filtered output still surfaces on error. 
+ if is_hip_target(target): + run_kwargs.setdefault("stdout", subprocess.PIPE) + run_kwargs.setdefault("stderr", subprocess.STDOUT) try: if verbose: @@ -180,10 +192,21 @@ def compile_lib(self, timeout: float = None): except Exception as e: raise RuntimeError(f"Compile kernel failed because of {e}") from e + captured = "" + if ret.stdout is not None: + captured = ret.stdout.decode("utf-8", errors="replace") + if is_hip_target(target): + captured = filter_and_record(captured) + if ret.returncode != 0: - captured = ret.stdout.decode("utf-8", errors="replace") if ret.stdout else "" raise RuntimeError(f"Compilation Failed! {command}\n{captured}\n{self.lib_code}") + # On HIP success, surface any real (non-remark) warnings only when + # the caller asked for verbose output, matching `compile_hip`'s + # behavior. Quiet by default so autotune logs stay clean. + if is_hip_target(target) and verbose and captured.strip(): + print(captured) + self.srcpath = src.name self.libpath = libpath diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index b63884e30..19e61774a 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -7,7 +7,8 @@ except ImportError: # Python < 3.10 from typing_extensions import ParamSpec -from tilelang.jit.adapter.utils import is_cutedsl_target, is_metal_target, is_cuda_target +from tilelang.contrib.hip_resource_info import pop_recorded, reset_recorder +from tilelang.jit.adapter.utils import is_cutedsl_target, is_metal_target, is_cuda_target, is_hip_target from tvm.target import Target from tvm.tir import PrimFunc @@ -243,6 +244,15 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int dump_ir_path = pass_configs.get(PassConfigKey.TL_DUMP_IR_DIR, "./dump_ir") # Default dump path pass_instruments.append(tvm.ir.instrument.DumpIR(dump_dir=dump_ir_path)) + # On HIP, open a recorder window so the kernel-resource-usage + # remarks emitted by hipcc get parsed into self._resource_usage + # (consumed by kernel.n_regs 
etc.). The window has to stay open + # past adapter construction: tvm_ffi runs hipcc inside lower() + # (enable_device_compile=True), but cython invokes it later from + # the adapter constructor. + capture_resources = is_hip_target(target) + if capture_resources: + reset_recorder() with tvm.transform.PassContext(opt_level=3, config=pass_configs, instruments=pass_instruments), self.target: artifact = tilelang.lower( tilelang_func, @@ -332,6 +342,9 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int # Handle invalid backend. raise ValueError(f"Invalid execution backend: {execution_backend}") + if capture_resources: + self._resource_usage = pop_recorded() + return adapter def _create_adapter_from_database( @@ -626,6 +639,40 @@ def kernel_source(self) -> str: def host_source(self) -> str: return str(self.artifact.host_mod) if self.artifact else "" + @property + def resource_usage(self) -> dict[str, Any]: + """{kernel_name: KernelResourceUsage} parsed from clang's + kernel-resource-usage remarks. HIP only; empty for other + targets and for kernels loaded from cache (no compile happened). 
+ """ + return getattr(self, "_resource_usage", {}) or {} + + def _primary_resource_usage(self): + usage = self.resource_usage + if not usage: + return None + gsym = None + if self.prim_func is not None and self.prim_func.attrs is not None: + gsym = self.prim_func.attrs.get("global_symbol") + if gsym is not None and str(gsym) in usage: + return usage[str(gsym)] + return next(iter(usage.values())) + + @property + def n_regs(self) -> int | None: + info = self._primary_resource_usage() + return info.n_regs if info is not None else None + + @property + def n_spills(self) -> int | None: + info = self._primary_resource_usage() + return info.n_spills if info is not None else None + + @property + def n_max_threads(self) -> int | None: + info = self._primary_resource_usage() + return info.n_max_threads if info is not None else None + def export_library(self, kernel_file: str) -> None: """ Exports the compiled kernel function to a shared library file. From 7f594d13698323caa5ea7b4752548c8df768a494 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sat, 16 May 2026 06:06:04 +0000 Subject: [PATCH 02/34] fix for final merge --- tilelang/cache/kernel_cache.py | 4 +--- tilelang/jit/adapter/libgen.py | 26 +++++++------------------- tilelang/jit/kernel.py | 12 ++---------- 3 files changed, 10 insertions(+), 32 deletions(-) diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index bd67027b2..06a235b79 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -461,9 +461,7 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non self.logger.debug(f"Saving kernel parameters to disk: {params_path}") KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) - # Persist parsed kernel-resource-usage remarks (HIP only; - # empty dict on other targets) so cache hits don't lose - # kernel.n_regs / n_spills / resource_usage. 
+ # Persist HIP kernel-resource-usage remarks usage = getattr(kernel, "_resource_usage", None) or {} if usage: dump_to_file(usage, os.path.join(staging_path, self.resource_usage_path)) diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 125f693f6..5b61067a2 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -21,7 +21,7 @@ from tilelang.env import TILELANG_TEMPLATE_PATH from tilelang.utils.target import target_get_mcpu -from tilelang.contrib.hip_resource_info import filter_and_record, hipcc_remark_flag +from tilelang.contrib.hip_resource_info import filter_and_record from .utils import is_cpu_target, is_cuda_target, is_hip_target @@ -128,10 +128,7 @@ def compile_lib(self, timeout: float = None): f"--offload-arch={arch}", "--shared", src.name, - # Emit per-kernel resource-usage remarks; tilelang parses - # and strips them below before printing/raising so they - # don't pollute autotune logs. - hipcc_remark_flag(), + "-Rpass-analysis=kernel-resource-usage", ] command += [ "-I" + COMPOSABLE_KERNEL_INCLUDE_DIR, @@ -178,9 +175,6 @@ def compile_lib(self, timeout: float = None): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) - # On HIP we always capture stdio so the kernel-resource-usage - # remarks can be parsed + stripped before anything reaches the - # terminal; the filtered output still surfaces on error. if is_hip_target(target): run_kwargs.setdefault("stdout", subprocess.PIPE) run_kwargs.setdefault("stderr", subprocess.STDOUT) @@ -192,20 +186,14 @@ def compile_lib(self, timeout: float = None): except Exception as e: raise RuntimeError(f"Compile kernel failed because of {e}") from e - captured = "" - if ret.stdout is not None: - captured = ret.stdout.decode("utf-8", errors="replace") - if is_hip_target(target): - captured = filter_and_record(captured) - if ret.returncode != 0: + captured = ret.stdout.decode("utf-8", errors="replace") if ret.stdout else "" raise RuntimeError(f"Compilation Failed! 
{command}\n{captured}\n{self.lib_code}") - # On HIP success, surface any real (non-remark) warnings only when - # the caller asked for verbose output, matching `compile_hip`'s - # behavior. Quiet by default so autotune logs stay clean. - if is_hip_target(target) and verbose and captured.strip(): - print(captured) + if is_hip_target(target) and ret.stdout is not None: + captured = filter_and_record(ret.stdout.decode("utf-8", errors="replace")) + if verbose and captured.strip(): + print(captured) self.srcpath = src.name self.libpath = libpath diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 19e61774a..ff31cfdc8 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -244,12 +244,7 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int dump_ir_path = pass_configs.get(PassConfigKey.TL_DUMP_IR_DIR, "./dump_ir") # Default dump path pass_instruments.append(tvm.ir.instrument.DumpIR(dump_dir=dump_ir_path)) - # On HIP, open a recorder window so the kernel-resource-usage - # remarks emitted by hipcc get parsed into self._resource_usage - # (consumed by kernel.n_regs etc.). The window has to stay open - # past adapter construction: tvm_ffi runs hipcc inside lower() - # (enable_device_compile=True), but cython invokes it later from - # the adapter constructor. + # open a recorder window for kernel-resource-usage remarks capture_resources = is_hip_target(target) if capture_resources: reset_recorder() @@ -641,10 +636,7 @@ def host_source(self) -> str: @property def resource_usage(self) -> dict[str, Any]: - """{kernel_name: KernelResourceUsage} parsed from clang's - kernel-resource-usage remarks. HIP only; empty for other - targets and for kernels loaded from cache (no compile happened). 
- """ + """HIP only now""" return getattr(self, "_resource_usage", {}) or {} def _primary_resource_usage(self): From ad72ada6bd9f2a5362c14cbe7a6f6129f5592953 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sat, 16 May 2026 06:07:26 +0000 Subject: [PATCH 03/34] fix for final merge --- tilelang/jit/adapter/libgen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 5b61067a2..2de3bbacc 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -20,7 +20,6 @@ from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch from tilelang.env import TILELANG_TEMPLATE_PATH from tilelang.utils.target import target_get_mcpu - from tilelang.contrib.hip_resource_info import filter_and_record from .utils import is_cpu_target, is_cuda_target, is_hip_target From 23be555ca90e421832018788348509d43552085d Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sat, 16 May 2026 06:08:12 +0000 Subject: [PATCH 04/34] fix for final merge --- tilelang/jit/kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index ff31cfdc8..b8ac91aac 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -7,7 +7,6 @@ except ImportError: # Python < 3.10 from typing_extensions import ParamSpec -from tilelang.contrib.hip_resource_info import pop_recorded, reset_recorder from tilelang.jit.adapter.utils import is_cutedsl_target, is_metal_target, is_cuda_target, is_hip_target from tvm.target import Target from tvm.tir import PrimFunc @@ -26,6 +25,7 @@ from tilelang.profiler import Profiler, TensorSupplyType from tilelang.utils.target import determine_target from tilelang.contrib import nvcc as tl_nvcc +from tilelang.contrib.hip_resource_info import pop_recorded, reset_recorder from tilelang.transform import PassConfigKey from tilelang.transform.pass_config import normalize_pass_configs import logging From f276441ae45854bfcd7aeb6d09c9fe55d22e5dbe 
Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sat, 16 May 2026 17:32:38 +0000 Subject: [PATCH 05/34] add a flag to save temp files --- tilelang/env.py | 2 ++ tilelang/jit/adapter/libgen.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tilelang/env.py b/tilelang/env.py index f66b7fa35..68b795be4 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -312,6 +312,7 @@ class Environment: TILELANG_CLEANUP_TEMP_FILES = EnvVar( "TILELANG_CLEANUP_TEMP_FILES", "0" ) # cleanup temporary compiler files/dirs after compilation (default: keep for debugging) + TILELANG_HIP_SAVE_TEMP_FILES = EnvVar("TILELANG_HIP_SAVE_TEMP_FILES", "0") # save temporary files for HIP compilation # Auto-tuning settings TILELANG_AUTO_TUNING_DISABLE_CACHE = EnvVar("TILELANG_AUTO_TUNING_DISABLE_CACHE", "0") @@ -502,3 +503,4 @@ def prepend_pythonpath(path): CUTLASS_INCLUDE_DIR = env.CUTLASS_INCLUDE_DIR COMPOSABLE_KERNEL_INCLUDE_DIR = env.COMPOSABLE_KERNEL_INCLUDE_DIR TILELANG_TEMPLATE_PATH = env.TILELANG_TEMPLATE_PATH +TILELANG_HIP_SAVE_TEMP_FILES = env.TILELANG_HIP_SAVE_TEMP_FILES diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 2de3bbacc..238775e35 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -114,7 +114,7 @@ def compile_lib(self, timeout: float = None): ] elif is_hip_target(target): - from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR + from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR, TILELANG_HIP_SAVE_TEMP_FILES src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") @@ -132,6 +132,8 @@ def compile_lib(self, timeout: float = None): command += [ "-I" + COMPOSABLE_KERNEL_INCLUDE_DIR, ] + if TILELANG_HIP_SAVE_TEMP_FILES != "0": + command += ["--save-temps", "-g"] elif is_cpu_target(target): from tilelang.contrib.cc import get_cplus_compiler From 600b516bea9270583a05744b5c82e5db054da217 Mon Sep 17 00:00:00 2001 From: zhutaoyu 
Date: Sat, 16 May 2026 17:47:08 +0000 Subject: [PATCH 06/34] add a flag to save temp files --- tilelang/engine/phase.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 7c5b433dc..03ae0b4ce 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -282,10 +282,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.MarkCudaSyncCalls(have_pdl(target))(mod) mod = tilelang.transform.AnnotateReadOnlyParams()(mod) + print(mod.script()) # MergeSharedMemoryAllocations must be applied after SplitHostDevice # because the merged allocation site is at the beginning of each device function - enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) + # enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) + # mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) # InjectFenceProxy is a no-op on targets that lack the TMA / async-proxy # programming model; the pass itself checks the PrimFunc's target. mod = tilelang.transform.InjectFenceProxy()(mod) From f0fd44f9912a30709d866b5e6eb8b5266ac68765 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:09:09 +0000 Subject: [PATCH 07/34] M2: add cp_async_gs_lds_with_rsrc device templates for gfx950 Extends src/tl_templates/hip/copy.h with two new HIP device templates that emit buffer_load_dwordx4 ... lds (the gfx950 direct global-to-LDS DMA that bypasses VGPRs): - cp_async_gs_lds: self-contained variant; computes the buffer resource descriptor and base address per call. - cp_async_gs_lds_with_rsrc: variant taking a pre-hoisted resource descriptor + base address. 
This is what the HoistBufferResource Python pass (Round 5 / M6) will rewrite calls to use, amortising the readfirstlane overhead across unrolled loops. Both templates only emit the direct-DMA path for N == 16; smaller copies fall back to the existing cp_async_gs. The 16-byte path requires that the LDS destination be lane-contiguous (base + lane_id * 16); the swizzle-swap optimisation in lower_tile_op.cc (Round 6 / M7) guarantees this by moving the XOR swizzle from the LDS store side to the global load side. Reuses the existing make_wave_buffer_resource helper at copy.h:22 rather than redeclaring it. The inline-asm body is lifted from the reference branch zty_opt_can_run_1120flops because the asm is hardware-pinned to gfx950 and has no branch-specific dependencies. Function-to-site mapping audit (Round 0 / M1) located the insertion point for HIP codegen handlers at src/backend/rocm/codegen/codegen_hip.cc (refactored from the reference's src/target/codegen_hip.cc); those handlers land in Round 3 / M4. Verification: USE_ROCM=ON pip install -e . succeeds; `import tilelang` loads. Inline-asm validation deferred to JIT-compile time (header-only template; uninstantiated until codegen emits the call). Co-Authored-By: Claude Opus 4 (1M context) --- src/tl_templates/hip/copy.h | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/tl_templates/hip/copy.h b/src/tl_templates/hip/copy.h index 6142e3fdf..82155aecb 100644 --- a/src/tl_templates/hip/copy.h +++ b/src/tl_templates/hip/copy.h @@ -138,4 +138,55 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, } } +// gfx950 (CDNA4 / MI350) direct global→LDS copy via buffer_load_dwordx4 ... lds. +// Bypasses VGPRs entirely. Only valid when LDS destination is lane-contiguous +// (base + lane_id * N bytes); the swizzle-swap pass in LowerTileOp guarantees +// this by moving the XOR swizzle from the LDS store side to the global load +// side. 
The compiler is expected to emit this only for 16-byte copies; other +// sizes fall back to cp_async_gs. +template +TL_DEVICE void cp_async_gs_lds(void *lds_base_ptr, void const *global_base_ptr) { + if constexpr (N == 16) { + auto rsrc = make_wave_buffer_resource(global_base_ptr); + uint32_t my_lo = + static_cast(reinterpret_cast(global_base_ptr)); + uint32_t base_lo = __builtin_amdgcn_readfirstlane(my_lo); + uint32_t voffset = my_lo - base_lo; + uint32_t lds_cur = __builtin_amdgcn_readfirstlane( + static_cast(reinterpret_cast(lds_base_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" + : : "s"(lds_cur), "v"(voffset), "s"(rsrc) + : "memory"); + } else { + cp_async_gs(lds_base_ptr, global_base_ptr); + } +} + +// Variant with pre-hoisted buffer resource descriptor and base address. +// rsrc and rsrc_base_lo are computed once at kernel entry (see the +// HoistBufferResource Python pass) so per-call readfirstlane overhead is +// amortised across the many cp_async_gs_lds_with_rsrc calls in an unrolled +// loop. rsrc_base_lo must equal readfirstlane((uint32_t)(uintptr_t)A) for +// the same A passed to make_wave_buffer_resource that produced rsrc. 
+template +TL_DEVICE void cp_async_gs_lds_with_rsrc(void *lds_base_ptr, + void const *global_base_ptr, + int32x4_t rsrc, + uint32_t rsrc_base_lo) { + if constexpr (N == 16) { + uint32_t my_lo = + static_cast(reinterpret_cast(global_base_ptr)); + uint32_t voffset = my_lo - rsrc_base_lo; + uint32_t lds_cur = __builtin_amdgcn_readfirstlane( + static_cast(reinterpret_cast(lds_base_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" + : : "s"(lds_cur), "v"(voffset), "s"(rsrc) + : "memory"); + } else { + cp_async_gs(lds_base_ptr, global_base_ptr); + } +} + } // namespace tl From adb4364884097c38f98ad9e3bd644ab5620754a7 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:10:52 +0000 Subject: [PATCH 08/34] M3: register ptx_cp_async_lds / ptx_make_buffer_resource / ptx_cp_async_lds_rsrc TIR ops Declares three new TIR builtin ops on the tl namespace and registers them in builtin.cc. Inserted right after the existing ptx_cp_async registration so they appear in a coherent block: - ptx_cp_async_lds(dst, src, bytes): same shape as ptx_cp_async but signals to codegen that the call lowers to cp_async_gs_lds (the hardware buffer_load ... lds path added in M2). - ptx_make_buffer_resource(global_ptr): single-arg op that lowers to make_wave_buffer_resource((const void*)(global_ptr)). - ptx_cp_async_lds_rsrc(dst, src, bytes, rsrc, base): extended form carrying the pre-hoisted resource descriptor and base address. The HoistBufferResource Python pass (M6) rewrites ptx_cp_async_lds calls to this form once per kernel. All three use the same call-effect kind (kOpaque) as ptx_cp_async since they have global memory side effects. set_num_inputs(-1) on the *_lds variants matches ptx_cp_async, which already uses -1 to support both predicated and non-predicated forms. Verification: pip install -e . succeeds; `import tilelang` and `from tvm import tir` both succeed. 
Co-Authored-By: Claude Opus 4 (1M context) --- src/op/builtin.cc | 15 +++++++++++++++ src/op/builtin.h | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 508f90c65..2798520c7 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -272,6 +272,21 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_cp_async_lds) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_make_buffer_resource) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_cp_async_lds_rsrc) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(fence_proxy_async) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 69a65ab10..9d9755806 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -464,6 +464,43 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc(); */ TVM_DLL const Op &ptx_cp_async(); +/*! + * \brief Truly async G2S copy via buffer_load_dwordx4 ... lds (gfx950+). + * + * Same signature as ptx_cp_async but lowers to cp_async_gs_lds which + * uses the hardware buffer_load ... lds instruction (bypasses VGPRs). + * Only valid when LDS addresses are lane-contiguous; the swizzle-swap + * pass in lower_tile_op.cc moves the XOR swizzle from the LDS store side + * to the global load side to make this safe. + * + * ptx_cp_async_lds(dst_access_ptr, src_access_ptr, bytes) + */ +TVM_DLL const Op &ptx_cp_async_lds(); + +/*! + * \brief Create a buffer resource descriptor for async G2S LDS copy (gfx950+). + * + * ptx_make_buffer_resource(global_ptr) + * + * Returns an int32x4_t buffer resource descriptor via + * make_wave_buffer_resource (defined in src/tl_templates/hip/copy.h). 
+ */ +TVM_DLL const Op &ptx_make_buffer_resource(); + +/*! + * \brief Truly async G2S copy with pre-computed buffer resource (gfx950+). + * + * Same as ptx_cp_async_lds but takes a pre-hoisted buffer resource + * descriptor + base address to avoid redundant readfirstlane / + * make_wave_buffer_resource calls inside unrolled loops. The + * HoistBufferResource Python pass rewrites ptx_cp_async_lds calls to this + * form once per kernel. + * + * ptx_cp_async_lds_rsrc(dst_access_ptr, src_access_ptr, bytes, rsrc_var, + * base_var) + */ +TVM_DLL const Op &ptx_cp_async_lds_rsrc(); + /*! * \brief Pack two b16 value into a b32 value * From e38d5fd112a9c7b595f3c389e1b43c93b3c8d216 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:13:11 +0000 Subject: [PATCH 09/34] M4: HIP codegen handlers for buffer_load_lds ops + hoisted-resource AttrStmt Adds codegen support in src/backend/rocm/codegen/codegen_hip.cc for the three TIR ops registered in M3, plus the AttrStmt/LetStmt machinery that the HoistBufferResource Python pass (M6) will emit. CallNode handlers (VisitExpr_): - tl::ptx_make_buffer_resource(ptr) -> expression make_wave_buffer_resource((const void*)(ptr)) - tl::ptx_cp_async_lds_rsrc(dst, src, bytes, rsrc, base) -> statement tl::cp_async_gs_lds_with_rsrc(dst, src, rsrc, base); - tl::ptx_cp_async_lds(dst, src, bytes [, pred]) -> statement tl::cp_async_gs_lds(dst, src); (predicated form falls back to cp_async_gs_conditional like the regular ptx_cp_async) LetStmt visitor (new): when the bound value is a ptx_make_buffer_resource Call, emit `auto x = ...;` instead of letting the base CodeGenC try to print a C-typed declaration for the int32x4_t result. 
AttrStmt branches (extended): - "buffer_resource_var": emit `auto {rsrc_vid} = make_wave_buffer_resource( (const void*)({buf_ptr}));` - "buffer_base_var": emit `uint32_t {base_vid} = __builtin_amdgcn_readfirstlane((uint32_t)(uintptr_t)({buf_ptr}));` These match the prologue shape demonstrated in /root/tile-kernel-bench-cdna4/_fast.cpp lines 10-13. The hoisting pass in M6 will wrap the kernel body with these AttrStmts so the descriptors materialise at kernel entry rather than per call. Verification: pip install -e . succeeds; `import tilelang` succeeds. Inner-loop emission of cp_async_gs_lds_with_rsrc will be exercised once M5 (injection decision) routes T.copy through the new ops. Co-Authored-By: Claude Opus 4 (1M context) --- src/backend/rocm/codegen/codegen_hip.cc | 77 +++++++++++++++++++++++-- src/backend/rocm/codegen/codegen_hip.h | 1 + 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index c6ed1323d..6c7984ed2 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -913,9 +913,31 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } this->stream << ");\n"; }; - if (op->op.same_as(builtin::ptx_cp_async())) { + if (op->op.same_as(tl::ptx_make_buffer_resource())) { + // Expression form: emits make_wave_buffer_resource((const void*)(ptr)). + // The enclosing LetStmt visitor recognises this Call and emits `auto x =`. 
+ ICHECK(op->args.size() == 1) + << "ptx_make_buffer_resource expects 1 argument (global_ptr)"; + std::string ptr = this->PrintExpr(op->args[0]); + os << "make_wave_buffer_resource((const void*)(" << ptr << "))"; + } else if (op->op.same_as(tl::ptx_cp_async_lds_rsrc())) { + // args = [dst, src, bytes, rsrc_var, base_var] + ICHECK(op->args.size() == 5) + << "ptx_cp_async_lds_rsrc expects 5 arguments"; + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + std::string size = this->PrintExpr(op->args[2]); + std::string rsrc = this->PrintExpr(op->args[3]); + std::string base = this->PrintExpr(op->args[4]); + this->PrintIndent(); + this->stream << "tl::cp_async_gs_lds_with_rsrc<" << size << ">(" + << dst << ", " << src << ", " << rsrc << ", " << base + << ");\n"; + } else if (op->op.same_as(builtin::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())) { // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, - // args[3] = predicate (optional) + // args[3] = predicate (optional). ptx_cp_async_lds shares this shape but + // routes to the gs_lds (buffer_load ... lds) template instead. ICHECK(op->args.size() == 3 || op->args.size() == 4) << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " "src_access_ptr, bytes, [predicate])"; @@ -925,8 +947,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); if (op->args.size() == 3) { // Non-predicated version - this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src - << ");\n"; + bool use_lds = op->op.same_as(tl::ptx_cp_async_lds()); + this->stream << (use_lds ? 
"tl::cp_async_gs_lds<" : "tl::cp_async_gs<") + << size << ">(" << dst << ", " << src << ");\n"; } else { // Predicated version std::string condition = this->PrintExpr(op->args[3]); @@ -1437,7 +1460,32 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { - if (op->attr_key == tl::attr::kLexicalAllocScope) { + if (op->attr_key == "buffer_resource_var") { + // Hoisted resource descriptor from the HoistBufferResource Python pass. + // Emits: auto {rsrc_var} = make_wave_buffer_resource((const void*)({buf_var})); + auto rsrc_var = Downcast(op->node); + std::string rsrc_vid = AllocVarID(rsrc_var.get()); + std::string buf_ptr = PrintExpr(op->value); + this->PrintIndent(); + this->stream << "auto " << rsrc_vid + << " = make_wave_buffer_resource((const void*)(" << buf_ptr + << "));\n"; + this->VisitStmt(op->body); + return; + } else if (op->attr_key == "buffer_base_var") { + // Hoisted readfirstlane base address from the HoistBufferResource pass. + // Emits: uint32_t {base_var} = __builtin_amdgcn_readfirstlane( + // (uint32_t)(uintptr_t)({buf_var})); + auto base_var = Downcast(op->node); + std::string base_vid = AllocVarID(base_var.get()); + std::string buf_ptr = PrintExpr(op->value); + this->PrintIndent(); + this->stream << "uint32_t " << base_vid + << " = __builtin_amdgcn_readfirstlane(" + << "(uint32_t)(uintptr_t)(" << buf_ptr << "));\n"; + this->VisitStmt(op->body); + return; + } else if (op->attr_key == tl::attr::kLexicalAllocScope) { PrintIndent(); stream << "{\n"; int scope = BeginScope(); @@ -1472,6 +1520,25 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { CodeGenC::VisitStmt_(op); } +void CodeGenTileLangHIP::VisitStmt_(const LetStmtNode *op) { + // For LetStmt(var = ptx_make_buffer_resource(buf)), emit `auto x = ...;` + // instead of the C-typed declaration the base class would produce. 
The + // return type is int32x4_t and naming it explicitly is brittle across + // backends, so `auto` keeps the template lookup in make_wave_buffer_resource + // responsible for the type. + if (auto *call = op->value.as()) { + if (call->op.same_as(tl::ptx_make_buffer_resource())) { + std::string value = PrintExpr(op->value); + PrintIndent(); + stream << "auto " << AllocVarID(op->var.get()) << " = " << value + << ";\n"; + PrintStmt(op->body); + return; + } + } + CodeGenC::VisitStmt_(op); +} + void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) { ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); diff --git a/src/backend/rocm/codegen/codegen_hip.h b/src/backend/rocm/codegen/codegen_hip.h index 1030352e9..fc995e008 100644 --- a/src/backend/rocm/codegen/codegen_hip.h +++ b/src/backend/rocm/codegen/codegen_hip.h @@ -52,6 +52,7 @@ class CodeGenTileLangHIP final : public CodeGenC { void VisitExpr_(const ShuffleNode *op, std::ostream &os) final; // NOLINT(*) void VisitStmt_(const AllocateNode *op) final; void VisitStmt_(const AttrStmtNode *op) final; + void VisitStmt_(const LetStmtNode *op) final; // Override this as a work around for __grid_constant__ parameter void AddFunction(const PrimFunc &f); From 1053bcb55e70d1b47f99fed4b80504ab56fd49ec Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:33:37 +0000 Subject: [PATCH 10/34] M5: route eligible 16B copies to ptx_cp_async_lds on gfx950 Adds an opt-in `enable_buffer_load_lds` knob to InjectPTXAsyncCopy and the underlying class, wired from src/backend/rocm/op/copy.cc only when TargetIsGfx950(T.target) holds. Default false everywhere else, so CUDA and non-gfx950 ROCm paths are unchanged. 
Routing in MakeCPAsyncStmtFromLoads emits tl::ptx_cp_async_lds(dst, src, total_bytes) instead of tl::ptx_cp_async(dst, src, num_elems) when ALL of the following hold: - the flag is on - the copy is non-predicated - total_bytes == 16 (matches the device template's only specialised N) - the destination buffer scope is "shared" or "shared.dyn" - the destination LDS index contains no bitwise_xor term (a conservative proxy for lane contiguity; before the M7 swizzle-swap optimisation lands, swizzled LDS layouts will contain XOR and so the routing safely no-ops back to the existing ptx_cp_async path) Note arg 2 is byte width, not logical element count. The codegen handler for ptx_cp_async_lds (M4) prints arg 2 verbatim as the template width because the device template is currently only specialised for N == 16. Keeping the logical-vs-byte distinction at the boundary avoids the hazard Codex flagged where a blind op swap could otherwise emit cp_async_gs_lds<1> or <8>. Side artefacts: CopyIndexInfo gains a total_bytes field already computed in PrepareCopyIndexInfo; MakeCPAsyncStmtFromLoads loses its static qualifier so it can read enable_buffer_load_lds_ from `this`. Downstream cp.async recognisers updated so they treat the new ops the same as the existing ones: - src/transform/lower_ptx_async_copy.cc AnalyzeCopyRegion - src/transform/legalize_safe_memory_access.cc IsCPAsyncOp - src/transform/vectorize_loop.cc Call dispatch - src/transform/thread_storage_sync.cc is_cp_async lambda Verification: pip install -e . succeeds; `import tilelang` succeeds. Full bench correctness/perf evaluation is the M8 job after M6 (hoisting) and M7 (swizzle-swap) land. 
Co-Authored-By: Claude Opus 4 (1M context) --- src/backend/rocm/op/copy.cc | 4 +- src/transform/legalize_safe_memory_access.cc | 3 +- src/transform/lower_ptx_async_copy.cc | 74 +++++++++++++++++--- src/transform/ptx_async_copy_injector.h | 8 ++- src/transform/thread_storage_sync.cc | 4 +- src/transform/vectorize_loop.cc | 4 +- 6 files changed, 83 insertions(+), 14 deletions(-) diff --git a/src/backend/rocm/op/copy.cc b/src/backend/rocm/op/copy.cc index df1c8ce7e..ffb2ac2b4 100644 --- a/src/backend/rocm/op/copy.cc +++ b/src/backend/rocm/op/copy.cc @@ -130,7 +130,9 @@ struct Copy { auto inject_result = InjectPTXAsyncCopy(lowered_loop, /*enable_auto_async_copy=*/true, /*async_without_async_commit_wait=*/ - no_implicit_commit_wait || GetIsAsyncCopy(op)); + no_implicit_commit_wait || GetIsAsyncCopy(op), + /*enable_buffer_load_lds=*/ + TargetIsGfx950(T.target)); Stmt cp_async_loop = inject_result.stmt; if (!inject_result.injected_ptx_async_copy) { DLOG(WARNING) << "cp.async rewrite miss for copy src=" << op.src->name diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index de85e0ac7..f67f34ffb 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -315,7 +315,8 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { } bool IsCPAsyncOp(const Op &op) { - return op == builtin::ptx_cp_async() || op == tl::ptx_cp_async(); + return op == builtin::ptx_cp_async() || op == tl::ptx_cp_async() || + op == tl::ptx_cp_async_lds() || op == tl::ptx_cp_async_lds_rsrc(); } static constexpr int kCPAsyncDstPtrArg = 0; diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index fc90b9ad4..1c6e3a965 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -32,9 +32,11 @@ using namespace tir; class PTXAsyncCopyInjector : public StmtMutator { public: explicit PTXAsyncCopyInjector(bool 
enable_auto_async_copy, - bool async_without_async_commit_wait) + bool async_without_async_commit_wait, + bool enable_buffer_load_lds = false) : enable_auto_async_copy_(enable_auto_async_copy), - async_without_async_commit_wait_(async_without_async_commit_wait) {} + async_without_async_commit_wait_(async_without_async_commit_wait), + enable_buffer_load_lds_(enable_buffer_load_lds) {} bool InjectedPTXAsyncCopy() const { return injected_ptx_async_copy_; } @@ -121,7 +123,9 @@ class PTXAsyncCopyInjector : public StmtMutator { store, /*dst_base_load=*/BufferLoad(store->buffer, store->indices), /*src_base_load=*/BufferLoad(load->buffer, load->indices), - /*num_elems=*/index_info->per_access_num_elems, predicated, + /*num_elems=*/index_info->per_access_num_elems, + /*total_bytes=*/index_info->total_bytes, + /*dst_check_index=*/index_info->dst_index, predicated, predicate_value); } @@ -142,8 +146,9 @@ class PTXAsyncCopyInjector : public StmtMutator { store, /*dst_base_load=*/BufferLoad(store->buffer, dst_base_indices.value()), /*src_base_load=*/BufferLoad(load->buffer, src_base_indices.value()), - /*num_elems=*/index_info->per_access_num_elems, predicated, - predicate_value); + /*num_elems=*/index_info->per_access_num_elems, + /*total_bytes=*/index_info->total_bytes, + /*dst_check_index=*/index_info->dst_index, predicated, predicate_value); } Stmt VisitStmt_(const SeqStmtNode *op) final { @@ -298,6 +303,10 @@ class PTXAsyncCopyInjector : public StmtMutator { PrimExpr dst_index; int index_lanes{1}; int per_access_num_elems{0}; + // Byte width of one vectorized transfer at the final PTX emission point. + // Already factors in `current_vectorized_lanes_`. Used by the gfx950 + // buffer_load...lds routing to confirm the 16-byte gate. + int total_bytes{0}; }; // Synchronization state for injected cp.async runs carried across statements. 
@@ -418,9 +427,32 @@ class PTXAsyncCopyInjector : public StmtMutator { info.dst_index = dst_index; info.index_lanes = index_lanes; info.per_access_num_elems = effective_lanes; + info.total_bytes = total_bytes; return info; } + // Conservative structural check used to gate the gfx950 buffer_load...lds + // routing: the destination LDS index must be lane-contiguous, i.e. it has + // no XOR-style swizzle term. Until the swizzle-swap optimisation in + // lower_tile_op.cc moves the XOR off the LDS store side (M7 in the + // accompanying plan) this returns false for ordinary swizzled layouts, so + // the routing safely no-ops and the existing ptx_cp_async path is used. + static bool ContainsBitwiseXor(const PrimExpr &expr) { + bool found = false; + tir::PostOrderVisit(expr, [&](const ObjectRef &node) { + if (found) return; + if (const auto *call = node.as()) { + if (call->op.same_as(tvm::tir::builtin::bitwise_xor())) { + found = true; + } + } + }); + return found; + } + static bool IsLdsLaneContiguous(const PrimExpr &dst_index) { + return !ContainsBitwiseXor(dst_index); + } + static PrimExpr ExtractVectorBase(const PrimExpr &index) { if (index.dtype().lanes() == 1) { return index; @@ -486,16 +518,35 @@ class PTXAsyncCopyInjector : public StmtMutator { IntImm(DataType::Int(32), rw_mask)}); } - static Optional + Optional MakeCPAsyncStmtFromLoads(const BufferStoreNode *store, const BufferLoad &dst_base_load, const BufferLoad &src_base_load, int num_elems, + int total_bytes, const PrimExpr &dst_check_index, bool predicated, const PrimExpr &predicate_value) { PrimExpr dst_access_ptr = MakeAccessPtrFromLoad(dst_base_load, num_elems, /*rw_mask=*/2); PrimExpr src_access_ptr = MakeAccessPtrFromLoad(src_base_load, num_elems, /*rw_mask=*/1); + // gfx950 routing: emit tl::ptx_cp_async_lds when the destination is a + // 16-byte non-predicated shared-memory write whose LDS index is lane- + // contiguous (no XOR swizzle). 
Note: arg 2 here is *byte width*, not + // logical element count, because the codegen handler for ptx_cp_async_lds + // prints arg 2 directly as the template width and the device template + // currently only specialises N == 16. + if (enable_buffer_load_lds_ && !predicated && total_bytes == 16) { + const std::string dst_scope = store->buffer.scope(); + const bool is_shared = + dst_scope == "shared" || dst_scope == "shared.dyn"; + if (is_shared && IsLdsLaneContiguous(dst_check_index)) { + ffi::Array lds_args = {dst_access_ptr, src_access_ptr, + PrimExpr(total_bytes)}; + return Evaluate(Call(store->buffer->dtype, tvm::tl::ptx_cp_async_lds(), + lds_args)); + } + } + ffi::Array cp_async_args; if (predicated) { cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems), @@ -614,7 +665,9 @@ class PTXAsyncCopyInjector : public StmtMutator { return out; } if (call->op.same_as(builtin::ptx_cp_async()) || - call->op.same_as(tl::ptx_cp_async())) { + call->op.same_as(tl::ptx_cp_async()) || + call->op.same_as(tl::ptx_cp_async_lds()) || + call->op.same_as(tl::ptx_cp_async_lds_rsrc())) { return out; } if (call->op.same_as(builtin::ptx_commit_group())) { @@ -687,6 +740,7 @@ class PTXAsyncCopyInjector : public StmtMutator { bool enable_auto_async_copy_{true}; bool async_without_async_commit_wait_{false}; + bool enable_buffer_load_lds_{false}; int explicit_async_scope_depth_{0}; int current_vectorized_lanes_{1}; std::vector active_vectorized_loops_; @@ -700,9 +754,11 @@ using namespace tir::transform; PTXAsyncCopyInjectResult InjectPTXAsyncCopy(const Stmt &body, bool enable_auto_async_copy, - bool async_without_async_commit_wait) { + bool async_without_async_commit_wait, + bool enable_buffer_load_lds) { PTXAsyncCopyInjector injector(enable_auto_async_copy, - async_without_async_commit_wait); + async_without_async_commit_wait, + enable_buffer_load_lds); Stmt injected = injector(body); return {injector.Finalize(injected), injector.InjectedPTXAsyncCopy()}; } diff --git 
a/src/transform/ptx_async_copy_injector.h b/src/transform/ptx_async_copy_injector.h index 80c642562..8fb549059 100644 --- a/src/transform/ptx_async_copy_injector.h +++ b/src/transform/ptx_async_copy_injector.h @@ -15,10 +15,16 @@ struct PTXAsyncCopyInjectResult { * This is the statement-level entrypoint used by other transforms to apply the * same rewrite as the `tl.LowerPTXAsyncCopy` pass, but scoped to a region * (e.g., a lowered parallel loop) rather than the whole PrimFunc. + * + * `enable_buffer_load_lds` enables the gfx950-specific routing that emits + * tl::ptx_cp_async_lds for eligible 16-byte non-predicated shared-memory- + * destined copies whose LDS index is lane-contiguous (no XOR swizzle). The + * ROCm copy lowering pass passes this flag only when the target is gfx950+. */ PTXAsyncCopyInjectResult InjectPTXAsyncCopy(const tvm::tir::Stmt &body, bool enable_auto_async_copy, - bool async_without_async_commit_wait = false); + bool async_without_async_commit_wait = false, + bool enable_buffer_load_lds = false); } // namespace tl } // namespace tvm diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index a9320b150..f3ed659e8 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1108,7 +1108,9 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (auto opt = op->op.as()) { const Op &call_op = opt.value(); return call_op.same_as(builtin::ptx_cp_async()) || - call_op.same_as(tl::ptx_cp_async()); + call_op.same_as(tl::ptx_cp_async()) || + call_op.same_as(tl::ptx_cp_async_lds()) || + call_op.same_as(tl::ptx_cp_async_lds_rsrc()); } return false; }(); diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index 9735a9a5b..5fdbdeaf8 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -797,7 +797,9 @@ class TLVectorizer : public StmtMutator, } else if (op->op.same_as(builtin::tvm_access_ptr())) { return 
MutateAccessPtrCall_(op); } else if (op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async())) { + op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds()) || + op->op.same_as(tl::ptx_cp_async_lds_rsrc())) { return MutatePTXCPAsyncExpr_(op); } auto optional_op = op->op.as(); From 7ac1d964e454d911844c07a34c7a8f71922abcb7 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:35:45 +0000 Subject: [PATCH 11/34] M6: HoistBufferResource Python pass + pipeline wiring (gfx950 only) New file tilelang/transform/hoist_buffer_resource.py with the descriptor-hoist half of the reference implementation. Scans the post-LowerAccessPtr PrimFunc body for tl.ptx_cp_async_lds calls, extracts the source buffer Var from each call's tvm_access_ptr arg, creates one __rsrc_<name> and one __base_<name> Var per unique source buffer (where <name> is the source buffer Var's name), wraps the body with buffer_resource_var / buffer_base_var AttrStmts (consumed by the HIP codegen handlers from M4), and rewrites the calls to ptx_cp_async_lds_rsrc carrying the hoisted vars. Registered in tilelang/transform/__init__.py alongside the existing HoistBroadcastValues / DecoupleTypeCast imports. Inserted into tilelang/engine/phase.py OptimizeForTarget after MergeIfStmt, before MakePackedAPI, matching Codex's directive site so LowerAccessPtr has already lowered tl.access_ptr to tvm_access_ptr by the time the pass runs. The pass guards on target_is_gfx950(target) and is a no-op on every other target. AMD vmcnt wait-count scaling (the second half of the reference) is intentionally NOT included in this commit; it will land as a separate milestone (M6.5) only if the M6 / M7 bench shows a correctness failure attributable to async wait counts. Verification: pip install -e . succeeds; `from tilelang.transform import HoistBufferResource` imports the callable.
End-to-end emission will be exercised once M7 lands the swizzle-swap optimisation that removes the XOR from LDS-side indices so M5's IsLdsLaneContiguous gate begins routing copies to ptx_cp_async_lds. Co-Authored-By: Claude Opus 4 (1M context) --- tilelang/engine/phase.py | 3 + tilelang/transform/__init__.py | 1 + tilelang/transform/hoist_buffer_resource.py | 126 ++++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100644 tilelang/transform/hoist_buffer_resource.py diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 03ae0b4ce..45e62adcd 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -299,6 +299,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectTcgen05Fence()(mod) mod = tilelang.transform.MergeIfStmt()(mod) # NOTE: LowerPTXAsyncCopy is applied earlier (before PipelinePlanning). + # Hoist buffer resource descriptors for the gfx950 buffer_load...lds path. + # No-op on non-gfx950 targets (pass guards on target_is_gfx950). 
+ mod = tilelang.transform.HoistBufferResource()(mod) if allow_warp_specialized(pass_ctx=pass_ctx, target=target): mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) mod = tilelang.transform.MakePackedAPI()(mod) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 599ccef1a..c3a2a3679 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -8,6 +8,7 @@ from tvm.ir.transform import PassContext # noqa: F401 from .add_bufstore_wrapper import AddWrapperForSingleBufStore # noqa: F401 from .hoist_broadcast_values import HoistBroadcastValues # noqa: F401 +from .hoist_buffer_resource import HoistBufferResource # noqa: F401 from .decouple_type_cast import DecoupleTypeCast # noqa: F401 diff --git a/tilelang/transform/hoist_buffer_resource.py b/tilelang/transform/hoist_buffer_resource.py new file mode 100644 index 000000000..b3d13c10c --- /dev/null +++ b/tilelang/transform/hoist_buffer_resource.py @@ -0,0 +1,126 @@ +"""Hoist make_wave_buffer_resource descriptors for gfx950 buffer_load...lds. + +On gfx950, the `cp_async_gs_lds_with_rsrc<16>` device template takes a +pre-computed buffer resource descriptor and a pre-computed wave-uniform +base address. Computing those per call would emit 4x readfirstlane plus +the resource bit-cast on every call site. In an unrolled tile-copy loop +the same global buffer is touched many times, so we lift the descriptor +to the kernel prologue once per source buffer and rewrite the calls to +the variant that takes the pre-hoisted pair. + +Pipeline order: this pass runs in the OptimizeForTarget phase after +ThreadSync/MergeIfStmt and before MakePackedAPI, which means the +tl::access_ptr calls have already been lowered by `LowerAccessPtr` to +`tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)`, so the +buffer Var is at args[1] of each access_ptr term. + +This pass is gfx950-only: on every other target it returns the PrimFunc +unchanged. 
+ +NOTE: AMD vmcnt wait-count scaling (the second half of the reference +implementation on the zty_opt_can_run_1120flops branch) is deliberately +omitted in this commit. It will land as a separate milestone (M6.5) +only if the M6 bench shows a correctness failure attributable to async +wait counts. +""" + +from tvm import tir +from tvm.tir import AttrStmt, Call, Evaluate, Var, PrimFunc, stmt_functor +from tvm.tir.transform import prim_func_pass + +from tilelang.utils.target import target_is_gfx950 + +_op_ptx_cp_async_lds = tir.op.Op.get("tl.ptx_cp_async_lds") +_op_ptx_cp_async_lds_rsrc = tir.op.Op.get("tl.ptx_cp_async_lds_rsrc") +_op_tvm_access_ptr = tir.op.Op.get("tir.tvm_access_ptr") + + +def _extract_buffer_var(access_ptr_expr): + """Pull the buffer-data Var out of a lowered tvm_access_ptr call. + + After tl.LowerAccessPtr the access pointer is encoded as + ``tvm_access_ptr(ptype, data, offset, extent, rw_mask)`` so args[1] + is the Var of interest. Anything else (e.g. an unlowered tl.access_ptr + or a plain pointer expression) returns None and the call is skipped. + """ + if not isinstance(access_ptr_expr, Call): + return None + if access_ptr_expr.op != _op_tvm_access_ptr: + return None + if len(access_ptr_expr.args) < 2: + return None + data_arg = access_ptr_expr.args[1] + if isinstance(data_arg, Var): + return data_arg + return None + + +def _collect_buffer_vars(body): + """Discover unique source buffer Vars referenced by ptx_cp_async_lds calls. + + Returns an ordered dict {buf_var: (rsrc_var, base_var)} so the prologue + AttrStmts emit in a stable order. 
+ """ + buffer_vars = {} + + def _visit(stmt): + if isinstance(stmt, Evaluate) and isinstance(stmt.value, Call): + if stmt.value.op == _op_ptx_cp_async_lds: + # ptx_cp_async_lds args: (dst_access_ptr, src_access_ptr, bytes) + buf_var = _extract_buffer_var(stmt.value.args[1]) + if buf_var is not None and buf_var not in buffer_vars: + rsrc_var = Var("__rsrc_" + buf_var.name, dtype="handle") + base_var = Var("__base_" + buf_var.name, dtype="uint32") + buffer_vars[buf_var] = (rsrc_var, base_var) + + stmt_functor.post_order_visit(body, _visit) + return buffer_vars + + +def _rewrite_calls(body, buffer_vars): + """Rewrite ptx_cp_async_lds -> ptx_cp_async_lds_rsrc with hoisted vars.""" + + def _postorder(op): + if isinstance(op, Evaluate) and isinstance(op.value, Call): + if op.value.op == _op_ptx_cp_async_lds: + buf_var = _extract_buffer_var(op.value.args[1]) + if buf_var is not None and buf_var in buffer_vars: + rsrc_var, base_var = buffer_vars[buf_var] + new_call = Call( + op.value.dtype, + _op_ptx_cp_async_lds_rsrc, + [ + op.value.args[0], + op.value.args[1], + op.value.args[2], + rsrc_var, + base_var, + ], + ) + return Evaluate(new_call) + return None + + return stmt_functor.ir_transform(body, None, _postorder, ["tir.Evaluate"]) + + +def HoistBufferResource(): + """gfx950: hoist buffer resource descriptors out of the inner copy loop.""" + + def pass_fn(func: PrimFunc, _mod, _ctx): + target = func.attrs.get("target", None) + if target is None or not target_is_gfx950(target): + return func + + buffer_vars = _collect_buffer_vars(func.body) + if not buffer_vars: + return func + + new_body = _rewrite_calls(func.body, buffer_vars) + + for buf_var, (rsrc_var, base_var) in reversed(list(buffer_vars.items())): + new_body = AttrStmt(base_var, "buffer_base_var", buf_var, new_body) + new_body = AttrStmt(rsrc_var, "buffer_resource_var", buf_var, new_body) + + return func.with_body(new_body) + + return prim_func_pass(pass_fn, opt_level=0) From 
3c82e3b9d0f92dab56e0e954052538c2b9449e88 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:40:20 +0000 Subject: [PATCH 12/34] M7: layout SwizzleDelta API + ROCm swizzle-swap in lower_tile_op Three changes that together let the M5 IsLdsLaneContiguous gate route the target shape's g2s copies into the gfx950 buffer_load...lds path. 1. Layout swizzle-delta API (src/layout/layout.h, layout.cc): - New Optional<PrimExpr> swizzle_delta_ member on LayoutNode. - virtual PrimExpr SwizzleDelta(input_indices): substitutes the last InputDim() entries of input_indices into the stored delta (same convention as Forward), returns 0 when no delta is set. - bool HasSwizzle() / void SetSwizzleDelta(): trivial accessors. - LayoutNode::Expand propagates swizzle_delta_ through the variable remap that shifts InputPlaceholders by the leading-shape offset. 2. Swizzle factories now record their column-XOR delta (src/layout/gemm_layouts.cc) so HasSwizzle() returns true for the layouts the GEMM tile-op actually produces: - MakeQuarterBankSwizzleLayout2D: (xor2x2(c, s>>2) - c) * vector_size - MakeHalfBankSwizzleLayout2D: (xor4x4(c, s>>1) - c) * vector_size - MakeFullBankSwizzleLayout2D: (xor8x8(c, s) - c) * vector_size 3. Swizzle-swap in src/transform/lower_tile_op.cc BufferStoreNode visitor: when TargetIsRocm && !is_ptx_ && IsSharedBuffer(buffer) && layout has a swizzle AND the store value is a direct global BufferLoad of matching arity AND the layout output has unit leading dim, rewrite shared[Forward(local)] = global[base + local] to shared[Forward(local) - delta(local)] = global[base + local + delta(local)] XOR is self-inverse so net data movement is unchanged, but the LDS destination becomes lane-contiguous and the cp.async injector's IsLdsLaneContiguous check stops rejecting the store. Other layouts (no HasSwizzle, non-ROCm, ptx path, non-shared dst, non-direct-load stores) fall through to the existing Forward-only path. Diagnostic LOG(INFO) noise from the reference branch is intentionally omitted.
Verification: pip install -e . succeeds; `import tilelang` succeeds. End-to-end emission + perf is the M8 job. Co-Authored-By: Claude Opus 4 (1M context) --- src/layout/gemm_layouts.cc | 15 ++++++++-- src/layout/layout.cc | 27 +++++++++++++++++- src/layout/layout.h | 23 +++++++++++++++ src/transform/lower_tile_op.cc | 51 +++++++++++++++++++++++++++++++++- 4 files changed, 111 insertions(+), 5 deletions(-) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index d06182efa..758bb4788 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -449,7 +449,10 @@ static Layout MakeQuarterBankSwizzleLayout2D(int stride, int continuous, PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor2x2(c, FloorDiv(s, 4)); PrimExpr index = vec + (c_swizzle + s * 2) * vector_size; - return Layout(Array{stride, continuous}, {tc, ts, index}); + PrimExpr swizzle_delta = (c_swizzle - c) * vector_size; + Layout result(Array{stride, continuous}, {tc, ts, index}); + const_cast(result.get())->SetSwizzleDelta(swizzle_delta); + return result; } Layout makeQuarterBankSwizzleLayout(const Buffer &buffer) { @@ -477,7 +480,10 @@ static Layout MakeHalfBankSwizzleLayout2D(int stride, int continuous, PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor4x4(c, FloorDiv(s, 2)); PrimExpr index = vec + (c_swizzle + s * 4) * vector_size; - return Layout(Array{stride, continuous}, {tc, ts, index}); + PrimExpr swizzle_delta = (c_swizzle - c) * vector_size; + Layout result(Array{stride, continuous}, {tc, ts, index}); + const_cast(result.get())->SetSwizzleDelta(swizzle_delta); + return result; } Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { @@ -505,7 +511,10 @@ static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor8x8(c, s); PrimExpr index = vec + (c_swizzle + s * 8) * vector_size; - return Layout(Array{stride, continuous}, {tc, ts, index}); + PrimExpr 
swizzle_delta = (c_swizzle - c) * vector_size; + Layout result(Array{stride, continuous}, {tc, ts, index}); + const_cast(result.get())->SetSwizzleDelta(swizzle_delta); + return result; } Layout makeFullBankSwizzleLayout(const Buffer &buffer) { diff --git a/src/layout/layout.cc b/src/layout/layout.cc index c5661dc91..045234ede 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -505,7 +505,32 @@ Layout LayoutNode::Expand(const Array &leading_shape) const { new_forward_index.push_back(Substitute(e, vmap)); } - return Layout(new_input_size, new_forward_index); + Layout result(new_input_size, new_forward_index); + // Propagate swizzle_delta_ through Expand: substitute placeholder + // indices so the delta keeps referring to the same physical input + // dimension after the leading-shape prefix is added. + if (swizzle_delta_.defined()) { + const_cast(result.get()) + ->SetSwizzleDelta(Substitute(swizzle_delta_.value(), vmap)); + } + return result; +} + +PrimExpr LayoutNode::SwizzleDelta(const Array &input_indices) const { + if (!swizzle_delta_.defined()) { + return IntImm(DataType::Int(32), 0); + } + // Substitute the last InputDim() entries of input_indices into + // swizzle_delta_, matching the convention Forward() uses. + PrimExpr delta = swizzle_delta_.value(); + size_t offset = + input_indices.size() >= InputDim() ? input_indices.size() - InputDim() + : 0; + for (size_t i = 0; i < InputDim(); ++i) { + delta = Substitute(delta, + {{InputPlaceholder(i), input_indices[offset + i]}}); + } + return delta; } Fragment FragmentNode::Repeat(const Array &repeats, diff --git a/src/layout/layout.h b/src/layout/layout.h index 8043c5765..d888d8c26 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -103,6 +103,24 @@ class LayoutNode : public Object { virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const; + /*! + * \brief Get the XOR swizzle column delta on the last input dimension. 
+ * + * For swizzled layouts (quarter/half/full bank) returns the column + * delta caused by the XOR: delta = (c_swizzled - c) * vector_size, + * substituted against the supplied indices. For non-swizzle layouts + * returns 0. Used by the swizzle-swap optimisation in lower_tile_op.cc + * to move the XOR off the LDS-store side and onto the global-load + * side when the target supports buffer_load ... lds direct DMA. + */ + virtual PrimExpr SwizzleDelta(const Array &input_indices) const; + + /*! \brief Whether this layout carries a non-trivial swizzle delta. */ + bool HasSwizzle() const { return swizzle_delta_.defined(); } + + /*! \brief Set the swizzle delta expression (called by layout factories). */ + void SetSwizzleDelta(PrimExpr delta) { swizzle_delta_ = delta; } + static void RegisterReflection(); TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object); static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = @@ -113,6 +131,11 @@ class LayoutNode : public Object { void UpdateAnalyzer(arith::Analyzer *analyzer) const; Array forward_index_; Array input_size_; + /*! + * \brief Optional XOR swizzle delta in terms of InputPlaceholders, set + * by swizzle layout factories and propagated through Expand/Reshape. + */ + Optional swizzle_delta_; }; /*! diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index ea92ac66a..74a3a0ca4 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -953,9 +953,58 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); auto buffer = store->buffer; if (buffer_remap_.count(buffer)) { - auto new_indices = layout_map_[buffer]->Forward(store->indices); auto new_buffer = buffer_remap_[store->buffer]; layout_remap_.Set(new_buffer, layout_map_[store->buffer]); + + // gfx950 buffer_load_dwordx4 ... lds requires LDS destinations be + // lane-contiguous (no XOR). 
When the shared-buffer layout carries + // an XOR swizzle, move the XOR off the LDS store side and onto + // the global load side (XOR is self-inverse, so the net data + // movement is unchanged). Gated on: ROCm target, non-PTX path, + // shared destination, layout has a swizzle delta, and the store + // value is a direct global BufferLoad (the g2s shape the cp.async + // injector recognises). + if (TargetIsRocm(target_) && !is_ptx_ && IsSharedBuffer(buffer) && + layout_map_[buffer]->HasSwizzle()) { + const BufferLoadNode *load_node = nullptr; + if (auto *load = store->value.as()) { + if (IsGlobalBuffer(load->buffer)) { + load_node = load; + } + } + if (load_node && is_one(layout_map_[buffer]->OutputShape()[0]) && + load_node->indices.size() == store->indices.size()) { + auto swizzled_store = layout_map_[buffer]->Forward(store->indices); + PrimExpr delta = analyzer_->Simplify( + layout_map_[buffer]->SwizzleDelta(store->indices)); + + Array sequential_store(swizzled_store.begin(), + swizzled_store.end()); + int last_out = static_cast(sequential_store.size()) - 1; + sequential_store.Set( + last_out, + analyzer_->Simplify(sequential_store[last_out] - delta)); + + Array reflected(store->indices.begin(), + store->indices.end()); + int last_in = static_cast(reflected.size()) - 1; + reflected.Set( + last_in, analyzer_->Simplify(reflected[last_in] + delta)); + + Array new_load_indices; + for (size_t k = 0; k < load_node->indices.size(); ++k) { + PrimExpr base = analyzer_->Simplify(load_node->indices[k] - + store->indices[k]); + new_load_indices.push_back( + analyzer_->Simplify(base + reflected[k])); + } + + BufferLoad rewritten_load(load_node->buffer, new_load_indices); + return BufferStore(new_buffer, rewritten_load, sequential_store); + } + } + + auto new_indices = layout_map_[buffer]->Forward(store->indices); return BufferStore(new_buffer, store->value, new_indices); } else if (var_remap_.count(buffer->data)) { auto new_buffer = Buffer( From 
ae29f93939bd336db161dc12f46c84e40dfef761 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:45:28 +0000 Subject: [PATCH 13/34] M8a: re-enable MergeSharedMemoryAllocations in OptimizeForTarget Without this pass, kernels with multiple `shared.dyn` buffers (which includes the bench's matmul with A_shared + B_shared) fail LowerDeviceKernelLaunch with "Only one dynamic shared memory allocation is allowed". The pass was commented out on this branch with no explanation; re-enabling it lets the bench compile through to codegen + kernel launch. This is a prerequisite for M8 (perf gate). Pre-existing constraint, not related to the buffer_load_lds work; landed as a separate small milestone for traceability per AC-3. Co-Authored-By: Claude Opus 4 (1M context) --- tilelang/engine/phase.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 45e62adcd..3021d5255 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -282,11 +282,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.MarkCudaSyncCalls(have_pdl(target))(mod) mod = tilelang.transform.AnnotateReadOnlyParams()(mod) - print(mod.script()) # MergeSharedMemoryAllocations must be applied after SplitHostDevice - # because the merged allocation site is at the beginning of each device function - # enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) - # mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) + # because the merged allocation site is at the beginning of each device + # function. LowerDeviceKernelLaunch enforces "Only one dynamic shared + # memory allocation"; keeping this disabled breaks any kernel with + # multiple .dyn buffers (the bench matmul has two: A_shared + B_shared). 
+ enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) + mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) # InjectFenceProxy is a no-op on targets that lack the TMA / async-proxy # programming model; the pass itself checks the PrimFunc's target. mod = tilelang.transform.InjectFenceProxy()(mod) From d101b3823b278ff68f06052c01b0c4cb0f09c3e8 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 06:57:13 +0000 Subject: [PATCH 14/34] M8b: integrate ptx_cp_async_lds with vec-loop folding + buf-merge + codegen num_elems convention The Round 1 emission had an outer `for vec=0..8` loop wrapping each tl::cp_async_gs_lds_with_rsrc<16> call, causing 8x overlapping LDS writes per thread and a runtime GPU coredump. Root cause: tl::ptx_cp_async stores its transfer width as a *logical element count* and several downstream passes (vectorize_loop.cc::MutatePTXCPAsyncExpr_, loop_vectorize.cc dispatch, merge_shared_memory_allocations.cc rewrite) key off the op identity to widen/remap. Earlier I emitted ptx_cp_async_lds with arg 2 in *bytes*, so those passes either skipped the call (no widening = vec loop survived) or could not remap the merged shared buffer. This commit makes the ptx_cp_async_lds family follow the same logical- element-count convention as tl::ptx_cp_async: - src/transform/lower_ptx_async_copy.cc: MakeCPAsyncStmtFromLoads's LDS branch now passes `num_elems` (not `total_bytes`) as arg 2. - src/backend/rocm/codegen/codegen_hip.cc: the tl::ptx_cp_async and tl::ptx_cp_async_lds CallNode handlers are merged; both go through GetTileLangCPAsyncTransferBytes to derive the template byte width. The tl::ptx_cp_async_lds_rsrc handler does the same conversion inline (it has 5 args so cannot reuse GetTileLang...). - src/transform/vectorize_loop.cc: GetCPAsyncBitsPerCall and MutatePTXCPAsyncExpr_ ICHECKs widen to accept tl::ptx_cp_async_lds. 
This lets the vec(k) widening multiply num_elems by k and produces one call covering the full vec range, so the surrounding vec loop's body becomes uniform and the loop is consumed by the vectorizer. - src/transform/loop_vectorize.cc: same widening for the ScalarToVector pipeline. - src/transform/merge_shared_memory_allocations.cc: extend the cp_async dst-ptr remap to recognise tl::ptx_cp_async_lds so the merged dyn-shared buffer pointer is substituted correctly. Result on the bench (8192x8192x8192 NT, tile 256x256x64): emission now matches `_fast.cpp`'s outer-loop shape verbatim (tl::cp_async_gs_lds_with_rsrc<16> per i_1 iteration, no inner vec loop). Bench compile passes, kernel launches without coredump, ~867 TFLOPS. Correctness still fails because the swizzle-swap from M7 is not engaging on the CDNA layout, leaving the XOR on the LDS-store side; that is the next debug step. Co-Authored-By: Claude Opus 4 (1M context) --- src/backend/rocm/codegen/codegen_hip.cc | 48 +++++++++++++------ src/transform/loop_vectorize.cc | 6 ++- src/transform/lower_ptx_async_copy.cc | 11 +++-- .../merge_shared_memory_allocations.cc | 3 +- src/transform/vectorize_loop.cc | 6 ++- 5 files changed, 50 insertions(+), 24 deletions(-) diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index 6c7984ed2..46bdf770f 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -926,18 +926,34 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { << "ptx_cp_async_lds_rsrc expects 5 arguments"; std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); - std::string size = this->PrintExpr(op->args[2]); + // arg 2 carries logical element count (inherited from the + // ptx_cp_async_lds call that HoistBufferResource rewrote into this + // rsrc form). 
Convert to bytes the same way the ptx_cp_async / lds + // handler does, by inspecting the access-ptr element type. + const auto *num_elems_imm = op->args[2].as(); + ICHECK(num_elems_imm) + << "ptx_cp_async_lds_rsrc num_elems must be IntImm, but got " + << op->args[2]; + auto dst_elem_type = GetAccessPtrElementType(op->args[0]); + ICHECK(dst_elem_type.has_value()) + << "ptx_cp_async_lds_rsrc dst must be tvm_access_ptr / tl.access_ptr / " + "address_of(BufferLoad)"; + int64_t total_bits = num_elems_imm->value * + dst_elem_type.value().bits() * + dst_elem_type.value().lanes(); + ICHECK_EQ(total_bits % 8, 0) + << "ptx_cp_async_lds_rsrc requires byte-aligned transfer, got " + << total_bits << " bits"; + int total_bytes = static_cast(total_bits / 8); + std::string size = std::to_string(total_bytes); std::string rsrc = this->PrintExpr(op->args[3]); std::string base = this->PrintExpr(op->args[4]); this->PrintIndent(); this->stream << "tl::cp_async_gs_lds_with_rsrc<" << size << ">(" << dst << ", " << src << ", " << rsrc << ", " << base << ");\n"; - } else if (op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async_lds())) { - // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, - // args[3] = predicate (optional). ptx_cp_async_lds shares this shape but - // routes to the gs_lds (buffer_load ... lds) template instead. + } else if (op->op.same_as(builtin::ptx_cp_async())) { + // builtin::ptx_cp_async stores byte width directly in arg 2. ICHECK(op->args.size() == 3 || op->args.size() == 4) << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " "src_access_ptr, bytes, [predicate])"; @@ -946,25 +962,29 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { std::string size = this->PrintExpr(op->args[2]); this->PrintIndent(); if (op->args.size() == 3) { - // Non-predicated version - bool use_lds = op->op.same_as(tl::ptx_cp_async_lds()); - this->stream << (use_lds ? 
"tl::cp_async_gs_lds<" : "tl::cp_async_gs<") - << size << ">(" << dst << ", " << src << ");\n"; + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; } else { - // Predicated version std::string condition = this->PrintExpr(op->args[3]); this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << ", " << src << ", " << condition << ");\n"; } - } else if (op->op.same_as(tl::ptx_cp_async())) { + } else if (op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())) { + // Both store logical element count in arg 2; convert to bytes via + // GetTileLangCPAsyncTransferBytes. ptx_cp_async_lds routes the + // non-predicated path to tl::cp_async_gs_lds instead of + // tl::cp_async_gs; predicated copies always fall back to the + // generic cp_async_gs_conditional template. int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); std::string size = std::to_string(total_bytes); this->PrintIndent(); if (op->args.size() == 3) { - this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src - << ");\n"; + bool use_lds = op->op.same_as(tl::ptx_cp_async_lds()); + this->stream << (use_lds ? 
"tl::cp_async_gs_lds<" : "tl::cp_async_gs<") + << size << ">(" << dst << ", " << src << ");\n"; } else { std::string condition = this->PrintExpr(op->args[3]); this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index d1c8e2fea..0be22e5aa 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -445,7 +445,8 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { if (node->op.same_as(builtin::ptx_cp_async())) { return count * 8; } - ICHECK(node->op.same_as(tl::ptx_cp_async())); + ICHECK(node->op.same_as(tl::ptx_cp_async()) || + node->op.same_as(tl::ptx_cp_async_lds())); auto dst_elem_bits = GetAccessPtrElementBits(node->args[0]); auto src_elem_bits = GetAccessPtrElementBits(node->args[1]); if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { @@ -520,7 +521,8 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { buffer_vector_infos_.push_back({Buffer(), vectorize_length, false, {}}); return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else if (node->op.same_as(builtin::ptx_cp_async()) || - node->op.same_as(tl::ptx_cp_async())) { + node->op.same_as(tl::ptx_cp_async()) || + node->op.same_as(tl::ptx_cp_async_lds())) { // builtin::ptx_cp_async stores bytes, while tl::ptx_cp_async stores // logical element counts. In both cases we pick the largest vector width // whose eventual PTX payload is one of {4, 8, 16} bytes. diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 1c6e3a965..77c3a915e 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -531,17 +531,18 @@ class PTXAsyncCopyInjector : public StmtMutator { // gfx950 routing: emit tl::ptx_cp_async_lds when the destination is a // 16-byte non-predicated shared-memory write whose LDS index is lane- - // contiguous (no XOR swizzle). 
Note: arg 2 here is *byte width*, not - // logical element count, because the codegen handler for ptx_cp_async_lds - // prints arg 2 directly as the template width and the device template - // currently only specialises N == 16. + // contiguous (no XOR swizzle). Arg 2 carries the logical element count + // (same convention tl::ptx_cp_async uses) so the existing vec-loop + // folding in vectorize_loop.cc widens it correctly when the call sits + // inside a T.vectorized(k) loop. The codegen handler converts the + // logical count back to bytes via GetTileLangCPAsyncTransferBytes. if (enable_buffer_load_lds_ && !predicated && total_bytes == 16) { const std::string dst_scope = store->buffer.scope(); const bool is_shared = dst_scope == "shared" || dst_scope == "shared.dyn"; if (is_shared && IsLdsLaneContiguous(dst_check_index)) { ffi::Array lds_args = {dst_access_ptr, src_access_ptr, - PrimExpr(total_bytes)}; + PrimExpr(num_elems)}; return Evaluate(Call(store->buffer->dtype, tvm::tl::ptx_cp_async_lds(), lds_args)); } diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 8ce33dc5d..0b264524c 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -600,7 +600,8 @@ class SharedMemoryRewriter : public StmtExprMutator { {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async())) { + op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())) { ICHECK(op->args.size() == 3U || op->args.size() == 4U) << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " "src_access_ptr, count[, predicate])"; diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index 5fdbdeaf8..9be5e9a02 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -685,7 +685,8 @@ class TLVectorizer : 
public StmtMutator, if (op->op.same_as(builtin::ptx_cp_async())) { return scalar_count * 8; } - ICHECK(op->op.same_as(tl::ptx_cp_async())); + ICHECK(op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())); auto dst_elem_bits = GetAccessPtrElementBits(op->args[0]); auto src_elem_bits = GetAccessPtrElementBits(op->args[1]); if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { @@ -706,7 +707,8 @@ class TLVectorizer : public StmtMutator, // the final codegen validate the derived PTX byte width. PrimExpr MutatePTXCPAsyncExpr_(const CallNode *op) { ICHECK(op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async())); + op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())); if (op->args.size() != 3 && op->args.size() != 4) { return tvm::ffi::GetRef(op); } From 22fd40293108b76e668427b6abc894ce1befe5a1 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 07:15:03 +0000 Subject: [PATCH 15/34] M5-harden: replace XOR-call gate with real affine lane-contiguity proof Codex Round 1 review flagged the previous IsLdsLaneContiguous as unsafe: it only rejected expressions containing a tir.bitwise_xor call, but the swizzle layouts on this branch expand the XOR into shift/and/add bit arithmetic (see e.g. xor2x2 in src/layout/gemm_layouts.cc:370). The expanded form passed the gate even when the LDS destination was not lane-contiguous. Replace IsLdsLaneContiguous with a multi-sample affine proof modelled on /root/backuptilelang/src/transform/inject_ptx_async_copy.cc::IsLdsContiguous: - Walk the dst index, collect free Vars. - Pick a thread-like Var (name contains "thread" or equals "tx"/"tid"). - Compute f(0), f(1), expected_stride = f(1) - f(0); require IntImm. - Sample at points 2..1023; require analyzer.CanProveEqual(f(k) - f(0), k * stride) at every sample. Covers wave-32/64 boundaries, warp tiles, bank-swizzle phases (powers of two up to 64), and the 256-thread block boundaries the bench uses. 
Returns false if no thread-like var is found or any sample disagrees, so the LDS route safely no-ops back to the existing ptx_cp_async path until M9 actually moves the swizzle off the LDS side. Removes ContainsBitwiseXor helper (no longer needed). Build + import OK. Co-Authored-By: Claude Opus 4 (1M context) --- src/transform/lower_ptx_async_copy.cc | 65 ++++++++++++++++++++------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 77c3a915e..9d29da41f 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -431,26 +431,57 @@ class PTXAsyncCopyInjector : public StmtMutator { return info; } - // Conservative structural check used to gate the gfx950 buffer_load...lds - // routing: the destination LDS index must be lane-contiguous, i.e. it has - // no XOR-style swizzle term. Until the swizzle-swap optimisation in - // lower_tile_op.cc moves the XOR off the LDS store side (M7 in the - // accompanying plan) this returns false for ordinary swizzled layouts, so - // the routing safely no-ops and the existing ptx_cp_async path is used. - static bool ContainsBitwiseXor(const PrimExpr &expr) { - bool found = false; - tir::PostOrderVisit(expr, [&](const ObjectRef &node) { - if (found) return; - if (const auto *call = node.as()) { - if (call->op.same_as(tvm::tir::builtin::bitwise_xor())) { - found = true; + // Real lane-contiguity proof for the gfx950 buffer_load...lds routing. + // The destination LDS index must be an affine function of the thread-like + // variable with a constant per-lane stride. Bit-arithmetic swizzles + // (xor2x2 etc. expanded as `((a + b) & 1) * k`) have a constant + // f(1)-f(0) but are NOT linear globally; we catch those by sampling at + // many points and requiring f(k) - f(0) == k * stride at every point. 
+ // Returns false unless we find a recognisable thread-like free var whose + // contribution is provably linear. + static bool IsLdsLaneContiguous(const PrimExpr &dst_index) { + arith::Analyzer analyzer; + std::unordered_set seen; + Array free_vars; + tir::PostOrderVisit(dst_index, [&](const ObjectRef &node) { + if (auto *v = node.as()) { + if (seen.insert(v).second) { + free_vars.push_back(Downcast(node)); } } }); - return found; - } - static bool IsLdsLaneContiguous(const PrimExpr &dst_index) { - return !ContainsBitwiseXor(dst_index); + // Sample 0, 1, 2, ..., 1023 — covers wave-32/64 boundaries, warp tiles, + // bank-swizzle phases (typically powers of two up to 64), and the + // wider 256-thread block boundaries the bench uses. + constexpr int kNumSamples = 1024; + for (const auto &var : free_vars) { + const std::string name(var->name_hint); + if (name.find("thread") == std::string::npos && + name != "tx" && name != "tid") { + continue; + } + PrimExpr f0 = analyzer.Simplify(Substitute( + dst_index, Map{{var, IntImm(var->dtype, 0)}})); + PrimExpr f1 = analyzer.Simplify(Substitute( + dst_index, Map{{var, IntImm(var->dtype, 1)}})); + PrimExpr expected_stride = analyzer.Simplify(f1 - f0); + auto *stride_imm = expected_stride.as(); + if (!stride_imm) { + return false; + } + int64_t stride = stride_imm->value; + for (int pt = 2; pt < kNumSamples; ++pt) { + PrimExpr fk = analyzer.Simplify(Substitute( + dst_index, Map{{var, IntImm(var->dtype, pt)}})); + PrimExpr actual = analyzer.Simplify(fk - f0); + PrimExpr expected = IntImm(DataType::Int(64), stride * pt); + if (!analyzer.CanProveEqual(actual, expected)) { + return false; + } + } + return true; + } + return false; } static PrimExpr ExtractVectorBase(const PrimExpr &index) { From 91878e4a31b07789e3fcc441f987a1e98fb3d0ef Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 07:21:19 +0000 Subject: [PATCH 16/34] M9: VisitExpr_(CallNode) swizzle-swap on tl::ptx_cp_async_lds (real fix for AC-2/AC-1) 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The M7 swap in VisitStmt_(BufferStoreNode) never fired on this branch because rocm::Copy::Lower calls InjectPTXAsyncCopy inline and the BufferStores into A_shared / B_shared are converted to tl::ptx_cp_async_lds Calls before LowerTileOpPass re-visits the lowered Stmt (`return IRMutatorWithAnalyzer::VisitStmt(lowered)` at ~line 1130). Diagnosed by adding LOG(WARNING) at the top of the BufferStore visitor: it only ever saw C_local, A_local, B_local, C - never the shared buffers. Per Codex Round 1 review's directive, add the swap on the resulting Call node instead. New VisitExpr_(CallNode) branch that, for tl::ptx_cp_async_lds on a ROCm target with a swizzled remapped shared destination and a direct global source BufferLoad: 1. resolve_load() pulls BufferLoad from the tl::access_ptr arg (handles direct BufferLoad and let-bound BufferLoad via let_bindings_) 2. compute swizzled = layout->Forward(dst_logical_indices) and delta = layout->SwizzleDelta(dst_logical_indices) 3. build new dst access_ptr against buffer_remap_[dst_buf] with swizzled[last] - delta on the last dim (lane-contiguous LDS) 4. build new src access_ptr with src_logical_indices[last] + delta on the last dim (XOR moved to the global side; self-inverse) 5. return rewritten Call directly so the default arg visitor does not re-apply the swizzled layout to the destination Rank-difference between dst and src is allowed: the LDS dst typically has a leading Expand dim that the global src does not, but the "last dim" of each maps to the same physical column so the swap is well-defined on the last dim alone. Bench result with M5-harden + M9 (M6.5 still pending): correctness PASS, TFLOPS 866 (still below 1000 floor — M6.5 wait-count scaling should close the gap). Emission shape now matches _fast.cpp exactly: LDS dst is `i_1 * 4096 + threadIdx.x * 8` (linear), global src carries the bank-swizzle pattern. 
Co-Authored-By: Claude Opus 4 (1M context) --- src/transform/lower_tile_op.cc | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 74a3a0ca4..81ffbc1bc 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -923,6 +923,76 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return call; } + // gfx950 swizzle-swap for buffer_load_dwordx4 ... lds. On this branch + // the BufferStores that feed A_shared/B_shared are consumed by + // InjectPTXAsyncCopy inside rocm::Copy::Lower before they reach the + // BufferStore visitor, so the swap has to happen on the resulting + // tl::ptx_cp_async_lds Call. We rewrite the dst access_ptr to use + // (Forward(idx) - delta) on the last dim against the remapped shared + // buffer (lane-contiguous LDS writes) and shift the global src + // access_ptr by +delta on the last dim (XOR is self-inverse, so net + // data movement is unchanged). Returning the rewritten Call directly + // prevents the default visitor from re-applying the swizzled layout. 
+ if (op->op.same_as(tl::ptx_cp_async_lds()) && + TargetIsRocm(target_) && op->args.size() == 3) { + auto resolve_load = [&](const PrimExpr &arg) -> const BufferLoadNode * { + const auto *call = arg.as(); + if (!call || !call->op.same_as(tl::access_ptr())) return nullptr; + const auto *direct = call->args[0].as(); + if (direct) return direct; + if (const auto *var = call->args[0].as()) { + auto it = let_bindings_.find(Downcast(call->args[0])); + if (it != let_bindings_.end()) { + return it->second.as(); + } + (void)var; + } + return nullptr; + }; + const auto *dst_ap = op->args[0].as(); + const auto *src_ap = op->args[1].as(); + const BufferLoadNode *dst_load = resolve_load(op->args[0]); + const BufferLoadNode *src_load = resolve_load(op->args[1]); + if (dst_ap && src_ap && dst_load && src_load && + IsSharedBuffer(dst_load->buffer) && + IsGlobalBuffer(src_load->buffer) && + buffer_remap_.count(dst_load->buffer) && + layout_map_.count(dst_load->buffer) && + layout_map_[dst_load->buffer]->HasSwizzle() && + dst_load->indices.size() > 0 && src_load->indices.size() > 0) { + Buffer new_dst_buf = buffer_remap_[dst_load->buffer]; + layout_remap_.Set(new_dst_buf, layout_map_[dst_load->buffer]); + auto layout = layout_map_[dst_load->buffer]; + auto swizzled = layout->Forward(dst_load->indices); + PrimExpr delta = + analyzer_->Simplify(layout->SwizzleDelta(dst_load->indices)); + + Array new_dst_indices(swizzled.begin(), swizzled.end()); + int last_dst = static_cast(new_dst_indices.size()) - 1; + new_dst_indices.Set( + last_dst, + analyzer_->Simplify(new_dst_indices[last_dst] - delta)); + + Array new_src_indices(src_load->indices.begin(), + src_load->indices.end()); + int last_src = static_cast(new_src_indices.size()) - 1; + new_src_indices.Set( + last_src, + analyzer_->Simplify(new_src_indices[last_src] + delta)); + + BufferLoad new_dst_load(new_dst_buf, new_dst_indices); + BufferLoad new_src_load(src_load->buffer, new_src_indices); + PrimExpr new_dst_ap = + 
Call(dst_ap->dtype, tl::access_ptr(), + {new_dst_load, dst_ap->args[1], dst_ap->args[2]}); + PrimExpr new_src_ap = + Call(src_ap->dtype, tl::access_ptr(), + {new_src_load, src_ap->args[1], src_ap->args[2]}); + return Call(op->dtype, op->op, + {new_dst_ap, new_src_ap, op->args[2]}); + } + } + // Default: visit normally auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); return call; From 4736c94ac34f2e88bc2ad0ecc139cee3e2cf05b6 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 07:23:20 +0000 Subject: [PATCH 17/34] M6.5: AMD vmcnt wait-count scaling in HoistBufferResource Codex Round 1 review flagged the generated cp_async_wait<1> vs the target's cp_async_wait<8> as a likely 1000 TFLOPS blocker, and after M9 fixed the swizzle direction the bench correctness passed at only 866 TFLOPS - confirming Codex's analysis. Add _get_loads_per_group + _fix_amd_wait_counts to tilelang/transform/hoist_buffer_resource.py, matching the reference algorithm but rewriting on this branch's IR shape: - _is_async_load_call recognises Evaluate(Call(ptx_cp_async_lds OR ptx_cp_async_lds_rsrc, ...)). - _is_commit_call recognises Evaluate(Call(ptx_commit_group, ...)) (the existing pipeline emits these directly; the reference branch's async_commit_queue_scope AttrStmt is already lowered by the time we run). - _find_for_with_commit walks the body to find the innermost For whose subtree contains a commit; _count_async_loads counts async loads inside one iteration of that For, multiplying by IntImm loop extents for nested unrolls. - _fix_amd_wait_counts rewrites Evaluate(Call(ptx_wait_group, [n])) to Evaluate(Call(ptx_wait_group, [n * loads_per_group])) when n > 0. wait_group(0) (wait-all) stays unchanged. The pass runs the scaling AFTER the existing descriptor hoist so the post-hoist body (now using ptx_cp_async_lds_rsrc) is the input to the load counter; the recognised loads include the rsrc form. 
Bench result with M5-harden + M9 + M6.5: [iter] correctness: PASS [iter] latency: 0.9803 ms [iter] TFLOPS: 1121.62 [iter] TB/s: 0.411 Above the AC-1 hard floor of 1000. Matches the reference branch's "1120 TFLOPS now." commit message. Emitted .hipi has the target tl::cp_async_wait<8> at the K-loop wait and tl::cp_async_wait<0> at the final wait; .s contains 16 buffer_load_dwordx4 ... lds occurrences. Co-Authored-By: Claude Opus 4 (1M context) --- tilelang/transform/hoist_buffer_resource.py | 134 ++++++++++++++++++-- 1 file changed, 126 insertions(+), 8 deletions(-) diff --git a/tilelang/transform/hoist_buffer_resource.py b/tilelang/transform/hoist_buffer_resource.py index b3d13c10c..6aba490c2 100644 --- a/tilelang/transform/hoist_buffer_resource.py +++ b/tilelang/transform/hoist_buffer_resource.py @@ -1,4 +1,4 @@ -"""Hoist make_wave_buffer_resource descriptors for gfx950 buffer_load...lds. +"""Hoist make_wave_buffer_resource descriptors + scale AMD async wait counts. On gfx950, the `cp_async_gs_lds_with_rsrc<16>` device template takes a pre-computed buffer resource descriptor and a pre-computed wave-uniform @@ -8,6 +8,18 @@ to the kernel prologue once per source buffer and rewrite the calls to the variant that takes the pre-hoisted pair. +Second half: AMD vmcnt tracks individual `buffer_load` issues, not the +NVIDIA-style commit groups. `tl::cp_async_wait` lowers to +`s_waitcnt vmcnt(N)`, so a wait-for-N-groups must become +wait-for-(N * loads_per_group) on AMD. NVIDIA's cp.async commits group +every async load issued since the last commit; on AMD we have to scale +the wait count manually. We do that by finding the for-loop that +contains the `ptx_commit_group` call, counting async loads in one +iteration of that loop (multiplied by loop extents for nested unrolls), +and rewriting every positive `ptx_wait_group(n)` to `ptx_wait_group(n * +loads_per_group)`. `ptx_wait_group(0)` (wait-all) stays as `vmcnt(0)`, +which is already correct. 
+ Pipeline order: this pass runs in the OptimizeForTarget phase after ThreadSync/MergeIfStmt and before MakePackedAPI, which means the tl::access_ptr calls have already been lowered by `LowerAccessPtr` to @@ -16,12 +28,6 @@ This pass is gfx950-only: on every other target it returns the PrimFunc unchanged. - -NOTE: AMD vmcnt wait-count scaling (the second half of the reference -implementation on the zty_opt_can_run_1120flops branch) is deliberately -omitted in this commit. It will land as a separate milestone (M6.5) -only if the M6 bench shows a correctness failure attributable to async -wait counts. """ from tvm import tir @@ -33,6 +39,8 @@ _op_ptx_cp_async_lds = tir.op.Op.get("tl.ptx_cp_async_lds") _op_ptx_cp_async_lds_rsrc = tir.op.Op.get("tl.ptx_cp_async_lds_rsrc") _op_tvm_access_ptr = tir.op.Op.get("tir.tvm_access_ptr") +_op_ptx_commit_group = tir.op.Op.get("tir.ptx_commit_group") +_op_ptx_wait_group = tir.op.Op.get("tir.ptx_wait_group") def _extract_buffer_var(access_ptr_expr): @@ -55,6 +63,110 @@ def _extract_buffer_var(access_ptr_expr): return None +def _is_async_load_call(stmt): + if not isinstance(stmt, Evaluate) or not isinstance(stmt.value, Call): + return False + op = stmt.value.op + return op == _op_ptx_cp_async_lds or op == _op_ptx_cp_async_lds_rsrc + + +def _is_commit_call(stmt): + if not isinstance(stmt, Evaluate) or not isinstance(stmt.value, Call): + return False + return stmt.value.op == _op_ptx_commit_group + + +def _contains_commit_call(stmt): + found = [False] + + def _v(s): + if _is_commit_call(s): + found[0] = True + + stmt_functor.post_order_visit(stmt, _v) + return found[0] + + +def _find_for_with_commit(stmt): + """Find the innermost For loop whose body contains a commit call.""" + if isinstance(stmt, tir.For): + inner = _find_for_with_commit(stmt.body) + if inner is not None: + return inner + if _contains_commit_call(stmt.body): + return stmt + elif isinstance(stmt, tir.SeqStmt): + for s in stmt.seq: + r = _find_for_with_commit(s) + if r 
is not None: + return r + elif hasattr(stmt, "body"): + return _find_for_with_commit(stmt.body) + return None + + +def _count_async_loads(stmt, multiplier=1): + if _is_async_load_call(stmt): + return multiplier + if isinstance(stmt, tir.For): + ext = multiplier + if isinstance(stmt.extent, tir.IntImm): + ext = multiplier * stmt.extent.value + return _count_async_loads(stmt.body, ext) + if isinstance(stmt, tir.SeqStmt): + return sum(_count_async_loads(s, multiplier) for s in stmt.seq) + if isinstance(stmt, tir.AttrStmt): + return _count_async_loads(stmt.body, multiplier) + if isinstance(stmt, tir.IfThenElse): + c = _count_async_loads(stmt.then_case, multiplier) + if stmt.else_case is not None: + c = max(c, _count_async_loads(stmt.else_case, multiplier)) + return c + if isinstance(stmt, tir.LetStmt): + return _count_async_loads(stmt.body, multiplier) + return 0 + + +def _get_loads_per_group(body): + for_node = _find_for_with_commit(body) + if for_node is not None: + return _count_async_loads(for_node.body) + return 0 + + +def _fix_amd_wait_counts(body, loads_per_group): + """Multiply positive ptx_wait_group(n) arguments by loads_per_group. + + Each `tl::cp_async_wait` on AMD lowers to `s_waitcnt vmcnt(N)`, + which counts individual buffer_loads rather than NVIDIA-style commit + groups. wait_group(0) (wait-all) stays unchanged because vmcnt(0) + is already the correct "wait for everything" sentinel. 
+ """ + + def _postorder(op): + if not isinstance(op, Evaluate): + return None + if not isinstance(op.value, Call): + return None + if op.value.op != _op_ptx_wait_group: + return None + if len(op.value.args) != 1: + return None + n_arg = op.value.args[0] + if not isinstance(n_arg, tir.IntImm): + return None + if n_arg.value <= 0: + return None + new_call = Call( + op.value.dtype, + _op_ptx_wait_group, + [tir.IntImm(n_arg.dtype, n_arg.value * loads_per_group)], + ) + return Evaluate(new_call) + + return stmt_functor.ir_transform(body, None, _postorder, ["tir.Evaluate"]) + + def _collect_buffer_vars(body): """Discover unique source buffer Vars referenced by ptx_cp_async_lds calls. @@ -104,7 +216,7 @@ def _postorder(op): def HoistBufferResource(): - """gfx950: hoist buffer resource descriptors out of the inner copy loop.""" + """gfx950: hoist buffer resource descriptors + scale AMD vmcnt waits.""" def pass_fn(func: PrimFunc, _mod, _ctx): target = func.attrs.get("target", None) @@ -121,6 +233,12 @@ def pass_fn(func: PrimFunc, _mod, _ctx): new_body = AttrStmt(base_var, "buffer_base_var", buf_var, new_body) new_body = AttrStmt(rsrc_var, "buffer_resource_var", buf_var, new_body) + # AMD wait-count scaling. Only meaningful when there's at least one + # commit group; otherwise loads_per_group is 0 and we skip. 
+ loads_per_group = _get_loads_per_group(new_body) + if loads_per_group > 1: + new_body = _fix_amd_wait_counts(new_body, loads_per_group) + return func.with_body(new_body) return prim_func_pass(pass_fn, opt_level=0) From c3c0f9a6de6f21b1a6e639591d43f3156776bc1a Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 14:54:49 +0000 Subject: [PATCH 18/34] add a flag to save temp files --- src/backend/rocm/codegen/codegen_hip.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index 46bdf770f..25d7873cf 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -994,6 +994,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { print_extern_call_stmt("tl::cp_async_commit"); } else if (op->op.same_as(builtin::ptx_wait_group())) { int n = Downcast(op->args[0])->value; + // AMDGPU s_waitcnt vmcnt field is 6-bit (max 63); clamp to keep the + // "n"(cnt) immediate constraint in tl::cp_async_wait valid. + if (n > 63) n = 63; std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; print_extern_call_stmt(func_name, 1); } else if (op->op.same_as(builtin::create_barriers())) { From e6ff3f05f8fc9a606af9978851d7bb26495e1edd Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sun, 17 May 2026 15:36:06 +0000 Subject: [PATCH 19/34] M9-safe: downgrade ptx_cp_async_lds to ptx_cp_async when swap can't produce linear LDS The M9 swizzle-swap subtracts SwizzleDelta from the LAST output dim of Forward(dst_indices). That cancels the XOR cleanly for layouts whose swizzle is confined to the last column (the bench's A_shared layout, and B_shared in the bench's symmetric NT shape), but layouts where the swizzle spreads across multiple output dims leave residual swizzle in the higher output dims AND over-correct src. The dst LDS index is no longer lane-contiguous, so the emitted buffer_load_dwordx4 ... 
lds writes to wrong addresses. Concrete reproducer: python gemm/example_gemm.py at the default 1024x1024x16384 shape. A's LDS dst comes out as `i_1 * 1024 + threadIdx.x * 8` (clean linear) but B's comes out as `((tx & 15) >> 3) * 2048 + i_2 * 512 + (tx >> 4) * 64 + (tx & 7) * 8` (NOT lane-contiguous), and B's global src ends up with a spurious `- ((tx & 7) * 8)` term. Bench data 99.7% mismatched. Fix has two parts: 1. Restructure the M9 handler so EITHER outcome (swap-or-skip) is a well-defined action: if the candidate gate (`dst_ap` / `src_ap` / shared dst / global src / remap / HasSwizzle / non-empty indices) does not hold, immediately downgrade the op from tl::ptx_cp_async_lds back to tl::ptx_cp_async (same arg shape) and recurse with the default visitor. Without this, a call that the guard rejected would keep its ptx_cp_async_lds op and codegen would emit the LDS template against an unmodified (swizzled) dst. 2. After the swap math runs, sample each post-swap new_dst_indices[d] against a thread-like Var: compute f(0), f(1), expected_stride, then require analyzer.CanProveEqual(f(k)-f(0), k*stride) at sample points 2..63. If any dim is non-affine in the thread var (bit- extract `(tx & m) >> s` terms etc.), downgrade as above. Verified on two shapes: - python gemm/example_gemm.py (1024x1024x16384): All check passed. Latency 0.0443 ms. - bash scripts/iter_buffer_load.sh (8192x8192x8192): correctness PASS, TFLOPS 1116.84 (A still uses LDS fast path; B symmetric here also passes the linearity check). 
Co-Authored-By: Claude Opus 4 (1M context) --- src/transform/lower_tile_op.cc | 116 ++++++++++++++++++++++++++++----- 1 file changed, 99 insertions(+), 17 deletions(-) diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 81ffbc1bc..60b65b1bf 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -953,13 +953,26 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { const auto *src_ap = op->args[1].as(); const BufferLoadNode *dst_load = resolve_load(op->args[0]); const BufferLoadNode *src_load = resolve_load(op->args[1]); - if (dst_ap && src_ap && dst_load && src_load && - IsSharedBuffer(dst_load->buffer) && - IsGlobalBuffer(src_load->buffer) && - buffer_remap_.count(dst_load->buffer) && - layout_map_.count(dst_load->buffer) && - layout_map_[dst_load->buffer]->HasSwizzle() && - dst_load->indices.size() > 0 && src_load->indices.size() > 0) { + // Candidate gate: do we have everything we need to even attempt + // the swap? If not, fall through to "downgrade" below — we must + // NOT leave the call as ptx_cp_async_lds, because codegen would + // then emit the LDS template against the unmodified (swizzled) + // dst index and produce wrong addresses. + bool m9_candidate = dst_ap && src_ap && dst_load && src_load && + IsSharedBuffer(dst_load->buffer) && + IsGlobalBuffer(src_load->buffer) && + buffer_remap_.count(dst_load->buffer) && + layout_map_.count(dst_load->buffer) && + layout_map_[dst_load->buffer]->HasSwizzle() && + dst_load->indices.size() > 0 && + src_load->indices.size() > 0; + if (!m9_candidate) { + // Can't do the swap. Downgrade so codegen uses the safe + // synchronous cp_async_gs path instead of buffer_load_lds. 
+ Call downgraded(op->dtype, tl::ptx_cp_async(), op->args); + return IRMutatorWithAnalyzer::VisitExpr(downgraded); + } + { Buffer new_dst_buf = buffer_remap_[dst_load->buffer]; layout_remap_.Set(new_dst_buf, layout_map_[dst_load->buffer]); auto layout = layout_map_[dst_load->buffer]; @@ -980,16 +993,85 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { last_src, analyzer_->Simplify(new_src_indices[last_src] + delta)); - BufferLoad new_dst_load(new_dst_buf, new_dst_indices); - BufferLoad new_src_load(src_load->buffer, new_src_indices); - PrimExpr new_dst_ap = - Call(dst_ap->dtype, tl::access_ptr(), - {new_dst_load, dst_ap->args[1], dst_ap->args[2]}); - PrimExpr new_src_ap = - Call(src_ap->dtype, tl::access_ptr(), - {new_src_load, src_ap->args[1], src_ap->args[2]}); - return Call(op->dtype, op->op, - {new_dst_ap, new_src_ap, op->args[2]}); + // Post-swap linearity guard: the single-dim subtract-delta swap + // only cancels the XOR when the layout's swizzle is confined to + // the last output dim. For layouts (e.g. B's matmul layout in + // some shapes) where the swizzle spreads across multiple output + // dims, the post-swap LDS index is NOT lane-contiguous and + // buffer_load_dwordx4...lds would write to wrong addresses. + // Detect by sampling each new_dst_indices[d] against a + // thread-like var and requiring an affine (constant or + // single-stride) dependence. If any dim is non-affine in tx + // (typical for bit-extract `(tx & m) >> s` terms), bail out so + // the call falls through to the safe ptx_cp_async path. 
+ auto is_affine_in_thread_var = [&](const PrimExpr &e) -> bool { + arith::Analyzer post_analyzer; + std::unordered_set seen; + Array free_vars; + tir::PostOrderVisit(e, [&](const ObjectRef &node) { + if (auto *v = node.as()) { + if (seen.insert(v).second) { + free_vars.push_back(Downcast(node)); + } + } + }); + for (const auto &var : free_vars) { + const std::string name(var->name_hint); + if (name.find("thread") == std::string::npos && name != "tx" && + name != "tid") { + continue; + } + PrimExpr f0 = post_analyzer.Simplify(tir::Substitute( + e, Map{{var, IntImm(var->dtype, 0)}})); + PrimExpr f1 = post_analyzer.Simplify(tir::Substitute( + e, Map{{var, IntImm(var->dtype, 1)}})); + PrimExpr stride = post_analyzer.Simplify(f1 - f0); + const auto *stride_imm = stride.as(); + if (!stride_imm) return false; + for (int pt = 2; pt < 64; ++pt) { + PrimExpr fk = post_analyzer.Simplify(tir::Substitute( + e, Map{{var, IntImm(var->dtype, pt)}})); + PrimExpr actual = post_analyzer.Simplify(fk - f0); + PrimExpr expected = + IntImm(DataType::Int(64), stride_imm->value * pt); + if (!post_analyzer.CanProveEqual(actual, expected)) { + return false; + } + } + return true; + } + // No thread-like var found: constant w.r.t. tx -> affine OK. + return true; + }; + bool all_dims_affine = true; + for (const auto &idx : new_dst_indices) { + if (!is_affine_in_thread_var(idx)) { + all_dims_affine = false; + break; + } + } + if (!all_dims_affine) { + // The swap can't produce a lane-contiguous LDS dst for this + // layout. Downgrade the op from tl::ptx_cp_async_lds to + // tl::ptx_cp_async (same arg shape) so codegen emits the safe + // synchronous cp_async_gs path rather than buffer_load_lds + // with a non-contiguous LDS index. Let the default visitor + // recurse from there so the access_ptr children still get the + // ordinary swizzled-layout treatment. 
+ Call downgraded(op->dtype, tl::ptx_cp_async(), op->args); + return IRMutatorWithAnalyzer::VisitExpr(downgraded); + } else { + BufferLoad new_dst_load(new_dst_buf, new_dst_indices); + BufferLoad new_src_load(src_load->buffer, new_src_indices); + PrimExpr new_dst_ap = + Call(dst_ap->dtype, tl::access_ptr(), + {new_dst_load, dst_ap->args[1], dst_ap->args[2]}); + PrimExpr new_src_ap = + Call(src_ap->dtype, tl::access_ptr(), + {new_src_load, src_ap->args[1], src_ap->args[2]}); + return Call(op->dtype, op->op, + {new_dst_ap, new_src_ap, op->args[2]}); + } } } From 79c658011fbdc7a54798e65e1731e947366b441a Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Mon, 18 May 2026 04:02:42 +0000 Subject: [PATCH 20/34] add a flag to save temp files --- src/backend/rocm/codegen/codegen_hip.cc | 20 ++++++++--------- src/layout/layout.cc | 10 ++++----- src/tl_templates/hip/copy.h | 28 ++++++++++++----------- src/transform/lower_ptx_async_copy.cc | 24 ++++++++++---------- src/transform/lower_tile_op.cc | 30 ++++++++++++------------- 5 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index 25d7873cf..905faa1e0 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -922,8 +922,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { os << "make_wave_buffer_resource((const void*)(" << ptr << "))"; } else if (op->op.same_as(tl::ptx_cp_async_lds_rsrc())) { // args = [dst, src, bytes, rsrc_var, base_var] - ICHECK(op->args.size() == 5) - << "ptx_cp_async_lds_rsrc expects 5 arguments"; + ICHECK(op->args.size() == 5) << "ptx_cp_async_lds_rsrc expects 5 arguments"; std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); // arg 2 carries logical element count (inherited from the @@ -938,8 +937,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { 
ICHECK(dst_elem_type.has_value()) << "ptx_cp_async_lds_rsrc dst must be tvm_access_ptr / tl.access_ptr / " "address_of(BufferLoad)"; - int64_t total_bits = num_elems_imm->value * - dst_elem_type.value().bits() * + int64_t total_bits = num_elems_imm->value * dst_elem_type.value().bits() * dst_elem_type.value().lanes(); ICHECK_EQ(total_bits % 8, 0) << "ptx_cp_async_lds_rsrc requires byte-aligned transfer, got " @@ -949,9 +947,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { std::string rsrc = this->PrintExpr(op->args[3]); std::string base = this->PrintExpr(op->args[4]); this->PrintIndent(); - this->stream << "tl::cp_async_gs_lds_with_rsrc<" << size << ">(" - << dst << ", " << src << ", " << rsrc << ", " << base - << ");\n"; + this->stream << "tl::cp_async_gs_lds_with_rsrc<" << size << ">(" << dst + << ", " << src << ", " << rsrc << ", " << base << ");\n"; } else if (op->op.same_as(builtin::ptx_cp_async())) { // builtin::ptx_cp_async stores byte width directly in arg 2. ICHECK(op->args.size() == 3 || op->args.size() == 4) @@ -996,7 +993,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { int n = Downcast(op->args[0])->value; // AMDGPU s_waitcnt vmcnt field is 6-bit (max 63); clamp to keep the // "n"(cnt) immediate constraint in tl::cp_async_wait valid. - if (n > 63) n = 63; + if (n > 63) + n = 63; std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; print_extern_call_stmt(func_name, 1); } else if (op->op.same_as(builtin::create_barriers())) { @@ -1485,7 +1483,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { if (op->attr_key == "buffer_resource_var") { // Hoisted resource descriptor from the HoistBufferResource Python pass. 
- // Emits: auto {rsrc_var} = make_wave_buffer_resource((const void*)({buf_var})); + // Emits: auto {rsrc_var} = make_wave_buffer_resource((const + // void*)({buf_var})); auto rsrc_var = Downcast(op->node); std::string rsrc_vid = AllocVarID(rsrc_var.get()); std::string buf_ptr = PrintExpr(op->value); @@ -1553,8 +1552,7 @@ void CodeGenTileLangHIP::VisitStmt_(const LetStmtNode *op) { if (call->op.same_as(tl::ptx_make_buffer_resource())) { std::string value = PrintExpr(op->value); PrintIndent(); - stream << "auto " << AllocVarID(op->var.get()) << " = " << value - << ";\n"; + stream << "auto " << AllocVarID(op->var.get()) << " = " << value << ";\n"; PrintStmt(op->body); return; } diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 045234ede..452e38349 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -523,12 +523,12 @@ PrimExpr LayoutNode::SwizzleDelta(const Array &input_indices) const { // Substitute the last InputDim() entries of input_indices into // swizzle_delta_, matching the convention Forward() uses. PrimExpr delta = swizzle_delta_.value(); - size_t offset = - input_indices.size() >= InputDim() ? input_indices.size() - InputDim() - : 0; + size_t offset = input_indices.size() >= InputDim() + ? input_indices.size() - InputDim() + : 0; for (size_t i = 0; i < InputDim(); ++i) { - delta = Substitute(delta, - {{InputPlaceholder(i), input_indices[offset + i]}}); + delta = + Substitute(delta, {{InputPlaceholder(i), input_indices[offset + i]}}); } return delta; } diff --git a/src/tl_templates/hip/copy.h b/src/tl_templates/hip/copy.h index 82155aecb..461cafe85 100644 --- a/src/tl_templates/hip/copy.h +++ b/src/tl_templates/hip/copy.h @@ -138,14 +138,15 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, } } -// gfx950 (CDNA4 / MI350) direct global→LDS copy via buffer_load_dwordx4 ... lds. -// Bypasses VGPRs entirely. 
Only valid when LDS destination is lane-contiguous -// (base + lane_id * N bytes); the swizzle-swap pass in LowerTileOp guarantees -// this by moving the XOR swizzle from the LDS store side to the global load -// side. The compiler is expected to emit this only for 16-byte copies; other -// sizes fall back to cp_async_gs. +// gfx950 (CDNA4 / MI350) direct global→LDS copy via buffer_load_dwordx4 ... +// lds. Bypasses VGPRs entirely. Only valid when LDS destination is +// lane-contiguous (base + lane_id * N bytes); the swizzle-swap pass in +// LowerTileOp guarantees this by moving the XOR swizzle from the LDS store side +// to the global load side. The compiler is expected to emit this only for +// 16-byte copies; other sizes fall back to cp_async_gs. template -TL_DEVICE void cp_async_gs_lds(void *lds_base_ptr, void const *global_base_ptr) { +TL_DEVICE void cp_async_gs_lds(void *lds_base_ptr, + void const *global_base_ptr) { if constexpr (N == 16) { auto rsrc = make_wave_buffer_resource(global_base_ptr); uint32_t my_lo = @@ -156,7 +157,8 @@ TL_DEVICE void cp_async_gs_lds(void *lds_base_ptr, void const *global_base_ptr) static_cast(reinterpret_cast(lds_base_ptr))); asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" - : : "s"(lds_cur), "v"(voffset), "s"(rsrc) + : + : "s"(lds_cur), "v"(voffset), "s"(rsrc) : "memory"); } else { cp_async_gs(lds_base_ptr, global_base_ptr); @@ -170,10 +172,9 @@ TL_DEVICE void cp_async_gs_lds(void *lds_base_ptr, void const *global_base_ptr) // loop. rsrc_base_lo must equal readfirstlane((uint32_t)(uintptr_t)A) for // the same A passed to make_wave_buffer_resource that produced rsrc. 
template -TL_DEVICE void cp_async_gs_lds_with_rsrc(void *lds_base_ptr, - void const *global_base_ptr, - int32x4_t rsrc, - uint32_t rsrc_base_lo) { +TL_DEVICE void +cp_async_gs_lds_with_rsrc(void *lds_base_ptr, void const *global_base_ptr, + int32x4_t rsrc, uint32_t rsrc_base_lo) { if constexpr (N == 16) { uint32_t my_lo = static_cast(reinterpret_cast(global_base_ptr)); @@ -182,7 +183,8 @@ TL_DEVICE void cp_async_gs_lds_with_rsrc(void *lds_base_ptr, static_cast(reinterpret_cast(lds_base_ptr))); asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" - : : "s"(lds_cur), "v"(voffset), "s"(rsrc) + : + : "s"(lds_cur), "v"(voffset), "s"(rsrc) : "memory"); } else { cp_async_gs(lds_base_ptr, global_base_ptr); diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 9d29da41f..b7d261d5a 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -456,8 +456,8 @@ class PTXAsyncCopyInjector : public StmtMutator { constexpr int kNumSamples = 1024; for (const auto &var : free_vars) { const std::string name(var->name_hint); - if (name.find("thread") == std::string::npos && - name != "tx" && name != "tid") { + if (name.find("thread") == std::string::npos && name != "tx" && + name != "tid") { continue; } PrimExpr f0 = analyzer.Simplify(Substitute( @@ -549,12 +549,13 @@ class PTXAsyncCopyInjector : public StmtMutator { IntImm(DataType::Int(32), rw_mask)}); } - Optional - MakeCPAsyncStmtFromLoads(const BufferStoreNode *store, - const BufferLoad &dst_base_load, - const BufferLoad &src_base_load, int num_elems, - int total_bytes, const PrimExpr &dst_check_index, - bool predicated, const PrimExpr &predicate_value) { + Optional MakeCPAsyncStmtFromLoads(const BufferStoreNode *store, + const BufferLoad &dst_base_load, + const BufferLoad &src_base_load, + int num_elems, int total_bytes, + const PrimExpr &dst_check_index, + bool predicated, + const PrimExpr &predicate_value) { PrimExpr 
dst_access_ptr = MakeAccessPtrFromLoad(dst_base_load, num_elems, /*rw_mask=*/2); PrimExpr src_access_ptr = @@ -569,13 +570,12 @@ class PTXAsyncCopyInjector : public StmtMutator { // logical count back to bytes via GetTileLangCPAsyncTransferBytes. if (enable_buffer_load_lds_ && !predicated && total_bytes == 16) { const std::string dst_scope = store->buffer.scope(); - const bool is_shared = - dst_scope == "shared" || dst_scope == "shared.dyn"; + const bool is_shared = dst_scope == "shared" || dst_scope == "shared.dyn"; if (is_shared && IsLdsLaneContiguous(dst_check_index)) { ffi::Array lds_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems)}; - return Evaluate(Call(store->buffer->dtype, tvm::tl::ptx_cp_async_lds(), - lds_args)); + return Evaluate( + Call(store->buffer->dtype, tvm::tl::ptx_cp_async_lds(), lds_args)); } } diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 60b65b1bf..240218a31 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -933,13 +933,15 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // access_ptr by +delta on the last dim (XOR is self-inverse, so net // data movement is unchanged). Returning the rewritten Call directly // prevents the default visitor from re-applying the swizzled layout. 
- if (op->op.same_as(tl::ptx_cp_async_lds()) && - TargetIsRocm(target_) && op->args.size() == 3) { + if (op->op.same_as(tl::ptx_cp_async_lds()) && TargetIsRocm(target_) && + op->args.size() == 3) { auto resolve_load = [&](const PrimExpr &arg) -> const BufferLoadNode * { const auto *call = arg.as(); - if (!call || !call->op.same_as(tl::access_ptr())) return nullptr; + if (!call || !call->op.same_as(tl::access_ptr())) + return nullptr; const auto *direct = call->args[0].as(); - if (direct) return direct; + if (direct) + return direct; if (const auto *var = call->args[0].as()) { auto it = let_bindings_.find(Downcast(call->args[0])); if (it != let_bindings_.end()) { @@ -983,15 +985,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { Array new_dst_indices(swizzled.begin(), swizzled.end()); int last_dst = static_cast(new_dst_indices.size()) - 1; new_dst_indices.Set( - last_dst, - analyzer_->Simplify(new_dst_indices[last_dst] - delta)); + last_dst, analyzer_->Simplify(new_dst_indices[last_dst] - delta)); Array new_src_indices(src_load->indices.begin(), src_load->indices.end()); int last_src = static_cast(new_src_indices.size()) - 1; new_src_indices.Set( - last_src, - analyzer_->Simplify(new_src_indices[last_src] + delta)); + last_src, analyzer_->Simplify(new_src_indices[last_src] + delta)); // Post-swap linearity guard: the single-dim subtract-delta swap // only cancels the XOR when the layout's swizzle is confined to @@ -1027,7 +1027,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { e, Map{{var, IntImm(var->dtype, 1)}})); PrimExpr stride = post_analyzer.Simplify(f1 - f0); const auto *stride_imm = stride.as(); - if (!stride_imm) return false; + if (!stride_imm) + return false; for (int pt = 2; pt < 64; ++pt) { PrimExpr fk = post_analyzer.Simplify(tir::Substitute( e, Map{{var, IntImm(var->dtype, pt)}})); @@ -1069,8 +1070,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { PrimExpr new_src_ap = Call(src_ap->dtype, tl::access_ptr(), {new_src_load, 
src_ap->args[1], src_ap->args[2]}); - return Call(op->dtype, op->op, - {new_dst_ap, new_src_ap, op->args[2]}); + return Call(op->dtype, op->op, {new_dst_ap, new_src_ap, op->args[2]}); } } } @@ -1140,13 +1140,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { Array reflected(store->indices.begin(), store->indices.end()); int last_in = static_cast(reflected.size()) - 1; - reflected.Set( - last_in, analyzer_->Simplify(reflected[last_in] + delta)); + reflected.Set(last_in, + analyzer_->Simplify(reflected[last_in] + delta)); Array new_load_indices; for (size_t k = 0; k < load_node->indices.size(); ++k) { - PrimExpr base = analyzer_->Simplify(load_node->indices[k] - - store->indices[k]); + PrimExpr base = + analyzer_->Simplify(load_node->indices[k] - store->indices[k]); new_load_indices.push_back( analyzer_->Simplify(base + reflected[k])); } From ff6cb4ebf74ef462bc6aa848ff0accd6bece210d Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Mon, 18 May 2026 04:04:33 +0000 Subject: [PATCH 21/34] format file --- tilelang/transform/hoist_buffer_resource.py | 50 ++++++++++----------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/tilelang/transform/hoist_buffer_resource.py b/tilelang/transform/hoist_buffer_resource.py index 6aba490c2..a939ded47 100644 --- a/tilelang/transform/hoist_buffer_resource.py +++ b/tilelang/transform/hoist_buffer_resource.py @@ -67,7 +67,7 @@ def _is_async_load_call(stmt): if not isinstance(stmt, Evaluate) or not isinstance(stmt.value, Call): return False op = stmt.value.op - return op == _op_ptx_cp_async_lds or op == _op_ptx_cp_async_lds_rsrc + return op in (_op_ptx_cp_async_lds, _op_ptx_cp_async_lds_rsrc) def _is_commit_call(stmt): @@ -176,14 +176,13 @@ def _collect_buffer_vars(body): buffer_vars = {} def _visit(stmt): - if isinstance(stmt, Evaluate) and isinstance(stmt.value, Call): - if stmt.value.op == _op_ptx_cp_async_lds: - # ptx_cp_async_lds args: (dst_access_ptr, src_access_ptr, bytes) - buf_var = 
_extract_buffer_var(stmt.value.args[1]) - if buf_var is not None and buf_var not in buffer_vars: - rsrc_var = Var("__rsrc_" + buf_var.name, dtype="handle") - base_var = Var("__base_" + buf_var.name, dtype="uint32") - buffer_vars[buf_var] = (rsrc_var, base_var) + if isinstance(stmt, Evaluate) and isinstance(stmt.value, Call) and stmt.value.op == _op_ptx_cp_async_lds: + # ptx_cp_async_lds args: (dst_access_ptr, src_access_ptr, bytes) + buf_var = _extract_buffer_var(stmt.value.args[1]) + if buf_var is not None and buf_var not in buffer_vars: + rsrc_var = Var("__rsrc_" + buf_var.name, dtype="handle") + base_var = Var("__base_" + buf_var.name, dtype="uint32") + buffer_vars[buf_var] = (rsrc_var, base_var) stmt_functor.post_order_visit(body, _visit) return buffer_vars @@ -193,23 +192,22 @@ def _rewrite_calls(body, buffer_vars): """Rewrite ptx_cp_async_lds -> ptx_cp_async_lds_rsrc with hoisted vars.""" def _postorder(op): - if isinstance(op, Evaluate) and isinstance(op.value, Call): - if op.value.op == _op_ptx_cp_async_lds: - buf_var = _extract_buffer_var(op.value.args[1]) - if buf_var is not None and buf_var in buffer_vars: - rsrc_var, base_var = buffer_vars[buf_var] - new_call = Call( - op.value.dtype, - _op_ptx_cp_async_lds_rsrc, - [ - op.value.args[0], - op.value.args[1], - op.value.args[2], - rsrc_var, - base_var, - ], - ) - return Evaluate(new_call) + if isinstance(op, Evaluate) and isinstance(op.value, Call) and op.value.op == _op_ptx_cp_async_lds: + buf_var = _extract_buffer_var(op.value.args[1]) + if buf_var is not None and buf_var in buffer_vars: + rsrc_var, base_var = buffer_vars[buf_var] + new_call = Call( + op.value.dtype, + _op_ptx_cp_async_lds_rsrc, + [ + op.value.args[0], + op.value.args[1], + op.value.args[2], + rsrc_var, + base_var, + ], + ) + return Evaluate(new_call) return None return stmt_functor.ir_transform(body, None, _postorder, ["tir.Evaluate"]) From 1007a033829c108f62fd066d8eb01a6993547a0c Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Thu, 21 
May 2026 05:48:17 +0000 Subject: [PATCH 22/34] M10: chunk-block-aware binding via early CBA hook in InferLayout Existing CBA helper in ComputePlanCandidate only runs when source_buffer is undefined. For T.copy(B, B_shared) at kCommon/kStrict, source_buffer chain takes B_shared and routes through ComputeLoopLayoutFromBuffer, so CBA never fires and the default PlanLoopPartition flatten puts 16 lanes on N (crossing the FullBank tc=2 boundary at lane 8 -> +1920B jump). Add an early hook at the top of ParallelOp::InferLayout that runs CBA at all 3 levels and sets loop_layout_ unconditionally when an eligible target buffer exists. CBA's gate (last_dim_bytes > 128B and divisible) keeps it inert on A_shared / NT-B_shared / C_local cases. Effect: - NN 1024^3 K=16384 128x128x32: correctness PASS (was wrong values before) - NT 8192^2 K=8192 256x256x64 stages=2: 1111 TFLOPS, correctness PASS (NT B last-dim = K = 64 = 1 bank cycle, gate skips; perf unchanged) Co-Authored-By: Claude Opus 4 (1M context) --- src/op/parallel.cc | 200 +++++++++++++++++++++++++++++++++++++++++++++ src/op/parallel.h | 11 +++ 2 files changed, 211 insertions(+) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index dd55a9a0f..df1ec0e90 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -395,6 +395,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // 3. Non-replicated read buffer // 4. Fully replicated write buffer (backup, may cause issues) // 5. Free inference mode (no source buffer) + // Early chunk-block-aware override: if a written shared buffer has a + // FullBank-style swizzled layout with tc > 1, the default flatten policy + // produces a binding whose wavefront lanes straddle the chunk-block + // boundary, which breaks lane-contiguous LDS WRITEs (buffer_load ... lds). + // Override here BEFORE source_buffer dispatch so the CBA fragment wins. 
+ // Fire at ALL levels so the first level that sees the layout map populated + // wins; subsequent levels short-circuit via loop_layout_inferred_. + if (!loop_layout_.defined()) { + // Reuse the same vec_size calculation as ComputePlanCandidate. + auto maybe_remapped_root = + IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); + int vector_size = + GetVectorizeSize(maybe_remapped_root, T.analyzer, T.layout_map); + PrimExpr loop_total_size = 1; + for (Stmt l = root_; l.as().has_value(); l = l.as().value()->body) + loop_total_size = loop_total_size * l.as().value()->extent; + while (!analyzer_.CanProve(floormod(loop_total_size, + T.thread_bounds->extent * + vector_size) == 0) && + vector_size > 1) + vector_size /= 2; + if (auto cba = + ComputeChunkBlockAwarePlanCandidate(T, vector_size); + cba.defined()) { + loop_layout_ = cba; + } + } if (!loop_layout_.defined() && annotated_layout_unbound_.defined()) { loop_layout_ = annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds); @@ -699,12 +726,185 @@ Fragment ParallelOpNode::ComputePlanCandidate(const LayoutInferArgs &T) const { DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_ << " ############# vector_size = " << vector_size << ", thread_bounds = " << T.thread_bounds << '\n'; + // Prefer chunk-block aware binding when applicable. The default flatten + // policy below would put 16 (or more) lanes on the continuous dim; when a + // shared destination has a FullBank swizzle layout with tc > 1, those lanes + // straddle the chunk-block boundary, breaking lane-contiguous LDS WRITEs and + // any downstream swizzle-swap. 
+ if (auto cba_plan = + ComputeChunkBlockAwarePlanCandidate(T, vector_size); + cba_plan.defined()) { + DLOG(INFO) << "[PlanLoopPartition] chunk-block-aware candidate = " + << cba_plan->DebugOutput() << '\n'; + return cba_plan; + } auto plan = PlanLoopPartition(root_, vector_size, T.thread_bounds); DLOG(INFO) << "[PlanLoopPartition] candidate = " << plan->DebugOutput() << '\n'; return plan; } +Fragment ParallelOpNode::ComputeChunkBlockAwarePlanCandidate( + const LayoutInferArgs &T, int vector_size) const { + // 1. Find a written shared buffer with a swizzle layout whose continuous + // (innermost) dim exceeds one CDNA LDS bank cycle (128 bytes). When + // that happens FullBank splits the dim into `tc` planes; the default + // flatten policy puts too many lanes on the continuous dim and one + // wavefront ends up straddling the tc boundary. We compute tc from the + // buffer's last-dim extent + element size to avoid having to interpret + // the layout's output-dim structure (which can vary with pipelining). + Buffer target; + // Continuous (innermost) buffer dim extent. When the parallel loop is + // fused, this comes from buffer->shape.back(); when unfused, it equals + // the extent of the matching loop var. + int64_t cont_ext = 0; + int64_t inner_extent = 1; + // Index into loop_vars_ for the loop var that drives the continuous dim, + // or -1 if the access is `fused_var % cont_ext` (1D fused case). 
+ int split_axis = -1; + constexpr int kBankCycleBytes = 128; + PostOrderVisit(root_, [&](const ObjectRef &obj) { + if (target.defined()) + return; + const auto *store = obj.as(); + if (!store) + return; + const Buffer &buffer = store->buffer; + if (!IsSharedBuffer(buffer)) + return; + if (!T.layout_map.count(buffer)) + return; + Layout layout = T.layout_map[buffer]; + if (layout.as()) + return; + if (store->indices.empty()) + return; + if (buffer->shape.empty()) + return; + auto *last_dim_imm = as_const_int(buffer->shape.back()); + if (!last_dim_imm) + return; + int64_t cont = *last_dim_imm; + int element_bytes = buffer->dtype.bytes(); + if (element_bytes <= 0) + return; + int64_t bank_cycle_elems = kBankCycleBytes / element_bytes; + if (bank_cycle_elems <= 0) + return; + if (cont * element_bytes <= kBankCycleBytes) + return; + if (cont % bank_cycle_elems != 0) + return; + if ((cont / bank_cycle_elems) <= 1) + return; + + // Identify the loop var(s) driving the continuous dim. + PrimExpr last_idx = analyzer_.Simplify(store->indices.back()); + int chosen_axis = -1; + if (auto var_opt = last_idx.as()) { + // Unfused N-D case: bare loop var on the cont dim. + for (int i = 0; i < static_cast(loop_vars_.size()); i++) { + if (loop_vars_[i]->var.same_as(var_opt.value())) { + chosen_axis = i; + break; + } + } + if (chosen_axis < 0) + return; + auto *axis_ext_imm = + as_const_int(loop_vars_[chosen_axis]->dom->extent); + if (!axis_ext_imm || *axis_ext_imm != cont) + return; + } else { + // Fused 1D case: index is `fused_var % cont_ext` (after pipelining + // multi-dim accesses get flattened). Require exactly one loop var of + // extent that is a multiple of cont. + if (loop_vars_.size() != 1) + return; + auto *total_ext_imm = as_const_int(loop_vars_[0]->dom->extent); + if (!total_ext_imm || *total_ext_imm % cont != 0) + return; + // Match `fused_var % cont` (with cont equal to the buffer's last dim). 
+ const auto *mod = last_idx.as(); + if (!mod) + return; + auto *mod_imm = as_const_int(mod->b); + if (!mod_imm || *mod_imm != cont) + return; + if (!mod->a.same_as(loop_vars_[0]->var)) + return; + } + + target = buffer; + split_axis = chosen_axis; + cont_ext = cont; + inner_extent = bank_cycle_elems; + }); + if (!target.defined()) + return Fragment(); + + // 2. Build flatten expressed purely in the existing loop vars so the + // resulting Fragment matches root_'s loop_vars and downstream + // PartitionLoop / LowerParallelLoop are unaffected. + ICHECK(!loop_vars_.empty()); + DataType dtype = loop_vars_[0]->var.dtype(); + PrimExpr inner_pe = IntImm(dtype, inner_extent); + PrimExpr flat; + if (split_axis >= 0) { + // Unfused N-D: split the chosen loop var into outer/inner and reorder + // to [outer, ..., inner] before row-major flatten. + PrimExpr split_var = loop_vars_[split_axis]->var; + PrimExpr outer_part = FloorDiv(split_var, inner_pe); + PrimExpr inner_part = FloorMod(split_var, inner_pe); + PrimExpr modified_total = IntImm(dtype, 1); + PrimExpr modified_flat = make_zero(dtype); + for (int i = 0; i < static_cast(loop_vars_.size()); i++) { + PrimExpr ext = + (i == split_axis) ? inner_pe : loop_vars_[i]->dom->extent; + PrimExpr v = (i == split_axis) + ? inner_part + : static_cast(loop_vars_[i]->var); + modified_total = modified_total * ext; + modified_flat = modified_flat * ext + v; + } + flat = outer_part * modified_total + modified_flat; + } else { + // Fused 1D: decompose fused_var into (rest, cont_inner_part, c_inner) + // where cont_inner_part = (fused_var % cont)/inner. New flat puts + // n_outer (= cont_inner_part) outermost. 
+ PrimExpr fused = loop_vars_[0]->var; + auto *total_ext_imm = as_const_int(loop_vars_[0]->dom->extent); + ICHECK(total_ext_imm); + int64_t total = *total_ext_imm; + int64_t rest = total / cont_ext; + PrimExpr cont_pe = IntImm(dtype, cont_ext); + PrimExpr c = FloorMod(fused, cont_pe); + PrimExpr rest_part = FloorDiv(fused, cont_pe); + PrimExpr n_outer = FloorDiv(c, inner_pe); + PrimExpr c_inner = FloorMod(c, inner_pe); + PrimExpr rest_pe = IntImm(dtype, rest); + flat = n_outer * (rest_pe * inner_pe) + rest_part * inner_pe + c_inner; + } + + // 3. Apply the same coalesce policy as LoopPartitioner::Partition: + // access_idx = flat / vec_size, thd = access_idx % num_thread, + // idx = (access_idx / num_thread) * vec_size + flat % vec_size. + auto *num_thread_imm = as_const_int(T.thread_bounds->extent); + if (!num_thread_imm) + return Fragment(); // Symbolic thread bounds: fall back to default plan. + PrimExpr vec_pe = IntImm(dtype, vector_size); + PrimExpr num_thread_pe = IntImm(dtype, *num_thread_imm); + PrimExpr access_idx = FloorDiv(flat, vec_pe); + PrimExpr thd = FloorMod(access_idx, num_thread_pe); + PrimExpr idx = FloorDiv(access_idx, num_thread_pe) * vec_pe + + FloorMod(flat, vec_pe); + + Fragment fragment = Fragment(loop_vars_, /*forward_index=*/{idx}, + /*forward_thread=*/thd, + /*thread_replicate=*/IterVar()); + return fragment->BindThreadRange(T.thread_bounds); +} + void ParallelOpNode::BuildReplicationGuardsIfNeeded( const LayoutInferArgs &T, const std::vector &store_shared_global_buffers, diff --git a/src/op/parallel.h b/src/op/parallel.h index 751e14a22..14c42e33c 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -149,6 +149,17 @@ class ParallelOpNode : public TileOperatorNode { // Compute plan-based loop layout candidate using vectorization and thread // bounds. Fragment ComputePlanCandidate(const LayoutInferArgs &T) const; + // Compute a "chunk-block aware" plan candidate. 
When a written shared buffer + // has a swizzled layout whose outer (tc) dim has extent > 1 (FullBank with + // continuous-bytes > one LDS bank cycle), the default flatten-and-partition + // policy makes one wavefront's lanes straddle the chunk-block boundary, + // which breaks the lane-contiguous LDS WRITE constraint and prevents any + // downstream swizzle-swap. This candidate splits the continuous loop var + // (n = n_outer * inner + n_inner) and reorders to [n_outer, ..., n_inner] + // before flattening, so consecutive lanes stay inside one tc plane. + // Returns an undefined Fragment if no eligible buffer is found. + Fragment ComputeChunkBlockAwarePlanCandidate(const LayoutInferArgs &T, + int vector_size) const; // Add replication guard predicates when needed for cross-thread stores. void BuildReplicationGuardsIfNeeded( const LayoutInferArgs &T, From 0a864a880c0cb075d423472073e00de862989e9f Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Thu, 21 May 2026 06:03:05 +0000 Subject: [PATCH 23/34] M11: always emit ptx_cp_async_lds; let M9 swap or downgrade MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the IsLdsLaneContiguous gate at lower_ptx_async_copy.cc:574 rejected any swizzled (XOR-laden) LDS index, falling through to ptx_cp_async. That cut M9 out of the loop for FullBank tc>1 cases (NN B under the new CBA binding from M10), even though M9's SwizzleDelta swap would produce a lane-contiguous index. Drop the IsLdsLaneContiguous gate. Always emit ptx_cp_async_lds when the destination is shared + 16B + non-predicated. M9 then either rewrites the call (LDS path stays) or downgrades to ptx_cp_async if post-swap is non-affine. 
Effect: - NN 8192^3 128x128x32 stages=3 threads=128: 443 TFLOPS (was 370) - NN 8192^3 256x256x32 stages=3 threads=256: 602 TFLOPS, PASS - NN 8192^3 256x256x64 stages=2 threads=256: 760 TFLOPS, PASS - NT 8192^2 K=8192 256x256x64 stages=2 threads=512: 1116 TFLOPS, PASS A_shared still downgrades (M9 affine check fails because HalfBank's ts boundary splits the warp), so A continues on cp_async path. Fixing A needs CBA-style binding on the stride dim — open work. Co-Authored-By: Claude Opus 4 (1M context) --- src/transform/lower_ptx_async_copy.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index b7d261d5a..0db946d17 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -571,7 +571,14 @@ class PTXAsyncCopyInjector : public StmtMutator { if (enable_buffer_load_lds_ && !predicated && total_bytes == 16) { const std::string dst_scope = store->buffer.scope(); const bool is_shared = dst_scope == "shared" || dst_scope == "shared.dyn"; - if (is_shared && IsLdsLaneContiguous(dst_check_index)) { + // Emit ptx_cp_async_lds whenever shared + 16B + non-predicated. If + // dst_check_index is already lane-contiguous, codegen emits the LDS + // template directly. Otherwise the LowerTileOp M9 visitor will see + // tl::ptx_cp_async_lds and either (a) rewrite via the swizzle-swap + // (subtract SwizzleDelta on LDS, add to global) when the layout is + // amenable, or (b) downgrade to tl::ptx_cp_async if the post-swap + // index is not lane-contiguous. Both paths produce correct code. 
+ if (is_shared) { ffi::Array lds_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems)}; return Evaluate( From df611d30d003db0c3bb857b49d02a4c079360a1c Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Thu, 21 May 2026 09:03:06 +0000 Subject: [PATCH 24/34] docs: annotate FullBank/makeGemmABLayout with inline derivations Adds in-source Chinese comments tracing the FullBank swizzle math (ts/s/tc/c/vec/c_swizzle/index, with concrete dimensions for the NN B_shared=32x128 bf16 case) so the layout's role in the M10/M11 chunk-block-aware binding work is documented at the source. Pure comment additions, no behavior change. Co-Authored-By: Claude Opus 4 (1M context) --- src/layout/gemm_layouts.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 758bb4788..f9f5bf969 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -495,19 +495,19 @@ Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { } // Layout swizzling for 128 bytes -static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, - int element_size) { +static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, // stride = 32, continuous = 128 # BLOCK_K, BLOCK_N + int element_size) { // element_size = 16 # bf16 // Swizzle 3 bit Var i = InputPlaceholder(0); Var j = InputPlaceholder(1); - int vector_size = 128 / element_size; + int vector_size = 128 / element_size; // vector_size = 128 / 16 = 8 # 一次写入 float4. 
ICHECK(stride % 8 == 0) << "stride=" << stride; ICHECK(continuous % (vector_size * 8) == 0) << "continuous=" << continuous << ", vector_size=" << vector_size; - PrimExpr ts = FloorDiv(i, 8); - PrimExpr s = FloorMod(i, 8); - PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8); - PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8); + PrimExpr ts = FloorDiv(i, 8); // tile stride row + PrimExpr s = FloorMod(i, 8); // stride + PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8); // tc = "tile-continuous", 第几个16字节的chunk. + PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8); // 块内的编号. PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor8x8(c, s); PrimExpr index = vec + (c_swizzle + s * 8) * vector_size; @@ -819,7 +819,7 @@ Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous, * \return A Layout object representing the chosen memory layout. */ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, - int element_size, bool k_inner) { + int element_size, bool k_inner) { // makeGemmABLayout(stride=32, continuous=128, element_size=16, k_inner=False) if (element_size == 64) { if (!k_inner && continuity % 16 == 0) // float64 KxN return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); @@ -827,10 +827,10 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); } - int vector_size = 128 / element_size; + int vector_size = 128 / element_size; // vector_size = 128 / 16 = 8 if (!k_inner && element_size == 8) // int8 KxN return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 8) == 0) + else if (mat_continuous % (vector_size * 8) == 0) // mat_continuous = 128 % (8 * 8) = 0 return MakeFullBankSwizzleLayout2D(mat_stride, mat_continuous, element_size); else if (mat_continuous % (vector_size * 4) == 0) From 
c5b1973b332c603adc29087031c11cf96f9f05a6 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Fri, 22 May 2026 15:12:14 +0000 Subject: [PATCH 25/34] Revert "docs: annotate FullBank/makeGemmABLayout with inline derivations" This reverts commit df611d30d003db0c3bb857b49d02a4c079360a1c. --- src/layout/gemm_layouts.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index f9f5bf969..758bb4788 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -495,19 +495,19 @@ Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { } // Layout swizzling for 128 bytes -static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, // stride = 32, continuous = 128 # BLOCK_K, BLOCK_N - int element_size) { // element_size = 16 # bf16 +static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, + int element_size) { // Swizzle 3 bit Var i = InputPlaceholder(0); Var j = InputPlaceholder(1); - int vector_size = 128 / element_size; // vector_size = 128 / 16 = 8 # 一次写入 float4. + int vector_size = 128 / element_size; ICHECK(stride % 8 == 0) << "stride=" << stride; ICHECK(continuous % (vector_size * 8) == 0) << "continuous=" << continuous << ", vector_size=" << vector_size; - PrimExpr ts = FloorDiv(i, 8); // tile stride row - PrimExpr s = FloorMod(i, 8); // stride - PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8); // tc = "tile-continuous", 第几个16字节的chunk. - PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8); // 块内的编号. 
+ PrimExpr ts = FloorDiv(i, 8); + PrimExpr s = FloorMod(i, 8); + PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 8); + PrimExpr c = FloorMod(FloorDiv(j, vector_size), 8); PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor8x8(c, s); PrimExpr index = vec + (c_swizzle + s * 8) * vector_size; @@ -819,7 +819,7 @@ Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous, * \return A Layout object representing the chosen memory layout. */ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, - int element_size, bool k_inner) { // makeGemmABLayout(stride=32, continuous=128, element_size=16, k_inner=False) + int element_size, bool k_inner) { if (element_size == 64) { if (!k_inner && continuity % 16 == 0) // float64 KxN return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); @@ -827,10 +827,10 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); } - int vector_size = 128 / element_size; // vector_size = 128 / 16 = 8 + int vector_size = 128 / element_size; if (!k_inner && element_size == 8) // int8 KxN return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 8) == 0) // mat_continuous = 128 % (8 * 8) = 0 + else if (mat_continuous % (vector_size * 8) == 0) return MakeFullBankSwizzleLayout2D(mat_stride, mat_continuous, element_size); else if (mat_continuous % (vector_size * 4) == 0) From 781eb74c6dd2f858972781767ed1954739769812 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Fri, 22 May 2026 15:16:25 +0000 Subject: [PATCH 26/34] lower_ptx_async_copy: drop dead IsLdsLaneContiguous + dst_check_index After always-emit-LDS landed (let the LowerTileOp swizzle-swap visitor own the lane-contiguity decision and downgrade when it can't), the IsLdsLaneContiguous sampler and the dst_check_index parameter to 
MakeCPAsyncStmtFromLoads are no longer reachable. Remove them so the upstream-facing diff has no dead code; behavior is unchanged. Co-Authored-By: Claude Opus 4 (1M context) --- src/transform/lower_ptx_async_copy.cc | 84 ++++----------------------- 1 file changed, 12 insertions(+), 72 deletions(-) diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 0db946d17..04f208bd0 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -124,9 +124,7 @@ class PTXAsyncCopyInjector : public StmtMutator { /*dst_base_load=*/BufferLoad(store->buffer, store->indices), /*src_base_load=*/BufferLoad(load->buffer, load->indices), /*num_elems=*/index_info->per_access_num_elems, - /*total_bytes=*/index_info->total_bytes, - /*dst_check_index=*/index_info->dst_index, predicated, - predicate_value); + /*total_bytes=*/index_info->total_bytes, predicated, predicate_value); } Optional> src_base_indices = @@ -147,8 +145,7 @@ class PTXAsyncCopyInjector : public StmtMutator { /*dst_base_load=*/BufferLoad(store->buffer, dst_base_indices.value()), /*src_base_load=*/BufferLoad(load->buffer, src_base_indices.value()), /*num_elems=*/index_info->per_access_num_elems, - /*total_bytes=*/index_info->total_bytes, - /*dst_check_index=*/index_info->dst_index, predicated, predicate_value); + /*total_bytes=*/index_info->total_bytes, predicated, predicate_value); } Stmt VisitStmt_(const SeqStmtNode *op) final { @@ -431,59 +428,6 @@ class PTXAsyncCopyInjector : public StmtMutator { return info; } - // Real lane-contiguity proof for the gfx950 buffer_load...lds routing. - // The destination LDS index must be an affine function of the thread-like - // variable with a constant per-lane stride. Bit-arithmetic swizzles - // (xor2x2 etc. expanded as `((a + b) & 1) * k`) have a constant - // f(1)-f(0) but are NOT linear globally; we catch those by sampling at - // many points and requiring f(k) - f(0) == k * stride at every point. 
- // Returns false unless we find a recognisable thread-like free var whose - // contribution is provably linear. - static bool IsLdsLaneContiguous(const PrimExpr &dst_index) { - arith::Analyzer analyzer; - std::unordered_set seen; - Array free_vars; - tir::PostOrderVisit(dst_index, [&](const ObjectRef &node) { - if (auto *v = node.as()) { - if (seen.insert(v).second) { - free_vars.push_back(Downcast(node)); - } - } - }); - // Sample 0, 1, 2, ..., 1023 — covers wave-32/64 boundaries, warp tiles, - // bank-swizzle phases (typically powers of two up to 64), and the - // wider 256-thread block boundaries the bench uses. - constexpr int kNumSamples = 1024; - for (const auto &var : free_vars) { - const std::string name(var->name_hint); - if (name.find("thread") == std::string::npos && name != "tx" && - name != "tid") { - continue; - } - PrimExpr f0 = analyzer.Simplify(Substitute( - dst_index, Map{{var, IntImm(var->dtype, 0)}})); - PrimExpr f1 = analyzer.Simplify(Substitute( - dst_index, Map{{var, IntImm(var->dtype, 1)}})); - PrimExpr expected_stride = analyzer.Simplify(f1 - f0); - auto *stride_imm = expected_stride.as(); - if (!stride_imm) { - return false; - } - int64_t stride = stride_imm->value; - for (int pt = 2; pt < kNumSamples; ++pt) { - PrimExpr fk = analyzer.Simplify(Substitute( - dst_index, Map{{var, IntImm(var->dtype, pt)}})); - PrimExpr actual = analyzer.Simplify(fk - f0); - PrimExpr expected = IntImm(DataType::Int(64), stride * pt); - if (!analyzer.CanProveEqual(actual, expected)) { - return false; - } - } - return true; - } - return false; - } - static PrimExpr ExtractVectorBase(const PrimExpr &index) { if (index.dtype().lanes() == 1) { return index; @@ -553,7 +497,6 @@ class PTXAsyncCopyInjector : public StmtMutator { const BufferLoad &dst_base_load, const BufferLoad &src_base_load, int num_elems, int total_bytes, - const PrimExpr &dst_check_index, bool predicated, const PrimExpr &predicate_value) { PrimExpr dst_access_ptr = @@ -562,22 +505,19 @@ class 
PTXAsyncCopyInjector : public StmtMutator { MakeAccessPtrFromLoad(src_base_load, num_elems, /*rw_mask=*/1); // gfx950 routing: emit tl::ptx_cp_async_lds when the destination is a - // 16-byte non-predicated shared-memory write whose LDS index is lane- - // contiguous (no XOR swizzle). Arg 2 carries the logical element count - // (same convention tl::ptx_cp_async uses) so the existing vec-loop - // folding in vectorize_loop.cc widens it correctly when the call sits - // inside a T.vectorized(k) loop. The codegen handler converts the - // logical count back to bytes via GetTileLangCPAsyncTransferBytes. + // 16-byte non-predicated shared-memory write. Arg 2 carries the logical + // element count (same convention tl::ptx_cp_async uses) so the existing + // vec-loop folding in vectorize_loop.cc widens it correctly when the + // call sits inside a T.vectorized(k) loop. The codegen handler converts + // the logical count back to bytes via GetTileLangCPAsyncTransferBytes. + // If the LDS index carries an XOR swizzle, the swizzle-swap visitor in + // LowerTileOp rewrites the call (subtract SwizzleDelta on LDS, add on + // global) so the destination becomes lane-contiguous; if the swap + // can't produce an affine destination it downgrades back to + // tl::ptx_cp_async, so both paths produce correct code. if (enable_buffer_load_lds_ && !predicated && total_bytes == 16) { const std::string dst_scope = store->buffer.scope(); const bool is_shared = dst_scope == "shared" || dst_scope == "shared.dyn"; - // Emit ptx_cp_async_lds whenever shared + 16B + non-predicated. If - // dst_check_index is already lane-contiguous, codegen emits the LDS - // template directly. Otherwise the LowerTileOp M9 visitor will see - // tl::ptx_cp_async_lds and either (a) rewrite via the swizzle-swap - // (subtract SwizzleDelta on LDS, add to global) when the layout is - // amenable, or (b) downgrade to tl::ptx_cp_async if the post-swap - // index is not lane-contiguous. Both paths produce correct code. 
if (is_shared) { ffi::Array lds_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems)}; From 7163b44d4e2a6b85923cdf8ec1ac7e86de7df582 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Fri, 22 May 2026 15:20:09 +0000 Subject: [PATCH 27/34] parallel: drop redundant CBA call in ComputePlanCandidate The chunk-block-aware binding now lands via the early hook at the top of ParallelOp::InferLayout (runs at all 3 levels, before source_buffer dispatch). By the time PlanLoopPartition's ComputePlanCandidate runs, loop_layout_ is already set when CBA applies, so the second ComputeChunkBlockAwarePlanCandidate call site here is dead in practice. Removed it; behavior is unchanged. Verified: - NT 8192^2 K=8192 256x256x64 stages=2 threads=512: 1114 TFLOPS, PASS - NN 1024x1024x16384 128x128x32 stages=3 threads=128: 64 TFLOPS, PASS Co-Authored-By: Claude Opus 4 (1M context) --- src/op/parallel.cc | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index df1ec0e90..0bc09b998 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -726,18 +726,10 @@ Fragment ParallelOpNode::ComputePlanCandidate(const LayoutInferArgs &T) const { DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_ << " ############# vector_size = " << vector_size << ", thread_bounds = " << T.thread_bounds << '\n'; - // Prefer chunk-block aware binding when applicable. The default flatten - // policy below would put 16 (or more) lanes on the continuous dim; when a - // shared destination has a FullBank swizzle layout with tc > 1, those lanes - // straddle the chunk-block boundary, breaking lane-contiguous LDS WRITEs and - // any downstream swizzle-swap. 
- if (auto cba_plan = - ComputeChunkBlockAwarePlanCandidate(T, vector_size); - cba_plan.defined()) { - DLOG(INFO) << "[PlanLoopPartition] chunk-block-aware candidate = " - << cba_plan->DebugOutput() << '\n'; - return cba_plan; - } + // Chunk-block-aware binding is taken by the early hook in + // ParallelOp::InferLayout (before source_buffer dispatch). By the time + // we reach here loop_layout_ is already set when CBA applies, so no + // need to re-try it. auto plan = PlanLoopPartition(root_, vector_size, T.thread_bounds); DLOG(INFO) << "[PlanLoopPartition] candidate = " << plan->DebugOutput() << '\n'; From 03343e21ce682dac4f61430ab2ebecff28cf028f Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Fri, 22 May 2026 16:16:27 +0000 Subject: [PATCH 28/34] parallel: apply clang-format Co-Authored-By: Claude Opus 4 (1M context) --- src/op/parallel.cc | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 3475bb51b..66867ab8d 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -417,13 +417,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, PrimExpr loop_total_size = 1; for (Stmt l = root_; l.as().has_value(); l = l.as().value()->body) loop_total_size = loop_total_size * l.as().value()->extent; - while (!analyzer_.CanProve(floormod(loop_total_size, - T.thread_bounds->extent * - vector_size) == 0) && - vector_size > 1) + while ( + !analyzer_.CanProve(floormod(loop_total_size, T.thread_bounds->extent * + vector_size) == 0) && + vector_size > 1) vector_size /= 2; - if (auto cba = - ComputeChunkBlockAwarePlanCandidate(T, vector_size); + if (auto cba = ComputeChunkBlockAwarePlanCandidate(T, vector_size); cba.defined()) { loop_layout_ = cba; } @@ -742,8 +741,9 @@ Fragment ParallelOpNode::ComputePlanCandidate(const LayoutInferArgs &T) const { return plan; } -Fragment ParallelOpNode::ComputeChunkBlockAwarePlanCandidate( - const LayoutInferArgs &T, int vector_size) const { 
+Fragment +ParallelOpNode::ComputeChunkBlockAwarePlanCandidate(const LayoutInferArgs &T, + int vector_size) const { // 1. Find a written shared buffer with a swizzle layout whose continuous // (innermost) dim exceeds one CDNA LDS bank cycle (128 bytes). When // that happens FullBank splits the dim into `tc` planes; the default @@ -809,8 +809,7 @@ Fragment ParallelOpNode::ComputeChunkBlockAwarePlanCandidate( } if (chosen_axis < 0) return; - auto *axis_ext_imm = - as_const_int(loop_vars_[chosen_axis]->dom->extent); + auto *axis_ext_imm = as_const_int(loop_vars_[chosen_axis]->dom->extent); if (!axis_ext_imm || *axis_ext_imm != cont) return; } else { @@ -857,8 +856,7 @@ Fragment ParallelOpNode::ComputeChunkBlockAwarePlanCandidate( PrimExpr modified_total = IntImm(dtype, 1); PrimExpr modified_flat = make_zero(dtype); for (int i = 0; i < static_cast(loop_vars_.size()); i++) { - PrimExpr ext = - (i == split_axis) ? inner_pe : loop_vars_[i]->dom->extent; + PrimExpr ext = (i == split_axis) ? inner_pe : loop_vars_[i]->dom->extent; PrimExpr v = (i == split_axis) ? inner_part : static_cast(loop_vars_[i]->var); @@ -894,8 +892,8 @@ Fragment ParallelOpNode::ComputeChunkBlockAwarePlanCandidate( PrimExpr num_thread_pe = IntImm(dtype, *num_thread_imm); PrimExpr access_idx = FloorDiv(flat, vec_pe); PrimExpr thd = FloorMod(access_idx, num_thread_pe); - PrimExpr idx = FloorDiv(access_idx, num_thread_pe) * vec_pe + - FloorMod(flat, vec_pe); + PrimExpr idx = + FloorDiv(access_idx, num_thread_pe) * vec_pe + FloorMod(flat, vec_pe); Fragment fragment = Fragment(loop_vars_, /*forward_index=*/{idx}, /*forward_thread=*/thd, From e9d7bdbbec72081f0c78b693253f331d096943a3 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Fri, 22 May 2026 16:41:57 +0000 Subject: [PATCH 29/34] review fixes for ptx_cp_async_lds family - codegen_hip: route ptx_cp_async_lds_rsrc through GetTileLangCPAsyncTransferBytes so src/dst widths must match and the final byte width is validated against {4,8,16}. 
- layout: ICHECK_GE the index count in SwizzleDelta so a too-short call fails with a clear contract error instead of OOB-reading. - builtin.h: fix ptx_cp_async_lds / _rsrc docstrings -- arg 2 is num_elems, not bytes (lowering derives byte width from the access-ptr dtype). - lower_tile_op: drop the "tx"/"tid"/"thread*" name-matching in the affine lane-contiguity proof and key off the real threadIdx.x binding tracked in thread_var_ so a future rename of the lane var can't silently misclassify a non-affine LDS index as OK. - vectorize_loop: accept ptx_cp_async_lds_rsrc in MutatePTXCPAsyncExpr_ / GetCPAsyncBitsPerCall (it's already routed here from VisitExpr_) and preserve the trailing (rsrc, base) args through the rewrite so codegen still sees them. Co-Authored-By: Claude Opus 4 (1M context) --- src/backend/rocm/codegen/codegen_hip.cc | 34 ++++------- src/layout/layout.cc | 7 ++- src/op/builtin.h | 13 ++++- src/transform/lower_tile_op.cc | 78 +++++++++++++------------ src/transform/vectorize_loop.cc | 34 ++++++++--- 5 files changed, 92 insertions(+), 74 deletions(-) diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index 41ab5a4ea..3ec7cac98 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -56,9 +56,12 @@ std::optional GetAccessPtrElementType(const PrimExpr &expr) { } int GetTileLangCPAsyncTransferBytes(const CallNode *op) { - ICHECK(op->args.size() == 3 || op->args.size() == 4) - << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " - "src_access_ptr, num_elems, [predicate])"; + // Accepts ptx_cp_async / ptx_cp_async_lds (3 or 4 args: dst, src, num_elems, + // [predicate]) and ptx_cp_async_lds_rsrc (5 args: dst, src, num_elems, + // rsrc_var, base_var) -- only args[0..2] are read here. 
+ ICHECK(op->args.size() == 3 || op->args.size() == 4 || op->args.size() == 5) + << "tl::ptx_cp_async family expects 3-5 arguments (dst_access_ptr, " + "src_access_ptr, num_elems, ...)"; const auto *num_elems_imm = op->args[2].as(); ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got " << op->args[2]; @@ -924,28 +927,15 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { std::string ptr = this->PrintExpr(op->args[0]); os << "make_wave_buffer_resource((const void*)(" << ptr << "))"; } else if (op->op.same_as(tl::ptx_cp_async_lds_rsrc())) { - // args = [dst, src, bytes, rsrc_var, base_var] + // args = [dst, src, num_elems, rsrc_var, base_var]. arg 2 is the logical + // element count inherited from the ptx_cp_async_lds call that + // HoistBufferResource rewrote into this rsrc form -- the helper does the + // src/dst width-equality and {4,8,16} validation that the plain + // ptx_cp_async path also relies on. ICHECK(op->args.size() == 5) << "ptx_cp_async_lds_rsrc expects 5 arguments"; std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); - // arg 2 carries logical element count (inherited from the - // ptx_cp_async_lds call that HoistBufferResource rewrote into this - // rsrc form). Convert to bytes the same way the ptx_cp_async / lds - // handler does, by inspecting the access-ptr element type. 
- const auto *num_elems_imm = op->args[2].as(); - ICHECK(num_elems_imm) - << "ptx_cp_async_lds_rsrc num_elems must be IntImm, but got " - << op->args[2]; - auto dst_elem_type = GetAccessPtrElementType(op->args[0]); - ICHECK(dst_elem_type.has_value()) - << "ptx_cp_async_lds_rsrc dst must be tvm_access_ptr / tl.access_ptr / " - "address_of(BufferLoad)"; - int64_t total_bits = num_elems_imm->value * dst_elem_type.value().bits() * - dst_elem_type.value().lanes(); - ICHECK_EQ(total_bits % 8, 0) - << "ptx_cp_async_lds_rsrc requires byte-aligned transfer, got " - << total_bits << " bits"; - int total_bytes = static_cast(total_bits / 8); + int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string size = std::to_string(total_bytes); std::string rsrc = this->PrintExpr(op->args[3]); std::string base = this->PrintExpr(op->args[4]); diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 62f81542b..5329d00c0 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -521,10 +521,11 @@ PrimExpr LayoutNode::SwizzleDelta(const Array &input_indices) const { } // Substitute the last InputDim() entries of input_indices into // swizzle_delta_, matching the convention Forward() uses. + ICHECK_GE(input_indices.size(), InputDim()) + << "SwizzleDelta requires at least " << InputDim() << " indices, but got " + << input_indices.size(); PrimExpr delta = swizzle_delta_.value(); - size_t offset = input_indices.size() >= InputDim() - ? input_indices.size() - InputDim() - : 0; + size_t offset = input_indices.size() - InputDim(); for (size_t i = 0; i < InputDim(); ++i) { delta = Substitute(delta, {{InputPlaceholder(i), input_indices[offset + i]}}); diff --git a/src/op/builtin.h b/src/op/builtin.h index 32000618e..482a074cf 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -523,7 +523,13 @@ TVM_DLL const Op &ptx_cp_async(); * pass in lower_tile_op.cc moves the XOR swizzle from the LDS store side * to the global load side to make this safe. 
* - * ptx_cp_async_lds(dst_access_ptr, src_access_ptr, bytes) + * ptx_cp_async_lds(dst_access_ptr, src_access_ptr, num_elems) + * + * num_elems is the logical element count (NOT byte width). Lowering + * derives the {4, 8, 16} byte transfer width from the access-ptr dtype. + * Passing this as elements keeps vec-loop folding in vectorize_loop.cc + * (which multiplies the count when it widens a loop) consistent with + * the plain ptx_cp_async path. */ TVM_DLL const Op &ptx_cp_async_lds(); @@ -546,8 +552,11 @@ TVM_DLL const Op &ptx_make_buffer_resource(); * HoistBufferResource Python pass rewrites ptx_cp_async_lds calls to this * form once per kernel. * - * ptx_cp_async_lds_rsrc(dst_access_ptr, src_access_ptr, bytes, rsrc_var, + * ptx_cp_async_lds_rsrc(dst_access_ptr, src_access_ptr, num_elems, rsrc_var, * base_var) + * + * num_elems uses the same convention as ptx_cp_async_lds -- logical + * element count, not bytes; lowering converts via the access-ptr dtype. */ TVM_DLL const Op &ptx_cp_async_lds_rsrc(); diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index d95fba908..2fe2dad06 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -1077,49 +1077,51 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // some shapes) where the swizzle spreads across multiple output // dims, the post-swap LDS index is NOT lane-contiguous and // buffer_load_dwordx4...lds would write to wrong addresses. - // Detect by sampling each new_dst_indices[d] against a - // thread-like var and requiring an affine (constant or - // single-stride) dependence. If any dim is non-affine in tx - // (typical for bit-extract `(tx & m) >> s` terms), bail out so - // the call falls through to the safe ptx_cp_async path. + // Detect by sampling each new_dst_indices[d] against the actual + // threadIdx.x binding tracked in thread_var_ (set from the + // tirx::attr::thread_extent AttrStmt). 
Substituting concrete + // lane values and requiring the dependence to be a single + // constant stride catches the case where the LDS index has + // bit-extract terms like `(tx & m) >> s` that buffer_load_lds + // would scatter. If any dim is non-affine in the lane var, bail + // out so the call falls through to the safe ptx_cp_async path. + // + // Use the real binding rather than a name-based heuristic so a + // future rename of the lane var (or any kernel whose lane var + // doesn't happen to be called "tx"/"tid"/"thread*") doesn't + // silently misclassify the call as affine-OK and emit a + // wrong-banks tile. + Var lane_var; + if (thread_var_.defined() && thread_var_->var.defined() && + thread_block_size_ > 1) { + lane_var = thread_var_->var; + } auto is_affine_in_thread_var = [&](const PrimExpr &e) -> bool { + if (!lane_var.defined()) { + // Serial / no-thread kernel: expression is trivially + // constant w.r.t. the lane and any LDS layout works. + return true; + } arith::Analyzer post_analyzer; - std::unordered_set seen; - Array free_vars; - tirx::PostOrderVisit(e, [&](const ObjectRef &node) { - if (auto *v = node.as()) { - if (seen.insert(v).second) { - free_vars.push_back(Downcast(node)); - } - } - }); - for (const auto &var : free_vars) { - const std::string name(var->name_hint); - if (name.find("thread") == std::string::npos && name != "tx" && - name != "tid") { - continue; - } - PrimExpr f0 = post_analyzer.Simplify(tirx::Substitute( - e, Map{{var, IntImm(var->dtype, 0)}})); - PrimExpr f1 = post_analyzer.Simplify(tirx::Substitute( - e, Map{{var, IntImm(var->dtype, 1)}})); - PrimExpr stride = post_analyzer.Simplify(f1 - f0); - const auto *stride_imm = stride.as(); - if (!stride_imm) + PrimExpr f0 = post_analyzer.Simplify(tirx::Substitute( + e, Map{{lane_var, IntImm(lane_var->dtype, 0)}})); + PrimExpr f1 = post_analyzer.Simplify(tirx::Substitute( + e, Map{{lane_var, IntImm(lane_var->dtype, 1)}})); + PrimExpr stride = post_analyzer.Simplify(f1 - f0); + 
const auto *stride_imm = stride.as(); + if (!stride_imm) + return false; + for (int pt = 2; pt < 64; ++pt) { + PrimExpr fk = post_analyzer.Simplify(tirx::Substitute( + e, + Map{{lane_var, IntImm(lane_var->dtype, pt)}})); + PrimExpr actual = post_analyzer.Simplify(fk - f0); + PrimExpr expected = + IntImm(DataType::Int(64), stride_imm->value * pt); + if (!post_analyzer.CanProveEqual(actual, expected)) { return false; - for (int pt = 2; pt < 64; ++pt) { - PrimExpr fk = post_analyzer.Simplify(tirx::Substitute( - e, Map{{var, IntImm(var->dtype, pt)}})); - PrimExpr actual = post_analyzer.Simplify(fk - f0); - PrimExpr expected = - IntImm(DataType::Int(64), stride_imm->value * pt); - if (!post_analyzer.CanProveEqual(actual, expected)) { - return false; - } } - return true; } - // No thread-like var found: constant w.r.t. tx -> affine OK. return true; }; bool all_dims_affine = true; diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index e01c814d9..6c556b6de 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -681,7 +681,8 @@ class TLVectorizer : public StmtMutator, return scalar_count * 8; } ICHECK(op->op.same_as(tl::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async_lds())); + op->op.same_as(tl::ptx_cp_async_lds()) || + op->op.same_as(tl::ptx_cp_async_lds_rsrc())); auto dst_elem_bits = GetAccessPtrElementBits(op->args[0]); auto src_elem_bits = GetAccessPtrElementBits(op->args[1]); if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { @@ -703,8 +704,11 @@ class TLVectorizer : public StmtMutator, PrimExpr MutatePTXCPAsyncExpr_(const CallNode *op) { ICHECK(op->op.same_as(builtin::ptx_cp_async()) || op->op.same_as(tl::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async_lds())); - if (op->args.size() != 3 && op->args.size() != 4) { + op->op.same_as(tl::ptx_cp_async_lds()) || + op->op.same_as(tl::ptx_cp_async_lds_rsrc())); + // 3 or 4 args: dst, src, count, [predicate] (plain cp_async family). 
+ // 5 args: dst, src, count, rsrc_var, base_var (hoisted-resource form). + if (op->args.size() != 3 && op->args.size() != 4 && op->args.size() != 5) { return GetRef(op); } @@ -720,13 +724,27 @@ class TLVectorizer : public StmtMutator, } predicate = pred; } + // For the rsrc form, args[3..4] are the hoisted (rsrc_var, base_var) and + // must be preserved through the rewrite so codegen still sees them. + Array trailing_rsrc_args; + if (op->args.size() == 5) { + trailing_rsrc_args.push_back(VisitExpr(op->args[3])); + trailing_rsrc_args.push_back(VisitExpr(op->args[4])); + } + + auto append_trailing = [&](Array &args) { + if (predicate.defined()) { + args.push_back(predicate.value()); + } + for (const auto &a : trailing_rsrc_args) { + args.push_back(a); + } + }; auto lanes_ptr = as_const_int(var_lanes_); if (!lanes_ptr || *lanes_ptr <= 1) { Array new_args{dst, src, count}; - if (predicate.defined()) { - new_args.push_back(predicate.value()); - } + append_trailing(new_args); if (new_args.same_as(op->args)) { return GetRef(op); } @@ -754,9 +772,7 @@ class TLVectorizer : public StmtMutator, int total_count = static_cast(Downcast(count)->value) * vector_size; Array new_args{dst, src, IntImm(count.dtype(), total_count)}; - if (predicate.defined()) { - new_args.push_back(predicate.value()); - } + append_trailing(new_args); if (new_args.same_as(op->args)) { return GetRef(op); } From 3d95235595998d40e761599998a804453b227b05 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sat, 23 May 2026 13:18:09 +0000 Subject: [PATCH 30/34] parallel: gate early CBA hook on fragment validation The unconditional override in the early CBA branch of InferLayout was forcing the loop layout onto MMA accumulators like acc_o_l / C_local / dq, whose fragment is already in T.layout_map with its own MMA-derived binding. The follow-up ValidateCandidateAgainstFragments call would then throw a "Layout infer conflict between ... and ..." error.
Only adopt the CBA candidate when it actually validates against the existing fragments; otherwise fall through to the normal source-buffer / free-inference paths so the MMA fragment's binding still wins. Co-Authored-By: Claude Opus 4 (1M context) --- src/op/parallel.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 66867ab8d..34eb3c3fd 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -424,7 +424,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, vector_size /= 2; if (auto cba = ComputeChunkBlockAwarePlanCandidate(T, vector_size); cba.defined()) { - loop_layout_ = cba; + // Only adopt the CBA layout if it doesn't conflict with a fragment + // (e.g. an MMA accumulator like acc_o_l / C_local) that already has + // a layout in T.layout_map. Otherwise the unconditional override + // would force the loop onto a binding incompatible with the + // fragment and ValidateCandidateAgainstFragments would fail later. + if (ValidateCandidateAgainstFragments(cba, T, /*throw_on_error=*/false, + /*check_forward_index=*/false, + /*source_buffer=*/Buffer{})) { + loop_layout_ = cba; + } } } if (!loop_layout_.defined() && annotated_layout_unbound_.defined()) { From 7f8b103218174c17a989a66fe81202be8bad2a00 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Sat, 23 May 2026 13:22:23 +0000 Subject: [PATCH 31/34] parallel: scope early CBA hook to gfx950 only The chunk-block-aware override exists solely to make buffer_load_dwordx4...lds usable on gfx950; firing it on CUDA / older AMD targets can only force the loop binding into a shape that conflicts with MMA fragment layouts (the CUDA CI hit this on acc_o_l / C_local / dq / C_local_accum). Gate on TargetIsGfx950 so every other target keeps its existing layout-inference behaviour untouched. The validation fallback added in the previous commit stays as defense-in-depth. 
Co-Authored-By: Claude Opus 4 (1M context) --- src/op/parallel.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 34eb3c3fd..2c1cf7cdd 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -408,7 +408,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // Override here BEFORE source_buffer dispatch so the CBA fragment wins. // Fire at ALL levels so the first level that sees the layout map populated // wins; subsequent levels short-circuit via loop_layout_inferred_. - if (!loop_layout_.defined()) { + // + // gfx950-only: this hook exists to make buffer_load_dwordx4...lds usable, + // and the alternate binding it picks can conflict with MMA fragment + // bindings on NVIDIA / older AMD. Skip on every other target so the + // existing CUDA / pre-CDNA4 layout inference is untouched. + if (!loop_layout_.defined() && TargetIsGfx950(T.target)) { // Reuse the same vec_size calculation as ComputePlanCandidate. auto maybe_remapped_root = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); From 2718f499be73b01dd84ca5067488e21e22bf0761 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Wed, 27 May 2026 09:46:48 +0000 Subject: [PATCH 32/34] drop bare ptx_cp_async_lds codegen, fall back to sync cp_async_gs HoistBufferResource rewrites every ptx_cp_async_lds it can match to the _rsrc form, which is the form codegen actually turns into the buffer_load_dwordx4 ... lds fast path. The two corner cases where the rewrite silently skips a call (empty buffer_vars early-return, or an access_ptr shape _extract_buffer_var can't pattern-match) used to land on a dedicated cp_async_gs_lds codegen branch with its own buffer_load asm path -- duplicating the fast path and creating a second way to emit lds that could silently diverge from the _rsrc one. 
Collapse to a single safety net: treat ptx_cp_async_lds identically to ptx_cp_async in codegen so an unhoisted call just lowers to the synchronous tl::cp_async_gs. Correct, no buffer_load_lds win for that particular call -- which is fine because in practice all calls get hoisted. Removes the now-dead cp_async_gs_lds template too and updates the ptx_cp_async_lds docstring to reflect the new contract. Co-Authored-By: Claude Opus 4 (1M context) --- src/backend/rocm/codegen/codegen_hip.cc | 20 +++++++++++------- src/op/builtin.h | 20 +++++++++++------- src/tl_templates/hip/copy.h | 27 ------------------------- 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index 80cad0782..ecb12248e 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -1205,9 +1205,11 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { std::string size = this->PrintExpr(op->args[2]); this->PrintIndent(); if (op->args.size() == 3) { + // Non-predicated version this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src << ");\n"; } else { + // Predicated version std::string condition = this->PrintExpr(op->args[3]); this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << ", " << src << ", " << condition << ");\n"; @@ -1215,19 +1217,23 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } else if (op->op.same_as(tl::ptx_cp_async()) || op->op.same_as(tl::ptx_cp_async_lds())) { // Both store logical element count in arg 2; convert to bytes via - // GetTileLangCPAsyncTransferBytes. ptx_cp_async_lds routes the - // non-predicated path to tl::cp_async_gs_lds instead of - // tl::cp_async_gs; predicated copies always fall back to the - // generic cp_async_gs_conditional template. + // GetTileLangCPAsyncTransferBytes. 
+ // + // tl::ptx_cp_async_lds is normally rewritten to ptx_cp_async_lds_rsrc + // by the HoistBufferResource pass. If a call survives the rewrite + // (e.g. an access_ptr shape _extract_buffer_var can't pattern-match, + // or the pass found nothing to hoist), fall back to the synchronous + // tl::cp_async_gs path here -- correctness is preserved at + // the cost of giving up the buffer_load_dwordx4...lds fast path for + // that particular call. Treat both ops identically in codegen. int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); std::string size = std::to_string(total_bytes); this->PrintIndent(); if (op->args.size() == 3) { - bool use_lds = op->op.same_as(tl::ptx_cp_async_lds()); - this->stream << (use_lds ? "tl::cp_async_gs_lds<" : "tl::cp_async_gs<") - << size << ">(" << dst << ", " << src << ");\n"; + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; } else { std::string condition = this->PrintExpr(op->args[3]); this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst diff --git a/src/op/builtin.h b/src/op/builtin.h index 1bf41a7e5..30e7ee481 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -517,13 +517,19 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc(); TVM_DLL const Op &ptx_cp_async(); /*! - * \brief Truly async G2S copy via buffer_load_dwordx4 ... lds (gfx950+). - * - * Same signature as ptx_cp_async but lowers to cp_async_gs_lds which - * uses the hardware buffer_load ... lds instruction (bypasses VGPRs). - * Only valid when LDS addresses are lane-contiguous; the swizzle-swap - * pass in lower_tile_op.cc moves the XOR swizzle from the LDS store side - * to the global load side to make this safe. + * \brief Marker for an eligible async G2S copy on gfx950+. 
+ * + * Emitted by LowerPTXAsyncCopy in place of ptx_cp_async for 16-byte + * non-predicated shared-memory writes whose LDS index is (post + * swizzle-swap, see lower_tile_op.cc) lane-contiguous. The + * HoistBufferResource pass then rewrites each call to + * ptx_cp_async_lds_rsrc with a pre-computed buffer resource + * descriptor + base address; that rsrc form is what codegen emits as + * the buffer_load_dwordx4 ... lds fast path. + * + * If a call survives the rewrite (e.g. an access_ptr the hoister + * can't pattern-match), codegen falls back to the synchronous + * tl::cp_async_gs path -- correct, but no buffer_load_lds win. * * ptx_cp_async_lds(dst_access_ptr, src_access_ptr, num_elems) * diff --git a/src/tl_templates/hip/copy.h b/src/tl_templates/hip/copy.h index 461cafe85..2f36f88a6 100644 --- a/src/tl_templates/hip/copy.h +++ b/src/tl_templates/hip/copy.h @@ -138,33 +138,6 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, } } -// gfx950 (CDNA4 / MI350) direct global→LDS copy via buffer_load_dwordx4 ... -// lds. Bypasses VGPRs entirely. Only valid when LDS destination is -// lane-contiguous (base + lane_id * N bytes); the swizzle-swap pass in -// LowerTileOp guarantees this by moving the XOR swizzle from the LDS store side -// to the global load side. The compiler is expected to emit this only for -// 16-byte copies; other sizes fall back to cp_async_gs. 
-template -TL_DEVICE void cp_async_gs_lds(void *lds_base_ptr, - void const *global_base_ptr) { - if constexpr (N == 16) { - auto rsrc = make_wave_buffer_resource(global_base_ptr); - uint32_t my_lo = - static_cast(reinterpret_cast(global_base_ptr)); - uint32_t base_lo = __builtin_amdgcn_readfirstlane(my_lo); - uint32_t voffset = my_lo - base_lo; - uint32_t lds_cur = __builtin_amdgcn_readfirstlane( - static_cast(reinterpret_cast(lds_base_ptr))); - asm volatile("s_mov_b32 m0, %0; \n\t" - "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" - : - : "s"(lds_cur), "v"(voffset), "s"(rsrc) - : "memory"); - } else { - cp_async_gs(lds_base_ptr, global_base_ptr); - } -} - // Variant with pre-hoisted buffer resource descriptor and base address. // rsrc and rsrc_base_lo are computed once at kernel entry (see the // HoistBufferResource Python pass) so per-call readfirstlane overhead is From 25f61a2e5420792db7f65a482f9a879ef6de719b Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Wed, 27 May 2026 09:48:20 +0000 Subject: [PATCH 33/34] clean code --- tilelang/engine/phase.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index ee2f1fc99..89a1f47a4 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -282,10 +282,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.MarkCudaSyncCalls(have_pdl(target))(mod) mod = tilelang.transform.AnnotateReadOnlyParams()(mod) # MergeSharedMemoryAllocations must be applied after SplitHostDevice - # because the merged allocation site is at the beginning of each device - # function. LowerDeviceKernelLaunch enforces "Only one dynamic shared - # memory allocation"; keeping this disabled breaks any kernel with - # multiple .dyn buffers (the bench matmul has two: A_shared + B_shared). 
+ # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) disable_reuse = should_disable_shared_memory_reuse(pass_ctx=pass_ctx) mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge, disable_reuse=disable_reuse)(mod) From a0312d5d0fa7dc9d29207333fc5a24680137e212 Mon Sep 17 00:00:00 2001 From: zhutaoyu Date: Wed, 27 May 2026 09:50:02 +0000 Subject: [PATCH 34/34] clean code --- src/tl_templates/hip/copy.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tl_templates/hip/copy.h b/src/tl_templates/hip/copy.h index 2f36f88a6..1c018272b 100644 --- a/src/tl_templates/hip/copy.h +++ b/src/tl_templates/hip/copy.h @@ -154,6 +154,7 @@ cp_async_gs_lds_with_rsrc(void *lds_base_ptr, void const *global_base_ptr, uint32_t voffset = my_lo - rsrc_base_lo; uint32_t lds_cur = __builtin_amdgcn_readfirstlane( static_cast(reinterpret_cast(lds_base_ptr))); + // TODO(benenzhu): here use inline asm is a little bit tricky. asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" :