Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def Rows : MagicInst; // given a transpose, normal rows, normal cols get the tru
def Concat : MagicInst;

def ShadowNoInc : MagicInst;
def Dep : MagicInst;

class Binop<string _s, list<string> _tys> {
string s = _s;
Expand Down Expand Up @@ -89,8 +90,13 @@ class IntMatchers<string _inty, string _outty, list<string> _before, list<string

def side_to_trans : IntMatchers<
"charType", "charType",
["'l'", "'L'", "'R'", "'r'"],
["'n'", "'N'", "'T'", "'t'"]
["'l'", "'L'", "141", "0", "'R'", "'r'", "142", "1"],
["'n'", "'N'", "'N'", "'N'", "'T'", "'t'", "'T'", "'T'"]
>;
// Maps a BLAS `side` argument to the transpose flag that is the opposite of
// the one produced by side_to_trans: every "left" encoding yields a
// transposed flag ('t'/'T') and every "right" encoding yields a
// non-transposed flag ('n'/'N').
// The two lists are matched position by position (input -> output).
// NOTE(review): 141/142 look like the CBLAS side enum values and 0/1 like
// the cuBLAS ones -- confirm against the respective headers.
def side_to_rtrans : IntMatchers<
"charType", "charType",
["'l'", "'L'", "141", "0", "'R'", "'r'", "142", "1"],
["'t'", "'T'", "'T'", "'T'", "'N'", "'n'", "'N'", "'N'"]
>;

def is_diag_int : IntMatchers<
Expand All @@ -107,6 +113,7 @@ def is_nonunit : MagicInst;
def First : MagicInst;
def Lookup : MagicInst;
def LoadLookup : MagicInst;
def mat_ld : MagicInst;

class FAdd<string _tmp=""> {
string unused = _tmp;
Expand Down Expand Up @@ -163,6 +170,10 @@ class MemcpyMatAdd<bit _shouldZero> {
bit shouldZero = _shouldZero;
}

// Rule node for a matrix accumulation step. In the derivative rules of this
// file it is used as (MatAdd<z> $layout, rows, cols, dst, src).
// NOTE(review): presumably lowered to the differential matrix memcpy helper
// (which adds src into dst), with `shouldZero` selecting its "_zero"
// variant -- confirm in the TableGen backend.
class MatAdd<bit _shouldZero> {
bit shouldZero = _shouldZero;
}

class For<string idx_, bit offset_> {
string idx = idx_;
bit offset = offset_;
Expand Down Expand Up @@ -672,6 +683,45 @@ def trtrs : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $nrhs, $A, $l
]
>;

// Derivative pattern for the triangular solve trsv, which overwrites x with
// the solution of op(A) * x = b.
//   - ["x"] lists the argument that the call mutates.
//   - The bracketed list holds one rule per potentially active argument
//     (A, then x).
//   - The trailing Seq is a single combined rule.
// NOTE(review): by analogy with the forward-mode tests for the sibling
// patterns, the bracketed list is presumably the reverse-mode rule set and
// the trailing Seq the forward-mode rule -- confirm against CallBlasPattern.
def trsv : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $A, $lda, $x, $incx),
["x"],
[cblas_layout, uplo, trans, diag, len, mld<["uplo", "n", "n"]>, vinc<["n"]>],
[
/* A */
// Works on a scratch n-by-n buffer "tri": copy the `uplo` triangle of the
// shadow of A in, apply a rank-1 update with coefficient -1, then copy the
// triangle back so only that triangle of the shadow is touched.
(Seq<["tri", "triangular", "n"], [], 1>
(BlasCall<"lacpy"> $layout, $uplo, $n, $n, (Shadow $A), use<"tri">, $n),
// The order of the two ger vectors (shadow of x, x) is swapped depending
// on $trans.
(BlasCall<"ger">
$layout,
$n,
$n,
Constant<"-1">,
(Rows $trans,
(Concat (Shadow $x)),
(Concat $x)),
(Rows $trans,
(Concat $x),
(Concat (Shadow $x))),
use<"tri">, $n),

// Strided copy (lda+1 -> n+1) walks the diagonal: element count is 0 for a
// non-unit diagonal and n otherwise, so for a unit diagonal the shadow's
// original diagonal overwrites the updated one in "tri".
(BlasCall<"copy">
(ISelect (is_nonunit $diag), ConstantInt<0>, $n),
(First (Shadow $A)), (Add $lda, ConstantInt<1>),
use<"tri">, (Add $n, ConstantInt<1>)),

(BlasCall<"lacpy"> $layout, $uplo, $n, $n, use<"tri">, $n, (Shadow $A))
),
/* x */ (BlasCall<"trsv"> $layout, $uplo, transpose<"trans">, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x)),
]
,
// Combined rule, using a scratch length-n vector "tmp":
//   tmp = x; tmp = op(shadow A) * tmp; shadow x -= tmp;
//   shadow x += x over 0 elements (non-unit diag) or n elements (otherwise);
//   finally solve op(A) * shadow x = shadow x in place.
// NOTE(review): the extra axpy appears to cancel the implicit unit diagonal
// that trmv applies to the shadow of A -- confirm.
// NOTE(review): Dep presumably makes a step conditional on the named shadow
// being active -- confirm in the TableGen backend.
(Seq<["tmp", "vector", "n"], [], 0>
(BlasCall<"copy"> $n, (Dep (Shadow $A), (Dep (Shadow $x), $x)), use<"tmp">, ConstantInt<1>),
(BlasCall<"trmv"> $layout, $uplo, $trans, $diag, $n, (Shadow $A), (Dep (Shadow $x), use<"tmp">), ConstantInt<1>),
(BlasCall<"axpy"> $n, Constant<"-1">, (Dep (Shadow $A), use<"tmp">), ConstantInt<1>, (Shadow $x)),
(BlasCall<"axpy"> (ISelect (is_nonunit $diag), ConstantInt<0>, $n), Constant<"1">, (Dep (Shadow $A), $x), (Shadow $x)),
(BlasCall<"trsv"> $layout, $uplo, $trans, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x))
)
>;

def spr2 : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy, $ap),
["ap"],
[cblas_layout, uplo, len, fp, vinc<["n"]>, vinc<["n"]>, ap<["n"]>],
Expand Down Expand Up @@ -718,9 +768,55 @@ def trsm: CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $al
[cblas_layout, side, uplo, trans, diag, len, len, fp, mld<["side", "m", "n"]>, mld<["m","n"]>],
[
/* alpha */ (AssertingInactiveArg),
/* A */ (AssertingInactiveArg),
/* B */ (AssertingInactiveArg),
/* A */ (Seq<["tmp", "product", "m", "n"], [], 1>
(BlasCall<"lacpy"> $layout, Char<"G">, $m, $n, (Shadow $B), use<"tmp">, (mat_ld $layout, $m, $n)),
(BlasCall<"trsm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, Constant<"1.0">, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), use<"tmp">, (mat_ld $layout, $m, $n)),
(Seq<["tri", "side_square", "side", "m", "n"], [], 1>
(BlasCall<"lacpy"> $layout, $uplo, (ISelect (is_left $side), $m, $n), (ISelect (is_left $side), $m, $n), (Shadow $A), use<"tri">, (ISelect (is_left $side), $m, $n)),
(BlasCall<"gemm">
$layout,
(side_to_trans $side),
(side_to_rtrans $side),
(ISelect (is_left $side), $m, $n),
(ISelect (is_left $side), $m, $n),
(ISelect (is_left $side), $n, $m),
Constant<"-1">,
(Rows (side_to_trans $side),
(Rows $transa,
(Concat use<"tmp">, (mat_ld $layout, $m, $n)),
(Concat $B, $ldb)),
(Rows $transa,
(Concat $B, $ldb),
(Concat use<"tmp">, (mat_ld $layout, $m, $n)))),
(Rows (side_to_trans $side),
(Rows $transa,
(Concat $B, $ldb),
(Concat use<"tmp">, (mat_ld $layout, $m, $n))),
(Rows $transa,
(Concat use<"tmp">, (mat_ld $layout, $m, $n)),
(Concat $B, $ldb))),
Constant<"1">,
use<"tri">, (ISelect (is_left $side), $m, $n)),
(BlasCall<"copy">
(ISelect (is_nonunit $diag), ConstantInt<0>, (ISelect (is_left $side), $m, $n)),
(First (Shadow $A)), (Add $lda, ConstantInt<1>),
use<"tri">, (Add (ISelect (is_left $side), $m, $n), ConstantInt<1>)),
(BlasCall<"lacpy"> $layout, $uplo, (ISelect (is_left $side), $m, $n), (ISelect (is_left $side), $m, $n), use<"tri">, (ISelect (is_left $side), $m, $n), (Shadow $A))
)
),
/* B */ (BlasCall<"trsm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $B)),
]
,
(Seq<["tmp", "product", "m", "n"], [], 0>
(Dep (Shadow $alpha), (AssertingInactiveArg)),
(BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $alpha, $m, $n, (Shadow $B), Alloca<1>),
(BlasCall<"lacpy"> $layout, Char<"G">, $m, $n, (Dep (Shadow $A), (Dep (Shadow $B), $B)), $ldb, use<"tmp">, (mat_ld $layout, $m, $n)),
(BlasCall<"trmm"> $layout, $side, $uplo, $transa, $diag, $m, $n, Constant<"1.0">, (Shadow $A), (Dep (Shadow $B), use<"tmp">), (mat_ld $layout, $m, $n)),
(BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"-1.0">, $m, $n, (Dep (Shadow $A), use<"tmp">), (mat_ld $layout, $m, $n), Alloca<1>),
(MatAdd<true> $layout, $m, $n, (Shadow $B), (Concat (Dep (Shadow $A), use<"tmp">), (mat_ld $layout, $m, $n))),
(MatAdd<false> $layout, (ISelect (is_nonunit $diag), ConstantInt<0>, $m), $n, (Shadow $B), (Concat (Dep (Shadow $A), $B), $ldb)),
(BlasCall<"trsm"> $layout, $side, $uplo, $transa, $diag, $m, $n, Constant<"1.0">, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $B))
)
>;

def uplo_to_normal : IntMatchers<
Expand All @@ -744,7 +840,6 @@ def uplo_to_rside : IntMatchers<
["'R'", "'R'", "'L'", "'L'"]
>;


def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
["A"],
[cblas_layout, uplo, len, mld<["uplo", "n", "n"]>, info],
Expand Down Expand Up @@ -888,4 +983,20 @@ def potrs: CallBlasPattern<(Op $layout, $uplo, $n, $nrhs, $A, $lda, $B, $ldb, $i
),
(BlasCall<"potrs"> $layout, $uplo, $n, $nrhs, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $B), Alloca<1>)
]
,
(Seq<["tmp", "product", "n", "nrhs"], [], 0>
(BlasCall<"lacpy"> $layout, Char<"G">, $n, $nrhs, (Dep (Shadow $A), (Dep (Shadow $B), $B)), $ldb, use<"tmp">, (mat_ld $layout, $n, $nrhs)),
(BlasCall<"trmm"> $layout, Char<"L">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $nrhs, Constant<"1.0">, (Dep (Shadow $A), $A), (ld $A, Char<"N">, $lda, $n, $n), (Dep (Shadow $B), use<"tmp">), (mat_ld $layout, $n, $nrhs)),
(BlasCall<"trmm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $nrhs, Constant<"1.0">, (Shadow $A), (Dep (Shadow $B), use<"tmp">), (mat_ld $layout, $n, $nrhs)),
(BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"-1.0">, $n, $nrhs, (Dep (Shadow $A), (Dep (Shadow $B), use<"tmp">)), (mat_ld $layout, $n, $nrhs), Alloca<1>),
(MatAdd<true> $layout, $n, $nrhs, (Shadow $B), (Concat (Dep (Shadow $A), use<"tmp">), (mat_ld $layout, $n, $nrhs))),

(BlasCall<"lacpy"> $layout, Char<"G">, $n, $nrhs, (Dep (Shadow $A), (Dep (Shadow $B), $B)), $ldb, use<"tmp">, (mat_ld $layout, $n, $nrhs)),
(BlasCall<"trmm"> $layout, Char<"L">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $nrhs, Constant<"1.0">, (Shadow $A), (Dep (Shadow $B), use<"tmp">), (mat_ld $layout, $n, $nrhs)),
(BlasCall<"trmm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $nrhs, Constant<"1.0">, (Dep (Shadow $A), $A), (ld $A, Char<"N">, $lda, $n, $n), (Dep (Shadow $B), use<"tmp">), (mat_ld $layout, $n, $nrhs)),
(BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"-1.0">, $n, $nrhs, (Dep (Shadow $A), (Dep (Shadow $B), use<"tmp">)), (mat_ld $layout, $n, $nrhs), Alloca<1>),
(MatAdd<true> $layout, $n, $nrhs, (Shadow $B), (Concat (Dep (Shadow $A), use<"tmp">), (mat_ld $layout, $n, $nrhs))),

(BlasCall<"potrs"> $layout, $uplo, $n, $nrhs, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $B), Alloca<1>)
)
>;
38 changes: 21 additions & 17 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1912,7 +1912,8 @@ Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT,

Function *getOrInsertDifferentialFloatMemcpyMat(
Module &Mod, Type *elementType, PointerType *PT, IntegerType *IT,
IntegerType *CT, unsigned dstalign, unsigned srcalign, bool zeroSrc) {
IntegerType *UT, IntegerType *LT, unsigned dstalign, unsigned srcalign,
bool zeroSrc) {
assert(elementType->isFPOrFPVectorTy());
#if LLVM_VERSION_MAJOR < 17
#if LLVM_VERSION_MAJOR >= 15
Expand All @@ -1927,10 +1928,13 @@ Function *getOrInsertDifferentialFloatMemcpyMat(
#endif
#endif
std::string name = "__enzyme_dmemcpy_" + tofltstr(elementType) + "_mat_" +
std::to_string(cast<IntegerType>(UT)->getBitWidth()) +
"_" + std::to_string(cast<IntegerType>(LT)->getBitWidth()) +
"_" +
std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
(zeroSrc ? "_zero" : "");
FunctionType *FT = FunctionType::get(Type::getVoidTy(Mod.getContext()),
{CT, IT, IT, PT, IT, PT, IT}, false);
{UT, LT, IT, IT, PT, IT, PT, IT}, false);

Function *F = cast<Function>(Mod.getOrInsertFunction(name, FT).getCallee());

Expand All @@ -1945,8 +1949,8 @@ Function *getOrInsertDifferentialFloatMemcpyMat(
#endif
F->addFnAttr(Attribute::NoUnwind);
F->addFnAttr(Attribute::AlwaysInline);
F->addParamAttr(3, Attribute::NoAlias);
F->addParamAttr(5, Attribute::NoAlias);
F->addParamAttr(4, Attribute::NoAlias);
F->addParamAttr(6, Attribute::NoAlias);

BasicBlock *entry = BasicBlock::Create(F->getContext(), "entry", F);
BasicBlock *swtch = BasicBlock::Create(F->getContext(), "swtch", F);
Expand All @@ -1957,7 +1961,9 @@ Function *getOrInsertDifferentialFloatMemcpyMat(

auto uplo = F->arg_begin();
uplo->setName("uplo");
auto M = uplo + 1;
auto layout = uplo + 1;
layout->setName("layout");
auto M = layout + 1;
M->setName("M");
auto N = M + 1;
N->setName("N");
Expand All @@ -1982,8 +1988,8 @@ Function *getOrInsertDifferentialFloatMemcpyMat(
{
IRBuilder<> B(swtch);
auto swtchT = B.CreateSwitch(uplo, Ginit);
swtchT->addCase(ConstantInt::get(CT, 'U'), Uinit);
swtchT->addCase(ConstantInt::get(CT, 'L'), Linit);
swtchT->addCase(ConstantInt::get(UT, 'U'), Uinit);
swtchT->addCase(ConstantInt::get(UT, 'L'), Linit);
}

std::pair<char, BasicBlock *> todo[] = {
Expand Down Expand Up @@ -2019,15 +2025,8 @@ Function *getOrInsertDifferentialFloatMemcpyMat(
PHINode *i = B.CreatePHI(IT, 2, dir + "i");
i->addIncoming(istart, init);

Value *srci = B.CreateInBoundsGEP(
elementType, src,
B.CreateAdd(i, B.CreateMul(j, lsrc, "", true, true), "", true, true),
dir + "src.i");

Value *dsti = B.CreateInBoundsGEP(
elementType, dst,
B.CreateAdd(i, B.CreateMul(j, ldst, "", true, true), "", true, true),
dir + "dst.i");
Value *srci = lookup_with_layout(B, elementType, layout, src, lsrc, i, j);
Value *dsti = lookup_with_layout(B, elementType, layout, dst, ldst, i, j);
LoadInst *srcl = B.CreateLoad(elementType, srci, dir + "src.i.l");
LoadInst *dstl = B.CreateLoad(elementType, dsti, dir + "dst.i.l");
auto res = B.CreateFAdd(srcl, dstl);
Expand Down Expand Up @@ -3567,6 +3566,7 @@ llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
"dot", "scal", "axpy", "gemv", "gemm", "spmv", "syrk", "nrm2",
"trmm", "trmv", "symm", "potrf", "potrs", "copy", "spmv", "syr2k",
"potrs", "getrf", "getrs", "trtrs", "getri", "symv", "lacpy", "trsv",
"trsm",
};
const char *floatType[] = {"s", "d", "c", "z"};
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
Expand Down Expand Up @@ -4100,7 +4100,11 @@ SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
if (!cublas) {

if (!byRef) {
cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
auto isCblasN =
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
cond = B.CreateOr(isCblasN, B.CreateOr(isN, isn));
} else {
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -808,8 +808,8 @@ llvm::Function *getOrInsertMemcpyMat(llvm::Module &M, llvm::Type *elementType,

llvm::Function *getOrInsertDifferentialFloatMemcpyMat(
llvm::Module &M, llvm::Type *elementType, llvm::PointerType *PT,
llvm::IntegerType *IT, llvm::IntegerType *CT, unsigned dstalign,
unsigned srcalign, bool zeroSrc);
llvm::IntegerType *IT, llvm::IntegerType *UT, llvm::IntegerType *LT,
unsigned dstalign, unsigned srcalign, bool zeroSrc);

/// Create function for type that performs the derivative memmove on floating
/// point memory
Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/blas/cblas_trsm_f.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -enzyme-detect-readthrow=0 -S | FileCheck %s

target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

declare dso_local void @__enzyme_fwddiff(...)

declare void @cblas_dtrsm(i32, i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32)

; Function under differentiation: a single cblas_dtrsm call with a runtime
; layout and constant side=141, uplo=121, transa=111, diag=131, m=4, n=3,
; alpha=2.0, lda=4 and ldb=4. B is overwritten by the solve.
; NOTE(review): 141/121/111/131 are presumably CblasLeft / CblasUpper /
; CblasNoTrans / CblasNonUnit -- confirm against cblas.h.
define void @f(i32 %layout, double* %A, double* %B) {
entry:
call void @cblas_dtrsm(i32 %layout, i32 141, i32 121, i32 111, i32 131, i32 4, i32 3, double 2.000000e+00, double* %A, i32 4, double* %B, i32 4)
ret void
}

; Driver: requests the forward-mode derivative of @f through
; __enzyme_fwddiff. The layout is passed as enzyme_const; A and B are each
; passed as enzyme_dup together with their shadows dA and dB.
define void @active(i32 %layout, double* %A, double* %dA, double* %B, double* %dB) {
entry:
call void (...) @__enzyme_fwddiff(void (i32, double*, double*)* @f, metadata !"enzyme_const", i32 %layout, metadata !"enzyme_dup", double* %A, double* %dA, metadata !"enzyme_dup", double* %B, double* %dB)
ret void
}

; CHECK-LABEL: define internal void @fwddiffef(
; CHECK: call void @cblas_dtrsm
; CHECK: icmp eq i32 %layout, 101
; CHECK: __enzyme_dmemcpy_double_mat_8_32_32_zero
; CHECK: call void @cblas_dtrsm
; CHECK-NOT: call void bitcast
; CHECK: define internal void @__enzyme_dmemcpy_double_mat_8_32_32_zero(i8 %uplo, i32 %layout,
45 changes: 45 additions & 0 deletions enzyme/test/Enzyme/ForwardMode/blas/potrs_f.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S -enzyme-detect-readthrow=0 | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -enzyme-detect-readthrow=0 -S | FileCheck %s

target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

declare dso_local void @__enzyme_fwddiff(...)

declare void @dpotrs_64_(i8*, i64*, i64*, double*, i64*, double*, i64*, i64*, i64)

; Function under differentiation: a single by-reference dpotrs_64_ call.
; Scalar arguments live in stack slots: uplo = 85 (ASCII 'U'), n = 4,
; nrhs = 3, lda = 4, ldb = 4. The %info slot is passed without being
; initialized here.
; NOTE(review): the trailing `i64 1` is presumably the hidden Fortran
; character-length argument for uplo -- confirm against the ABI in use.
define void @f(double* %A, double* %B) {
entry:
%uplo = alloca i8, align 1
%n = alloca i64, align 8
%nrhs = alloca i64, align 8
%lda = alloca i64, align 8
%ldb = alloca i64, align 8
%info = alloca i64, align 8
store i8 85, i8* %uplo, align 1
store i64 4, i64* %n, align 8
store i64 3, i64* %nrhs, align 8
store i64 4, i64* %lda, align 8
store i64 4, i64* %ldb, align 8
call void @dpotrs_64_(i8* %uplo, i64* %n, i64* %nrhs, double* %A, i64* %lda, double* %B, i64* %ldb, i64* %info, i64 1)
ret void
}

; Driver: requests the forward-mode derivative of @f through
; __enzyme_fwddiff, passing A and B as enzyme_dup together with their
; shadows dA and dB.
define void @active(double* %A, double* %dA, double* %B, double* %dB) {
entry:
call void (...) @__enzyme_fwddiff(void (double*, double*)* @f, metadata !"enzyme_dup", double* %A, double* %dA, metadata !"enzyme_dup", double* %B, double* %dB)
ret void
}

; CHECK-LABEL: define internal void @fwddiffef(
; CHECK: call void @dpotrs_64_
; CHECK: call void @dlacpy_64_
; CHECK: call void @dtrmm_64_
; CHECK: call void @dtrmm_64_
; CHECK: call void @dlascl_64_
; CHECK: call void @dlacpy_64_
; CHECK: call void @dtrmm_64_
; CHECK: call void @dtrmm_64_
; CHECK: call void @dlascl_64_
; CHECK: call void @dpotrs_64_
; CHECK: ret void
Loading
Loading