Non-specialized methods with custom rules result in unnecessary use of runtime handlers and "Non-constant keyword argument" error #1873

Open
danielwe opened this issue Sep 21, 2024 · 0 comments

@danielwe (Contributor)

Successor to #1845

Julia avoids specializing methods on arguments in certain cases, most notably when the argument's type is `<: Function` and the function is not called in the method body but only passed through to an inner function. This does not block type inference, only code generation, and runtime dispatch is often avoided anyway through inlining, since such "pass-through" methods are usually small.
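A minimal sketch of the heuristic described above, using hypothetical helpers `inner`, `outer_nospec`, and `outer_spec` (none of these come from Enzyme or the reproducer). Following the Julia performance tips, `outer_nospec` may be compiled only once, with `f::Function`, because `f` is merely forwarded; the `where {F}` version gets a separate instance per function type. Both return the same values, since only code generation differs.

```julia
# Hypothetical helpers, not from the issue.

# The inner function is the one that actually calls `f`.
inner(f, x) = f(x)

# `f` is only passed through here, never called, so Julia's heuristic may
# compile a single instance of this method with `f::Function`.
outer_nospec(f, x) = inner(f, x)

# The type parameter `F` forces a separate instance for each `typeof(f)`.
outer_spec(f::F, x) where {F} = inner(f, x)

square = x -> x^2

println(outer_nospec(square, 3))  # 9
println(outer_spec(square, 3))    # 9
```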

However, if a custom rule is written for such a method, Enzyme sees the call as type unstable and falls back to the runtime handler, with its limited activity analysis. Besides the performance penalty, this throws an error whenever the activity analysis fails to prove that a keyword argument is `Const`. An important example is the custom rules for QuadGK.jl, since `quadgk` takes both a function argument and non-active floating-point keyword arguments that set tolerances.

One solution could be for Enzyme to force recompilation with full specialization before choosing between runtime and compile-time handling. This seems feasible for a package like Enzyme and would be fair game: I'm certain no one would object to a little extra compilation in exchange for a faster, non-erroring gradient.

Reproducer below. Adding a type parameter (`f::F` with `where {F}`) to force specialization works around the issue.

```julia
using Enzyme

constcall(a, info) = call(() -> a; info)

function call(f; info=nothing)               # errors
# function call(f::F; info=nothing) where {F}  # works
    @info "$info"  # must use `info` somehow for the error to appear
    return f()
end

function EnzymeRules.augmented_primal(
    config, ::Const{typeof(call)}, ::Type{<:Active}, f::Active; kws...,
)
    primal = EnzymeRules.needs_primal(config) ? call(f.val; kws...) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(  # this rule is totally wrong, but that's beside the point
    config, ::Const{typeof(call)}, ::Active, tape, f::Active; kws...,
)
    return (f.val,)
end

@show constcall(1.0, 1e-10)
@show autodiff(Reverse, constcall, Active, Active(1.0), Const(1e-10))
```

Output:

```
[ Info: 1.0e-10
constcall(1.0, 1.0e-10) = 1.0
ERROR: LoadError: Enzyme execution failed.
Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), Duplicated{@NamedTuple{info::Float64}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.RevConfigWidth{1, true, false, (false, false), false}, Const{typeof(call)}, Type{Active{Float64}}, Active{var"#61#62"{Float64}}}

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7061 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6664 [inlined]
  [3] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6552 [inlined]
  [4] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(Core.kwcall), df::Nothing, primal_1::@NamedTuple{…}, shadow_1_1::Base.RefValue{…}, primal_2::typeof(call), shadow_2_1::Nothing, primal_3::var"#61#62"{…}, shadow_3_1::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/uXW2v/src/rules/jitrules.jl:368
  [5] constcall
    @ ~/issues/quadgkkwargs.jl:41 [inlined]
  [6] diffejulia_constcall_11499wrap
    @ ~/issues/quadgkkwargs.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:7061 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6664 [inlined]
  [9] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/uXW2v/src/compiler.jl:6541 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:316 [inlined]
 [11] autodiff(::ReverseMode{…}, ::typeof(constcall), ::Type{…}, ::Active{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/uXW2v/src/Enzyme.jl:328
 [12] macro expansion
    @ show.jl:1181 [inlined]
 [13] top-level scope
    @ ~/issues/quadgkkwargs.jl:63
 [14] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [15] top-level scope
    @ REPL[3]:1
in expression starting at /home/daniel/issues/quadgkkwargs.jl:63
Some type information was truncated. Use `show(err)` to see complete types.
```

(PS: This reproducer is somewhat deceptive in that `call` does call `f` in its body, so why is it still not specialized? My understanding is that the inner body method is specialized, but the keyword-handling wrapper that `call(f; info)` actually invokes is not.)
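For reference, a keyword call like `call(f; info)` lowers to a call of the form `Core.kwcall(kws, call, f)` (Julia 1.9 and later), with the keywords packed into a `NamedTuple`; this matches the `typeof(Core.kwcall)` and `@NamedTuple{info::Float64}` entries in the error message. The sketch below uses a hypothetical `call_demo` in place of the issue's `call` and does not involve Enzyme; it only shows that the two spellings are equivalent.

```julia
# Hypothetical stand-in for the issue's `call`; assumes Julia >= 1.9,
# where keyword calls are routed through `Core.kwcall`.
call_demo(f; info = nothing) = (f(), info)

g = () -> 1.0

# Surface syntax: the keyword argument is written as usual.
via_syntax = call_demo(g; info = 1e-10)

# Roughly what the syntax lowers to: the keywords travel as a NamedTuple
# through `Core.kwcall`, which then invokes the compiler-generated body method.
via_kwcall = Core.kwcall((info = 1e-10,), call_demo, g)

println(via_syntax == via_kwcall)  # true
```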
