From 6ddddf9b28a976a379b3e782aefbb9c94daa8486 Mon Sep 17 00:00:00 2001 From: Austin Garrett Date: Sat, 6 Jun 2026 19:42:14 -0400 Subject: [PATCH 1/5] [MLIR] Add C API for Enzyme dialect (MLIRCAPIEnzyme) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the missing C API layer for the Enzyme MLIR dialect, following the same pattern as upstream MLIR's mlir-c/ libraries: - `MLIR_DECLARE/DEFINE_CAPI_DIALECT_REGISTRATION` for the enzyme dialect - `enzymeActivityAttrGet` — constructs an `enzyme.activity` attr from the C-side `EnzymeActivity` enum (enzyme_active, enzyme_dup, enzyme_const, …) - `enzymeAutoDiffOpCreate` / `enzymeForwardDiffOpCreate` — construct unattached `enzyme.autodiff` / `enzyme.fwddiff` ops from C - `Integrations/c/CMakeLists.txt` wiring up `MLIRCAPIEnzyme` as an `add_mlir_public_c_api_library` target - `Integrations/CMakeLists.txt` + hook in `MLIR/CMakeLists.txt` This enables foreign-language bindings (Rust, Python, etc.) to emit Enzyme dialect ops without depending on C++ headers. --- enzyme/Enzyme/MLIR/CMakeLists.txt | 1 + .../Enzyme/MLIR/Integrations/CMakeLists.txt | 1 + .../Enzyme/MLIR/Integrations/c/CMakeLists.txt | 7 ++ .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp | 90 +++++++++++++++++++ .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.h | 38 ++++++++ enzyme/test/MLIR/CAPI/roundtrip.mlir | 29 ++++++ 6 files changed, 166 insertions(+) create mode 100644 enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt create mode 100644 enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt create mode 100644 enzyme/test/MLIR/CAPI/roundtrip.mlir diff --git a/enzyme/Enzyme/MLIR/CMakeLists.txt b/enzyme/Enzyme/MLIR/CMakeLists.txt index 157ec56f2fd9..f055873bca0a 100644 --- a/enzyme/Enzyme/MLIR/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(Dialect) add_subdirectory(Interfaces) add_subdirectory(Implementations) add_subdirectory(Passes) +add_subdirectory(Integrations) add_subdirectory(enzymemlir-translate) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) diff --git a/enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt new file mode 100644 index 000000000000..20bc4cc4b07a --- /dev/null +++ b/enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(c) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt b/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt new file mode 100644 index 000000000000..085d625fe7a9 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_public_c_api_library(MLIRCAPIEnzyme + EnzymeMLIR.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIREnzyme +) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp index c65709b22de4..c0035163a424 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp @@ -1,10 +1,100 @@ #include "EnzymeMLIR.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" #include "Dialect/Dialect.h" #include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme, + mlir::enzyme::EnzymeDialect) + +MlirAttribute enzymeActivityAttrGet(MlirContext ctx, EnzymeActivity activity) { + mlir::enzyme::Activity act; + switch (activity) { + case EnzymeActivity_enzyme_active: + act = mlir::enzyme::Activity::enzyme_active; + break; + case EnzymeActivity_enzyme_dup: + act = mlir::enzyme::Activity::enzyme_dup; + break; + case EnzymeActivity_enzyme_const: + act = mlir::enzyme::Activity::enzyme_const; + break; + case EnzymeActivity_enzyme_dupnoneed: + act = mlir::enzyme::Activity::enzyme_dupnoneed; + break; + case EnzymeActivity_enzyme_activenoneed: + act = mlir::enzyme::Activity::enzyme_activenoneed; + break; + case EnzymeActivity_enzyme_constnoneed: + act = mlir::enzyme::Activity::enzyme_constnoneed; + break; + } + return wrap(mlir::enzyme::ActivityAttr::get(unwrap(ctx), act)); +} + +static mlir::ArrayAttr activityArrayAttr(MlirContext ctx, + MlirAttribute *activity, intptr_t n) { + llvm::SmallVector attrs; + attrs.reserve(n); + for (intptr_t i = 0; i < n; ++i) + attrs.push_back(unwrap(activity[i])); + return mlir::ArrayAttr::get(unwrap(ctx), attrs); +} + +static void collectTypes(MlirType *src, intptr_t n, + llvm::SmallVectorImpl &out) { + out.reserve(n); + for (intptr_t i = 0; i < n; ++i) + out.push_back(unwrap(src[i])); +} + +static void collectValues(MlirValue *src, intptr_t n, + llvm::SmallVectorImpl &out) { + out.reserve(n); + for (intptr_t i = 0; i < n; ++i) + out.push_back(unwrap(src[i])); +} + +MlirOperation +enzymeAutoDiffOpCreate(MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, + intptr_t nResults, MlirValue *inputs, intptr_t nInputs, + MlirAttribute *activity, intptr_t nActivity, + MlirAttribute *retActivity, intptr_t nRetActivity, + int64_t width, bool strongZero, MlirLocation loc) { + auto *mlirCtx = unwrap(ctx); + llvm::SmallVector results; + collectTypes(resultTypes, nResults, results); + llvm::SmallVector operands; + collectValues(inputs, nInputs, operands); + auto op = mlir::OpBuilder(mlirCtx).create( + unwrap(loc), mlir::TypeRange(results), + llvm::StringRef(fn.data, fn.length), mlir::ValueRange(operands), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width, + strongZero); + return wrap(op.getOperation()); +} + +MlirOperation enzymeForwardDiffOpCreate( + MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults, + MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity, + intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity, + int64_t width, bool strongZero, MlirLocation loc) { + auto *mlirCtx = unwrap(ctx); + llvm::SmallVector results; + collectTypes(resultTypes, nResults, results); + llvm::SmallVector operands; + collectValues(inputs, nInputs, operands); + auto op = mlir::OpBuilder(mlirCtx).create( + unwrap(loc), mlir::TypeRange(results), + llvm::StringRef(fn.data, fn.length), mlir::ValueRange(operands), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width, + strongZero); + return wrap(op.getOperation()); +} MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist) { diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h index 62336c34b4b1..56d6df3749da 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h @@ -11,6 +11,44 @@ extern "C" { #endif +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme); + +//===----------------------------------------------------------------------===// +// Activity attribute +//===----------------------------------------------------------------------===// + +typedef enum { + EnzymeActivity_enzyme_active = 0, + EnzymeActivity_enzyme_dup = 1, + EnzymeActivity_enzyme_const = 2, + EnzymeActivity_enzyme_dupnoneed = 3, + EnzymeActivity_enzyme_activenoneed = 4, + EnzymeActivity_enzyme_constnoneed = 5, +} EnzymeActivity; + +MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx, + EnzymeActivity activity); + +//===----------------------------------------------------------------------===// +// Core AD ops +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate( + MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults, + MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity, + intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity, + int64_t width, bool strongZero, MlirLocation loc); + +MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate( + MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults, + MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity, + intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity, + int64_t width, bool strongZero, MlirLocation loc); + //===----------------------------------------------------------------------===// // Probabilistic Programming Ops //===----------------------------------------------------------------------===// diff --git a/enzyme/test/MLIR/CAPI/roundtrip.mlir b/enzyme/test/MLIR/CAPI/roundtrip.mlir new file mode 100644 index 000000000000..a01e574bcc91 --- /dev/null +++ b/enzyme/test/MLIR/CAPI/roundtrip.mlir @@ -0,0 +1,29 @@ +// Verify that enzyme.autodiff and enzyme.fwddiff ops — the ops constructed by +// the MLIRCAPIEnzyme C API — parse and round-trip correctly. +// +// RUN: %eopt %s | FileCheck %s + +module { + func.func @square(%x: f64) -> f64 { + %0 = arith.mulf %x, %x : f64 + return %0 : f64 + } + + // CHECK-LABEL: func.func @test_autodiff + func.func @test_autodiff(%x: f64, %dx: f64) -> f64 { + // CHECK: enzyme.autodiff @square(%{{.*}}, %{{.*}}) + // CHECK-SAME: activity = [#enzyme] + // CHECK-SAME: ret_activity = [#enzyme] + %0 = enzyme.autodiff @square(%x, %dx) {activity = [#enzyme], ret_activity = [#enzyme]} : (f64, f64) -> f64 + return %0 : f64 + } + + // CHECK-LABEL: func.func @test_fwddiff + func.func @test_fwddiff(%x: f64, %dx: f64) -> f64 { + // CHECK: enzyme.fwddiff @square(%{{.*}}, %{{.*}}) + // CHECK-SAME: activity = [#enzyme] + // CHECK-SAME: ret_activity = [#enzyme] + %0 = enzyme.fwddiff @square(%x, %dx) {activity = [#enzyme], ret_activity = [#enzyme]} : (f64, f64) -> f64 + return %0 : f64 + } +} From ece264d6f65d1da1f76d671ff4bd495e72b1a482 Mon Sep 17 00:00:00 2001 From: Austin Garrett Date: Sat, 6 Jun 2026 20:15:17 -0400 Subject: [PATCH 2/5] fix roundtrip test: avoid greedy %{{.*}} consuming attributes --- enzyme/test/MLIR/CAPI/roundtrip.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/test/MLIR/CAPI/roundtrip.mlir b/enzyme/test/MLIR/CAPI/roundtrip.mlir index a01e574bcc91..24b2e93ae23f 100644 --- a/enzyme/test/MLIR/CAPI/roundtrip.mlir +++ b/enzyme/test/MLIR/CAPI/roundtrip.mlir @@ -11,7 +11,7 @@ module { // CHECK-LABEL: func.func @test_autodiff func.func @test_autodiff(%x: f64, %dx: f64) -> f64 { - // CHECK: enzyme.autodiff @square(%{{.*}}, %{{.*}}) + // CHECK: enzyme.autodiff @square // CHECK-SAME: activity = [#enzyme] // CHECK-SAME: ret_activity = [#enzyme] %0 = enzyme.autodiff @square(%x, %dx) {activity = [#enzyme], ret_activity = [#enzyme]} : (f64, f64) -> f64 @@ -20,7 +20,7 @@ module { // CHECK-LABEL: func.func @test_fwddiff func.func @test_fwddiff(%x: f64, %dx: f64) -> f64 { - // CHECK: enzyme.fwddiff @square(%{{.*}}, %{{.*}}) + // CHECK: enzyme.fwddiff @square // CHECK-SAME: activity = [#enzyme] // CHECK-SAME: ret_activity = [#enzyme] %0 = enzyme.fwddiff @square(%x, %dx) {activity = [#enzyme], ret_activity = [#enzyme]} : (f64, f64) -> f64 From 9dc699c8dca817a2aaf35d5f890dd6ac0f3fda74 Mon Sep 17 00:00:00 2001 From: Austin Garrett Date: Sat, 6 Jun 2026 20:44:56 -0400 Subject: [PATCH 3/5] remove redundant test --- enzyme/test/MLIR/CAPI/roundtrip.mlir | 29 ---------------------------- 1 file changed, 29 deletions(-) delete mode 100644 enzyme/test/MLIR/CAPI/roundtrip.mlir diff --git a/enzyme/test/MLIR/CAPI/roundtrip.mlir b/enzyme/test/MLIR/CAPI/roundtrip.mlir deleted file mode 100644 index 24b2e93ae23f..000000000000 --- a/enzyme/test/MLIR/CAPI/roundtrip.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// Verify that enzyme.autodiff and enzyme.fwddiff ops — the ops constructed by -// the MLIRCAPIEnzyme C API — parse and round-trip correctly. -// -// RUN: %eopt %s | FileCheck %s - -module { - func.func @square(%x: f64) -> f64 { - %0 = arith.mulf %x, %x : f64 - return %0 : f64 - } - - // CHECK-LABEL: func.func @test_autodiff - func.func @test_autodiff(%x: f64, %dx: f64) -> f64 { - // CHECK: enzyme.autodiff @square - // CHECK-SAME: activity = [#enzyme] - // CHECK-SAME: ret_activity = [#enzyme] - %0 = enzyme.autodiff @square(%x, %dx) {activity = [#enzyme], ret_activity = [#enzyme]} : (f64, f64) -> f64 - return %0 : f64 - } - - // CHECK-LABEL: func.func @test_fwddiff - func.func @test_fwddiff(%x: f64, %dx: f64) -> f64 { - // CHECK: enzyme.fwddiff @square - // CHECK-SAME: activity = [#enzyme] - // CHECK-SAME: ret_activity = [#enzyme] - %0 = enzyme.fwddiff @square(%x, %dx) {activity = [#enzyme], ret_activity = [#enzyme]} : (f64, f64) -> f64 - return %0 : f64 - } -} From ff91852e8f50107c9e3f3d073f2874ec3ee5efe2 Mon Sep 17 00:00:00 2001 From: Austin Garrett Date: Mon, 8 Jun 2026 11:03:20 -0400 Subject: [PATCH 4/5] [MLIR] add pass creation fns, batch op constructor, MLIR_CAPI_EXPORTED on all definitions --- .../Enzyme/MLIR/Integrations/c/CMakeLists.txt | 12 + .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp | 227 ++++++++++++------ .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.h | 109 +++++++-- 3 files changed, 260 insertions(+), 88 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt b/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt index 085d625fe7a9..69981e15d347 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt @@ -1,7 +1,19 @@ add_mlir_public_c_api_library(MLIRCAPIEnzyme EnzymeMLIR.cpp + DEPENDS + MLIRImpulseEnumsIncGen + MLIRImpulseAttributesIncGen + LINK_LIBS PUBLIC MLIRIR + MLIRCAPIIR + + LINK_LIBS PRIVATE MLIREnzyme + MLIRImpulse + MLIREnzymeTransforms + MLIREnzymeAnalysis + MLIREnzymeAutoDiffInterface + MLIREnzymeImplementations ) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp index c0035163a424..dcb876499c13 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp @@ -1,103 +1,190 @@ +//===- EnzymeMLIR.cpp - C API for Enzyme MLIR dialect ---------------------===// +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "EnzymeMLIR.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Pass.h" #include "mlir/CAPI/Registration.h" #include "Dialect/Dialect.h" #include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Passes/Passes.h" +#include "llvm/Support/ErrorHandling.h" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme, mlir::enzyme::EnzymeDialect) -MlirAttribute enzymeActivityAttrGet(MlirContext ctx, EnzymeActivity activity) { +MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void) { + mlir::enzyme::registerDifferentiatePass(); + mlir::enzyme::registerExpandImpulsePass(); + mlir::enzyme::registerBatchPass(); + mlir::enzyme::registerBatchDiffPass(); + mlir::enzyme::registerDifferentiateWrapperPass(); + mlir::enzyme::registerInlineEnzymeIntoRegionPass(); + mlir::enzyme::registerOutlineEnzymeFromRegionPass(); + mlir::enzyme::registerPrintActivityAnalysisPass(); + mlir::enzyme::registerPrintAliasAnalysisPass(); + mlir::enzyme::registerRemoveUnusedEnzymeOpsPass(); +} + +MLIR_CAPI_EXPORTED void +enzymeRegisterDialectExtensions(MlirDialectRegistry registry) { + mlir::enzyme::registerCoreDialectAutodiffInterfaces(*unwrap(registry)); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void) { + return wrap(mlir::enzyme::createDifferentiatePass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePassWithOptions( + MlirStringRef postpasses, bool verifyPostPasses) { + mlir::enzyme::DifferentiatePassOptions opts; + opts.postpasses = std::string(postpasses.data, postpasses.length); + opts.verifyPostPasses = verifyPostPasses; + return wrap(mlir::enzyme::createDifferentiatePass(opts).release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void) { + return wrap(mlir::enzyme::createEnzymeOpsToMemRefPass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void) { + return wrap(mlir::enzyme::createBatchPass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void) { + return wrap(mlir::enzyme::createBatchDiffPass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void) { + return wrap(mlir::enzyme::createRemoveUnusedEnzymeOpsPass().release()); +} + +MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx, + uint32_t activity) { mlir::enzyme::Activity act; switch (activity) { - case EnzymeActivity_enzyme_active: + case 0: act = mlir::enzyme::Activity::enzyme_active; break; - case EnzymeActivity_enzyme_dup: + case 1: act = mlir::enzyme::Activity::enzyme_dup; break; - case EnzymeActivity_enzyme_const: + case 2: act = mlir::enzyme::Activity::enzyme_const; break; - case EnzymeActivity_enzyme_dupnoneed: + case 3: act = mlir::enzyme::Activity::enzyme_dupnoneed; break; - case EnzymeActivity_enzyme_activenoneed: + case 4: act = mlir::enzyme::Activity::enzyme_activenoneed; break; - case EnzymeActivity_enzyme_constnoneed: + case 5: act = mlir::enzyme::Activity::enzyme_constnoneed; break; + default: + llvm_unreachable("invalid Enzyme activity"); } return wrap(mlir::enzyme::ActivityAttr::get(unwrap(ctx), act)); } -static mlir::ArrayAttr activityArrayAttr(MlirContext ctx, - MlirAttribute *activity, intptr_t n) { - llvm::SmallVector attrs; - attrs.reserve(n); - for (intptr_t i = 0; i < n; ++i) - attrs.push_back(unwrap(activity[i])); - return mlir::ArrayAttr::get(unwrap(ctx), attrs); +static mlir::ArrayAttr +activityArrayAttr(MlirContext ctx, const MlirAttribute *activity, intptr_t n) { + return mlir::ArrayAttr::get( + unwrap(ctx), llvm::ArrayRef( + reinterpret_cast(activity), n)); } -static void collectTypes(MlirType *src, intptr_t n, - llvm::SmallVectorImpl &out) { - out.reserve(n); - for (intptr_t i = 0; i < n; ++i) - out.push_back(unwrap(src[i])); +MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), + (uint64_t)width, strongZero); + return wrap(op.getOperation()); } -static void collectValues(MlirValue *src, intptr_t n, - llvm::SmallVectorImpl &out) { - out.reserve(n); - for (intptr_t i = 0; i < n; ++i) - out.push_back(unwrap(src[i])); +MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), + (uint64_t)width, strongZero); + return wrap(op.getOperation()); } -MlirOperation -enzymeAutoDiffOpCreate(MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, - intptr_t nResults, MlirValue *inputs, intptr_t nInputs, - MlirAttribute *activity, intptr_t nActivity, - MlirAttribute *retActivity, intptr_t nRetActivity, - int64_t width, bool strongZero, MlirLocation loc) { - auto *mlirCtx = unwrap(ctx); - llvm::SmallVector results; - collectTypes(resultTypes, nResults, results); - llvm::SmallVector operands; - collectValues(inputs, nInputs, operands); - auto op = mlir::OpBuilder(mlirCtx).create( - unwrap(loc), mlir::TypeRange(results), - llvm::StringRef(fn.data, fn.length), mlir::ValueRange(operands), - activityArrayAttr(ctx, activity, nActivity), - activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width, - strongZero); +MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), + (uint64_t)width, strongZero); return wrap(op.getOperation()); } -MlirOperation enzymeForwardDiffOpCreate( - MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults, - MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity, - intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity, - int64_t width, bool strongZero, MlirLocation loc) { - auto *mlirCtx = unwrap(ctx); - llvm::SmallVector results; - collectTypes(resultTypes, nResults, results); - llvm::SmallVector operands; - collectValues(inputs, nInputs, operands); - auto op = mlir::OpBuilder(mlirCtx).create( - unwrap(loc), mlir::TypeRange(results), - llvm::StringRef(fn.data, fn.length), mlir::ValueRange(operands), - activityArrayAttr(ctx, activity, nActivity), - activityArrayAttr(ctx, retActivity, nRetActivity), (uint64_t)width, - strongZero); +MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + llvm::ArrayRef(batchShape, nBatchShape)); return wrap(op.getOperation()); } -MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, - EnzymeRngDistribution dist) { +MLIR_CAPI_EXPORTED MlirAttribute +enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist) { mlir::impulse::RngDistribution rngDist; switch (dist) { case EnzymeRngDistribution_Uniform: @@ -113,9 +200,9 @@ MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist)); } -MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, - bool hasLowerBound, double lowerBound, - bool hasUpperBound, double upperBound) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeSupportAttrGet( + MlirContext ctx, EnzymeSupportKind kind, bool hasLowerBound, + double lowerBound, bool hasUpperBound, double upperBound) { auto *mlirCtx = unwrap(ctx); mlir::impulse::SupportKind supportKind; @@ -154,8 +241,10 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, upperAttr)); } -MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, - bool adaptStepSize, bool adaptMassMatrix) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, + double trajectoryLength, + bool adaptStepSize, + bool adaptMassMatrix) { auto *mlirCtx = unwrap(ctx); auto trajectoryLengthAttr = mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength); @@ -164,10 +253,9 @@ MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix)); } -MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, - bool hasMaxDeltaEnergy, - double maxDeltaEnergy, bool adaptStepSize, - bool adaptMassMatrix) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeNUTSConfigAttrGet( + MlirContext ctx, int64_t maxTreeDepth, bool hasMaxDeltaEnergy, + double maxDeltaEnergy, bool adaptStepSize, bool adaptMassMatrix) { auto *mlirCtx = unwrap(ctx); mlir::FloatAttr maxDeltaEnergyAttr; @@ -180,6 +268,7 @@ MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, adaptMassMatrix)); } -MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, + uint64_t ptr) { return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr)); } diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h index 56d6df3749da..518906716c74 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h @@ -1,3 +1,16 @@ +//===- EnzymeMLIR.h - C API for Enzyme MLIR dialect ---------------*- C -*-===// +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to the Enzyme MLIR dialect. +// +//===----------------------------------------------------------------------===// + #ifndef ENZYME_MLIR_INTEGRATIONS_C_ENZYMEMLIR_H_ #define ENZYME_MLIR_INTEGRATIONS_C_ENZYMEMLIR_H_ @@ -5,6 +18,7 @@ #include #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #ifdef __cplusplus @@ -21,33 +35,90 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme); // Activity attribute //===----------------------------------------------------------------------===// -typedef enum { - EnzymeActivity_enzyme_active = 0, - EnzymeActivity_enzyme_dup = 1, - EnzymeActivity_enzyme_const = 2, - EnzymeActivity_enzyme_dupnoneed = 3, - EnzymeActivity_enzyme_activenoneed = 4, - EnzymeActivity_enzyme_constnoneed = 5, -} EnzymeActivity; - +/// Construct an `enzyme.activity` attribute from an activity code matching the +/// `Activity` enum ordinals (0=active, 1=dup, 2=const, 3=dupnoneed, ...). MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx, - EnzymeActivity activity); + uint32_t activity); + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +/// Register all Enzyme MLIR passes with the global registry. +MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void); + +/// Register Enzyme's external model AD implementations for the core MLIR +/// dialects (arith, func, scf, ...) into the given DialectRegistry. +/// Must be called before appending the registry to an MLIRContext so that +/// the Differentiate pass knows how to lower ops in those dialects. +MLIR_CAPI_EXPORTED void +enzymeRegisterDialectExtensions(MlirDialectRegistry registry); + +/// Create the core differentiation pass (`--enzyme`). +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void); + +/// Create the core differentiation pass with options. +/// postpasses: semicolon-separated list of passes to run after differentiation +/// (may be NULL or empty). verifyPostPasses: run the verifier after each +/// post-pass. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePassWithOptions( + MlirStringRef postpasses, bool verifyPostPasses); + +/// Create the `--enzyme-ops-to-memref` pass, which lowers `enzyme.init`, +/// `enzyme.get`, and `enzyme.set` to `memref` + `scf.if`. Required before +/// lowering to LLVM. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void); + +/// Create the `--enzyme-batch` pass, which lowers `enzyme.batch` ops to +/// calls to generated `@batched_*` functions. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void); + +/// Create the `--enzyme-batch-diff` pass, which differentiates batched +/// functions produced by `--enzyme-batch`. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void); + +/// Create the `--remove-unnecessary-enzyme-ops` pass, which removes dead +/// `enzyme.*` ops after differentiation. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void); //===----------------------------------------------------------------------===// // Core AD ops //===----------------------------------------------------------------------===// +/// Construct an `enzyme.autodiff` (reverse-mode) op referencing the function +/// named `fn`. `activity`/`retActivity` are arrays of `enzyme.activity` +/// attributes of length `nActivity`/`nRetActivity`. MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate( - MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults, - MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity, - intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity, - int64_t width, bool strongZero, MlirLocation loc); - + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc); + +/// Construct an `enzyme.fwddiff` (forward-mode) op referencing the function +/// named `fn`. See `enzymeAutoDiffOpCreate` for parameter semantics. MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate( - MlirContext ctx, MlirStringRef fn, MlirType *resultTypes, intptr_t nResults, - MlirValue *inputs, intptr_t nInputs, MlirAttribute *activity, - intptr_t nActivity, MlirAttribute *retActivity, intptr_t nRetActivity, - int64_t width, bool strongZero, MlirLocation loc); + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc); + +/// Construct an `enzyme.jacobian` op referencing the function named `fn`. +/// See `enzymeAutoDiffOpCreate` for parameter semantics. +MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc); + +/// Construct an `enzyme.batch` op that maps the function named `fn` over a +/// batch of inputs. `batchShape` is an array of `nBatchShape` batch dimensions. +MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc); //===----------------------------------------------------------------------===// // Probabilistic Programming Ops From aac41cf76b92ee922eeb0a99d0e4236753f2fa16 Mon Sep 17 00:00:00 2001 From: Austin Garrett Date: Tue, 9 Jun 2026 20:05:59 -0400 Subject: [PATCH 5/5] [MLIR] simplify enzymeActivityAttrGet --- .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp | 28 ++----------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp index dcb876499c13..a98f8b87f225 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp @@ -18,7 +18,7 @@ #include "Dialect/Ops.h" #include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Passes/Passes.h" -#include "llvm/Support/ErrorHandling.h" + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme, mlir::enzyme::EnzymeDialect) @@ -70,30 +70,8 @@ MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void) { MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx, uint32_t activity) { - mlir::enzyme::Activity act; - switch (activity) { - case 0: - act = mlir::enzyme::Activity::enzyme_active; - break; - case 1: - act = mlir::enzyme::Activity::enzyme_dup; - break; - case 2: - act = mlir::enzyme::Activity::enzyme_const; - break; - case 3: - act = mlir::enzyme::Activity::enzyme_dupnoneed; - break; - case 4: - act = mlir::enzyme::Activity::enzyme_activenoneed; - break; - case 5: - act = mlir::enzyme::Activity::enzyme_constnoneed; - break; - default: - llvm_unreachable("invalid Enzyme activity"); - } - return wrap(mlir::enzyme::ActivityAttr::get(unwrap(ctx), act)); + return wrap(mlir::enzyme::ActivityAttr::get( + unwrap(ctx), (mlir::enzyme::Activity)activity)); } static mlir::ArrayAttr