Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNMY: Enzyme extension #3712

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
JuMPDimensionalDataExt = "DimensionalData"
JuMPEnzymeExt = "Enzyme"

[compat]
DimensionalData = "0.24, 0.25, 0.26.2"
Expand Down
76 changes: 76 additions & 0 deletions ext/JuMPEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
module JuMPEnzymeExt

using Enzyme
using JuMP

function jump_operator(f::Function)
@inline function f!(y, x...)
y[1] = f(x...)

Check warning on line 8 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L6-L8

Added lines #L6 - L8 were not covered by tests
end
function gradient!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
y = zeros(T,1)
ry = ones(T,1)
rx = ntuple(N) do i
Active(x[i])

Check warning on line 14 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L10-L14

Added lines #L10 - L14 were not covered by tests
end
g .= autodiff(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
return nothing

Check warning on line 17 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L16-L17

Added lines #L16 - L17 were not covered by tests
end

function gradient_deferred!(g, y, ry, rx...)
g .= autodiff_deferred(Reverse, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
return nothing

Check warning on line 22 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L20-L22

Added lines #L20 - L22 were not covered by tests
end

function hessian!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
y = zeros(T,1)
dy = ntuple(N) do i
ones(1)

Check warning on line 28 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L25-L28

Added lines #L25 - L28 were not covered by tests
end
g = zeros(T,N)
dg = ntuple(N) do i
zeros(T,N)

Check warning on line 32 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L30-L32

Added lines #L30 - L32 were not covered by tests
end
ry = ones(1)
dry = ntuple(N) do i
zeros(T,1)

Check warning on line 36 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L34-L36

Added lines #L34 - L36 were not covered by tests
end
rx = ntuple(N) do i
Active(x[i])

Check warning on line 39 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
end

args = ntuple(N) do i
drx = ntuple(N) do j
if i == j
Active(one(T))

Check warning on line 45 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L42-L45

Added lines #L42 - L45 were not covered by tests
else
Active(zero(T))

Check warning on line 47 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L47

Added line #L47 was not covered by tests
end
end
BatchDuplicated(rx[i], drx)

Check warning on line 50 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L50

Added line #L50 was not covered by tests
end
autodiff(Forward, gradient_deferred!, Const, BatchDuplicated(g,dg), BatchDuplicated(y,dy), BatchDuplicated(ry, dry), args...)
for i in 1:N
for j in 1:N
if i <= j
H[j,i] = dg[j][i]

Check warning on line 56 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L52-L56

Added lines #L52 - L56 were not covered by tests
end
end
end
return nothing

Check warning on line 60 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L58-L60

Added lines #L58 - L60 were not covered by tests
end

return gradient!, hessian!

Check warning on line 63 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L63

Added line #L63 was not covered by tests
end

function JuMP.add_nonlinear_operator(

Check warning on line 66 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L66

Added line #L66 was not covered by tests
model::GenericModel,
dim::Int,
f::Function;
name::Symbol = Symbol(f),
)
gradient, hessian = jump_operator(f)
MOI.set(model, MOI.UserDefinedFunction(name, dim), tuple(f, gradient, hessian))
return NonlinearOperator(f, name)

Check warning on line 74 in ext/JuMPEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/JuMPEnzymeExt.jl#L72-L74

Added lines #L72 - L74 were not covered by tests
end
end
Loading