Forward over reverse for variadic function #1336

Closed
michel2323 opened this issue Mar 7, 2024 · 7 comments

@michel2323
Collaborator

michel2323 commented Mar 7, 2024

I'm trying to do forward over reverse on this function:

using Enzyme

# Rosenbrock
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

To achieve this, I try to write a wrapper, since I can't use Active for the Forward pass. An example:

# Wrapper to call with Vector, 1st element is output
function f!(inout::Vector{T}) where {T}
    inp = ntuple(length(inout) - 1) do i
        inout[i+1]
    end
    inout[1] = f(inp...)
    nothing
end

inout = [0.0, 1.0, 2.0]
f!(inout)
println("in = $(inout[2]) $(inout[3])")
println("out = $(inout[1])")
dinout = [0.0, 0.0, 0.0]
autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(inout, dinout))

This crashes with

┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Float64}}
└ @ Enzyme ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Float64}}
└ @ Enzyme ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
ERROR: LoadError: Enzyme execution failed.
Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, f, 6, 6)

The other version is the one I sent on Slack. That one doesn't crash, but the adjoints are not updated, although the adjoint of the output is zeroed.

using Enzyme

# Rosenbrock
f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

function f!(y::Ref{T}, x::Vararg{T,N}) where {T,N}
    t = f(x...)
    y[] = t
    nothing
end

function gradient(x::Vararg{T,N}) where {T,N}
    # tx = map(Active, x)
    # tx = ntuple(i -> Duplicated(x[i], 0.0), N)
    fx = x -> Duplicated(x, 0.0)
    tx = map(fx, x)
    y = Duplicated(Ref{T}(zero(T)), Ref{T}(one(T)))
    autodiff_deferred(ReverseWithPrimal, f!, Const, y, tx...)
    @show tx
    return getfield.(tx, :dval), y.val
end
@show g = gradient(1.0, 2.0)
michel2323 changed the title from "Duplicated wrapper for variadic function" to "Forward over reverse for variadic function" on Mar 7, 2024
@wsmoses
Member

wsmoses commented Mar 7, 2024

f(inp...) is type unstable, since you take a vector and splat it into an unknown number of elements. Enzyme's error says that we don't yet support that type-unstable tuple operation.

Make inout an ntuple? Alternatively, perhaps mark f as @inline.
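
A minimal sketch of the ntuple route (assuming the caller knows the number of inputs and passes it as Val(N); f_stable! is a hypothetical name, not from the original code):

# Building the input tuple with ntuple(..., Val(N)) keeps the tuple length in
# the type, so the splat into f stays type stable.
function f_stable!(inout::Vector{T}, ::Val{N}) where {T,N}
    inp = ntuple(i -> inout[i + 1], Val(N))
    inout[1] = f(inp...)
    return nothing
end

inout = [0.0, 1.0, 2.0]
f_stable!(inout, Val(2))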

@wsmoses
Member

wsmoses commented Mar 7, 2024

For the latter, you can't use a Duplicated of a float; it won't do what you expect (which is what's happening in your case). You have to make those Active.
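
A minimal sketch of what that looks like for the Rosenbrock f above (an illustration, not code from the original comment):

using Enzyme

f(x...) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

# In reverse mode, immutable numbers are passed as Active; the gradients come
# back in the first element of the returned tuple.
dx = autodiff(Reverse, f, Active(1.0), Active(2.0))[1]  # (∂f/∂x1, ∂f/∂x2)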

@michel2323
Collaborator Author

I can't, because I need to apply forward over the reverse pass. Or at least I got an error indicating that.

Do you know whether there is a way to compute second-order derivatives of variadic functions with Active number arguments using Enzyme? Or is that still unclear?

I read the type-instability section of the documentation, but this time it did not error with the proper warning. As far as I understand, not all type-unstable code fails. And I don't pass in temporary storage.
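
A quick sketch for making the instability the warning points at visible (assuming the f! wrapper from the first post):

# The input tuple built inside f! is inferred as Tuple{Vararg{Float64}}, i.e. a
# tuple whose length is unknown at compile time, which is what the warning reports.
inout = [0.0, 1.0, 2.0]
@code_warntype f!(inout)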

@michel2323
Collaborator Author

Ok. I think I see. I have to move the tuples out.

Thanks.

@wsmoses
Member

wsmoses commented Mar 8, 2024

Applying forward over reverse shouldn't be the issue. For any reverse call (independent of how it is called: directly, in forward over reverse, etc.), Active is required for non-mutable state and Duplicated for mutable state.
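
A minimal sketch of that rule (g! and its values here are hypothetical, purely for illustration):

using Enzyme

g!(y, a, b) = (y[1] = a * b; nothing)  # mutates y; a and b are immutable scalars

y, dy = zeros(1), ones(1)              # dy seeds the adjoint of the output with 1
res = autodiff(Reverse, g!, Const, Duplicated(y, dy), Active(2.0), Active(3.0))
# res[1] == (nothing, 3.0, 2.0): gradients of the Active scalars; the adjoint of
# the Duplicated array is accumulated through dy, which serves as the seed here.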

So if you got an error, it may help to see what it said?

@michel2323
Collaborator Author

michel2323 commented Mar 8, 2024

Okay, I finally got the second order working. Sorry for removing previous posts.

using Enzyme

# Rosenbrock
@inline function f(x...)
    (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
end

@inline function f!(y, x...)
    y[1] = f(x...)
end

x = (1.0, 2.0)
y = zeros(1)
f!(y,x...)
y[1] = 0.0
ry = ones(1)
g = zeros(2)
rx = ntuple(2) do i
    Active(x[i])
end
function gradient!(g, y, ry, rx...)
    g .= autodiff_deferred(ReverseWithPrimal, f!, Const, Duplicated(y,ry), rx...)[1][2:end]
    return nothing
end

gradient!(g, y,ry,rx...)


# FoR
y[1] = 0.0
dy = ones(1)
ry[1] = 1.0
dry = zeros(1)
drx = ntuple(2) do i
    Active(one(Float64))
end

tdrx = ntuple(2) do i
    Duplicated(rx[i], drx[i])
end

fill!(g, 0.0)
dg = zeros(2)
autodiff(Forward, gradient!, Const, Duplicated(g,dg), Duplicated(y,dy), Duplicated(ry, dry), tdrx...)
# H * drx
h = dg

@odow

odow commented Mar 14, 2024

Here's what I came up with:

julia> import Enzyme

julia> f(x...) = log(sum(exp.(x)))
f (generic function with 1 method)

julia> function ∇f!(g::AbstractVector{T}, x::Vararg{T,N}) where {T,N}
           g .= Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active.(x)...)[1]
           return
       end
∇f! (generic function with 1 method)

julia> function ∇²f!(H::AbstractMatrix{T}, x::Vararg{T,N}) where {T,N}
           direction(i) = ntuple(j -> Enzyme.Active(T(i == j)), N)
           hess = Enzyme.autodiff(
               Enzyme.Forward,
               (x...) -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, x...)[1],
               Enzyme.BatchDuplicated.(Enzyme.Active.(x), ntuple(direction, N))...,
           )[1]
           for j in 1:N, i in 1:j
               H[j, i] = hess[j][i]
           end
           return
       end
∇²f! (generic function with 1 method)

julia> N = 3
3

julia> x, g, H = rand(N), fill(NaN, N), fill(NaN, N, N);

julia> f(x...)
1.7419820145927152

julia> ∇f!(g, x...)

julia> ∇²f!(H, x...)

julia> g
3-element Vector{Float64}:
 0.24224320758303503
 0.30782611265915005
 0.4499306797578149

julia> H
3×3 Matrix{Float64}:
  0.183561   NaN         NaN
 -0.0745688    0.213069  NaN
 -0.108993    -0.1385      0.247493

See jump-dev/JuMP.jl#3712 (comment)
