
Custom rule errors with "active by ref type Active{...} is overwritten" when data in Active arg comes from location determined at runtime #1786

Open
danielwe opened this issue Sep 3, 2024 · 1 comment


danielwe (Contributor) commented Sep 3, 2024

I'm writing a custom rule for a function that takes, among other arguments, an Active tuple/namedtuple. Differentiating through the rule throws an error when the values in the tuple are read from an array at indices determined at runtime. See the MWE below, and note the two variants of the body of g: the error occurs with the uncommented version, which loops over the array indices so that the tuple argument is constructed as (qs[i],). If I instead use the commented-out version, which is manually unrolled and constructs the tuples with hardcoded indices, there is no error.

The error does not require the tuple's value to come from a mutable/Duplicated location: if I rewrite the example so that qs is an Active tuple instead of a Duplicated array, the custom rule still errors, as long as the tuples are constructed with runtime-valued rather than hardcoded indices.

```julia
using Enzyme

f(qt) = first(qt)^2

f_custom(qt) = f(qt)  # Wrapper for custom rule intercept

function make_g(func)
    return function g(qs)
        ret = zero(eltype(qs))
        for i in eachindex(qs)
            ret += func((qs[i],))
        end
        # ret = func((qs[1],)) + func((qs[2],))  # With this version the custom rule works
        return ret / 2
    end
end

const g = make_g(f)
const g_custom = make_g(f_custom)

function EnzymeRules.augmented_primal(
    config::EnzymeRules.Config, f::Const{typeof(f_custom)}, ::Type{<:Active}, qt::Active
)
    println("This is the augmented primal rule for f_custom")
    primal = EnzymeRules.needs_primal(config) ? f.val(qt.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing#=shadow=#, nothing#=tape=#)
end

function EnzymeRules.reverse(
    ::EnzymeRules.Config, ::Const{typeof(f_custom)}, dret::Active, _#=tape=#, qt::Active
)
    println("This is the reverse rule for f_custom")
    ((dqt_,),) = autodiff(Reverse, f, Active, qt)
    return (dret.val .* dqt_,)
end

qs = [0.5, 2.0]

# primal: sanity check
@show g(qs)

# gradient of g: no custom rule, works
dqs = make_zero(qs)
autodiff(Reverse, g, Active, Duplicated(qs, dqs))
@show dqs

# gradient of g_custom: uses custom rule, errors
dqs_custom = make_zero(qs)
autodiff(Reverse, g_custom, Active, Duplicated(qs, dqs_custom))
@show dqs_custom

# correctness check: not reached
@show isequal(dqs_custom, dqs)
```

Output:

```
g(qs) = 2.125
dqs = [0.5, 2.0]
ERROR: LoadError: Enzyme execution failed.
Enzyme: active by ref type Active{Tuple{Float64}} is overwritten in application of custom rule for MethodInstance for Main.B.f_custom(::Tuple{Float64}) val=  %8 = addrspacecast [1 x double]* %newstruct to [1 x double] addrspace(11)* ptr=  %21 = getelementptr inbounds [1 x [1 x double]], [1 x [1 x double]] addrspace(11)* %20, i64 0, i32 0, !dbg !66
Stacktrace:
 [1] g
   @ ~/issues/enzymecustomrules.jl:11

Stacktrace:
  [1] g
    @ ~/issues/enzymecustomrules.jl:11 [inlined]
  [2] diffejulia_g_3435wrap
    @ ~/issues/enzymecustomrules.jl:0
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:7151 [inlined]
  [4] enzyme_call
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6760 [inlined]
  [5] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6637 [inlined]
  [6] autodiff
    @ ~/.julia/packages/Enzyme/Tb3Iu/src/Enzyme.jl:320 [inlined]
  [7] autodiff(mode::ReverseMode{…}, f::Main.B.var"#g#1"{…}, ::Type{…}, args::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/Tb3Iu/src/Enzyme.jl:332
  [8] top-level scope
    @ ~/issues/enzymecustomrules.jl:49
  [9] include(mod::Module, _path::String)
    @ Base ./Base.jl:495
 [10] include(x::String)
    @ Main.B ./REPL[5]:1
 [11] top-level scope
    @ REPL[5]:2
in expression starting at /home/daniel/issues/enzymecustomrules.jl:49
Some type information was truncated. Use `show(err)` to see complete types.
```
danielwe (Contributor, Author) commented Sep 4, 2024

The obvious workaround here is to pass the values as separate arguments rather than packing them into a tuple/namedtuple. That solves the problem in my actual use case too, so this issue might not be particularly urgent, though I guess it wouldn't hurt if the error message explicitly suggested this.
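A minimal sketch of that workaround applied to the MWE: the wrapper takes the scalar directly, and g calls func(qs[i]) instead of packing the value into (qs[i],). The names f_scalar, f_scalar_custom and make_g_scalar are hypothetical, introduced only for illustration. The rule signatures are copied from the MWE, so this assumes the same Enzyme release; they may differ in other versions. Because g computes (qs[1]^2 + qs[2]^2)/2, the gradient should equal qs itself, matching the rule-free result dqs = [0.5, 2.0] above.

```julia
using Enzyme

# Hypothetical names (not from the issue): the value is passed as a plain scalar, no tuple.
f_scalar(q) = q^2
f_scalar_custom(q) = f_scalar(q)  # wrapper intercepted by the custom rule below

function make_g_scalar(func)
    return function g(qs)
        ret = zero(eltype(qs))
        for i in eachindex(qs)
            ret += func(qs[i])  # runtime index, but no (qs[i],) tuple is constructed
        end
        return ret / 2
    end
end

const g_scalar_custom = make_g_scalar(f_scalar_custom)

# Same signatures as the MWE's rules; the Active argument is now a Float64, not a tuple.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.Config, f::Const{typeof(f_scalar_custom)}, ::Type{<:Active}, q::Active
)
    primal = EnzymeRules.needs_primal(config) ? f.val(q.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing#=shadow=#, nothing#=tape=#)
end

function EnzymeRules.reverse(
    ::EnzymeRules.Config, ::Const{typeof(f_scalar_custom)}, dret::Active, _#=tape=#, q::Active
)
    ((dq,),) = autodiff(Reverse, f_scalar, Active, q)  # d(q^2)/dq evaluated at q.val
    return (dret.val * dq,)  # one entry per argument of f_scalar_custom
end

qs = [0.5, 2.0]
dqs_scalar = zero(qs)
autodiff(Reverse, g_scalar_custom, Active, Duplicated(qs, dqs_scalar))
@show dqs_scalar  # expected to equal qs, the gradient of sum(q^2)/2
```

With several values, each one becomes its own Active argument, and the reverse rule returns one derivative per argument, in order.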
