Skip to content

RFC: enzyme dialect bindings for melior#860

Draft
agarret7 wants to merge 1 commit into
mlir-rs:mainfrom
agarret7:enzyme-dialect-rfc
Draft

RFC: enzyme dialect bindings for melior#860
agarret7 wants to merge 1 commit into
mlir-rs:mainfrom
agarret7:enzyme-dialect-rfc

Conversation

@agarret7

@agarret7 agarret7 commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Status

This is a draft PR follow-up to #856. It is not yet tested on real Enzyme lowering passes.

It's blocked on Enzyme's MLIR C API for dialect registration (enzymeRegisterDialectExtensions), which isn't merged (see EnzymeAD/Enzyme#2859, still open) and MLIR 22 support in Enzyme. So these builders' enzyme.* ops currently require context.set_allow_unregistered_dialects(true). Until #2859 lands (e.g. via a future enzyme-sys crate), the ops are just parsed and printed.

  1. EnzymeAD/Enzyme#2859 -- the Enzyme MLIR C API (MLIRCAPIEnzyme): dialect registration, enzymeActivityAttrGet, enzymeAutoDiffOpCreate/enzymeForwardDiffOpCreate.
  2. EnzymeAD/Enzyme#2858 -- MLIR 22 support in EnzymeAD's CI -- bump enzyme-bazel.yml's pinned LLVM commit to MLIR 22, which requires fixing (getSuccessorInputs -> getSuccessorRegions) so MLIREnzymeAnalysis builds.

Motivation

Automatic differentiation (AD) is increasingly central to scientific computing, machine learning, and probabilistic programming. MLIR's multi-dialect, multi-stage compilation model is a natural fit for post-optimization, compile-time AD passes before lowering to LLVM. This is exactly what Enzyme's dialect provides.

Exposing Enzyme's MLIR dialect through Melior would give Rust users a safe, ergonomic path to compile-time gradients without leaving the MLIR ecosystem and without having to use macros to differentiate their functions. This is especially targeted at deep learning libraries (candle, burn, dfdx) that hand-roll tensor graphs and can't wrap their inner ops with std::autodiff or std::offload and thus would benefit from a compile-time AD pass directly on their IR.

Scope

This is purely targeted at exposing EnzymeAD MLIR ops through mlir-rs.
Closes #856.

New API Surface

  1. autodiff: enzyme.autodiff, reverse-mode derivative of a function with respect to its enzyme_active/enzyme_dup operands.

    pub fn autodiff<'c>(
        context: &'c Context,
        function: FlatSymbolRefAttribute<'c>,
        operands: &[Value<'c, '_>],
        result_types: &[Type<'c>],
        activity: ArrayAttribute<'c>,
        ret_activity: ArrayAttribute<'c>,
        width: i64,
        strong_zero: bool,
        location: Location<'c>,
    ) -> Operation<'c>
  2. fwddiff: enzyme.fwddiff, forward-mode (tangent) derivative. Same signature as autodiff.

  3. jacobian: enzyme.jacobian, full Jacobian with respect to active/dup operands. Same signature as autodiff.

  4. batch: enzyme.batch, vectorizes a call over leading batch dimensions of its operands and results.

    pub fn batch<'c>(
        context: &'c Context,
        function: FlatSymbolRefAttribute<'c>,
        operands: &[Value<'c, '_>],
        result_types: &[Type<'c>],
        batch_shape: &[i64],
        location: Location<'c>,
    ) -> Operation<'c>
  5. Activity enum and activity_attribute/activity_array_attribute helpers for building #enzyme<activity ...> attributes (enzyme_active, enzyme_dup, enzyme_const, enzyme_dupnoneed, enzyme_activenoneed, enzyme_constnoneed).

    pub enum Activity {
        Active,
        Dup,
        Const,
        DupNoNeed,
        ActiveNoNeed,
        ConstNoNeed,
    }
    
    pub fn activity_attribute(context: &Context, activity: Activity) -> Attribute<'_>
    
    pub fn activity_array_attribute<'c>(
        context: &'c Context,
        activities: &[Activity],
    ) -> ArrayAttribute<'c>

Tests

5 new tests pass (with context.set_allow_unregistered_dialects(true) set, allowing unregistered dialect ops) activity_mnemonics, compile_autodiff, compile_fwddiff, compile_jacobian, compile_batch, each verifying the built operation and asserting an insta snapshot of its printed form.

Example

autodiff builds the reverse-mode derivative of @square with respect to a
enzyme_dup operand:

let result = block
    .append_operation(autodiff(
        &context,
        FlatSymbolRefAttribute::new(&context, "square"),
        &[lhs, rhs],
        &[f64_type],
        activity_array_attribute(&context, &[Activity::Dup]),
        activity_array_attribute(&context, &[Activity::ActiveNoNeed]),
        1,
        false,
        location,
    ))
    .result(0)
    .unwrap()
    .into();

The compile_autodiff test asserts this produces the module with the currently-unregistered (waiting on the PR on Enzyme's repo) enzyme.autodiff op.

module {
  func.func @foo(%arg0: f64, %arg1: f64) -> f64 {
    %0 = "enzyme.autodiff"(%arg0, %arg1) {
      // differentiate `@square`
      fn = @square,
      // `@square` takes one argument, marked `enzyme_dup`: both its primal
      // value (%arg0) and a shadow/derivative slot (%arg1) are passed in
      activity = [#enzyme<activity enzyme_dup>],
      // the return value is active, but the primal result itself isn't
      // returned -- only its gradient (see `enzyme_activenoneed` discussion
      // above)
      ret_activity = [#enzyme<activity enzyme_activenoneed>],
      // compute a single (non-batched) derivative
      width = 1 : i64,
      // shadow/derivative buffers are not assumed to be pre-zeroed
      strong_zero = false
    } : (f64, f64) -> f64
    return %0 : f64
  }
}

Open Questions

  1. Where should this live? Here, in Melior, or in a separate crate like mlir-rs/enzyme-sys?

    (A) On by Default: the bindings require extra registration of the Enzyme MLIR dialect. Tracking changes hopefully would be relatively contained to AD.
    (B) Feature-Gated: Here, but behind a Cargo feature flag (e.g. enzyme) so the Enzyme C API dependency (and its MLIR-22-only requirement, see Upstream Dependencies) is opt-in rather than forced on all melior users. Adds a feature-flag and CI-matrix maintenance burden while #2859/#2858 are still in flux.
    (C) Separate Crate: Enzyme isn't part of upstream MLIR core. Melior's existing (arith, cf, func, index, llvm, memref, scf) are all upstream MLIR-core dialects, usable via register_all_dialects. Enzyme's dialect (and the enzymeRegisterDialectExtensions(MlirDialectRegistry) function proposed in EnzymeAD/Enzyme#2859) would be in Enzyme's repo and bound on the Rust side as enzyme_sys::register_dialect_extensions(registry: &DialectRegistry), mirroring register_all_dialects(registry: &DialectRegistry).

  2. Naming: fwddiff/jacobian match Enzyme's op mnemonics. Should these instead follow melior's usual multi-word snake_case convention (e.g. fwd_diff, auto_diff)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature request: Enzyme-backed Automatic Differentiation (AD)

1 participant