RFC: enzyme dialect bindings for melior#860
Draft
agarret7 wants to merge 1 commit into
Draft
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Status
This is a draft PR follow-up to #856. It is not yet tested on real Enzyme lowering passes.
It's blocked on Enzyme's MLIR C API for dialect registration (
enzymeRegisterDialectExtensions), which isn't merged (see EnzymeAD/Enzyme#2859, still open) and MLIR 22 support in Enzyme. So these builders'enzyme.*ops currently requirecontext.set_allow_unregistered_dialects(true). Until #2859 lands (e.g. via a futureenzyme-syscrate), the ops are just parsed and printed.MLIRCAPIEnzyme): dialect registration,enzymeActivityAttrGet,enzymeAutoDiffOpCreate/enzymeForwardDiffOpCreate.enzyme-bazel.yml's pinned LLVM commit to MLIR 22, which requires fixing (getSuccessorInputs->getSuccessorRegions) soMLIREnzymeAnalysisbuilds.Motivation
Automatic differentiation (AD) is increasingly central to scientific computing, machine learning, and probabilistic programming. MLIR's multi-dialect, multi-stage compilation model is a natural fit for post-optimization, compile-time AD passes before lowering to LLVM. This is exactly what Enzyme's dialect provides.
Exposing Enzyme's MLIR dialect through Melior would give Rust users a safe, ergonomic path to compile-time gradients without leaving the MLIR ecosystem and without having to use macros to differentiate their functions. This is especially targeted at deep learning libraries (
candle,burn,dfdx) that hand-roll tensor graphs and can't wrap their inner ops withstd::autodifforstd::offloadand thus would benefit from a compile-time AD pass directly on their IR.Scope
This is purely targeted at exposing EnzymeAD MLIR ops through
mlir-rs.Closes #856.
New API Surface
autodiff:enzyme.autodiff, reverse-mode derivative of a function with respect to itsenzyme_active/enzyme_dupoperands.fwddiff:enzyme.fwddiff, forward-mode (tangent) derivative. Same signature asautodiff.jacobian:enzyme.jacobian, full Jacobian with respect to active/dup operands. Same signature asautodiff.batch:enzyme.batch, vectorizes a call over leading batch dimensions of its operands and results.Activityenum andactivity_attribute/activity_array_attributehelpers for building#enzyme<activity ...>attributes (enzyme_active,enzyme_dup,enzyme_const,enzyme_dupnoneed,enzyme_activenoneed,enzyme_constnoneed).Tests
5 new tests pass (with
context.set_allow_unregistered_dialects(true)set, allowing unregistered dialect ops)activity_mnemonics,compile_autodiff,compile_fwddiff,compile_jacobian,compile_batch, each verifying the built operation and asserting aninstasnapshot of its printed form.Example
autodiffbuilds the reverse-mode derivative of@squarewith respect to aenzyme_dupoperand:The
compile_autodifftest asserts this produces the module with the currently-unregistered (waiting on the PR on Enzyme's repo)enzyme.autodiffop.Open Questions
Where should this live? Here, in Melior, or in a separate crate like
mlir-rs/enzyme-sys?(A) On by Default: the bindings require extra registration of the Enzyme MLIR dialect. Tracking changes hopefully would be relatively contained to AD.
(B) Feature-Gated: Here, but behind a Cargo feature flag (e.g.
enzyme) so the Enzyme C API dependency (and its MLIR-22-only requirement, see Upstream Dependencies) is opt-in rather than forced on all melior users. Adds a feature-flag and CI-matrix maintenance burden while #2859/#2858 are still in flux.(C) Separate Crate: Enzyme isn't part of upstream MLIR core. Melior's existing (arith, cf, func, index, llvm, memref, scf) are all upstream MLIR-core dialects, usable via
register_all_dialects. Enzyme's dialect (and theenzymeRegisterDialectExtensions(MlirDialectRegistry)function proposed in EnzymeAD/Enzyme#2859) would be in Enzyme's repo and bound on the Rust side asenzyme_sys::register_dialect_extensions(registry: &DialectRegistry), mirroringregister_all_dialects(registry: &DialectRegistry).Naming:
fwddiff/jacobianmatch Enzyme's op mnemonics. Should these instead follow melior's usual multi-wordsnake_caseconvention (e.g.fwd_diff,auto_diff)?