Forward over reverse drops custom rule if corresponding function is inlined #1795

Open
danielwe opened this issue Sep 5, 2024 · 0 comments

My custom reverse rule works as expected in first-order reverse mode, but it is only picked up in second-order forward-over-reverse mode if I mark the corresponding function @noinline. In the MWE below, I've introduced a deliberate bug in the rule so that both the gradient and the Hessian-vector (hv) product should differ between g and g_custom. The gradients always differ, but the hv products only differ when I add @noinline.

using Enzyme

#=@noinline=# f(x) = sum(abs2, x)
#=@noinline=# f_custom(x) = sum(abs2, x)

g(x) = cos(f(x))
g_custom(x) = cos(f_custom(x))

function dg_deferred!(dx, x)
    make_zero!(dx)
    autodiff_deferred(Reverse, g, Active, Duplicated(x, dx))
    return nothing
end

function dg_custom_deferred!(dx, x)
    make_zero!(dx)
    autodiff_deferred(Reverse, g_custom, Active, Duplicated(x, dx))
    return nothing
end

function EnzymeRules.augmented_primal(
    config::EnzymeRules.Config, f::Const{typeof(f_custom)}, ::Type{<:Active}, x::Duplicated
)
    tape = EnzymeRules.overwritten(config)[2] ? copy(x.val) : nothing
    primal = EnzymeRules.needs_primal(config) ? f.val(x.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing#=shadow=#, tape)
end

function EnzymeRules.reverse(
    config::EnzymeRules.Config,
    ::Const{typeof(f_custom)},
    dret::Active,
    tape,
    x::Duplicated,
)
    xval = EnzymeRules.overwritten(config)[2] ? tape : x.val
    x.dval .= (2dret.val) .* xval
    x.dval .^= 2  # Deliberate bug as signature of custom rule 🐛
    return (nothing,)
end

x = [2.0]
dx, dx_custom = make_zero(x), make_zero(x)

v = first(onehot(x))
hv, hv_custom = make_zero(v), make_zero(v)

# gradients
dg_deferred!(dx, x)
@show dx

dg_custom_deferred!(dx_custom, x)
@show dx_custom

# hvps
autodiff(Forward, dg_deferred!, Const, Duplicated(dx, hv), Duplicated(x, v))
@show hv

autodiff(
    Forward,
    dg_custom_deferred!,
    Const,
    Duplicated(dx_custom, hv_custom),
    Duplicated(x, v),
)
@show hv_custom

Output as written: different gradients, equal hv products.

dx = [3.027209981231713]
dx_custom = [9.164000270468907]
hv = [11.971902924433648]
hv_custom = [11.971902924433648]

Output with @noinline: both gradients and hv products different.

dx = [3.027209981231713]
dx_custom = [9.164000270468907]
hv = [11.971902924433648]
hv_custom = [72.48292805436535]
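
For reference, a quick closed-form check of these numbers. This is only a sketch, assuming x = [2.0], v = [1.0] (the first one-hot vector), and that the effect of the deliberate bug is to square the correct gradient elementwise, so that the custom "gradient" is $g'(x)^2$ and forward mode over it should give $2\,g'(x)\,g''(x)\,v$:

$$
\begin{aligned}
g(x) &= \cos(x^2) \\
g'(x) &= -2x \sin(x^2), & g'(2) &= -4 \sin(4) \approx 3.0272 \\
g''(x) &= -2 \sin(x^2) - 4x^2 \cos(x^2), & g''(2) &= -2 \sin(4) - 16 \cos(4) \approx 11.9719 \\
g'(x)^2 &\text{ (buggy gradient)}, & g'(2)^2 &\approx 9.1640 \\
\tfrac{d}{dx}\, g'(x)^2 &= 2\, g'(x)\, g''(x), & 2\, g'(2)\, g''(2) &\approx 72.4829
\end{aligned}
$$

Under those assumptions, hv = 11.9719 matches the derivative of the correct gradient, and hv_custom = 72.4829 matches the derivative of the buggy one. The value hv_custom = 11.9719 in the run without @noinline is therefore what one would expect if the custom rule were skipped and f_custom were differentiated directly.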