diff --git a/enzyme/Enzyme/MLIR/CMakeLists.txt b/enzyme/Enzyme/MLIR/CMakeLists.txt index 157ec56f2fd..f055873bca0 100644 --- a/enzyme/Enzyme/MLIR/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(Dialect) add_subdirectory(Interfaces) add_subdirectory(Implementations) add_subdirectory(Passes) +add_subdirectory(Integrations) add_subdirectory(enzymemlir-translate) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) diff --git a/enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt new file mode 100644 index 00000000000..20bc4cc4b07 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Integrations/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(c) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt b/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt new file mode 100644 index 00000000000..69981e15d34 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Integrations/c/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_public_c_api_library(MLIRCAPIEnzyme + EnzymeMLIR.cpp + + DEPENDS + MLIRImpulseEnumsIncGen + MLIRImpulseAttributesIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRCAPIIR + + LINK_LIBS PRIVATE + MLIREnzyme + MLIRImpulse + MLIREnzymeTransforms + MLIREnzymeAnalysis + MLIREnzymeAutoDiffInterface + MLIREnzymeImplementations +) diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp index c65709b22de..a98f8b87f22 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp @@ -1,13 +1,168 @@ +//===- EnzymeMLIR.cpp - C API for Enzyme MLIR dialect ---------------------===// +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #include "EnzymeMLIR.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Pass.h" +#include "mlir/CAPI/Registration.h" #include "Dialect/Dialect.h" #include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Passes/Passes.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme, + mlir::enzyme::EnzymeDialect) + +MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void) { + mlir::enzyme::registerDifferentiatePass(); + mlir::enzyme::registerExpandImpulsePass(); + mlir::enzyme::registerBatchPass(); + mlir::enzyme::registerBatchDiffPass(); + mlir::enzyme::registerDifferentiateWrapperPass(); + mlir::enzyme::registerInlineEnzymeIntoRegionPass(); + mlir::enzyme::registerOutlineEnzymeFromRegionPass(); + mlir::enzyme::registerPrintActivityAnalysisPass(); + mlir::enzyme::registerPrintAliasAnalysisPass(); + mlir::enzyme::registerRemoveUnusedEnzymeOpsPass(); +} + +MLIR_CAPI_EXPORTED void +enzymeRegisterDialectExtensions(MlirDialectRegistry registry) { + mlir::enzyme::registerCoreDialectAutodiffInterfaces(*unwrap(registry)); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void) { + return wrap(mlir::enzyme::createDifferentiatePass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePassWithOptions( + MlirStringRef postpasses, bool verifyPostPasses) { + mlir::enzyme::DifferentiatePassOptions opts; + opts.postpasses = std::string(postpasses.data, postpasses.length); + opts.verifyPostPasses = verifyPostPasses; + return wrap(mlir::enzyme::createDifferentiatePass(opts).release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void) { + return wrap(mlir::enzyme::createEnzymeOpsToMemRefPass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void) { + return wrap(mlir::enzyme::createBatchPass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void) { + return wrap(mlir::enzyme::createBatchDiffPass().release()); +} + +MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void) { + return wrap(mlir::enzyme::createRemoveUnusedEnzymeOpsPass().release()); +} + +MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx, + uint32_t activity) { + return wrap(mlir::enzyme::ActivityAttr::get( + unwrap(ctx), (mlir::enzyme::Activity)activity)); +} + +static mlir::ArrayAttr +activityArrayAttr(MlirContext ctx, const MlirAttribute *activity, intptr_t n) { + return mlir::ArrayAttr::get( + unwrap(ctx), llvm::ArrayRef( + reinterpret_cast(activity), n)); +} + +MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), + (uint64_t)width, strongZero); + return wrap(op.getOperation()); +} + +MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), + (uint64_t)width, strongZero); + return wrap(op.getOperation()); +} + +MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + activityArrayAttr(ctx, activity, nActivity), + activityArrayAttr(ctx, retActivity, nRetActivity), + (uint64_t)width, strongZero); + return wrap(op.getOperation()); +} + +MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc) { + auto op = + mlir::OpBuilder(unwrap(ctx)) + .create( + unwrap(loc), + mlir::TypeRange(llvm::ArrayRef( + reinterpret_cast(resultTypes), nResults)), + llvm::StringRef(fn.data, fn.length), + mlir::ValueRange(llvm::ArrayRef( + reinterpret_cast(inputs), nInputs)), + llvm::ArrayRef(batchShape, nBatchShape)); + return wrap(op.getOperation()); +} -MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, - EnzymeRngDistribution dist) { +MLIR_CAPI_EXPORTED MlirAttribute +enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist) { mlir::impulse::RngDistribution rngDist; switch (dist) { case EnzymeRngDistribution_Uniform: @@ -23,9 +178,9 @@ MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist)); } -MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, - bool hasLowerBound, double lowerBound, - bool hasUpperBound, double upperBound) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeSupportAttrGet( + MlirContext ctx, EnzymeSupportKind kind, bool hasLowerBound, + double lowerBound, bool hasUpperBound, double upperBound) { auto *mlirCtx = unwrap(ctx); mlir::impulse::SupportKind supportKind; @@ -64,8 +219,10 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, upperAttr)); } -MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, - bool adaptStepSize, bool adaptMassMatrix) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, + double trajectoryLength, + bool adaptStepSize, + bool adaptMassMatrix) { auto *mlirCtx = unwrap(ctx); auto trajectoryLengthAttr = mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength); @@ -74,10 +231,9 @@ MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix)); } -MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, - bool hasMaxDeltaEnergy, - double maxDeltaEnergy, bool adaptStepSize, - bool adaptMassMatrix) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeNUTSConfigAttrGet( + MlirContext ctx, int64_t maxTreeDepth, bool hasMaxDeltaEnergy, + double maxDeltaEnergy, bool adaptStepSize, bool adaptMassMatrix) { auto *mlirCtx = unwrap(ctx); mlir::FloatAttr maxDeltaEnergyAttr; @@ -90,6 +246,7 @@ MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, adaptMassMatrix)); } -MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) { +MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, + uint64_t ptr) { return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr)); } diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h index 62336c34b4b..518906716c7 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.h @@ -1,3 +1,16 @@ +//===- EnzymeMLIR.h - C API for Enzyme MLIR dialect ---------------*- C -*-===// +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to the Enzyme MLIR dialect. +// +//===----------------------------------------------------------------------===// + #ifndef ENZYME_MLIR_INTEGRATIONS_C_ENZYMEMLIR_H_ #define ENZYME_MLIR_INTEGRATIONS_C_ENZYMEMLIR_H_ @@ -5,12 +18,108 @@ #include #include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { #endif +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Enzyme, enzyme); + +//===----------------------------------------------------------------------===// +// Activity attribute +//===----------------------------------------------------------------------===// + +/// Construct an `enzyme.activity` attribute from an activity code matching the +/// `Activity` enum ordinals (0=active, 1=dup, 2=const, 3=dupnoneed, ...). +MLIR_CAPI_EXPORTED MlirAttribute enzymeActivityAttrGet(MlirContext ctx, + uint32_t activity); + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +/// Register all Enzyme MLIR passes with the global registry. +MLIR_CAPI_EXPORTED void enzymeRegisterPasses(void); + +/// Register Enzyme's external model AD implementations for the core MLIR +/// dialects (arith, func, scf, ...) into the given DialectRegistry. +/// Must be called before appending the registry to an MLIRContext so that +/// the Differentiate pass knows how to lower ops in those dialects. +MLIR_CAPI_EXPORTED void +enzymeRegisterDialectExtensions(MlirDialectRegistry registry); + +/// Create the core differentiation pass (`--enzyme`). +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePass(void); + +/// Create the core differentiation pass with options. +/// postpasses: semicolon-separated list of passes to run after differentiation +/// (may be NULL or empty). verifyPostPasses: run the verifier after each +/// post-pass. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateDifferentiatePassWithOptions( + MlirStringRef postpasses, bool verifyPostPasses); + +/// Create the `--enzyme-ops-to-memref` pass, which lowers `enzyme.init`, +/// `enzyme.get`, and `enzyme.set` to `memref` + `scf.if`. Required before +/// lowering to LLVM. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateConvertEnzymeToMemRefPass(void); + +/// Create the `--enzyme-batch` pass, which lowers `enzyme.batch` ops to +/// calls to generated `@batched_*` functions. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchPass(void); + +/// Create the `--enzyme-batch-diff` pass, which differentiates batched +/// functions produced by `--enzyme-batch`. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateBatchDiffPass(void); + +/// Create the `--remove-unnecessary-enzyme-ops` pass, which removes dead +/// `enzyme.*` ops after differentiation. +MLIR_CAPI_EXPORTED MlirPass enzymeCreateRemoveUnusedEnzymeOpsPass(void); + +//===----------------------------------------------------------------------===// +// Core AD ops +//===----------------------------------------------------------------------===// + +/// Construct an `enzyme.autodiff` (reverse-mode) op referencing the function +/// named `fn`. `activity`/`retActivity` are arrays of `enzyme.activity` +/// attributes of length `nActivity`/`nRetActivity`. +MLIR_CAPI_EXPORTED MlirOperation enzymeAutoDiffOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc); + +/// Construct an `enzyme.fwddiff` (forward-mode) op referencing the function +/// named `fn`. See `enzymeAutoDiffOpCreate` for parameter semantics. +MLIR_CAPI_EXPORTED MlirOperation enzymeForwardDiffOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc); + +/// Construct an `enzyme.jacobian` op referencing the function named `fn`. +/// See `enzymeAutoDiffOpCreate` for parameter semantics. +MLIR_CAPI_EXPORTED MlirOperation enzymeJacobianOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const MlirAttribute *activity, intptr_t nActivity, + const MlirAttribute *retActivity, intptr_t nRetActivity, int64_t width, + bool strongZero, MlirLocation loc); + +/// Construct an `enzyme.batch` op that maps the function named `fn` over a +/// batch of inputs. `batchShape` is an array of `nBatchShape` batch dimensions. +MLIR_CAPI_EXPORTED MlirOperation enzymeBatchOpCreate( + MlirContext ctx, MlirStringRef fn, const MlirType *resultTypes, + intptr_t nResults, const MlirValue *inputs, intptr_t nInputs, + const int64_t *batchShape, intptr_t nBatchShape, MlirLocation loc); + //===----------------------------------------------------------------------===// // Probabilistic Programming Ops //===----------------------------------------------------------------------===//