Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/cuda/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace cuda {
namespace {

// String tags identifying the CUDA GEMM instruction families. The instruction
// selection code visible in this file returns kCudaWGMMA, kCudaFMA or kCudaMMA;
// NOTE(review): kCudaTCGEN05 is presumably returned by a branch outside this
// hunk -- confirm.
constexpr const char *kCudaMMA = "cuda.mma";
constexpr const char *kCudaFMA = "cuda.fma";
constexpr const char *kCudaWGMMA = "cuda.wgmma";
constexpr const char *kCudaTCGEN05 = "cuda.tcgen05";

Expand Down Expand Up @@ -94,6 +95,26 @@ bool AllowWgmma(const GemmNode &op, int block_size, Target target) {
CheckWgmma(op);
}

// Returns true when `op` passes every precondition this file checks before
// keeping a Volta GEMM on the MMA path; the caller falls back to kCudaFMA
// when this returns false on a Volta target.
//
// Requirements, checked in order:
//   * A is a shared or fragment buffer and B is a shared buffer;
//   * A is not transposed;
//   * A and B are float16;
//   * C is float16 or float32;
//   * M and N are multiples of 16 and K is a multiple of 4.
bool AllowVoltaMma(const GemmNode &op) {
  // Buffer scopes: A may be shared or fragment, B must be shared.
  if (!(IsSharedBuffer(op.a_) || IsFragmentBuffer(op.a_))) {
    return false;
  }
  if (!IsSharedBuffer(op.b_)) {
    return false;
  }

  // A transposed A operand is rejected outright.
  if (op.transA_) {
    return false;
  }

  const auto fp16 = DataType::Float(16);
  const auto fp32 = DataType::Float(32);

  // Both inputs must be half precision.
  if (op.a_->dtype != fp16 || op.b_->dtype != fp16) {
    return false;
  }
  // The accumulator may be half or single precision, nothing else.
  if (op.c_->dtype != fp16 && op.c_->dtype != fp32) {
    return false;
  }

  // Tile-shape divisibility constraints.
  const bool m_aligned = op.m_ % 16 == 0;
  const bool n_aligned = op.n_ % 16 == 0;
  const bool k_aligned = op.k_ % 4 == 0;
  return m_aligned && n_aligned && k_aligned;
}

void FatalWgmmaUnavailable(const GemmNode &op, Target target) {
LOG(FATAL) << "T.wgmma_gemm() requires Hopper WGMMA lowering, but "
"constraints were not satisfied. Got target="
Expand Down Expand Up @@ -283,6 +304,9 @@ struct Gemm {
if (AllowWgmma(op, block_size, target)) {
return kCudaWGMMA;
}
if (TargetIsVolta(target) && !AllowVoltaMma(op)) {
return kCudaFMA;
}
return kCudaMMA;
}

Expand All @@ -298,6 +322,11 @@ struct Gemm {
if (gemm_inst == kCudaWGMMA) {
return ComputeWgmmaWarpPartition(policy, M, N, num_warps);
}
if (gemm_inst == kCudaFMA) {
policy.m_warp = 1;
policy.n_warp = num_warps;
return {1, num_warps};
}
int k_n_per_warp = TargetIsVolta(target) ? 16 : 8;
return ComputeDefaultWarpPartition(policy, M, N, num_warps, k_n_per_warp);
}
Expand Down
2 changes: 2 additions & 0 deletions testing/python/cuda/test_cuda_mma_sm75_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tilelang.testing
from tilelang.cuda.intrinsics.macro.mma_macro_generator import TensorCoreIntrinEmitter, TensorCoreIntrinEmitterWithLadderTransform
from tilelang.cuda.intrinsics.macro.mma_sm75_macro_generator import TensorCoreIntrinEmitterSM75
from tilelang.cuda.op.gemm.gemm_fma import GemmFMA
from tilelang.cuda.op.gemm.gemm_mma import GemmMMA
from tilelang.cuda.op.gemm.gemm_mma_sm70 import GemmMMASm70
from tilelang.cuda.op.gemm.gemm_mma_sm75 import GemmMMASm75
Expand All @@ -17,6 +18,7 @@ def test_sm75_uses_sm75_mma_gemm_impl():
assert resolve_gemm_impl("cuda.mma", Target({"kind": "cuda", "arch": "sm_70"})) is GemmMMASm70
assert resolve_gemm_impl("cuda.mma", Target({"kind": "cuda", "arch": "sm_75"})) is GemmMMASm75
assert resolve_gemm_impl("cuda.mma", Target({"kind": "cuda", "arch": "sm_80"})) is GemmMMA
assert resolve_gemm_impl("cuda.fma", Target({"kind": "cuda", "arch": "sm_70"})) is GemmFMA


def test_sm75_fp16_emitter_uses_m16n8k8_shape():
Expand Down
48 changes: 48 additions & 0 deletions testing/python/kernel/test_tilelang_kernel_sm70_fragment_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch

import tilelang
import tilelang.language as T
import tilelang.testing


def _make_fragment_cast_into_rs_gemm():
    """Build a prim_func whose GEMM A operand is a fragment produced by a cast and copy.

    The kernel loads A and B into shared buffers, widens A element-wise into a
    float32 fragment, copies that fragment into a float16 fragment, and uses the
    float16 fragment as the A operand of ``T.gemm`` against the shared B tile.

    Returns:
        The ``T.prim_func`` taking A (64x16 float16), B (16x16 float16) and
        writing C (64x16 float32).
    """
    @T.prim_func
    def main(
        A: T.Tensor((64, 16), T.float16),
        B: T.Tensor((16, 16), T.float16),
        C: T.Tensor((64, 16), T.float32),
    ):
        with T.Kernel(1, threads=128):
            A_shared = T.alloc_shared((64, 16), T.float16)
            B_shared = T.alloc_shared((16, 16), T.float16)
            P = T.alloc_fragment((64, 16), T.float32)
            P_half = T.alloc_fragment((64, 16), T.float16)
            C_local = T.alloc_fragment((64, 16), T.float32)

            T.copy(A, A_shared)
            T.copy(B, B_shared)

            # Widen the shared A tile into the float32 fragment P.
            for i, j in T.Parallel(64, 16):
                P[i, j] = T.cast(A_shared[i, j], T.float32)

            # Fragment-to-fragment copy with a float32 -> float16 dtype change;
            # this is the copy the accompanying test inspects the source for.
            T.copy(P, P_half)
            T.clear(C_local)
            # A operand is the fragment P_half, B operand is the shared tile.
            T.gemm(P_half, B_shared, C_local, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C)

    return main


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(7, 0)
def test_sm70_fragment_cast_copy_feeds_rs_gemm():
    """Compile the cast-then-GEMM kernel and check its source and numerics.

    The generated kernel source must not contain ``tl_frag_copy``, and the
    kernel output must match a float32 matmul reference within tolerance.
    """
    compiled = tilelang.compile(_make_fragment_cast_into_rs_gemm(), target="cuda", out_idx=[2])
    assert "tl_frag_copy" not in compiled.get_kernel_source()

    lhs = torch.randn((64, 16), device="cuda", dtype=torch.float16)
    rhs = torch.randn((16, 16), device="cuda", dtype=torch.float16)
    actual = compiled(lhs, rhs)
    expected = torch.matmul(lhs.float(), rhs.float())

    tilelang.testing.torch_assert_close(
        actual,
        expected,
        rtol=1e-2,
        atol=1e-2,
        max_mismatched_ratio=0.01,
    )
125 changes: 125 additions & 0 deletions testing/python/kernel/test_tilelang_kernel_sm70_gemm_fma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import torch

import tilelang
import tilelang.language as T
import tilelang.testing


def _make_bf16_ss_gemm():
    """Build a prim_func running one bfloat16 GEMM with both operands in shared buffers.

    Returns:
        The ``T.prim_func`` taking A (64x16 bfloat16), B (16x16 bfloat16) and
        writing C (64x16 float32).
    """
    @T.prim_func
    def main(
        A: T.Tensor((64, 16), T.bfloat16),
        B: T.Tensor((16, 16), T.bfloat16),
        C: T.Tensor((64, 16), T.float32),
    ):
        with T.Kernel(1, threads=128):
            A_shared = T.alloc_shared((64, 16), T.bfloat16)
            B_shared = T.alloc_shared((16, 16), T.bfloat16)
            C_local = T.alloc_fragment((64, 16), T.float32)

            T.copy(A, A_shared)
            T.copy(B, B_shared)
            # Zero the float32 accumulator fragment before the GEMM writes into it.
            T.clear(C_local)
            T.gemm(A_shared, B_shared, C_local, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C)

    return main


def _make_bf16_ss_and_rs_gemm():
    """Build a prim_func chaining two bfloat16 GEMMs through a fragment.

    The first GEMM multiplies the shared A and B tiles into the float32
    fragment ``P``. ``P`` is then copied into a bfloat16 fragment, which is the
    A operand of a second GEMM against the same shared B tile.

    Returns:
        The ``T.prim_func`` taking A (64x16 bfloat16), B (16x16 bfloat16) and
        writing C (64x16 float32).
    """
    @T.prim_func
    def main(
        A: T.Tensor((64, 16), T.bfloat16),
        B: T.Tensor((16, 16), T.bfloat16),
        C: T.Tensor((64, 16), T.float32),
    ):
        with T.Kernel(1, threads=128):
            A_shared = T.alloc_shared((64, 16), T.bfloat16)
            B_shared = T.alloc_shared((16, 16), T.bfloat16)
            P = T.alloc_fragment((64, 16), T.float32)
            P_bf16 = T.alloc_fragment((64, 16), T.bfloat16)
            C_local = T.alloc_fragment((64, 16), T.float32)

            T.copy(A, A_shared)
            T.copy(B, B_shared)

            # First GEMM: both operands are shared buffers, result goes to P.
            T.clear(P)
            T.gemm(A_shared, B_shared, P, policy=T.GemmWarpPolicy.FullRow)
            # Narrow the float32 intermediate into the bfloat16 fragment.
            T.copy(P, P_bf16)

            # Second GEMM: A operand is the fragment P_bf16, B is still shared.
            T.clear(C_local)
            T.gemm(P_bf16, B_shared, C_local, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C)

    return main


def _make_fp16_transposed_a_ss_gemm():
    """Build a prim_func running a float16 GEMM with ``transpose_A=True``.

    A is declared and staged as a 16x64 tile and passed to ``T.gemm`` with
    ``transpose_A=True``; B is 16x16 and the output is 64x16.

    Returns:
        The ``T.prim_func`` taking A (16x64 float16), B (16x16 float16) and
        writing C (64x16 float32).
    """
    @T.prim_func
    def main(
        A: T.Tensor((16, 64), T.float16),
        B: T.Tensor((16, 16), T.float16),
        C: T.Tensor((64, 16), T.float32),
    ):
        with T.Kernel(1, threads=128):
            A_shared = T.alloc_shared((16, 64), T.float16)
            B_shared = T.alloc_shared((16, 16), T.float16)
            C_local = T.alloc_fragment((64, 16), T.float32)

            T.copy(A, A_shared)
            T.copy(B, B_shared)
            T.clear(C_local)
            # transpose_A=True is the property the accompanying test relies on.
            T.gemm(A_shared, B_shared, C_local, transpose_A=True, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C)

    return main


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(7, 0)
def test_sm70_bf16_ss_gemm_uses_fma_fallback():
    """Check the bfloat16 shared/shared GEMM avoids the sm70 bfloat16 MMA call.

    The generated source must not contain the bfloat16 ``mma_sync_sm70``
    instantiation, and the result must match a float32 matmul reference.
    """
    compiled = tilelang.compile(_make_bf16_ss_gemm(), target="cuda", out_idx=[2])
    assert "mma_sync_sm70<tl::DataType::kBFloat16" not in compiled.get_kernel_source()

    lhs = torch.randn((64, 16), device="cuda", dtype=torch.bfloat16)
    rhs = torch.randn((16, 16), device="cuda", dtype=torch.bfloat16)
    actual = compiled(lhs, rhs)
    expected = torch.matmul(lhs.float(), rhs.float())

    tilelang.testing.torch_assert_close(
        actual,
        expected,
        rtol=1e-2,
        atol=1e-2,
        max_mismatched_ratio=0.01,
    )


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(7, 0)
def test_sm70_bf16_rs_gemm_uses_fma_fallback():
    """Check the chained bfloat16 GEMM kernel avoids the sm70 bfloat16 MMA call.

    The reference mirrors the kernel: the first product is rounded to bfloat16
    before being multiplied by B again in float32.
    """
    compiled = tilelang.compile(_make_bf16_ss_and_rs_gemm(), target="cuda", out_idx=[2])
    assert "mma_sync_sm70<tl::DataType::kBFloat16" not in compiled.get_kernel_source()

    # Inputs are scaled down by 0.25 before the two chained products.
    lhs = torch.randn((64, 16), device="cuda", dtype=torch.bfloat16) * 0.25
    rhs = torch.randn((16, 16), device="cuda", dtype=torch.bfloat16) * 0.25
    actual = compiled(lhs, rhs)

    rhs_f32 = rhs.float()
    intermediate = torch.matmul(lhs.float(), rhs_f32).to(torch.bfloat16)
    expected = torch.matmul(intermediate.float(), rhs_f32)

    tilelang.testing.torch_assert_close(
        actual,
        expected,
        rtol=2e-2,
        atol=2e-2,
        max_mismatched_ratio=0.01,
    )


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(7, 0)
def test_sm70_fp16_transposed_a_uses_fma_fallback():
    """Check the transposed-A float16 GEMM emits no ``mma_sync_sm70`` call.

    The result is compared against ``A^T @ B`` computed in float32.
    """
    compiled = tilelang.compile(_make_fp16_transposed_a_ss_gemm(), target="cuda", out_idx=[2])
    assert "mma_sync_sm70" not in compiled.get_kernel_source()

    lhs = torch.randn((16, 64), device="cuda", dtype=torch.float16) * 0.25
    rhs = torch.randn((16, 16), device="cuda", dtype=torch.float16) * 0.25
    actual = compiled(lhs, rhs)
    expected = torch.matmul(lhs.t().float(), rhs.float())

    tilelang.testing.torch_assert_close(
        actual,
        expected,
        rtol=1e-2,
        atol=1e-2,
        max_mismatched_ratio=0.01,
    )


# Allow running this test module directly as a script.
if __name__ == "__main__":
    tilelang.testing.main()
7 changes: 7 additions & 0 deletions tilelang/cuda/op/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from tilelang.tileop.gemm.registry import register_gemm_impl
from .gemm_fma import GEMM_INST_FMA, GemmFMA
from .gemm_mma import GEMM_INST_MMA, GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_mma_sm75 import GemmMMASm75
Expand All @@ -23,6 +24,11 @@ def _match_mma_sm75(target) -> bool:
return target_is_turing(target)


def _match_fma(target) -> bool:
    """Return whether the ``cuda.fma`` GEMM implementation may be used for ``target``.

    This matcher only delegates to ``target_is_cuda``; it does not look at the
    GEMM operands or shapes.
    """
    # _match_fma checks CUDA FMA capability; SelectInst decides when to use it.
    return target_is_cuda(target)


def _match_wgmma(target) -> bool:
    """Return whether the ``cuda.wgmma`` GEMM implementation may be used for ``target``.

    This matcher only delegates to ``target_is_cuda``.
    """
    return target_is_cuda(target)

Expand All @@ -34,5 +40,6 @@ def _match_tcgen05(target) -> bool:
# Register every CUDA GEMM implementation. Each call passes a registration
# name, a GEMM_INST_* constant, a target matcher and the implementation class.
# NOTE(review): presumably the registry resolves an implementation from the
# instruction constant plus whichever matcher accepts the target -- confirm
# against tilelang.tileop.gemm.registry.
register_gemm_impl("cuda.mma", GEMM_INST_MMA, _match_mma, GemmMMA)
register_gemm_impl("cuda.mma_sm70", GEMM_INST_MMA, _match_mma_sm70, GemmMMASm70)
register_gemm_impl("cuda.mma_sm75", GEMM_INST_MMA, _match_mma_sm75, GemmMMASm75)
register_gemm_impl("cuda.fma", GEMM_INST_FMA, _match_fma, GemmFMA)
register_gemm_impl("cuda.wgmma", GEMM_INST_WGMMA, _match_wgmma, GemmWGMMA)
register_gemm_impl("cuda.tcgen05", GEMM_INST_TCGEN05, _match_tcgen05, GemmTCGEN5)
Loading
Loading