Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

Expand Down Expand Up @@ -297,6 +299,111 @@ struct InsertValueOpInterfaceReverse
MGradientUtilsReverse *gutils) const {}
};

struct MemcpyOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<MemcpyOpInterfaceReverse,
LLVM::MemcpyOp> {

static Type inferElemType(LLVM::MemcpyOp cp) {
if (auto t = cp->getAttrOfType<TypeAttr>("enzyme.elem_type"))
return t.getValue();
auto walk = [](Value p) -> Type {
for (Operation *user : p.getUsers()) {
if (auto ld = dyn_cast<LLVM::LoadOp>(user))
if (isa<AutoDiffTypeInterface>(ld.getType()))
return ld.getType();
if (auto st = dyn_cast<LLVM::StoreOp>(user))
if (isa<AutoDiffTypeInterface>(st.getValue().getType()))
return st.getValue().getType();
}
return nullptr;
};
if (Type t = walk(cp.getDst()))
return t;
if (Type t = walk(cp.getSrc()))
return t;
return Float64Type::get(cp.getContext());

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is wrong, we should throw an error if things can't be deduced. This is also somewhat hacky and we should use type analysis

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will fix it

@xys-syx xys-syx Jun 8, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses I have fixed it, but only for the case where the inferred element type is a scalar int or float and using upstream MLIR for inferring types. The memcpy ops in RSBench that copy arrays of struct<(f64, f64)> and array<10 x f64> need to handle it recursively, I plan to open a new PR for it.

}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
auto cp = cast<LLVM::MemcpyOp>(op);
if (gutils->isConstantValue(cp.getDst()))
return {};
bool srcActive = !gutils->isConstantValue(cp.getSrc());
OpBuilder cb(gutils->getNewFromOriginal(op));
SmallVector<Value> caches;
caches.push_back(
gutils->initAndPushCache(gutils->invertPointerM(cp.getDst(), cb), cb));
caches.push_back(gutils->initAndPushCache(
srcActive ? gutils->invertPointerM(cp.getSrc(), cb)
: gutils->getNewFromOriginal(cp.getSrc()),
cb));
caches.push_back(
gutils->initAndPushCache(gutils->getNewFromOriginal(cp.getLen()), cb));
return caches;
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}

LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto cp = cast<LLVM::MemcpyOp>(op);
if (gutils->isConstantValue(cp.getDst()))
return success();
bool srcActive = !gutils->isConstantValue(cp.getSrc());

Value dDst = gutils->popCache(caches[0], builder);
Value dSrc = gutils->popCache(caches[1], builder);
Value len = gutils->popCache(caches[2], builder);

Type elemTy = inferElemType(cp);
auto adt = dyn_cast<AutoDiffTypeInterface>(elemTy);
if (!adt || !elemTy.isIntOrFloat())
return op->emitError()
<< "memcpy reverse: unsupported element type " << elemTy
<< " (annotate enzyme.elem_type or lower to scalar stores)";

Location loc = op->getLoc();
unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8;

// n_elements = len / sizeof(elemTy)
Value byteSz =
LLVM::ConstantOp::create(builder, loc, len.getType(),
builder.getIntegerAttr(len.getType(), bytes));
Value nInt = LLVM::SDivOp::create(builder, loc, len, byteSz);
Value n =
arith::IndexCastOp::create(builder, loc, builder.getIndexType(), nInt);

Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
Value c1 = arith::ConstantIndexOp::create(builder, loc, 1);
Value zeroElem = adt.createNullValue(builder, loc);
Type ptrTy = cp.getDst().getType();

auto forOp = scf::ForOp::create(builder, loc, c0, n, c1);

OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(forOp.getBody()->getTerminator());
Value ivIdx = forOp.getInductionVar();
Value iv = arith::IndexCastOp::create(builder, loc, len.getType(), ivIdx);

Value gDst = LLVM::GEPOp::create(builder, loc, ptrTy, elemTy, dDst,
ArrayRef<LLVM::GEPArg>{iv});
Value vDst = LLVM::LoadOp::create(builder, loc, elemTy, gDst);
if (srcActive) {
Value gSrc = LLVM::GEPOp::create(builder, loc, ptrTy, elemTy, dSrc,
ArrayRef<LLVM::GEPArg>{iv});
Value vSrc = LLVM::LoadOp::create(builder, loc, elemTy, gSrc);
Value sum = adt.createAddOp(builder, loc, vSrc, vDst);
LLVM::StoreOp::create(builder, loc, sum, gSrc);
}
LLVM::StoreOp::create(builder, loc, zeroElem, gDst);

return success();
}
};

std::optional<Value> findPtrSize(Value ptr) {
if (auto allocOp = ptr.getDefiningOp<llvm_ext::AllocOp>())
return allocOp.getSize();
Expand Down Expand Up @@ -467,6 +574,7 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
*context);
LLVM::InsertValueOp::attachInterface<InsertValueOpInterfaceReverse>(
*context);
LLVM::MemcpyOp::attachInterface<MemcpyOpInterfaceReverse>(*context);
LLVM::UnreachableOp::template attachInterface<
detail::NoopRevAutoDiffInterface<LLVM::UnreachableOp>>(*context);
LLVM::LLVMFuncOp::attachInterface<AutoDiffLLVMFuncOpFunctionInterface>(
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -53,7 +54,8 @@ struct DifferentiatePass

registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
mlir::linalg::LinalgDialect, mlir::enzyme::EnzymeDialect>();
mlir::scf::SCFDialect, mlir::linalg::LinalgDialect,
mlir::enzyme::EnzymeDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Interfaces/FunctionInterfaces.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#define DEBUG_TYPE "enzyme"

Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> {
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"tensor::TensorDialect",
"scf::SCFDialect",
"enzyme::EnzymeDialect",
];
let options = [
Expand Down Expand Up @@ -85,6 +86,7 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> {
"arith::ArithDialect",
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"scf::SCFDialect",
"enzyme::EnzymeDialect"
];
let options = [
Expand Down
64 changes: 64 additions & 0 deletions enzyme/test/MLIR/ReverseMode/memcpy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: %eopt --enzyme %s | FileCheck %s

func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
"llvm.intr.memcpy"(%dst, %src, %n)
<{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}>
: (!llvm.ptr, !llvm.ptr, i64) -> ()
return
}

func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr,
%src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
enzyme.autodiff @copy1(%dst, %ddst, %src, %dsrc, %n) {
activity = [#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_const>],
ret_activity = []
} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
return
}

func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
"llvm.intr.memcpy"(%dst, %src, %n)
<{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}],
isVolatile = false}>
: (!llvm.ptr, !llvm.ptr, i64) -> ()
return
}

func.func @dcopy2(%dst: !llvm.ptr, %ddst: !llvm.ptr,
%src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
enzyme.autodiff @copy2(%dst, %ddst, %src, %dsrc, %n) {
activity = [#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_const>],
ret_activity = []
} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
return
}

// CHECK-LABEL: func.func private @diffecopy1(
// Forward: the primal memcpy is preserved.
// CHECK: "llvm.intr.memcpy"
// Reverse: n / sizeof(f64) element-wise loop, d_src[i] += d_dst[i]; d_dst[i]=0.
// CHECK: %[[BYTES:.+]] = llvm.mlir.constant(8 : i64) : i64
// CHECK: llvm.sdiv %{{.+}}, %[[BYTES]] : i64
// CHECK: arith.index_cast
// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: scf.for
// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
// CHECK: %[[SUM:.+]] = arith.addf
// CHECK: llvm.store %[[SUM]], %{{.+}} : f64, !llvm.ptr
// CHECK: llvm.store %[[ZERO]], %{{.+}} : f64, !llvm.ptr

// CHECK-LABEL: func.func private @diffecopy2(
// CHECK: "llvm.intr.memcpy"
// CHECK: llvm.mlir.constant(8 : i64) : i64
// CHECK: %[[ZERO2:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: scf.for
// CHECK: %[[SUM2:.+]] = arith.addf
// CHECK: llvm.store %[[SUM2]], %{{.+}} : f64, !llvm.ptr
// CHECK: llvm.store %[[ZERO2]], %{{.+}} : f64, !llvm.ptr
Loading