Skip to content

Fix reverse mode for memref.alloca#2826

Open
xys-syx wants to merge 1 commit into
mainfrom
memref-alloca
Open

Fix reverse mode for memref.alloca#2826
xys-syx wants to merge 1 commit into
mainfrom
memref-alloca

Conversation

@xys-syx

@xys-syx xys-syx commented May 15, 2026

Copy link
Copy Markdown
Collaborator
  1. use scf.parallel + memref.store/memref.dim to replace linalg.fill
  2. substitue the enzyme.placeholder with real shadow

The key issue for memref.alloca is: the mutable type's enzyme.placeholder does not be subtitute by the real shadow in reverse mode.

The output before fix:

module {
  func.func @foo_flat(%arg0: f64) -> f64 {
    %alloca = memref.alloca() : memref<f64>
    memref.store %arg0, %alloca[] : memref<f64>
    %0 = memref.load %alloca[] : memref<f64>
    return %0 : f6 }
  func.func @dfoo_flat(%arg0: f64, %arg1: f64) -> f64 {
    %0 = call @diffefoo_flat(%arg0, %arg1) : (f64, f64) -> f64
    return %0 : f64
  }
  func.func private @diffefoo_flat(%arg0: f64, %arg1: f64) -> f64 {
    %0 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%0, %cst) : (!enzyme.Gradient<f64>, f64) -> ()
    %1 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %2 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %3 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst_0 = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
    %4 = "enzyme.placeholder"() : () -> memref<f64>
    %alloca = memref.alloca() : memref<f64>
    %cst_1 = arith.constant 0.000000e+00 : f64
    memref.store %cst_1, %alloca[] : memref<f64>
    %alloca_2 = memref.alloca() : memref<f64>
    "enzyme.push"(%1, %4) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    memref.store %arg0, %alloca_2[] : memref<f64>
    "enzyme.push"(%2, %4) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    %5 = memref.load %alloca_2[] : memref<f64>
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %6 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
    %7 = arith.addf %6, %arg1 : f64
    "enzyme.set"(%3, %7) : (!enzyme.Gradient<f64>, f64) -> ()
    %8 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
    %9 = "enzyme.pop"(%2) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %10 = memref.load %9[] : memref<f64>
    %11 = arith.addf %10, %8 : f64
    memref.store %11, %9[] : memref<f64>
    %12 = "enzyme.pop"(%1) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %13 = memref.load %12[] : memref<f64>
    %14 = "enzyme.get"(%0) : (!enzyme.Gradient<f64>) -> f64
    %15 = arith.addf %14, %13 : f64
    "enzyme.set"(%0, %15) : (!enzyme.Gradient<f64>, f64) -> ()
    %cst_3 = arith.constant 0.000000e+00 : f64
    memref.store %cst_3, %12[] : memref<f64>
    %16 = "enzyme.get"(%0) : (!enzyme.Gradient<f64>) -> f64
    return %16 : f64
  }
}

@xys-syx xys-syx requested a review from vimarsh6739 May 15, 2026 04:06
Comment thread enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@xys-syx xys-syx requested a review from vimarsh6739 May 20, 2026 21:47
Comment thread enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@xys-syx xys-syx requested review from vimarsh6739 and wsmoses May 23, 2026 04:56
@vimarsh6739 vimarsh6739 changed the title Fix revert mode for memref.alloca Fix reverse mode for memref.alloca Jun 4, 2026
Comment thread enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
@xys-syx xys-syx requested a review from wsmoses June 21, 2026 07:21
@vimarsh6739 vimarsh6739 force-pushed the memref-alloca branch 2 times, most recently from c2e2d9a to e362e6f Compare July 2, 2026 21:39
Allow setDiffe to replace placeholder shadows for mutable values in reverse
mode, while preserving existing user-provided shadows for mutable arguments.
Add memref.alloca coverage for both internal allocation shadows and memref
argument shadows.
@wsmoses

wsmoses commented Jul 2, 2026

Copy link
Copy Markdown
Member

what is the backtrace where this fails currently

@xys-syx

xys-syx commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator Author

@wsmoses The issue is the placeholder was still not replaced by shadow value, while not a crash, the reason is setInvertedPointer never runs for the reverse+mutable value

Could you please tell would you hope to see to have some dumped MLIR in the downstream, or to see if it crash in the raising without this fix?

