I'm writing a custom rule for a function that takes, among other arguments, an `Active` tuple/`NamedTuple`. Differentiating through the rule throws an error when the values in the tuple are read from an array at indices determined at runtime. See the MWE below, and note the two variants of the body of `g`: the error occurs with the uncommented version, which loops over the array indices so that the tuple argument is constructed as `(qs[i],)`. With the commented-out version, which is manually unrolled and constructs the tuples using hardcoded indices (`(qs[1],)`, `(qs[2],)`), there's no error.

The error does not require the tuple value to come from a mutable/`Duplicated` location: if I rewrite the example so that `qs` is an `Active` tuple instead of a `Duplicated` array, the custom rule still errors as long as the tuples are constructed using runtime-valued rather than hardcoded indices.
```julia
using Enzyme

f(qt) = first(qt)^2
f_custom(qt) = f(qt)  # Wrapper for custom rule intercept

function make_g(func)
    return function g(qs)
        ret = zero(eltype(qs))
        for i in eachindex(qs)
            ret += func((qs[i],))
        end
        # ret = func((qs[1],)) + func((qs[2],))  # With this version the custom rule works
        return ret / 2
    end
end

const g = make_g(f)
const g_custom = make_g(f_custom)

function EnzymeRules.augmented_primal(
    config::EnzymeRules.Config, f::Const{typeof(f_custom)}, ::Type{<:Active}, qt::Active
)
    println("This is the augmented primal rule for f_custom")
    primal = EnzymeRules.needs_primal(config) ? f.val(qt.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing #=shadow=#, nothing #=tape=#)
end

function EnzymeRules.reverse(
    ::EnzymeRules.Config, ::Const{typeof(f_custom)}, dret::Active, _ #=tape=#, qt::Active
)
    println("This is the reverse rule for f_custom")
    ((dqt_,),) = autodiff(Reverse, f, Active, qt)
    return (dret.val .* dqt_,)
end

qs = [0.5, 2.0]

# primal: sanity check
@show g(qs)

# gradient of g: no custom rule, works
dqs = make_zero(qs)
autodiff(Reverse, g, Active, Duplicated(qs, dqs))
@show dqs

# gradient of g_custom: uses custom rule, errors
dqs_custom = make_zero(qs)
autodiff(Reverse, g_custom, Active, Duplicated(qs, dqs_custom))
@show dqs_custom

# correctness check: not reached
@show isequal(dqs_custom, dqs)
```
Output:

```
g(qs) = 2.125
dqs = [0.5, 2.0]
ERROR: LoadError: Enzyme execution failed.
Enzyme: active by ref type Active{Tuple{Float64}} is overwritten in application of custom rule for MethodInstance for Main.B.f_custom(::Tuple{Float64}) val=  %8 = addrspacecast [1 x double]* %newstruct to [1 x double] addrspace(11)* ptr=  %21 = getelementptr inbounds [1 x [1 x double]], [1 x [1 x double]] addrspace(11)* %20, i64 0, i32 0, !dbg !66
Stacktrace:
 [1] g
   @ ~/issues/enzymecustomrules.jl:11

Stacktrace:
  [1] g
    @ ~/issues/enzymecustomrules.jl:11 [inlined]
  [2] diffejulia_g_3435wrap
    @ ~/issues/enzymecustomrules.jl:0
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:7151 [inlined]
  [4] enzyme_call
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6760 [inlined]
  [5] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6637 [inlined]
  [6] autodiff
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/Enzyme.jl:320 [inlined]
  [7] autodiff(mode::ReverseMode{…}, f::Main.B.var"#g#1"{…}, ::Type{…}, args::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/Tb3Iu/src/Enzyme.jl:332
  [8] top-level scope
    @ ~/issues/enzymecustomrules.jl:49
  [9] include(mod::Module, _path::String)
    @ Base ./Base.jl:495
 [10] include(x::String)
    @ Main.B ./REPL[5]:1
 [11] top-level scope
    @ REPL[5]:2
in expression starting at /home/daniel/issues/enzymecustomrules.jl:49
Some type information was truncated. Use `show(err)` to see complete types.
```
The obvious workaround here is to pass the values as separate arguments rather than packing them into a tuple/namedtuple. That solves the problem in my actual use case too, so this issue might not be particularly urgent, though I guess it wouldn't hurt if the error message explicitly suggested this.
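A minimal sketch of that workaround, assuming the same `EnzymeRules.Config`-based rule API the MWE above uses (newer Enzyme releases may name the config type differently). The names `f_scalar`, `f_scalar_custom`, and `g_scalar_custom` are hypothetical: the value is passed to the rule as a plain `Active` scalar instead of being packed into a 1-tuple, so the rule's argument is no longer an "active by ref" tuple.

```julia
using Enzyme

# Hypothetical scalar-argument version of `f`: takes the value directly
# instead of a 1-tuple.
f_scalar(q) = q^2
f_scalar_custom(q) = f_scalar(q)  # Wrapper for custom rule intercept

function EnzymeRules.augmented_primal(
    config::EnzymeRules.Config, func::Const{typeof(f_scalar_custom)}, ::Type{<:Active}, q::Active
)
    primal = EnzymeRules.needs_primal(config) ? func.val(q.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing #=shadow=#, nothing #=tape=#)
end

function EnzymeRules.reverse(
    ::EnzymeRules.Config, ::Const{typeof(f_scalar_custom)}, dret::Active, _ #=tape=#, q::Active
)
    # Differentiate the un-wrapped function to get d(f_scalar)/dq.
    ((dq,),) = autodiff(Reverse, f_scalar, Active, q)
    return (dret.val * dq,)
end

# Same loop over runtime indices as `g` in the MWE, but each element is
# passed to the custom-rule function as a separate scalar argument.
function g_scalar_custom(qs)
    ret = zero(eltype(qs))
    for i in eachindex(qs)
        ret += f_scalar_custom(qs[i])
    end
    return ret / 2
end

qs = [0.5, 2.0]
dqs = make_zero(qs)
autodiff(Reverse, g_scalar_custom, Active, Duplicated(qs, dqs))
@show dqs
```

If this behaves as reported for the real use case, `dqs` should match the gradient computed for `g` without the custom rule.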