Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_subdirectory(Dialect)
add_subdirectory(Interfaces)
add_subdirectory(Implementations)
add_subdirectory(Passes)
add_subdirectory(Integrations)
add_subdirectory(enzymemlir-translate)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(c)
19 changes: 19 additions & 0 deletions enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
add_mlir_public_c_api_library(MLIRCAPIEnzyme
EnzymeMLIR.cpp

DEPENDS
MLIRImpulseEnumsIncGen
MLIRImpulseAttributesIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRCAPIIR

LINK_LIBS PRIVATE
MLIREnzyme
MLIRImpulse
MLIREnzymeTransforms
MLIREnzymeAnalysis
MLIREnzymeAutoDiffInterface
MLIREnzymeImplementations
)
181 changes: 169 additions & 12 deletions enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,168 @@
//===- EnzymeMLIR.cpp - C API for Enzyme MLIR dialect ---------------------===//
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "EnzymeMLIR.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Registration.h"

#include "Dialect/Dialect.h"
#include "Dialect/Impulse/Impulse.h"
#include "Dialect/Ops.h"
#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Passes/Passes.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme,
mlir::enzyme::EnzymeDialect)

MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void) {
mlir::enzyme::registerDifferentiatePass();
mlir::enzyme::registerExpandImpulsePass();
mlir::enzyme::registerBatchPass();
mlir::enzyme::registerBatchDiffPass();
mlir::enzyme::registerDifferentiateWrapperPass();
mlir::enzyme::registerInlineEnzymeIntoRegionPass();
mlir::enzyme::registerOutlineEnzymeFromRegionPass();
mlir::enzyme::registerPrintActivityAnalysisPass();
mlir::enzyme::registerPrintAliasAnalysisPass();
mlir::enzyme::registerRemoveUnusedEnzymeOpsPass();
}

MLIR_CAPI_EXPORTED void
enzymeRegisterDialectExtensions(MlirDialectRegistry registry) {
mlir::enzyme::registerCoreDialectAutodiffInterfaces(*unwrap(registry));
}

MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void) {
return wrap(mlir::enzyme::createDifferentiatePass().release());
}

MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePassWithOptions(
MlirStringRef postpasses, bool verifyPostPasses) {
mlir::enzyme::DifferentiatePassOptions opts;
opts.postpasses = std::string(postpasses.data, postpasses.length);
opts.verifyPostPasses = verifyPostPasses;
return wrap(mlir::enzyme::createDifferentiatePass(opts).release());
}

MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void) {
return wrap(mlir::enzyme::createEnzymeOpsToMemRefPass().release());
}

MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void) {
return wrap(mlir::enzyme::createBatchPass().release());
}

MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void) {
return wrap(mlir::enzyme::createBatchDiffPass().release());
}

MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void) {
return wrap(mlir::enzyme::createRemoveUnusedEnzymeOpsPass().release());
}

MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the feedback! Will simplify enzymeActivityAttrGet. Also wondering if we shouldn't remove enzymeAutoDiffOpCreate, enzymeForwardDiffOpCreate, enzymeJacobianOpCreate, enzymeBatchOpCreate, and activityArrayAttr from this PR (the rest could be built downstream). What do you think, worth keeping?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto for enzymeRegisterPasses

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for Op creation, at least in Reactant.jl, we always use mlirOperationCreate, so they are clearly not required.

enzymeRegisterPasses might be required though because AFAIK there's no other way to create a pass without calling the C++ API right?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even if not required I'm okay with keeping it, if it makes your or other downstream folks life easier

@agarret7 agarret7 Jun 9, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enzymeRegisterPasses also isn't strictly needed today. Pass constructors return the pass directly, so you can add to a local PassManager without the global registry (unsafety can be wrapped):

pm.add_pass(unsafe {
    melior::pass::Pass::from_raw_fn(enzymeCreateDifferentiatePass)
});

The registry does matter for MLIR string parsing (mlir-opt --pass-pipeline=...), which melior already exposes through parse_pass_pipeline.

Given that, I suggest:

  • Keep exposing the individual enzymeCreate*Pass constructors as it's needed to append to PassManager directly
  • Keep enzymeRegisterPasses for the mlir-opt pipeline.
  • Drop the AD *Op-construction helpers, since both should be expressible with the help of melior's existing builders (OperationBuilder + enzymeActivityAttrGet).

