
Cache subexpressions when building MOI.ScalarNonlinearExpression#4032

Open
blegat wants to merge 15 commits into master from bl/cache_sub


@blegat

@blegat blegat commented Jul 22, 2025

Member

Here is a simplified reproducer of the performance issue identified in #4024:

using JuMP

f(x, u) = [sin(x[1]) - x[1] * u, cos(x[2]) + x[1] * u]
function RK4(f, X, u)
    k1 = f(X     , u)
    k2 = f(X+k1/2, u)
    k3 = f(X+k2/2, u)
    k4 = f(X+k3  , u)
    X + (k1 + 2k2 + 2k3 + k4) / 6
end

model = Model()

@variable(model, q[1:2])
@variable(model, u)

x = q
for m = 1:4
    x = RK4(f, x, u)
end
@time moi_function.(x)

Before this PR, this gives

14.610110 seconds (207.45 M allocations: 6.927 GiB, 68.63% gc time)

After this PR, this gives

0.000103 seconds (2.16 k allocations: 74.250 KiB)

Note that the MOI.ScalarNonlinearFunction objects we generate now share common subexpressions by pointer! If I'm not mistaken, we don't support modifying these in-place so that shouldn't be an issue.
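
To illustrate why sharing by pointer matters for timings like the ones above, here is a minimal sketch on a toy type. `Node`, `tree_size`, `dag_size`, and `nested` are hypothetical names introduced for illustration only; they are not JuMP or MOI API. The sketch compares the number of nodes when every alias is expanded into its own copy with the number of distinct objects.

```julia
# Toy expression node: a hypothetical stand-in, not a JuMP or MOI type.
mutable struct Node
    head::Symbol
    args::Vector{Any}
end

# Number of Node visits when every alias is expanded into its own copy.
tree_size(x) = 0
tree_size(n::Node) = 1 + sum(tree_size, n.args; init = 0)

# Number of distinct Node objects: each alias is counted once, by identity.
function dag_size(n::Node, seen = IdDict{Node,Bool}())
    haskey(seen, n) && return 0
    seen[n] = true
    return 1 + sum(a -> a isa Node ? dag_size(a, seen) : 0, n.args; init = 0)
end

# Each level references the previous level twice, loosely mimicking how an
# RK4 step reuses its intermediate results.
function nested(depth)
    expr = Node(:sin, Any[:x])
    for _ in 1:depth
        expr = Node(:+, Any[expr, expr])
    end
    return expr
end

expr = nested(10)
(tree_size(expr), dag_size(expr))  # (2047, 11)
```

The expanded size doubles with every level while the number of distinct objects grows by one, which is consistent with the gap between the two timings reported above.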

This fixes the slow model-building issue, but ReverseAD will still be terribly slow. One thing we could do is detect, in MOI.Nonlinear, the MOI.ScalarNonlinearFunction objects that share sub-functions corresponding to the same object (with a dictionary). When it detects two functions with the same pointer, it can create subexpressions and use its existing support for them.
The nice thing is that we're not changing the MOI interface at all. Creating MOI.ScalarNonlinearFunction objects that share sub-functions was already possible from the beginning; we just didn't treat them any differently. So this change would just be a performance optimization of the AD, and the interface stays as simple and non-breaking as it gets.
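
The detection step described above could look roughly like the following sketch. It uses a toy `Node` type and the hypothetical helpers `count_references!` and `shared_nodes`; it is not the MOI.Nonlinear implementation, only an illustration of finding objects that are referenced more than once with an identity-keyed dictionary.

```julia
# Toy expression node: a hypothetical stand-in, not an MOI type.
mutable struct Node
    head::Symbol
    args::Vector{Any}
end

# Count how many times each distinct Node appears as an argument.
# `visited` ensures the arguments of a shared node are walked only once.
function count_references!(counts, visited, n::Node)
    haskey(visited, n) && return
    visited[n] = true
    for arg in n.args
        arg isa Node || continue
        counts[arg] = get(counts, arg, 0) + 1
        count_references!(counts, visited, arg)
    end
    return
end

# Nodes referenced more than once are candidates for subexpressions.
function shared_nodes(roots)
    counts = IdDict{Node,Int}()
    visited = IdDict{Node,Bool}()
    for root in roots
        count_references!(counts, visited, root)
    end
    return [node for (node, c) in counts if c > 1]
end

s = Node(:sin, Any[:x])
f1 = Node(:+, Any[s, 1.0])
f2 = Node(:*, Any[s, 2.0])
shared = shared_nodes([f1, f2])
length(shared) == 1 && shared[1] === s  # true
```

Because the dictionaries are keyed by identity, two structurally equal but separately built nodes would not be reported as shared in this sketch.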

Results on other benchmarks

For this benchmark, cherry-pick both jump-dev/MathOptInterface.jl#2803 and jump-dev/MathOptInterface.jl#3008. We don't yet know the right fix at the MOI level; these PRs just ensure that the timings only reflect the JuMP side.
Then, execute test/perf/benchmark_cache.jl.
The conclusion of the benchmark seems to be that a global dict is better than a new dict per constraint, and that hashing the objectid gives a small additional win.
The benchmark was generated by Claude as a neutral judge.
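
For readers who want the difference between the two cache scopes spelled out, here is a minimal sketch on toy types. `Node`, `Converted`, and `convert_expr` are hypothetical names for illustration and do not correspond to the code in this PR; the sketch only shows which aliases survive conversion under each scope.

```julia
# Toy source and target expression types: hypothetical stand-ins.
mutable struct Node
    head::Symbol
    args::Vector{Any}
end

mutable struct Converted
    head::Symbol
    args::Vector{Any}
end

# Convert `n`, reusing any result already stored in `cache` for the same object.
function convert_expr(n::Node, cache::IdDict{Node,Converted})
    haskey(cache, n) && return cache[n]
    out = Converted(n.head, Any[a isa Node ? convert_expr(a, cache) : a for a in n.args])
    cache[n] = out
    return out
end

s = Node(:sin, Any[:x])
c1 = Node(:+, Any[s, s])    # `s` aliased inside one constraint
c2 = Node(:*, Any[s, 2.0])  # `s` also used by a second constraint

# One dict per constraint: sharing is kept within a call, not across calls.
a1 = convert_expr(c1, IdDict{Node,Converted}())
a2 = convert_expr(c2, IdDict{Node,Converted}())
a1.args[1] === a1.args[2]  # true
a1.args[1] === a2.args[1]  # false

# One dict for all constraints: sharing is also kept across calls.
cache = IdDict{Node,Converted}()
b1 = convert_expr(c1, cache)
b2 = convert_expr(c2, cache)
b1.args[1] === b2.args[1]  # true
```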

A: aliased tree (one big DAG, K=16)

Current

BenchmarkTools.Trial: 149 samples with 1 evaluation per sample.
 Range (min … max):  10.813 ms … 176.359 ms  ┊ GC (min … max):  0.00% … 86.85%
 Time  (median):     26.873 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   33.761 ms ±  30.991 ms  ┊ GC (mean ± σ):  22.81% ± 20.14%

  ▁▂  ▆▆▂█▃                                                     
  ██▇▅█████▆▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▅▁▁▁▆▁▅▅ ▅
  10.8 ms       Histogram: log(frequency) by time       168 ms <

 Memory estimate: 24.01 MiB, allocs estimate: 721053.

One dict per constraint

BenchmarkTools.Trial: 1850 samples with 1 evaluation per sample.
 Range (min … max):  2.579 ms …  4.681 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.689 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.704 ms ± 86.148 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

          ▂▂▄▄▄▆█▄▅▂▆▂▃▂▁▃▂▁                                  
  ▂▂▂▃▃▅▆███████████████████▇█▆▆▅▅▇▄▅▅▄▄▄▄▃▃▃▃▃▃▃▂▃▃▃▃▂▂▁▂▂▃ ▄
  2.58 ms        Histogram: frequency by time        2.92 ms <

 Memory estimate: 17.45 KiB, allocs estimate: 303.

Global dict

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):   9.567 μs … 150.842 ms  ┊ GC (min … max):  0.00% … 99.95%
 Time  (median):     12.009 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   28.061 μs ±   1.508 ms  ┊ GC (mean ± σ):  53.73% ±  1.00%

      ▃▆█▇▄▂                                                    
  ▁▂▃▇███████▆▅▆▆▆▇▇▆▆▅▅▄▃▂▂▂▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  9.57 μs         Histogram: frequency by time         22.6 μs <

 Memory estimate: 17.41 KiB, allocs estimate: 325.

Global dict + object id

BenchmarkTools.Trial: 10000 samples with 3 evaluations per sample.
 Range (min … max):   9.182 μs … 211.604 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.418 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.339 μs ±   5.920 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

       ▂▅▆▇██▇▅▅▃▂▂▁                                            
  ▂▂▃▄▆███████████████▆▆▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂ ▄
  9.18 μs         Histogram: frequency by time         30.1 μs <

 Memory estimate: 16.58 KiB, allocs estimate: 297.

B: many independent constraints (N=5000)

Current

BenchmarkTools.Trial: 271 samples with 1 evaluation per sample.
 Range (min … max):   9.258 ms … 313.421 ms  ┊ GC (min … max):  0.00% … 95.58%
 Time  (median):     14.209 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   18.470 ms ±  34.692 ms  ┊ GC (mean ± σ):  25.40% ± 12.79%

  ▆█▁                                                           
  ███▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▄ ▅
  9.26 ms       Histogram: log(frequency) by time       251 ms <

 Memory estimate: 14.49 MiB, allocs estimate: 380142.

One dict per constraint

BenchmarkTools.Trial: 221 samples with 1 evaluation per sample.
 Range (min … max):  10.698 ms … 329.769 ms  ┊ GC (min … max):  0.00% … 94.39%
 Time  (median):     16.632 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   22.627 ms ±  41.113 ms  ┊ GC (mean ± σ):  29.84% ± 15.29%

  ▆█                                                            
  ██▇█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▄▁▁▄ ▅
  10.7 ms       Histogram: log(frequency) by time       267 ms <

 Memory estimate: 17.23 MiB, allocs estimate: 400142.

Global dict

BenchmarkTools.Trial: 236 samples with 1 evaluation per sample.
 Range (min … max):   7.718 ms … 338.331 ms  ┊ GC (min … max):  0.00% … 93.87%
 Time  (median):     18.529 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   21.190 ms ±  36.381 ms  ┊ GC (mean ± σ):  22.21% ± 12.21%

  ▅▅█                                                           
  ███▇▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▄ ▅
  7.72 ms       Histogram: log(frequency) by time       287 ms <

 Memory estimate: 17.37 MiB, allocs estimate: 418600.

Global dict + object id

BenchmarkTools.Trial: 156 samples with 1 evaluation per sample.
 Range (min … max):   6.855 ms …    2.565 s  ┊ GC (min … max):  0.00% … 99.34%
 Time  (median):     13.982 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   32.069 ms ± 204.190 ms  ┊ GC (mean ± σ):  50.93% ±  7.95%

            ▁▄▁▃█                                               
  ▃▁▁▃▃▅▃▁▁▃█████▆▄▃▁▆▃▃▃▁▃▃▄▃▃▁▁▃▁▃▃▁▁▄▃▃▃▁▁▃▁▁▃▁▁▁▁▁▃▃▁▁▁▁▁▃ ▃
  6.85 ms         Histogram: frequency by time         37.4 ms <

 Memory estimate: 15.81 MiB, allocs estimate: 360177.

C: shared big subexpr (N=200, M=200)

Current

BenchmarkTools.Trial: 292 samples with 1 evaluation per sample.
 Range (min … max):   6.926 ms … 522.915 ms  ┊ GC (min … max):  0.00% … 97.27%
 Time  (median):     10.711 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   17.169 ms ±  41.033 ms  ┊ GC (mean ± σ):  19.03% ±  8.04%

  ▃▆▄█▄ ▄▃ ▂                                                    
  ████████▆█▅█▆▄▄▁▄▁▁▄▁▁▁▄▄▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▄▁▄▁▁▁▁▄▄▄▅ ▅
  6.93 ms       Histogram: log(frequency) by time      74.8 ms <

 Memory estimate: 20.51 MiB, allocs estimate: 449903.

One dict per constraint

BenchmarkTools.Trial: 160 samples with 1 evaluation per sample.
 Range (min … max):  18.589 ms … 618.022 ms  ┊ GC (min … max):  0.00% … 94.65%
 Time  (median):     22.911 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   31.275 ms ±  60.962 ms  ┊ GC (mean ± σ):  21.55% ± 10.59%

    █▃ ▄       ▅▂▅▇           ▄ ▃                               
  ▅▃████▆▇▁▆▅▅▇████▇▃▁▅▁▃▁▁▆▇▇███▇▅▁▁▁▁▁▁▁▁▁▁▁▃▁▃▁▁▆▆▆▃▃█▇▃▁▃▆ ▃
  18.6 ms         Histogram: frequency by time         34.7 ms <

 Memory estimate: 27.11 MiB, allocs estimate: 453303.

Global dict

BenchmarkTools.Trial: 6570 samples with 1 evaluation per sample.
 Range (min … max):  248.308 μs … 524.374 ms  ┊ GC (min … max):  0.00% … 99.79%
 Time  (median):     554.571 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   760.800 μs ±   8.279 ms  ┊ GC (mean ± σ):  18.82% ±  1.74%

           █▂   ▂▂▁                                              
  ▄▇▅▅▃▄▄▆███▇▄▄███▅▃▃▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  248 μs           Histogram: frequency by time         1.98 ms <

 Memory estimate: 707.58 KiB, allocs estimate: 13364.

Global dict + object id

BenchmarkTools.Trial: 3525 samples with 1 evaluation per sample.
 Range (min … max):  223.772 μs …   2.514 s  ┊ GC (min … max):  0.00% … 99.83%
 Time  (median):     736.434 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):     1.879 ms ± 42.358 ms  ┊ GC (mean ± σ):  37.89% ±  1.68%

  ▃█▃▁▅▇▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁ ▁                                ▁
  ███████████████████████████████▇▇▇▇▇▇▇▅▁▅▅▅▆▃▅▅▆▅▃▁▅▄▁▄▁▃▁▃▆ █
  224 μs        Histogram: log(frequency) by time      5.68 ms <

 Memory estimate: 662.53 KiB, allocs estimate: 11724.

D: many aliased trees (M=200, K=8)

Current

BenchmarkTools.Trial: 91 samples with 1 evaluation per sample.
 Range (min … max):  14.584 ms …    1.458 s  ┊ GC (min … max):  0.00% … 98.41%
 Time  (median):     53.483 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   69.033 ms ± 149.943 ms  ┊ GC (mean ± σ):  22.83% ± 10.32%

       ▁    █                                                   
  ▃▄▄▆▆█▇▆███▇▄▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁
  14.6 ms         Histogram: frequency by time          285 ms <

 Memory estimate: 19.29 MiB, allocs estimate: 574306.

One dict per constraint

BenchmarkTools.Trial: 1481 samples with 1 evaluation per sample.
 Range (min … max):  3.041 ms …   4.058 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.369 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.377 ms ± 196.428 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                   ▂▁▃▅▄▃█▄                                    
  ▁▁▂▄▆▅▆▅██▆▃▁▂▃▅██████████▃▁▁▁▁▁▁▂▁▂▂▂▂▂▁▂▁▂▁▁▁▁▁▂▁▁▁▂▃▂▃▂▂ ▃
  3.04 ms         Histogram: frequency by time           4 ms <

 Memory estimate: 1.19 MiB, allocs estimate: 26906.

Global dict

BenchmarkTools.Trial: 3090 samples with 1 evaluation per sample.
 Range (min … max):  954.119 μs … 986.423 ms  ┊ GC (min … max):  0.00% … 99.38%
 Time  (median):       1.246 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):     1.618 ms ±  17.724 ms  ┊ GC (mean ± σ):  19.61% ±  1.79%

    ▁█▅ ▁                                                        
  ▂▅██████▅▃▄▄▄▄▄▄▄▄▃▂▁▁▁▁▄██▅▆███▇▆▅▄▃▂▁▁▂▂▂▁▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁ ▃
  954 μs           Histogram: frequency by time         2.08 ms <

 Memory estimate: 1.24 MiB, allocs estimate: 28180.

Global dict + object id

BenchmarkTools.Trial: 1992 samples with 1 evaluation per sample.
 Range (min … max):  845.061 μs …   2.848 s  ┊ GC (min … max):  0.00% … 99.98%
 Time  (median):     998.244 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):     2.510 ms ± 63.787 ms  ┊ GC (mean ± σ):  56.95% ±  2.24%

   ▃▁  █▅▃▂▁▁     ▁                                             
  ▅██▆▇██████▆▄▃▃▇█▅▃▃▃▂▂▃▃▄▄▄▃▃▄▅▄▃▂▂▂▂▂▂▃▃▃▂▂▂▂▂▂▂▁▁▂▂▂▂▂▂▂▂ ▃
  845 μs          Histogram: frequency by time         1.89 ms <

 Memory estimate: 1.15 MiB, allocs estimate: 25129.

@codecov

codecov Bot commented Jul 22, 2025


Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.79%. Comparing base (ff850e6) to head (85ec218).
⚠️ Report is 3 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #4032      +/-   ##
==========================================
- Coverage   99.91%   99.79%   -0.13%     
==========================================
  Files          42       42              
  Lines        6229     6250      +21     
==========================================
+ Hits         6224     6237      +13     
- Misses          5       13       +8     


@odow

odow commented Jul 22, 2025

Member

> If I'm not mistaken, we don't support modifying these in-place so that shouldn't be an issue.

In theory we don't, but they can be. This change needs very very careful consideration.

@blegat

blegat commented Jul 22, 2025

Member Author

Yes, it has nontrivial consequences. It may very well jeopardize our chances of adding a modification API for nonlinear constraints in the future.
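
A small sketch of the hazard under discussion, using a toy `Node` type that is a hypothetical stand-in rather than an MOI type: once two functions share a subexpression by pointer, modifying one of them in place is visible through the other.

```julia
# Toy expression node: a hypothetical stand-in, not an MOI type.
mutable struct Node
    head::Symbol
    args::Vector{Any}
end

shared = Node(:sin, Any[:x])
f = Node(:+, Any[shared, 1.0])  # f = sin(x) + 1
g = Node(:*, Any[shared, 2.0])  # g = sin(x) * 2

# "Modify f" in place by replacing x with y inside its first argument.
f.args[1].args[1] = :y

# g changed too, because both functions point at the same Node.
g.args[1].args[1]  # :y
```

Any future in-place modification API would presumably have to either copy shared nodes before writing or accept this kind of aliasing.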

@odow

odow commented Aug 3, 2025

Member

I think we can simplify this to:

function my_moi_function(f::GenericNonlinearExpr{V}) where {V}
    # Maps the objectid of each visited JuMP subexpression to its converted
    # MOI function, so aliased subexpressions are converted once per call.
    cache = Dict{UInt64,MOI.ScalarNonlinearFunction}()
    ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
    # Work list of (parent, argument index, JuMP subexpression to convert).
    stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[]
    for i in length(f.args):-1:1
        if f.args[i] isa GenericNonlinearExpr{V}
            push!(stack, (ret, i, f.args[i]))
        else
            ret.args[i] = moi_function(f.args[i])
        end
    end
    while !isempty(stack)
        parent, i, arg = pop!(stack)
        # On a cache miss, build the child and queue its nested arguments.
        # On a hit, reuse the existing child, which `parent` then shares.
        parent.args[i] = get!(cache, objectid(arg)) do
            child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args))
            for j in length(arg.args):-1:1
                if arg.args[j] isa GenericNonlinearExpr{V}
                    push!(stack, (child, j, arg.args[j]))
                else
                    child.args[j] = moi_function(arg.args[j])
                end
            end
            return child
        end
    end
    return ret
end

It doesn't hold state across multiple calls to moi_function, but if there are large nested expressions within a single expression (the common case), then we exploit that.

@odow

odow commented Aug 3, 2025

Member

The example gives:

julia> @time moi_function.(x);
  0.000142 seconds (1.53 k allocations: 106.688 KiB)

@blegat

blegat commented Aug 5, 2025

Member Author

If this were only about moi_function, I would just take the model from the first JuMP.VariableRef I encounter and then use the dictionary inside it. The reason I have to change the API is to merge moi_function and check_belongs_to_model, in view of this benchmark: jump-dev/MathOptInterface.jl#2788 (comment)

@odow

odow commented Aug 5, 2025

Member

We can fix check_belongs_to_model as well. I don't want to add the model-level cache. The cache should be within the function.

@blegat

blegat commented Aug 6, 2025

Member Author

It would be weird to have the subexpressions only work when they are on the same function. Especially since the AD backend supports them being on different ones. I guess we can still find out they are the same later in post-processing, it could be enough just for the sake of speeding up passing the model to the AD backend without the exponential blowup.
One issue would be that we start creating many small dictionaries if there are a lot of constraints, but that's probably negligible.
What's the reason for not having a model-level cache?

@odow

odow commented Aug 6, 2025

Member

> It would be weird to have the subexpressions only work when they are on the same function.

This is an implementation detail. Nothing should be observable at the JuMP level.

> Especially since the AD backend supports them being on different ones.

This is also an implementation detail.

> I guess we can still find out they are the same later in post-processing, it could be enough just for the sake of speeding up passing the model to the AD backend without the exponential blowup.

This is my strongly preferred option. I would like us to provide the full tape to an AD engine, and then have it turn everything into a single global DAG.

> One issue would be that we start creating many small dictionaries if there are a lot of constraints, but that's probably negligible.

Yes, we definitely need benchmarks before merging anything like this.

> What's the reason for not having a model-level cache?

It turns every nonlinear expression into a long-lived GC object that will never be freed until the model is.

Second, any expression occurring in multiple constraints seems like much less of an issue than the original example, where nested expressions are programmatically created.

Third, if we start with the internal cache, we can always change to a model-level cache later if needed.

@blegat

blegat commented Aug 7, 2025

Member Author

That makes sense. If we only avoid duplicating aliased subexpressions within each function, it is indeed easier.
We then detect subexpressions in the AD backend using hashes, so the aliases used by the user have no impact on the subexpressions used by the AD.
This means the user has no control over subexpressions. On the other hand, relying on the user's aliases was probably not the ideal way to give them that control anyway.
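
One way to picture hash-based detection that ignores user aliases is structural interning, sketched below on a toy type. `Node` and `intern!` are hypothetical names, and this is only an assumption about what such a backend step might look like, not a description of any existing AD backend.

```julia
# Toy expression node: a hypothetical stand-in, not an MOI type.
mutable struct Node
    head::Symbol
    args::Vector{Any}
end

# Give structurally identical subexpressions the same integer id, whether or
# not the user built them as aliases of one object.
function intern!(table::Dict{Any,Int}, x)
    key = if x isa Node
        (x.head, Tuple(intern!(table, a) for a in x.args))
    else
        (:leaf, x)
    end
    # Children are interned before `length(table)` is read, so ids are unique.
    return get!(table, key, length(table) + 1)
end

a = Node(:sin, Any[:x])
b = Node(:sin, Any[:x])  # same structure as `a`, but a different object
c = Node(:cos, Any[:x])

table = Dict{Any,Int}()
a !== b                                    # true: not aliases
intern!(table, a) == intern!(table, b)     # true: same structure, same id
intern!(table, a) == intern!(table, c)     # false
```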

Comment thread src/nlp_expr.jl
@blegat

blegat commented Feb 26, 2026

Member Author

To avoid keeping the JuMP objects alive, we can use their hash as keys instead of the objects themselves.
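
A rough sketch of the difference, on a toy `Node` type (a hypothetical stand-in, not a JuMP type): a dictionary keyed by the object holds a reference to it, whereas a dictionary keyed by `objectid` only stores an integer.

```julia
# Toy expression node: a hypothetical stand-in, not a JuMP type.
mutable struct Node
    head::Symbol
    args::Vector{Any}
end

by_object = IdDict{Node,Int}()  # each key is a strong reference to a Node
by_id = Dict{UInt,Int}()        # each key is only an integer

n = Node(:sin, Any[:x])
by_object[n] = 1
by_id[objectid(n)] = 1

# The integer key finds the entry for as long as `n` is alive.
haskey(by_id, objectid(n))  # true

# `by_object` alone keeps `n` reachable, so the garbage collector cannot
# free it until the entry (or the whole dictionary) is dropped.
only(keys(by_object)) === n  # true
```

One caveat, as far as I understand it, is that an id is only meaningful while the object it came from is alive, so a cache keyed this way should probably not outlive the expressions it was built from.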

@blegat

blegat commented Feb 26, 2026

Member Author

#3729 (comment) is a good benchmark but we need more.
What about a big sum of sin(x[i])?
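
Such a benchmark might be set up along the following lines. This is only a sketch: the size `n` is an arbitrary assumption, and `@time` is used in the same way as in the reproducer at the top rather than a proper benchmarking harness.

```julia
using JuMP

n = 1_000  # hypothetical problem size, chosen only for illustration

model = Model()
@variable(model, x[1:n])

# A single wide expression with no shared subexpressions, so a cache can
# only add overhead here.
expr = @expression(model, sum(sin(x[i]) for i in 1:n))

@time f = moi_function(expr)
```

This case is the opposite of the RK4 reproducer: there is nothing to deduplicate, so it measures the cost of the cache rather than its benefit.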

@blegat

blegat commented Jun 6, 2026

Member Author

I added benchmarks to the top comment.

@blegat blegat marked this pull request as ready for review June 6, 2026 11:21
2 participants