[AMD][ROCm] Fix CI failures on gfx950, gfx1100, gfx1151, and gfx1201#2326
zhangnju wants to merge 6 commits into
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 Walkthrough
This PR converts HIP target strings to dicts, adds run_auto scan entrypoints and lowers, updates MFMA/tir handling, adds RDNA detection and test gating, guards CUDA transforms, changes autotuner validation/fast-path timing, and extends CuTeDSL with TMEM stores, pow_of_int, and reduce.run_auto.
Changes
Platform, Scan, MFMA, Tests, and Autotuner Updates
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@testing/python/issue/test_tilelang_issue_2123.py`:
- Around line 10-16: The broad except Exception around importing
CUDAPassPipelineBodyPrologue and tilelang.cuda._ffi_api should be narrowed to
only catch import-related errors; replace the generic except with an except that
catches ImportError and ModuleNotFoundError so that failures inside the imported
modules (e.g., syntax/attribute errors) surface instead of being masked. Update
the try/except around the CUDAPassPipelineBodyPrologue and _cuda_ffi_api imports
that set _has_cuda_transforms to False on failure to only catch
ImportError/ModuleNotFoundError while leaving other exceptions to propagate.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 32b3e6ce-dead-47e1-ac9e-3fc515fcdb98
📒 Files selected for processing (12)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py
- src/backend/common/op/scan.h
- src/tl_templates/hip/scan.h
- testing/python/amd/test_tilelang_mxfp4_gfx950.py
- testing/python/issue/test_tilelang_issue_2123.py
- testing/python/kernel/test_tilelang_kernel_gemm.py
- testing/python/language/test_tilelang_language_eager_jit.py
- testing/python/target/test_tilelang_rocm_target.py
- tilelang/autotuner/tuner.py
- tilelang/language/tir/op.py
- tilelang/rocm/intrinsics/mfma_macro_generator.py
- tilelang/testing/__init__.py
    try:
        from tilelang.cuda.pipeline import CUDAPassPipelineBodyPrologue
        import tilelang.cuda._ffi_api as _cuda_ffi_api

        _has_cuda_transforms = hasattr(_cuda_ffi_api, "LowerBlackwell2SM")
    except Exception:
        _has_cuda_transforms = False
Narrow the exception handling to catch only import-related exceptions.
The broad except Exception: clause catches all exceptions, which could mask unexpected errors during import, such as syntax errors or attribute errors in the imported modules. This makes debugging harder if the CUDA modules have actual defects.
🛡️ Proposed fix to narrow exception handling
     try:
         from tilelang.cuda.pipeline import CUDAPassPipelineBodyPrologue
         import tilelang.cuda._ffi_api as _cuda_ffi_api
         _has_cuda_transforms = hasattr(_cuda_ffi_api, "LowerBlackwell2SM")
    -except Exception:
    +except (ImportError, AttributeError):
         _has_cuda_transforms = False

🧰 Tools
🪛 Ruff (0.15.15)
[warning] 15-15: Do not catch blind exception: Exception (BLE001)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@testing/python/issue/test_tilelang_issue_2123.py` around lines 10 - 16, The
broad except Exception around importing CUDAPassPipelineBodyPrologue and
tilelang.cuda._ffi_api should be narrowed to only catch import-related errors;
replace the generic except with an except that catches ImportError and
ModuleNotFoundError so that failures inside the imported modules (e.g.,
syntax/attribute errors) surface instead of being masked. Update the try/except
around the CUDAPassPipelineBodyPrologue and _cuda_ffi_api imports that set
_has_cuda_transforms to False on failure to only catch
ImportError/ModuleNotFoundError while leaving other exceptions to propagate.
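A minimal sketch of the narrowed guard, using only the standard library. The helper name `probe_optional_feature` and the module name `no_such_module_for_demo` are hypothetical and not part of TileLang; `math` stands in for the optional CUDA module. Only import failures are swallowed, so a defect inside an importable module would still propagate.

```python
import importlib


def probe_optional_feature(module_name: str, attr: str) -> bool:
    """Return True only if `module_name` imports and exposes `attr`.

    Hypothetical helper for illustration; not a TileLang API.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a missing
        # module lands here. Any other exception raised while the module
        # executes (for example a NameError inside it) is not caught.
        return False
    # hasattr() returns False for a missing attribute instead of raising,
    # so no AttributeError handling is needed for this check.
    return hasattr(module, attr)


has_sqrt = probe_optional_feature("math", "sqrt")
has_missing_attr = probe_optional_feature("math", "LowerBlackwell2SM")
has_missing_module = probe_optional_feature("no_such_module_for_demo", "anything")
print(has_sqrt, has_missing_attr, has_missing_module)
```

The same shape applies to the test file: the flag ends up False both when the module is absent and when the attribute is absent, without a blanket `except Exception`.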
thanks @zhangnju it's likely this commit breaks ci.
let me check it. Thanks for your info.
Two CUDA CI failures:
1. cuda/scan.h: Add run_auto to all scan structs (InclusiveScan1D/2D, CumSum1D/2D, CumMax1D/2D). PR tile-ai#2262 made codegen always emit ::run_auto but forgot to add the method to the CUDA templates (HIP had it). On CUDA warp size is fixed at 32, so run_auto delegates to run<T, 32>.
2. autotuner/tuner.py: Skip autotuning and scalar-input validation when the caller already supplies all config keys explicitly. PR tile-ai#2084 added validation in AutoTuner.run() that fires even when the user calls with fixed params (e.g. block_M=64, ...) and no set_autotune_inputs context, causing test_example_mha_fwd_varlen to fail with ValueError on max_seqlen_q.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
testing/python/language/test_tilelang_language_eager_jit.py (1)
79-87: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
The product(...) iterator is consumed before the runtime assertions loop.
prod is exhausted by the compile list comprehension, so the later verification loop does not run any cases.
💡 Proposed fix
    - prod = product(in_dtypes, [T.float32])
    + prod = list(product(in_dtypes, [T.float32]))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/language/test_tilelang_language_eager_jit.py` around lines 79 - 87, The product(...) iterator is consumed by the list comprehension in the gemm_ptr.par_compile call; convert the iterator to a reusable list before using it in both places (e.g., change prod = product(...) to prod = list(product(in_dtypes, [T.float32]))), then use that prod in gemm_ptr.par_compile and in the subsequent for in_dtype, out_dtype in prod loop so the verification loop actually iterates; reference symbols: prod, product, gemm_ptr.par_compile, and the for in_dtype, out_dtype in prod loop.
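The exhaustion behaviour can be reproduced in isolation. This is a small sketch with placeholder dtype names (plain strings, not the `T.*` dtypes from the test); the two list comprehensions stand in for the compile step and the verification loop.

```python
from itertools import product

in_dtypes = ["float16", "bfloat16"]  # placeholder names for illustration
out_dtypes = ["float32"]

# product() returns a one-shot iterator: the first pass drains it.
lazy = product(in_dtypes, out_dtypes)
compiled_lazy = [pair for pair in lazy]   # stands in for the compile step
verified_lazy = [pair for pair in lazy]   # second pass sees nothing

# Materialising the pairs once lets both passes iterate over them.
pairs = list(product(in_dtypes, out_dtypes))
compiled = [pair for pair in pairs]
verified = [pair for pair in pairs]

print(len(compiled_lazy), len(verified_lazy), len(compiled), len(verified))
```

With the iterator, the second pass is silently empty, which is why the verification loop in the test would pass without checking anything.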
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/autotuner/tuner.py`:
- Around line 1286-1289: The fast-path skip reads tunable keys from only
configs[0], which can wrongly skip tuning when later configs add keys; update
the logic in the tuner (where self.configs, config_keys, mode and kwargs are
used) to compute config_keys as the union of keys from all configs (e.g.,
iterate over self.configs if it's a list) instead of using only configs[0], then
perform the issubset check against kwargs as before so the skip condition is
correct for all provided configs.
---
Outside diff comments:
In `@testing/python/language/test_tilelang_language_eager_jit.py`:
- Around line 79-87: The product(...) iterator is consumed by the list
comprehension in the gemm_ptr.par_compile call; convert the iterator to a
reusable list before using it in both places (e.g., change prod = product(...)
to prod = list(product(in_dtypes, [T.float32]))), then use that prod in
gemm_ptr.par_compile and in the subsequent for in_dtype, out_dtype in prod loop
so the verification loop actually iterates; reference symbols: prod, product,
gemm_ptr.par_compile, and the for in_dtype, out_dtype in prod loop.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 44057fbc-277c-413c-8ef8-392d8223ae69
📒 Files selected for processing (13)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py
- src/backend/common/op/scan.h
- src/tl_templates/cuda/scan.h
- src/tl_templates/hip/scan.h
- testing/python/amd/test_tilelang_mxfp4_gfx950.py
- testing/python/issue/test_tilelang_issue_2123.py
- testing/python/kernel/test_tilelang_kernel_gemm.py
- testing/python/language/test_tilelang_language_eager_jit.py
- testing/python/target/test_tilelang_rocm_target.py
- tilelang/autotuner/tuner.py
- tilelang/language/tir/op.py
- tilelang/rocm/intrinsics/mfma_macro_generator.py
- tilelang/testing/__init__.py
✅ Files skipped from review due to trivial changes (1)
- src/backend/common/op/scan.h
🚧 Files skipped from review as they are similar to previous changes (8)
- testing/python/kernel/test_tilelang_kernel_gemm.py
- tilelang/rocm/intrinsics/mfma_macro_generator.py
- testing/python/amd/test_tilelang_mxfp4_gfx950.py
- src/tl_templates/hip/scan.h
- tilelang/testing/__init__.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_cdna4.py
- tilelang/language/tir/op.py
- testing/python/target/test_tilelang_rocm_target.py
    configs = self.configs
    config_keys = set(configs[0].keys()) if isinstance(configs, list) and configs else set()
    if config_keys and config_keys.issubset(kwargs.keys()):
        if mode == "lazy":
Fast-path skip condition only inspects the first config and can bypass tuning incorrectly.
The skip decision derives tunable keys from only configs[0]. If later configs contain additional keys, autotuning is skipped even when not all tunables are supplied.
💡 Proposed fix
    - configs = self.configs
    - config_keys = set(configs[0].keys()) if isinstance(configs, list) and configs else set()
    + configs = self.configs
    + if isinstance(configs, list) and configs:
    +     config_keys = set().union(
    +         *(cfg.keys() for cfg in configs if isinstance(cfg, dict))
    +     )
    + else:
    +     config_keys = set()
      if config_keys and config_keys.issubset(kwargs.keys()):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/autotuner/tuner.py` around lines 1286 - 1289, The fast-path skip
reads tunable keys from only configs[0], which can wrongly skip tuning when
later configs add keys; update the logic in the tuner (where self.configs,
config_keys, mode and kwargs are used) to compute config_keys as the union of
keys from all configs (e.g., iterate over self.configs if it's a list) instead
of using only configs[0], then perform the issubset check against kwargs as
before so the skip condition is correct for all provided configs.
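A standalone sketch of the difference between the two checks. The function names `first_config_keys`, `union_config_keys`, and `can_skip_tuning` are hypothetical, and the config values are made up for illustration; only the set logic mirrors the snippet under review.

```python
def first_config_keys(configs):
    """Keys of the first config only (the behaviour being questioned)."""
    return set(configs[0].keys()) if isinstance(configs, list) and configs else set()


def union_config_keys(configs):
    """Keys appearing in any config (the proposed behaviour)."""
    if not isinstance(configs, list):
        return set()
    return set().union(*(cfg.keys() for cfg in configs if isinstance(cfg, dict)))


def can_skip_tuning(config_keys, kwargs):
    """Skip only when every tunable key was passed explicitly."""
    return bool(config_keys) and config_keys.issubset(kwargs.keys())


# Illustrative configs: the second one introduces an extra tunable key.
configs = [{"block_M": 64}, {"block_M": 128, "num_stages": 2}]
kwargs = {"block_M": 64}

skip_with_first = can_skip_tuning(first_config_keys(configs), kwargs)
skip_with_union = can_skip_tuning(union_config_keys(configs), kwargs)
print(skip_with_first, skip_with_union)
```

With only `block_M` supplied, the first-config check would take the fast path while the union check would not, which is the case the comment describes.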
- tcgen05_st_32dp{32,64,128,256}bNx: mirror TMEM load functions for store direction
- tcgen05_before/after_thread_sync: SM100 TMEM ordering fences around __syncthreads
- pow_of_int: Python backend for tl::pow_of_int<N> call_extern translation
- CumSum1D/2D, CumMax1D/2D: add run_auto wrappers with @cute.jit
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tilelang/contrib/cutedsl/math.py`:
- Around line 243-260: The pow_of_int helper must reject negative exponents; add
validation at the start of pow_of_int (or inside the returned _pow) to raise a
clear error (e.g., ValueError) when exp < 0 so callers don't get silent
incorrect results—refer to the pow_of_int function and the inner _pow closure
and ensure the check runs before attempting the exp==0 or loop logic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 948b40b5-e211-4a73-85e8-ca35473a0be1
📒 Files selected for processing (3)
- tilelang/contrib/cutedsl/gemm_tcgen05.py
- tilelang/contrib/cutedsl/math.py
- tilelang/contrib/cutedsl/reduce.py
    def pow_of_int(exp: int):
        """Return a function that raises its argument to the integer power `exp`.

        Mirrors tl::pow_of_int<exp> from math.cc — the C++ codegen emits
        tl::pow_of_int<N> as a call_extern, which the CuTeDSL codegen translates
        to tl.pow_of_int(N)(base). On CUDA/HIP the op is lowered by FLowerIntrinsic
        before reaching call_extern; for the CuTeDSL Python backend it reaches here.
        """

        def _pow(base):
            if exp == 0:
                return type(base)(1)
            result = base
            for _ in range(exp - 1):
                result = result * base
            return result

        return _pow
Add validation for negative exponents to avoid silent incorrect results.
If exp < 0, range(exp - 1) returns an empty iterator, causing the function to return base unchanged instead of raising an error or computing 1/base^|exp|. Since this helper mirrors tl::pow_of_int<N> which expects a compile-time constant, consider validating non-negative input.
🛡️ Proposed fix to validate exponent
     def pow_of_int(exp: int):
         """Return a function that raises its argument to the integer power `exp`.
         Mirrors tl::pow_of_int<exp> from math.cc — the C++ codegen emits
         tl::pow_of_int<N> as a call_extern, which the CuTeDSL codegen translates
         to tl.pow_of_int(N)(base). On CUDA/HIP the op is lowered by FLowerIntrinsic
         before reaching call_extern; for the CuTeDSL Python backend it reaches here.
         """
    +    if exp < 0:
    +        raise ValueError(f"pow_of_int requires non-negative exponent, got {exp}")
         def _pow(base):
             if exp == 0:
                 return type(base)(1)
             result = base
             for _ in range(exp - 1):
                 result = result * base
             return result
         return _pow
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/contrib/cutedsl/math.py` around lines 243 - 260, The pow_of_int
helper must reject negative exponents; add validation at the start of pow_of_int
(or inside the returned _pow) to raise a clear error (e.g., ValueError) when exp
< 0 so callers don't get silent incorrect results—refer to the pow_of_int
function and the inner _pow closure and ensure the check runs before attempting
the exp==0 or loop logic.
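A sketch of the helper with the suggested check, exercised on plain Python numbers. This assumes ordinary `int`/`float` operands as stand-ins for CuTeDSL values; `pow_of_int_unchecked` is a hypothetical name used here only to show the silent result the comment warns about.

```python
def pow_of_int(exp: int):
    """Return a function computing base**exp by repeated multiplication."""
    if exp < 0:
        raise ValueError(f"pow_of_int requires non-negative exponent, got {exp}")

    def _pow(base):
        if exp == 0:
            return type(base)(1)
        result = base
        for _ in range(exp - 1):
            result = result * base
        return result

    return _pow


def pow_of_int_unchecked(exp: int):
    """Same loop without the guard, kept only to show the failure mode."""

    def _pow(base):
        if exp == 0:
            return type(base)(1)
        result = base
        for _ in range(exp - 1):  # empty range when exp is negative
            result = result * base
        return result

    return _pow


cube = pow_of_int(3)(2)              # 2 * 2 * 2
unit = pow_of_int(0)(2.5)            # float(1)
silent = pow_of_int_unchecked(-1)(2)  # loop never runs, base comes back unchanged

try:
    pow_of_int(-1)
    rejected = False
except ValueError:
    rejected = True

print(cube, unit, silent, rejected)
```

Raising in the outer function means the error surfaces when the power is constructed rather than when the returned closure is first called.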
@LeiWang1999 CUDA CI issues have been fixed now.
Hi @LeiWang1999
I ran CI on AMD machines against the latest code and found that some cases failed.
This PR includes the following bug fixes:

`tvm_mfma` intrinsic broken after tirx migration (`tilelang/language/tir/op.py`)
After the TIR→tirx migration, `tirx.Call` enforces strict type checking on all arguments. The six string arguments in `tvm_mfma` (shape, layouts, dtypes) were no longer implicitly coerced, producing `TypeError: Mismatched type on argument #2`. Fix by wrapping each string argument in `tvm.tirx.StringImm(str(...))`.

(`src/backend/common/op/scan.h`, `src/tl_templates/hip/scan.h`)
The scan lowering emitted `CumSum/CumMax{1D,2D}<threads>::run`, which hard-codes `SEG=64`. On Wave32 architectures `__shfl_up/down` is bounded by the 32-thread warp width, so `SEG=64` silently crosses warp boundaries and corrupts ~50% of results. Fix by:
- Switching the emitted call from `::run` to `::run_auto`.
- Adding `run_auto` methods to the `CumSum1D/2D` and `CumMax1D/2D` wrapper structs.
- `run_auto` queries `__builtin_amdgcn_wavefrontsize()` at kernel launch time and dispatches to `SEG=32` (Wave32) or `SEG=64` (Wave64) accordingly.

Autotune scalar-input validation bypassed by disk cache (`tilelang/autotuner/tuner.py`)
`_validate_input_supply_requirements` was called after the disk-cache lookup, so a cache hit would return a kernel without ever checking that scalar inputs were supplied via `set_autotune_inputs`. Move the validation before the cache lookup so it is unconditional.

`tfloat32` tests incorrectly running on ROCm (`test_tilelang_kernel_gemm.py`, `test_tilelang_language_eager_jit.py`)
`tfloat32` is unsupported on some ROCm targets and AMD does not advise customers to use it, so these tests do not need to run in AMD CI.
- `test_gemm_f32f32f32_nn/nt`: change decorator from `@requires_cuda_or_cdna` to `@requires_cuda`.
- `test_jit2_gemm_ptr`: exclude `T.tfloat32` from the dtype list when not on CUDA to prevent a `par_compile` failure.

(`tilelang/testing/__init__.py`, `test_tilelang_rocm_target.py`)
Three tests validating RDNA device-model behavior were marked `@requires_rocm`, so they ran on CDNA (gfx950) and failed trivially.
- Add a `requires_rdna` decorator that skips on non-RDNA targets.
- Replace `RDNA(gfx1200)` rejection assertions with monkeypatched `generation=10` scenarios, and fix match strings from "gfx11 targets only" to "gfx11/gfx12 targets only".

Test results:
Ran the full CI test suite on all four machines:
Remaining skips are CUDA-only tests (TMA, tfloat32) and RDNA-only tests on non-RDNA machines, all expected.
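The Wave32 scan failure described above can be illustrated with a pure-Python model of a shuffle-based inclusive scan. This is a simplified simulation, not the HIP template code: the functions `shfl_up` and `segmented_inclusive_scan` are hypothetical, and the model assumes that a shuffle read which would leave the segment or the hardware wave returns the lane's own value. Real Wave32 hardware may behave differently for out-of-wave reads; the point is only that a 64-wide segment cannot be served by 32-lane shuffles.

```python
def shfl_up(vals, delta, seg, wave):
    """Model of shuffle-up: lane i reads lane i - delta.

    Assumption for this sketch: reads that would cross the segment or the
    hardware wave boundary return the lane's own value instead.
    """
    out = []
    for lane, own in enumerate(vals):
        src = lane - delta
        in_segment = lane % seg >= delta
        in_wave = src >= 0 and src // wave == lane // wave
        out.append(vals[src] if in_segment and in_wave else own)
    return out


def segmented_inclusive_scan(vals, seg, wave):
    """Hillis-Steele inclusive sum within each `seg`-lane segment."""
    vals = list(vals)
    delta = 1
    while delta < seg:
        shuffled = shfl_up(vals, delta, seg, wave)
        vals = [
            v + s if lane % seg >= delta else v
            for lane, (v, s) in enumerate(zip(vals, shuffled))
        ]
        delta *= 2
    return vals


ones = [1] * 64
expected = list(range(1, 65))

wave64 = segmented_inclusive_scan(ones, seg=64, wave=64)      # SEG matches the wave
wave32_bad = segmented_inclusive_scan(ones, seg=64, wave=32)  # SEG wider than the wave
wave32_ok = segmented_inclusive_scan(ones, seg=32, wave=32)   # SEG chosen per wave size

print(wave64 == expected)
print(wave32_bad[:32] == expected[:32], wave32_bad[32])
print(wave32_ok == list(range(1, 33)) * 2)
```

In this model the lower 32 lanes stay correct while the upper lanes go wrong once `SEG` exceeds the wave width, and picking `SEG` equal to the wave width yields one independent, correct scan per 32-lane segment.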