// %eopt --enzyme test/MLIR/ReverseMode/alloca.mlir 
module {
  func.func @foo_flat(%arg0: f64) -> f64 {
    %alloca = memref.alloca() : memref<f64>
    memref.store %arg0, %alloca[] : memref<f64>
    %0 = memref.load %alloca[] : memref<f64>
    return %0 : f64
  }
  func.func @dfoo_flat(%arg0: f64, %arg1: f64) -> f64 {
    %0 = call @diffefoo_flat(%arg0, %arg1) : (f64, f64) -> f64
    return %0 : f64
  }
  func.func @load_arg(%arg0: memref<f64>) -> f64 {
    %alloca = memref.alloca() : memref<f64>
    %0 = memref.load %arg0[] : memref<f64>
    memref.store %0, %alloca[] : memref<f64>
    %1 = memref.load %alloca[] : memref<f64>
    return %1 : f64
  }
  func.func @dload_arg(%arg0: memref<f64>, %arg1: memref<f64>, %arg2: f64) {
    call @diffeload_arg(%arg0, %arg1, %arg2) : (memref<f64>, memref<f64>, f64) -> ()
    return
  }
  func.func private @diffefoo_flat(%arg0: f64, %arg1: f64) -> f64 {
    %0 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%0, %cst) : (!enzyme.Gradient<f64>, f64) -> ()
    %1 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %2 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %3 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst_0 = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
    %4 = "enzyme.placeholder"() : () -> memref<f64>
    %alloca = memref.alloca() : memref<f64>
    %cst_1 = arith.constant 0.000000e+00 : f64
    linalg.fill ins(%cst_1 : f64) outs(%alloca : memref<f64>)
    %alloca_2 = memref.alloca() : memref<f64>
    "enzyme.push"(%1, %4) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    memref.store %arg0, %alloca_2[] : memref<f64>
    "enzyme.push"(%2, %4) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    %5 = memref.load %alloca_2[] : memref<f64>
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %6 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
    %7 = arith.addf %6, %arg1 : f64
    "enzyme.set"(%3, %7) : (!enzyme.Gradient<f64>, f64) -> ()
    %8 = "enzyme.get"(%3) : (!enzyme.Gradient<f64>) -> f64
    %9 = "enzyme.pop"(%2) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %10 = memref.load %9[] : memref<f64>
    %11 = arith.addf %10, %8 : f64
    memref.store %11, %9[] : memref<f64>
    %12 = "enzyme.pop"(%1) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %13 = memref.load %12[] : memref<f64>
    %14 = "enzyme.get"(%0) : (!enzyme.Gradient<f64>) -> f64
    %15 = arith.addf %14, %13 : f64
    "enzyme.set"(%0, %15) : (!enzyme.Gradient<f64>, f64) -> ()
    %cst_3 = arith.constant 0.000000e+00 : f64
    memref.store %cst_3, %12[] : memref<f64>
    %16 = "enzyme.get"(%0) : (!enzyme.Gradient<f64>) -> f64
    return %16 : f64
  }
  func.func private @diffeload_arg(%arg0: memref<f64>, %arg1: memref<f64>, %arg2: f64) {
    %0 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %1 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%1, %cst) : (!enzyme.Gradient<f64>, f64) -> ()
    %2 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %3 = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
    %4 = "enzyme.init"() : () -> !enzyme.Gradient<f64>
    %cst_0 = arith.constant 0.000000e+00 : f64
    "enzyme.set"(%4, %cst_0) : (!enzyme.Gradient<f64>, f64) -> ()
    %5 = "enzyme.placeholder"() : () -> memref<f64>
    %alloca = memref.alloca() : memref<f64>
    %cst_1 = arith.constant 0.000000e+00 : f64
    linalg.fill ins(%cst_1 : f64) outs(%alloca : memref<f64>)
    %alloca_2 = memref.alloca() : memref<f64>
    "enzyme.push"(%0, %arg1) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    %6 = memref.load %arg0[] : memref<f64>
    "enzyme.push"(%2, %5) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    memref.store %6, %alloca_2[] : memref<f64>
    "enzyme.push"(%3, %5) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
    %7 = memref.load %alloca_2[] : memref<f64>
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %8 = "enzyme.get"(%4) : (!enzyme.Gradient<f64>) -> f64
    %9 = arith.addf %8, %arg2 : f64
    "enzyme.set"(%4, %9) : (!enzyme.Gradient<f64>, f64) -> ()
    %10 = "enzyme.get"(%4) : (!enzyme.Gradient<f64>) -> f64
    %11 = "enzyme.pop"(%3) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %12 = memref.load %11[] : memref<f64>
    %13 = arith.addf %12, %10 : f64
    memref.store %13, %11[] : memref<f64>
    %14 = "enzyme.pop"(%2) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %15 = memref.load %14[] : memref<f64>
    %16 = "enzyme.get"(%1) : (!enzyme.Gradient<f64>) -> f64
    %17 = arith.addf %16, %15 : f64
    "enzyme.set"(%1, %17) : (!enzyme.Gradient<f64>, f64) -> ()
    %cst_3 = arith.constant 0.000000e+00 : f64
    memref.store %cst_3, %14[] : memref<f64>
    %18 = "enzyme.get"(%1) : (!enzyme.Gradient<f64>) -> f64
    %19 = "enzyme.pop"(%0) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
    %20 = memref.load %19[] : memref<f64>
    %21 = arith.addf %20, %18 : f64
    memref.store %21, %19[] : memref<f64>
    return
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants