Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesCustom = usesCustom_;
}

class ResultIndex<int index_> {
int index = index_;
}
def Result : ResultIndex</*index=*/0>;
class DiffeRetIndex<list<int> indices_> {
list<int> indices = indices_;
}
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def : MLIRDerivative<"math", "CosOp", (Op $x),
>;
def : MLIRDerivative<"math", "ExpOp", (Op $x),
[
(CheckedMulF (DiffeRet), (ExpF $x))
(CheckedMulF (DiffeRet), (Result))
]
>;
def : MLIRDerivative<"math", "SinOp", (Op $x),
Expand Down
21 changes: 21 additions & 0 deletions enzyme/test/MLIR/ForwardMode/exp_wrap.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %eopt --enzyme-wrap="infn=main outfn= argTys=enzyme_dup retTys=enzyme_dup mode=ForwardMode" --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=FWD
// RUN: %eopt --enzyme-wrap="infn=main outfn= argTys=enzyme_active retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s --check-prefix=REV

module {
func.func @main(%arg0: f32) -> f32 {
%0 = math.exp %arg0 : f32
return %0 : f32
}
}

// FWD: func.func @main(%[[x:.+]]: f32, %[[dx:.+]]: f32) -> (f32, f32) {
// FWD-NEXT: %[[exp:.+]] = math.exp %[[x]] : f32
// FWD-NEXT: %[[shadow:.+]] = arith.mulf %[[dx]], %[[exp]] : f32
// FWD-NEXT: return %[[exp]], %[[shadow]] : f32, f32
// FWD-NEXT: }

// REV: func.func @main(%[[x:.+]]: f32, %[[dret:.+]]: f32) -> f32 {
// REV-NEXT: %[[exp:.+]] = math.exp %[[x]] : f32
// REV-NEXT: %[[darg:.+]] = arith.mulf %[[dret]], %[[exp]] : f32
// REV-NEXT: return %[[darg]] : f32
// REV-NEXT: }
59 changes: 59 additions & 0 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,47 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
os << curIndent << INDENT << "out; })\n";
}
return true;
} else if (opName == "ResultIndex" || Def->isSubClassOf("ResultIndex")) {
if (intrinsic != MLIRDerivatives) {
PrintFatalError(pattern->getLoc(),
"ResultIndex only supported in MLIR derivatives");
return false;
}

if (resultRoot->getNumArgs() != 0) {
PrintFatalError(pattern->getLoc(), "no arg supported for ResultIndex");
return false;
}

auto index = dyn_cast<IntInit>(Def->getValueInit("index"));

os << "({"
<< " // Computing ResultIndex (index=" << index->getValue() << ")"
<< argPattern << "\n";

llvm::SmallString<32> buf;
if (startsWith(argPattern.toStringRef(buf), "fwdarg")) {
os << curIndent << INDENT
<< "mlir::Value resultIndex = "
"gutils->getNewFromOriginal(op->getResult("
<< index->getValue() << "));\n";
os << curIndent << INDENT
<< "builder.setInsertionPointAfterValue(resultIndex)"
";\n";
os << curIndent << INDENT << "resultIndex;\n";
} else {
os << curIndent << INDENT
<< "auto neededArgs = cachedArguments(op, gutils);\n";
os << curIndent << INDENT << "int numNeededArgs = 0;\n";
os << curIndent << INDENT << "for (auto needed : neededArgs)\n";
os << curIndent << INDENT << INDENT
<< "if (needed) numNeededArgs += 1;\n";
os << curIndent << INDENT << "gutils->popCache(caches[numNeededArgs + "
<< index->getValue() << "], builder);\n";
}
os << curIndent << "})";

return false;
} else if (opName == "TypeOf" || Def->isSubClassOf("TypeOf")) {
if (resultRoot->getNumArgs() != 1)
PrintFatalError(pattern->getLoc(), "only single op TypeOf supported");
Expand Down Expand Up @@ -1408,6 +1449,8 @@ void handleUse(
if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) {
return;
}
if (opName == "ResultIndex" || Def->isSubClassOf("ResultIndex"))
return;
if (!Def->isSubClassOf("Operation")) {
errs() << *resultTree << "\n";
errs() << opName << " " << *Def << "\n";
Expand Down Expand Up @@ -1728,6 +1771,13 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
os << " return toret;\n";
os << " }\n";

os << " SmallVector<bool> cachedResults(Operation *op,\n"
" MGradientUtilsReverse *gutils) "
"const {\n"
" SmallVector<bool> neededResults(op->getNumResults(), "
"false);\n"
" return neededResults;\n}\n";

os << " SmallVector<Value> cacheValues(Operation *op,\n";
os << " MGradientUtilsReverse *gutils) "
"const {\n";
Expand All @@ -1743,6 +1793,15 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
"getOperand(en.index())), builder);\n";
os << " toret.push_back(cache);\n";
os << " }\n";
os << " "
"builder.setInsertionPointAfter(gutils->getNewFromOriginal(op));\n";
os << " auto neededResults = cachedResults(op, gutils);\n";
os << " for (auto en : llvm::enumerate(neededResults)) {\n";
os << " Value cache = "
"gutils->initAndPushCache(gutils->getNewFromOriginal(op->getResult(en."
"index())), builder);\n";
os << " toret.push_back(cache);\n";
os << " }";
os << " return toret;\n";
os << " }\n";
os << "\n";
Expand Down
Loading