
DNMY: Enzyme extension #3712

Closed · wants to merge 1 commit

Conversation

michel2323 (Contributor) commented:

I initially made this PR to Enzyme (EnzymeAD/Enzyme.jl#1337), but @wsmoses recommended making it a JuMP extension instead. Let me know if this works, and I can add it as a test.

This extends JuMP so that a user can differentiate an external (user-defined) function with Enzyme.

Use case:

using Ipopt
using JuMP
using Enzyme

# Rosenbrock
rosenbrock(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

model = Model(Ipopt.Optimizer)
op_rosenbrock = model[:op_rosenbrock] = add_nonlinear_operator(model, 2, rosenbrock; name = :op_rosenbrock)
@variable(model, x[1:2])

@objective(model, Min, op_rosenbrock(x[1],x[2]))

optimize!(model)
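
(Note: the long-form `add_nonlinear_operator` call above is equivalent to JuMP's `@operator` macro, so, reusing the `model` and `rosenbrock` from the snippet above, the registration can also be written as:)

# Equivalent registration via the @operator macro; this also stores the
# operator in model[:op_rosenbrock].
@operator(model, op_rosenbrock, 2, rosenbrock)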


codecov bot commented Mar 13, 2024

Codecov Report

Attention: Patch coverage is 0%, with 44 lines in your changes missing coverage. Please review.

Project coverage is 97.62%. Comparing base (a15daaa) to head (782b3b9).

Files Patch % Lines
ext/JuMPEnzymeExt.jl 0.00% 44 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3712      +/-   ##
==========================================
- Coverage   98.37%   97.62%   -0.75%     
==========================================
  Files          43       44       +1     
  Lines        5736     5780      +44     
==========================================
  Hits         5643     5643              
- Misses         93      137      +44     

☔ View full report in Codecov by Sentry.

odow (Member) commented Mar 13, 2024

Okay, this needs some discussion, likely at a monthly developer call.

Let's also put aside the exact syntax. Instead of pirating a method like this, we'd need to add some sort of type or flag for people to opt in, but that is a small issue that can be resolved.
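
To make the opt-in idea concrete, one hypothetical shape (the `automatic_differentiation` keyword and `EnzymeDifferentiation` type below are illustrative only, not an existing JuMP API):

# Hypothetical opt-in sketch only; `automatic_differentiation` and
# `EnzymeDifferentiation` do not exist in JuMP.
struct EnzymeDifferentiation end

op = add_nonlinear_operator(
    model,
    2,
    rosenbrock;
    name = :op_rosenbrock,
    automatic_differentiation = EnzymeDifferentiation(),
)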

I am somewhat in favor of this, but given the experience of #3707, I think we should be very careful about adding this.

Particularly relevant is this discussion: #3413 (comment)

I would be strongly in favor of making a requirement that new extensions must have a 1.0 release, and have no plans for a 2.0 release. This would rule out anything that has moved from v1.0.0 to v5.67.2 in a short time period, and it would rule out Enzyme, which is on v0.11.

Another option is that we add a page to the documentation which shows how to construct the appropriate gradient and hessian oracles, but we don't add this to JuMP, either directly or as an extension.

It's also worth evaluating the cost on compilation times for the tests and documentation if we add this. Enzyme is pretty heavy.
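
For reference, the "gradient and hessian oracles" suggestion above would look something like the following hand-written sketch, using JuMP's documented `add_nonlinear_operator(model, dim, f, ∇f, ∇²f)` form, where the gradient oracle fills `g` in place and the Hessian oracle only needs to fill the lower triangle of `H`:

using JuMP, Ipopt

rosenbrock(x, y) = (1 - x)^2 + 100 * (y - x^2)^2

# Gradient oracle: fills `g` in place.
function ∇rosenbrock(g::AbstractVector, x, y)
    g[1] = 2 * (x - 1) - 400 * x * (y - x^2)
    g[2] = 200 * (y - x^2)
    return
end

# Hessian oracle: JuMP only reads the lower triangle of `H`.
function ∇²rosenbrock(H::AbstractMatrix, x, y)
    H[1, 1] = 2 + 1200 * x^2 - 400 * y
    H[2, 1] = -400 * x
    H[2, 2] = 200
    return
end

model = Model(Ipopt.Optimizer)
@variable(model, x[1:2])
op = add_nonlinear_operator(
    model, 2, rosenbrock, ∇rosenbrock, ∇²rosenbrock;
    name = :op_rosenbrock,
)
@objective(model, Min, op(x[1], x[2]))
optimize!(model)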

odow added the labels "Category: Nonlinear" (Related to nonlinear programming) and "Status: Needs developer call" (This should be discussed on a monthly developer call) on Mar 13, 2024
odow (Member) commented Mar 13, 2024

Also, using your code I get:

julia> using Enzyme

julia> function jump_operator(f::Function)
           @inline function f!(y, x...)
               y[1] = f(x...)
           end
           function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
               y = zeros(T,1)
               ry = ones(T,1)
               rx = ntuple(N) do i
                   Active(x[i])
               end
               g .= autodiff(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
               return nothing
           end

           function gradient_deferred!(g, y, ry, rx...)
               g .= autodiff_deferred(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
               return nothing
           end

           function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
               y = zeros(T,1)
               dy = ntuple(N) do i
                   ones(1)
               end
               g = zeros(T,N)
               dg = ntuple(N) do i
                   zeros(T,N)
               end
               ry = ones(1)
               dry = ntuple(N) do i
                   zeros(T,1)
               end
               rx = ntuple(N) do i
                   Active(x[i])
               end

               args = ntuple(N) do i
                   drx = ntuple(N) do j
                       if i == j
                           Active(one(T))
                       else
                           Active(zero(T))
                       end
                   end
                   BatchDuplicated(rx[i], drx)
               end
               autodiff(Forward, gradient_deferred!, Const, BatchDuplicated(g,dg), BatchDuplicated(y,dy), BatchDuplicated(ry, dry), args...)
               for i in 1:N
                   for j in 1:N
                       if i <= j
                           H[j,i] = dg[j][i]
                       end
                   end
               end
               return nothing
           end

           return gradient!, hessian!
       end
jump_operator (generic function with 1 method)

julia> foo(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
foo (generic function with 1 method)

julia> ∇foo, ∇²foo = jump_operator(foo)
(var"#gradient!#9"{var"#f!#8"{typeof(foo)}}(var"#f!#8"{typeof(foo)}(foo)), var"#hessian!#12"{var"#gradient_deferred!#11"{var"#f!#8"{typeof(foo)}}}(var"#gradient_deferred!#11"{var"#f!#8"{typeof(foo)}}(var"#f!#8"{typeof(foo)}(foo))))

julia> N = 3
3

julia> x = rand(N)
3-element Vector{Float64}:
 0.23712902725864782
 0.6699243680780806
 0.530669076854107

julia> g = zeros(N)
3-element Vector{Float64}:
 0.0
 0.0
 0.0

julia> H = zeros(N, N)
3×3 Matrix{Float64}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

julia> foo(x...)
1.593666919840647

julia> ∇foo(g, x...)

julia> ∇²foo(H, x...)
ERROR: Attempting to call an indirect active function whose runtime value is inactive:
Backtrace

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378
 [2] enzyme_call
   @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056
 [3] AugmentedForwardThunk
   @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5009
 [4] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:179
 [5] runtime_generic_augfwd
   @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056 [inlined]
  [3] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5009 [inlined]
  [4] runtime_generic_augfwd
    @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:179 [inlined]
  [5] runtime_generic_augfwd
    @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0 [inlined]
  [6] fwddiffe3julia_runtime_generic_augfwd_3727_inner_1wrap
    @ ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5378 [inlined]
  [8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…}, ::BatchDuplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5056
  [9] (::Enzyme.Compiler.ForwardModeThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/compiler.jl:5001
 [10] runtime_generic_fwd(activity::Type{…}, width::Val{…}, RT::Val{…}, f::typeof(Enzyme.Compiler.runtime_generic_augfwd), df::Nothing, df_2::Nothing, df_3::Nothing, primal_1::Type{…}, shadow_1_1::Nothing, shadow_1_2::Nothing, shadow_1_3::Nothing, primal_2::Val{…}, shadow_2_1::Nothing, shadow_2_2::Nothing, shadow_2_3::Nothing, primal_3::Val{…}, shadow_3_1::Nothing, shadow_3_2::Nothing, shadow_3_3::Nothing, primal_4::Val{…}, shadow_4_1::Nothing, shadow_4_2::Nothing, shadow_4_3::Nothing, primal_5::typeof(foo), shadow_5_1::Nothing, shadow_5_2::Nothing, shadow_5_3::Nothing, primal_6::Nothing, shadow_6_1::Nothing, shadow_6_2::Nothing, shadow_6_3::Nothing, primal_7::Float64, shadow_7_1::Float64, shadow_7_2::Float64, shadow_7_3::Float64, primal_8::Base.RefValue{…}, shadow_8_1::Base.RefValue{…}, shadow_8_2::Base.RefValue{…}, shadow_8_3::Base.RefValue{…}, primal_9::Float64, shadow_9_1::Float64, shadow_9_2::Float64, shadow_9_3::Float64, primal_10::Base.RefValue{…}, shadow_10_1::Base.RefValue{…}, shadow_10_2::Base.RefValue{…}, shadow_10_3::Base.RefValue{…}, primal_11::Float64, shadow_11_1::Float64, shadow_11_2::Float64, shadow_11_3::Float64, primal_12::Base.RefValue{…}, shadow_12_1::Base.RefValue{…}, shadow_12_2::Base.RefValue{…}, shadow_12_3::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/wR2t7/src/rules/jitrules.jl:116
 [11] f!

odow (Member) commented Mar 13, 2024

Here's some code I had when experimenting with this:

abstract type AbstractADOperator end

#=
    Enzyme
=#

import Enzyme

struct ADOperatorEnzyme <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorEnzyme)
    @inline f!(y, x::Vararg{T,N}) where {T,N} = (y[1] = f(x...))
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        g .= Enzyme.autodiff(
            Enzyme.Reverse,
            f!,
            Enzyme.Const,
            Enzyme.Duplicated(zeros(T, 1), ones(T, 1)),
            Enzyme.Active.(x)...,
        )[1][2:end]
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        dg = ntuple(_ -> zeros(T, N), N)
        args = ntuple(N) do i
            return Enzyme.BatchDuplicated(
                Enzyme.Active(x[i]),
                ntuple(j -> Enzyme.Active(T(i == j)), N),
            )
        end
        function gradient_deferred!(g, y, ry, rx...)
            g .= Enzyme.autodiff_deferred(
                Enzyme.Reverse,
                f!,
                Enzyme.Const,
                Enzyme.Duplicated(y, ry),
                rx...,
            )[1][2:end]
            return
        end
        Enzyme.autodiff(
            Enzyme.Forward,
            gradient_deferred!,
            Enzyme.Const,
            Enzyme.BatchDuplicated(zeros(T, N), dg),
            Enzyme.BatchDuplicated(zeros(T, 1), ntuple(_ -> ones(T, 1), N)),
            Enzyme.BatchDuplicated(ones(T, 1), ntuple(_ -> zeros(T, 1), N)),
            args...,
        )
        for j in 1:N, i in 1:j
            H[j, i] = dg[j][i]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    ForwardDiff
=#

import ForwardDiff

struct ADOperatorForwardDiff <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorForwardDiff)
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        ForwardDiff.gradient!(g, y -> f(y...), collect(x))
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        h = ForwardDiff.hessian(y -> f(y...), collect(x))
        for i in 1:N, j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    Examples
=#

import LinearAlgebra
using Test

function example_logsumexp()
    f(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
    ∇f(g, x...) = g .= exp.(x) ./ sum(exp.(x))
    function ∇²f(H, x...)
        y = collect(x)
        g = exp.(y) / sum(exp.(y))
        h = LinearAlgebra.diagm(g) - g * g'
        for i in 1:length(y), j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f, ∇²f
end

function example_rosenbrock()
    f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
    function ∇f(g, x, y)
        g[1] = 2 * (-1 + x - 200 * (y * x + -x^3))
        g[2] = 200 * (y - x^2)
        return
    end
    function ∇²f(H, x, y)
        H[1, 1] = 2 + 1200 * x^2 - 400 * y
        H[2, 1] = -400 * x
        H[2, 2] = 200
        return
    end
    return f, ∇f, ∇²f
end

function test_example(example, N, config::AbstractADOperator)
    true_f, true_∇f, true_∇²f = example()
    f, ∇f, ∇²f = create_operator(true_f, config)
    x = rand(N)
    y = f(x...)
    true_y = true_f(x...)
    @test isapprox(y, true_y)
    g, true_g = zeros(N), zeros(N)
    ∇f(g, x...)
    true_∇f(true_g, x...)
    @test isapprox(g, true_g)
    H, true_H = zeros(N, N), zeros(N, N)
    ∇²f(H, x...)
    true_∇²f(true_H, x...)
    @test isapprox(H, true_H)
    return
end

@testset "Examples" begin
    for config in (ADOperatorForwardDiff(), ADOperatorEnzyme())
        for (example, N) in (
            example_rosenbrock => 2,
            example_logsumexp => 3,
            example_logsumexp => 20,
        )
            @testset "$example - $N - $config" begin
                test_example(example, N, config)
            end
        end
    end
end

Running this yields:

Examples                                           |   16      2     18  11.3s
  example_rosenbrock - 2 - ADOperatorForwardDiff() |    3             3   1.0s
  example_logsumexp - 3 - ADOperatorForwardDiff()  |    3             3   1.1s
  example_logsumexp - 20 - ADOperatorForwardDiff() |    3             3   1.2s
  example_rosenbrock - 2 - ADOperatorEnzyme()      |    3             3   0.4s
  example_logsumexp - 3 - ADOperatorEnzyme()       |    2      1      3   1.0s
  example_logsumexp - 20 - ADOperatorEnzyme()      |    2      1      3   6.6s
ERROR: LoadError: Some tests did not pass: 16 passed, 0 failed, 2 errored, 0 broken.

odow (Member) commented Mar 14, 2024

Okay, I've tightened things up considerably, and got rid of the Hessian error:

abstract type AbstractADOperator end

#=
    Enzyme
=#

import Enzyme

struct ADOperatorEnzyme <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorEnzyme)
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
        hess = Enzyme.autodiff(
            Enzyme.Forward,
            (x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
            Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
        )[1]
        for j in 1:N, i in 1:j
            H[j, i] = hess[j][i]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    ForwardDiff
=#

import ForwardDiff

struct ADOperatorForwardDiff <: AbstractADOperator end

function create_operator(f::Function, ::ADOperatorForwardDiff)
    function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
        ForwardDiff.gradient!(g, y -> f(y...), collect(x))
        return
    end
    function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
        h = ForwardDiff.hessian(y -> f(y...), collect(x))
        for i in 1:N, j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f!, ∇²f!
end

#=
    Examples
=#

import LinearAlgebra
using Test

function example_logsumexp()
    f(x...) = log(sum(exp(x[i]) for i in 1:length(x)))
    ∇f(g, x...) = g .= exp.(x) ./ sum(exp.(x))
    function ∇²f(H, x...)
        y = collect(x)
        g = exp.(y) / sum(exp.(y))
        h = LinearAlgebra.diagm(g) - g * g'
        for i in 1:length(y), j in 1:i
            H[i, j] = h[i, j]
        end
        return
    end
    return f, ∇f, ∇²f
end

function example_rosenbrock()
    f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2
    function ∇f(g, x, y)
        g[1] = 2 * (-1 + x - 200 * (y * x + -x^3))
        g[2] = 200 * (y - x^2)
        return
    end
    function ∇²f(H, x, y)
        H[1, 1] = 2 + 1200 * x^2 - 400 * y
        H[2, 1] = -400 * x
        H[2, 2] = 200
        return
    end
    return f, ∇f, ∇²f
end

function test_example(example, N, config::AbstractADOperator)
    true_f, true_∇f, true_∇²f = example()
    f, ∇f, ∇²f = create_operator(true_f, config)
    x = rand(N)
    y = f(x...)
    true_y = true_f(x...)
    @test isapprox(y, true_y)
    g, true_g = zeros(N), zeros(N)
    ∇f(g, x...)
    true_∇f(true_g, x...)
    @test isapprox(g, true_g)
    H, true_H = zeros(N, N), zeros(N, N)
    ∇²f(H, x...)
    true_∇²f(true_H, x...)
    @test isapprox(H, true_H)
    return
end

@testset "Examples" begin
    for config in (ADOperatorForwardDiff(), ADOperatorEnzyme())
        for (example, N) in (
            example_rosenbrock => 2,
            example_logsumexp => 3,
            example_logsumexp => 20,
        )
            @testset "$example - $N - $config" begin
                test_example(example, N, config)
            end
        end
    end
end
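
To connect this back to the PR's use case, here is a minimal sketch (assuming the `create_operator` and `ADOperatorEnzyme` definitions above, and that Ipopt is available) of wiring the generated oracles into a JuMP model:

import Ipopt
import JuMP

# Build Enzyme-backed oracles for the Rosenbrock function.
f, ∇f!, ∇²f! = create_operator(
    (x, y) -> (1 - x)^2 + 100 * (y - x^2)^2,
    ADOperatorEnzyme(),
)

model = JuMP.Model(Ipopt.Optimizer)
JuMP.@variable(model, x[1:2])
op = JuMP.add_nonlinear_operator(model, 2, f, ∇f!, ∇²f!; name = :op_rosenbrock)
JuMP.@objective(model, Min, op(x[1], x[2]))
JuMP.optimize!(model)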

odow (Member) commented Mar 14, 2024

Okay, since this is ~20 lines of code, I think this might be better as a tutorial in the documentation.

@blegat has asked for this before: #2348 (comment)

It'll also let us show off Enzyme and ForwardDiff.

I'll take a stab, and then we can discuss the relative merits of having the code as a JuMP extension vs asking people to copy-paste a snippet.

odow (Member) commented Mar 28, 2024

The developer call concluded that the documentation at https://jump.dev/JuMP.jl/dev/tutorials/nonlinear/operator_ad/ is sufficient.

odow closed this on Mar 28, 2024