Curious what you think.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd go for keeping them, if it may make folks lives easier longer term [not just the work we're doing in julia or rust right now]

uint32_t activity) {
return wrap(mlir::enzyme::ActivityAttr::get(
unwrap(ctx), (mlir::enzyme::Activity)activity));
}

static mlir::ArrayAttr
activityArrayAttr(MlirContext ctx, const MlirAttribute *activity, intptr_t n) {
return mlir::ArrayAttr::get(
unwrap(ctx), llvm::ArrayRef<mlir::Attribute>(
reinterpret_cast<const mlir::Attribute *>(activity), n));
}

MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const MlirAttribute *activity, intptr_t nActivity,
const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width,
bool strongZero, MlirLocation loc) {
auto op =
mlir::OpBuilder(unwrap(ctx))
.create<mlir::enzyme::AutoDiffOp>(
unwrap(loc),
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
llvm::StringRef(fn.data, fn.length),
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
activityArrayAttr(ctx, activity, nActivity),
activityArrayAttr(ctx, retActivity, nRetActivity),
(uint64_t)width, strongZero);
return wrap(op.getOperation());
}

MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const MlirAttribute *activity, intptr_t nActivity,
const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width,
bool strongZero, MlirLocation loc) {
auto op =
mlir::OpBuilder(unwrap(ctx))
.create<mlir::enzyme::ForwardDiffOp>(
unwrap(loc),
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
llvm::StringRef(fn.data, fn.length),
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
activityArrayAttr(ctx, activity, nActivity),
activityArrayAttr(ctx, retActivity, nRetActivity),
(uint64_t)width, strongZero);
return wrap(op.getOperation());
}

MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const MlirAttribute *activity, intptr_t nActivity,
const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width,
bool strongZero, MlirLocation loc) {
auto op =
mlir::OpBuilder(unwrap(ctx))
.create<mlir::enzyme::JacobianOp>(
unwrap(loc),
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
llvm::StringRef(fn.data, fn.length),
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
activityArrayAttr(ctx, activity, nActivity),
activityArrayAttr(ctx, retActivity, nRetActivity),
(uint64_t)width, strongZero);
return wrap(op.getOperation());
}

MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc) {
auto op =
mlir::OpBuilder(unwrap(ctx))
.create<mlir::enzyme::BatchOp>(
unwrap(loc),
mlir::TypeRange(llvm::ArrayRef<mlir::Type>(
reinterpret_cast<const mlir::Type *>(resultTypes), nResults)),
llvm::StringRef(fn.data, fn.length),
mlir::ValueRange(llvm::ArrayRef<mlir::Value>(
reinterpret_cast<const mlir::Value *>(inputs), nInputs)),
llvm::ArrayRef<int64_t>(batchShape, nBatchShape));
return wrap(op.getOperation());
}

MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx,
EnzymeRngDistribution dist) {
MLIR_CAPI_EXPORTED MlirAttribute
enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist) {
mlir::impulse::RngDistribution rngDist;
switch (dist) {
case EnzymeRngDistribution_Uniform:
Expand All @@ -23,9 +178,9 @@ MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx,
return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist));
}

MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
bool hasLowerBound, double lowerBound,
bool hasUpperBound, double upperBound) {
MLIR_CAPI_EXPORTED MlirAttribute enzymeSupportAttrGet(
MlirContext ctx, EnzymeSupportKind kind, bool hasLowerBound,
double lowerBound, bool hasUpperBound, double upperBound) {
auto *mlirCtx = unwrap(ctx);

mlir::impulse::SupportKind supportKind;
Expand Down Expand Up @@ -64,8 +219,10 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind,
upperAttr));
}

MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
bool adaptStepSize, bool adaptMassMatrix) {
MLIR_CAPI_EXPORTED MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx,
double trajectoryLength,
bool adaptStepSize,
bool adaptMassMatrix) {
auto *mlirCtx = unwrap(ctx);
auto trajectoryLengthAttr =
mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength);
Expand All @@ -74,10 +231,9 @@ MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength,
mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix));
}

MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
bool hasMaxDeltaEnergy,
double maxDeltaEnergy, bool adaptStepSize,
bool adaptMassMatrix) {
MLIR_CAPI_EXPORTED MlirAttribute enzymeNUTSConfigAttrGet(
MlirContext ctx, int64_t maxTreeDepth, bool hasMaxDeltaEnergy,
double maxDeltaEnergy, bool adaptStepSize, bool adaptMassMatrix) {
auto *mlirCtx = unwrap(ctx);

mlir::FloatAttr maxDeltaEnergyAttr;
Expand All @@ -90,6 +246,7 @@ MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth,
adaptMassMatrix));
}

MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) {
MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet(MlirContext ctx,
uint64_t ptr) {
return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr));
}
109 changes: 109 additions & 0 deletions enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,125 @@
//===- EnzymeMLIR.h - C API for Enzyme MLIR dialect ---------------*- C -*-===//
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares the C interface to the Enzyme MLIR dialect.
//
//===----------------------------------------------------------------------===//

#ifndef ENZYME_MLIR_INTEGRATIONS_C_ENZYMEMLIR_H_
#define ENZYME_MLIR_INTEGRATIONS_C_ENZYMEMLIR_H_

#include <stdbool.h>
#include <stdint.h>

#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"

#ifdef __cplusplus
extern "C" {
#endif

//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme);

//===----------------------------------------------------------------------===//
// Activity attribute
//===----------------------------------------------------------------------===//

/// Construct an `enzyme.activity` attribute from an activity code matching the
/// `Activity` enum ordinals (0=active, 1=dup, 2=const, 3=dupnoneed, ...).
MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx,
uint32_t activity);

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

/// Register all Enzyme MLIR passes with the global registry.
MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void);

/// Register Enzyme's external model AD implementations for the core MLIR
/// dialects (arith, func, scf, ...) into the given DialectRegistry.
/// Must be called before appending the registry to an MLIRContext so that
/// the Differentiate pass knows how to lower ops in those dialects.
MLIR_CAPI_EXPORTED void
enzymeRegisterDialectExtensions(MlirDialectRegistry registry);

/// Create the core differentiation pass (`--enzyme`).
MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void);

/// Create the core differentiation pass with options.
/// postpasses: semicolon-separated list of passes to run after differentiation
/// (may be NULL or empty). verifyPostPasses: run the verifier after each
/// post-pass.
MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePassWithOptions(
MlirStringRef postpasses, bool verifyPostPasses);

/// Create the `--enzyme-ops-to-memref` pass, which lowers `enzyme.init`,
/// `enzyme.get`, and `enzyme.set` to `memref` + `scf.if`. Required before
/// lowering to LLVM.
MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void);

/// Create the `--enzyme-batch` pass, which lowers `enzyme.batch` ops to
/// calls to generated `@batched_*` functions.
MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void);

/// Create the `--enzyme-batch-diff` pass, which differentiates batched
/// functions produced by `--enzyme-batch`.
MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void);

/// Create the `--remove-unnecessary-enzyme-ops` pass, which removes dead
/// `enzyme.*` ops after differentiation.
MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void);

//===----------------------------------------------------------------------===//
// Core AD ops
//===----------------------------------------------------------------------===//

/// Construct an `enzyme.autodiff` (reverse-mode) op referencing the function
/// named `fn`. `activity`/`retActivity` are arrays of `enzyme.activity`
/// attributes of length `nActivity`/`nRetActivity`.
MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const MlirAttribute *activity, intptr_t nActivity,
const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width,
bool strongZero, MlirLocation loc);

/// Construct an `enzyme.fwddiff` (forward-mode) op referencing the function
/// named `fn`. See `enzymeAutoDiffOpCreate` for parameter semantics.
MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const MlirAttribute *activity, intptr_t nActivity,
const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width,
bool strongZero, MlirLocation loc);

/// Construct an `enzyme.jacobian` op referencing the function named `fn`.
/// See `enzymeAutoDiffOpCreate` for parameter semantics.
MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const MlirAttribute *activity, intptr_t nActivity,
const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width,
bool strongZero, MlirLocation loc);

/// Construct an `enzyme.batch` op that maps the function named `fn` over a
/// batch of inputs. `batchShape` is an array of `nBatchShape` batch dimensions.
MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate(
MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes,
intptr_t nResults, const MlirValue *inputs, intptr_t nInputs,
const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc);

//===----------------------------------------------------------------------===//
// Probabilistic Programming Ops
//===----------------------------------------------------------------------===//
Expand Down