Hybrid overloading + SCT; not pure SCT? #24

Open · oxinabox opened this issue Jan 18, 2020 · 28 comments

@oxinabox (Collaborator) commented Jan 18, 2020

@ChrisRackauckas assured me that ForwardDiff2
was a pure source-code-transformation (SCT) based approach,
rather than an overloading-based approach.

My impression was that it is primarily an overloading-based approach that uses limited source code transformation to make it possible to extend the overloads with frules.

Admittedly, the line between the two is blurry.
But where I would draw it is that a pure source code transform approach never calls the original function passing in a special overloaded type. Not even as a fallback.
It always calls a function that it created.
This means it never hits MethodErrors,
since it created the function it's going to call.

julia> using ForwardDiff2: D
[ Info: Precompiling ForwardDiff2 [994df76e-a4c1-5e1f-bd5c-23b9b5303d4f]

julia> sq(x) = x^2
sq (generic function with 1 method)

julia> D(sq)(10)*1.0
20.0

julia> sq2(x::Float64) = x^2
sq2 (generic function with 1 method)

julia> D(sq2)(10)*1.0
ERROR: MethodError: no method matching sq2(::ForwardDiff2.Dual{ForwardDiff2.Tag{Nothing},Int64,Float64})
Closest candidates are:
  sq2(::Float64) at REPL[2]:1
Stacktrace:
 [1] call at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/context.jl:447 [inlined]
 [2] fallback at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/context.jl:445 [inlined]
 [3] _overdub_fallback at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/overdub.jl:486 [inlined]
 [4] _frule_overdub2 at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:145 [inlined]
 [5] alternative at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:186 [inlined]
 [6] #47 at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/api.jl:61 [inlined]
 [7] overdub(::Cassette.Context{nametype(DualContext),Nothing,Nothing,getfield(ForwardDiff2, Symbol("##PassType#371")),Nothing,Cassette.DisableHooks}, ::getfield(ForwardDiff2, Symbol("##47#49")){D{Int64,typeof(sq2)},Float64}) at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/overdub.jl:0
 [8] dualrun(::Function) at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:192
 [9] *(::D{Int64,typeof(sq2)}, ::Float64) at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/api.jl:59
 [10] top-level scope at REPL[3]:1

To describe the rewritten function:
I am going to write it without Dual types, but you can do it with Dual types; it just means allocating an extra slot in the transformed code and putting things into that slot.

We started with:
sq2(x::Float64) = x^2

The new code is:

function do_forward_mode(sq2, x, dsq2, dx)
    sub1_args = (Base.literal_pow, x, 2, Zero(), dx, Zero())
    sub1_res = frule(sub1_args...)
    if sub1_res === nothing
        # call the generated function to make itself
        return do_forward_mode(sub1_args...)
    else
        return sub1_res
    end
end
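A hypothetical call of the generated function, following the convention above (the function and its primal arguments first, then a tangent for each; the literal values here are just for illustration):

y, dy = do_forward_mode(sq2, 10.0, Zero(), 1.0)
# y == 100.0 and dy == 20.0, matching what D(sq2)(10)*1.0 should give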

Doing a more complicated one:

f(x) = g(x)*h(x)
function do_forward_mode(f, x, df, dx)
    subg_args = (g, x, Zero(), dx)
    subg_res = frule(subg_args...)
    if subg_res === nothing
        # call the generated function to make itself
        subg_res = do_forward_mode(subg_args...)
    end

    subh_args = (h, x, Zero(), dx)
    subh_res = frule(subh_args...)
    if subh_res === nothing
        # call the generated function to make itself
        subh_res = do_forward_mode(subh_args...)
    end

    # Result of `g` is being used, need to split the results
    subg_values, subg_partials = subg_res

    # Result of `h`  is being used, need to split the results
    subh_values, subh_partials = subh_res

    submul_args = (*, subg_values..., subh_values..., Zero(), subg_partials..., subh_partials...)
    submul_res = frule(submul_args...)
    if submul_res === nothing
        # call the generated function to make itself
        return do_forward_mode(submul_args...)
    end
    return submul_res
end

So one can see that this kind of transform is complete.
We never call the original function.
We never perform any overloaded operations.
We only call frule and our transform-generating function: do_forward_mode.


I believe ForwardDiff2 is meant to be using this approach,
and thus bugs have slipped in and caused it not to.

To make such bugs harder to slip through tests,
I suggest that Dual stop subtyping Number and stop defining operations like + and *.
That way we know that it is correctly doing the rewrites.

Alternatively, this may be intentional.
The hybrid overloading + source code transform approach ForwardDiff2 is using right now
works pretty well if everything is either a Number or an Array.

But we might need another forward-mode AD package that is pure source-to-source to easily handle concretely typed structs, and code that is otherwise hostile to the overloading approach.

@oxinabox (Collaborator, Author)

Feel free to close this if the hybrid approach is indeed intentional, and the long term plan.

@YingboMa (Owner)

Ref: JuliaLabs/Cassette.jl#157

@willtebbutt

@YingboMa the Cassette PR doesn't address Lyndon's question. It would be helpful to understand whether the long-term goal is to take the kind of SCT approach that Lyndon discusses above, or to continue using various dual types.

@YingboMa (Owner) commented Jan 19, 2020

That PR makes ForwardDiff2 work on sq2(x::Float64) = x^2, with Dual no longer being a subtype of Number.

@willtebbutt

Sure. So is the plan to have a single Dual type that can contain arbitrary data types?

@YingboMa (Owner)

Dual can already contain arbitrary data types. The plan is to make Dual not visible.

@YingboMa (Owner)

Pure source-to-source is also prone to missing-rule errors, so I don't think that is a good idea for forward mode.

@willtebbutt commented Jan 19, 2020

Dual can already contain arbitrary data types. The plan is to make Dual not visible.

Well, yes, kind of, but it only makes sense with subtypes of Real at the minute. As I understand it, you currently have Dual for Reals and DualArray for AbstractArrays.

Pure source to source is also prone to missing rule errors, so I don't think that is a good idea for the forward mode.

Is it? Could you provide an example please?

@YingboMa (Owner)

With Dual, we know what functions to differentiate, so it is pretty simple to ignore functions like methods and code_typed.

One can solve this problem with dependency analysis in source-to-source AD, but that is just doing more work. Won't the Cassette PR make FD2 as powerful as what pure source-to-source could be?

@shashi (Collaborator) commented Jan 19, 2020

Well, yes, kind of, but it only makes sense with subtypes of Real at the minute. As I understand it, you currently have Dual for Reals and DualArray for AbstractArrays.

They could potentially become the same type, yes.

What we do is a hybrid, but we'd love to move to using Tagging in Cassette, which would make it source-to-source. But tagging is pretty unusable at the moment, so it will need more work on the Cassette and compiler side.

Cassette's tagging does the same thing as DualArrays for arrays.

Won't the Cassette PR make FD2 as powerful as what pure source to source could be?

Almost. We still won't support structs that have fields restricted to subtypes of Real, for example. And

it only makes sense with subtypes of Real at the minute

this problem will show up from time to time when Dual tries to fit into the type hierarchy.

@willtebbutt

With Dual, we know what functions to differentiate, so it is pretty simple to ignore functions like

I see. There's only a finite number of such functions, though, and you can just define frules on them that are no-ops. This is essentially what Zygote does with @nograd. It seems to work fine and is really easy to fix if you find a new function that breaks stuff.

Moreover, you should generally be able to use Zero to resolve these types of problems for the cases you're considering -- if the differentials of all of the arguments to a function are Zero, then you can just invoke the function as usual and entirely avoid doing forwards mode. The functions you're considering operate on Functions, and functions generally have Zero differential.
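For illustration, a no-op rule in the frule convention used elsewhere in this thread might look like this (a sketch, not an existing rule):

# `methods` is non-differentiable reflection: run the primal and
# return a Zero() tangent, never recursing into its body
frule(::typeof(methods), f, dself, df) = (methods(f), Zero())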

Also, assuming that we do wind up in a world where Duals can contain structs, your argument about Duals avoiding these problems no longer holds. E.g.

struct Foo
    a::Float64
end
(foo::Foo)(x) = 5x * foo.a
foo = Dual(Foo(5.0), Composite{Foo}(; a=1.0))
methods(foo)

An instance of Foo has methods, so calling methods on it is a reasonable thing to do (Zero won't save you here either, of course). In this kind of situation you really do have to think about these types of functions.

What we do is a hybrid, but we'd love to move to use Tagging in Cassette which would make it source-to-source. But tagging is pretty unusable at the moment, so it will need more work in the Cassette and compiler side.

Understood. Yeah, the current state of Cassette's argument-local metadata propagation makes me sad.

this problem will show up from time to time when Dual tries to fit into the type hierarchy.

Agreed. But I assume that this problem will go away once you've got a generic Dual type that doesn't live in any hierarchy.

@shashi (Collaborator) commented Jan 19, 2020

Cassette's argument-local metadata propagation makes me sad.

It's super nice when you don't need performance though :)

@oxinabox (Collaborator, Author)

Note the source transform I propose in the top post doesn't need Tagging.
It just creates new code with 2x as many slots as what it is based on, and 2x as many arguments.
It's probably a lot easier to write with IRTools than Cassette.
And even then it won't be fun to write.
But it is strictly easier than the code in Zygote, as there is no need to mess with control flow.

I think hybrid SCT + overloading has some advantages: in particular, it's less likely to hit compiler edge cases, and one has to write fewer really annoying IR transforms.
But I think it has disadvantages in that it is harder to handle things to do with user types, like arbitrary type constraints on fields and on functions, and indeed user-provided structs in the first place.
You almost certainly can also handle these things in a hybrid overloading approach.
But it's easier in the pure SCT case, as they basically fall out of everything you are already doing.

oxinabox reopened this Jan 19, 2020

@oxinabox (Collaborator, Author) commented Jan 19, 2020

One thing about source-to-source is that it supports mutation basically as a side effect of everything you have to do anyway*.
Whereas in reverse mode mutation support is hard, because you need to remember the old values,
in forward mode you don't.

But you do run into issues relating to overloading (in either direction).
This is Tracker's ERROR: LoadError: MethodError: no method matching Float64(::TrackedReal) error when calling setindex!(::Array{Float64}, ::TrackedReal, ::Int).
Or (as of right now) ForwardDiff2's
ERROR: LoadError: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f),Float64},Float64,1}) in the same way.

Pure source code transformation does not use overloading, so it doesn't run into these.

@AntoineLevitt posted this example, which currently fails:

function f(x)
    a = [1.0, 2.0]
    a[1] = x
    sum(a)
end

In this case, making it work for overloading (and arguably making it more "correct" Julia code) is just a matter of changing line 2 to a = typeof(x)[1.0, 2.0].
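For concreteness, that overloading-friendly version is:

function f(x)
    a = typeof(x)[1.0, 2.0]  # element type follows `x`, so Duals fit
    a[1] = x
    sum(a)
end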

But let's think about what its source transform looks like:

We probably do want an frule for constructors
that can be written generically
(JuliaDiff/ChainRules.jl#154).
But for the sake of this example, let's assume a hand-coded frule for vect(X::T...) where T (which is what vector literals like that one lower to), just so it is clear.

function frule(::typeof(Base.vect), all_args...)
    n = length(all_args) ÷ 2
    args = all_args[1:n]
    # skip `dself` (all_args[n+1]) as `vect` isn't a functor.
    dargs = all_args[n+2:end]

    y = Base.vect(args...)
    D = Union{AbstractZero, eltype(y)}  # this might want to be a little different for performance.
    dy = D[dargs...]
    return y, dy
end
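Sanity-checking it against the call convention (for a = [1.0, 2.0], where nothing depends on x, every tangent is Zero()):

y, dy = frule(Base.vect, 1.0, 2.0, Zero(), Zero(), Zero())
# y == [1.0, 2.0]; dy is a Vector{Union{AbstractZero, Float64}} of Zero()s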

and we want an frule for setindex!, since that compiles down to a ccall, so we were always going to need one.
(If not supporting mutation in a source-to-source system, then I guess one could just make it error.)

function frule(::typeof(setindex!), all_args...)
    n = length(all_args) ÷ 2
    args = all_args[1:n]
    x, val = args[1], args[2]
    inds = args[3:end]
    # skip `dself` (all_args[n+1]) as `setindex!` isn't a functor.
    dargs = all_args[n+2:end]
    dx, dval = dargs[1], dargs[2]
    # don't need `dinds`

    y = setindex!(x, val, inds...)
    dy = setindex!(dx, dval, inds...)
    return y, dy
end

So now let's look at the whole program (repeating it again):

function f(x)
    a = [1.0, 2.0]
    a[1] = x
    sum(a)
end
function do_forward_mode(f, x, df, dx)
    # a = [1.0, 2.0]
    sub_vect_args = (Base.vect, 1.0, 2.0, Zero(), Zero(), Zero())
    sub_vect_res = frule(sub_vect_args...)
    if sub_vect_res === nothing
        # call the generated function to make itself
        sub_vect_res = do_forward_mode(sub_vect_args...)
    end
    a, da = sub_vect_res

    # a[1] = x
    sub_setindex_args = (setindex!, a, x, 1, Zero(), da, dx, Zero())
    sub_setindex_res = frule(sub_setindex_args...)
    if sub_setindex_res === nothing
        # call the generated function to make itself
        sub_setindex_res = do_forward_mode(sub_setindex_args...)
    end

    # sum(a)
    sub_sum_args = (sum, a, Zero(), da)
    sub_sum_res = frule(sub_sum_args...)
    if sub_sum_res === nothing
        # call the generated function to make itself
        sub_sum_res = do_forward_mode(sub_sum_args...)
    end
    return sub_sum_res
end

The compiler will optimize out the checks against === nothing if the frules are defined (which they all are in this case).
So it becomes:

function do_forward_mode(f, x, df, dx)
    # a = [1.0, 2.0]
    a, da = frule(Base.vect, 1.0, 2.0, Zero(), Zero(), Zero())
    
    # a[1] = x
    frule(setindex!, a, x, 1, Zero(), da, dx, Zero())
    
    # sum(a)
    frule(sum, a, Zero(), da)
end

Which is pretty slick.
And does what we want.
And never runs into issues with a or da being the wrong type to hold results.
(As long as the frules for constructors are written correctly.)

Things needed from the description language (ChainRulesCore):

  • There are a few things that are technically only needed for supporting mutability; they are things one should have anyway, but if one isn't supporting mutability one can probably live without them.
  1. We need to have mutable composite types: MutableTangent (JuliaDiff/ChainRulesCore.jl#105).
  2. We need a way to work out the differential type for a given primal type, so we can declare the differential arrays etc. (this can have a default for Any): Function to find the Differential type for a given primal type (JuliaDiff/ChainRulesCore.jl#106); a sketch follows below.
  3. Supporting chunking on functions that don't have arguments, e.g. zero-arg constructors. (I think they can just encode their chunkiness by passing a fill(dself, chunksize).) Support chunked frule (JuliaDiff/ChainRulesCore.jl#92 (comment)).
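A minimal sketch of the helper from point 2 (the name differential_type and its methods are hypothetical; JuliaDiff/ChainRulesCore.jl#106 has the real discussion):

differential_type(::Type{T}) where {T<:Real} = T  # a Real's differential is the same Real type
differential_type(::Type{Array{T,N}}) where {T,N} = Array{differential_type(T), N}
differential_type(::Type) = Any  # conservative fallback for arbitrary primal types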

@ChrisRackauckas (Collaborator)

That's just the easy case though, and FD2 can fix that. But if the buffers come from the user and the user cannot define the internal weird type that FD2 wants, then it won't be fixed. The real fix is #8. It's the same reason why constant global mutable buffers work in ForwardDiff, and don't work in Tracker/Zygote. Your solution is cute for the easy case that is still actually allocating, but doesn't work for non-allocating optimized code.

@oxinabox (Collaborator, Author) commented Jan 19, 2020

It's the same reason why constant global mutable buffers work in ForwardDiff, and don't work in Tracker/Zygote.

I can imagine that it doesn't work on that, sure.
Mutating global state is super gross:
not threadsafe, and hard to reason about.
If you need a reusable buffer, pass it in as an argument to your function.
Which of course means also passing in a reusable buffer for its derivative.

Your solution is cute for the easy case that is still actually allocating, but doesn't work for non-allocating optimized code.

Yes it does, AFAICT.
And it works for mixtures of the two.

Also, that case is great because if not allocating one doesn't need to worry about constructors, so many of the issues relating to them are solved for me.

Let's look at:

function f(out::Vector, x::Vector)
    a = ones(2, 10)
    a[:, 1] = x
    sum!(out, a)
end

f([0.0, 0.0], [10, 5])  # =[19.0, 14.0] 

I am going to write the transformed code in the abridged version, assuming that frules are always found,
since we know that when they are not, we can generate them via source transform.

function do_forward_mode(f, out, x, df, dout, dx)
    # a = ones(2, 10)
    a, da = frule(ones, 2, 10, Zero(), Zero(), Zero())
    
    # a[:, 1] = x
    frule(setindex!, a, x, :, 1, Zero(), da, dx, Zero(), Zero())
    
    # sum!(out, a)
    frule(sum!, out, a, Zero(), dout, da)
end

So just like before.
What does frule(sum!, out, a, Zero(), dout, da) do?
It mutates dout. IIRC, it should basically do sum!(dout, da).
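A sketch of that rule in this thread's convention (assuming mutating the tangent buffer in place is allowed):

function frule(::typeof(sum!), out, a, dself, dout, da)
    sum!(out, a)    # primal: reduce `a` into `out` in place
    sum!(dout, da)  # tangent: the same reduction applied to the partials
    return out, dout
end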


So it all just works the same.
Only now, even at the outermost level, the user has provided us with an array of differentials to write into.
(Just like doing frule(setindex!, ...) in the first example, only in that case we constructed it.)

We get only as much mutation as we had for the primal problem.

But if the buffers come from the user and the user cannot define the internal weird type that FD2 wants, then it won't be fixed.

Not sure what that means. If the user wants to provide a preallocated array to contain the derivatives, they had better know what type of derivatives they are getting back. Just like for normal preallocation, but I guess a little harder, as they have to know a bit about differentials.
Still, we can give them a helper for it; we have to write one anyway to make generating the types for constructors easier.
The differential type can basically always be determined down to a Union of a few options based on the primal type -- and if the user knows more about what operations will be done, then they can make it better.
E.g. if your primal is an Array{Float64} then you know the partial you are preallocating is going to be one too.


There is a weirder case where you have preallocated partials but not preallocated primal values.
That case would be trickier, since one needs machinery to handle it.
ChainRules has limited machinery that handles that for accumulation of derivatives (accumulate! and store! and InplaceThunk),
but the code transform (or otherwise) to generate code that uses that stuff seems tricky;
especially if one wants to try to find when something is no longer used and repurpose it.

@ChrisRackauckas (Collaborator)

Not sure what that means. If the user wants to provide a preallocated array to contain the derivatives, they had better know what type of derivatives they are getting back. Just like for normal preallocation, but I guess a little harder, as they have to know a bit about differentials.
Still, we can give them a helper for it; we have to write one anyway to make generating the types for constructors easier.
The differential type can basically always be determined down to a Union of a few options based on the primal type -- and if the user knows more about what operations will be done, then they can make it better.
E.g. if your primal is an Array{Float64} then you know the partial you are preallocating is going to be one too.

But that's incorrect, because then you'd have to use a Union for ForwardDiff, but you don't. You can just make a DualCache that just needs to know the chunk size and can reinterpret. With a type-based version this is fairly trivial to expose: you just say preallocations should be done with this type. If things like chunk size and such are never exposed, how do you make this helper function? What type does it output?

@ChrisRackauckas (Collaborator) commented Jan 19, 2020

@oxinabox let's make it concrete.

https://tutorials.juliadiffeq.org/html/introduction/03-optimizing_diffeq_code.html

Case 1:

Ayu = zeros(N,N)
uAx = zeros(N,N)
Du = zeros(N,N)
Ayv = zeros(N,N)
vAx = zeros(N,N)
Dv = zeros(N,N)
function gm3!(dr,r,p,t)
  a,α,ubar,β,D1,D2 = p
  u = @view r[:,:,1]
  v = @view r[:,:,2]
  du = @view dr[:,:,1]
  dv = @view dr[:,:,2]
  mul!(Ayu,Ay,u)
  mul!(uAx,u,Ax)
  mul!(Ayv,Ay,v)
  mul!(vAx,v,Ax)
  @. Du = D1*(Ayu + uAx)
  @. Dv = D2*(Ayv + vAx)
  @. du = Du + a*u*u./v + ubar - α*u
  @. dv = Dv + a*u*u - β*v
end
prob = ODEProblem(gm3!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())

Case 2:

p = (1.0,1.0,1.0,10.0,0.001,100.0,Ayu,uAx,Du,Ayv,vAx,Dv) # a,α,ubar,β,D1,D2
function gm4!(dr,r,p,t)
  a,α,ubar,β,D1,D2,Ayu,uAx,Du,Ayv,vAx,Dv = p
  u = @view r[:,:,1]
  v = @view r[:,:,2]
  du = @view dr[:,:,1]
  dv = @view dr[:,:,2]
  mul!(Ayu,Ay,u)
  mul!(uAx,u,Ax)
  mul!(Ayv,Ay,v)
  mul!(vAx,v,Ax)
  @. Du = D1*(Ayu + uAx)
  @. Dv = D2*(Ayv + vAx)
  @. du = Du + a*u*u./v + ubar - α*u
  @. dv = Dv + a*u*u - β*v
end
prob = ODEProblem(gm4!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())

The standard way to handle this with ForwardDiff would be to just do

Ayu = dualcache(u, N=ForwardDiff.pickchunksize(length(u0)))

for all of the buffers, and then _Ayu = get_tmp(u) inside of the function. Done: 0 allocation forward mode that supports mutation and global mutable buffers. #8 is then a ForwardDiff2 solution, which is likely also 10 lines of code away

https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/init.jl#L49-L61

How does your setup handle that?

@oxinabox (Collaborator, Author) commented Jan 19, 2020

How does your setup handle that?

Oh, DualCache is cute; took me a while to get it.
It acts differently depending on whether we are currently doing AD or not.

I think we can get this behaviour by defining an frule for get_tmp.
And we get to save a tiny bit of memory by getting to use both fields during AD,
if I understand the code right.
I am not sure I put the chunk dimension in the right place,
but I think this gets the idea across:

struct DiffCache{T<:AbstractArray, S<:AbstractArray}
    value::T
    partials::S
end

function DiffCache(u::AbstractArray{T}, siz, ::Type{Val{chunk_size}}) where {T, chunk_size}
    DiffCache(u, zeros(T, chunk_size, siz...))
end

dualcache(u::AbstractArray, N=Val{ForwardDiff.pickchunksize(length(u))}) = DiffCache(u, size(u), N)

get_tmp(dc::DiffCache, u::AbstractArray) = dc.value
function frule(::typeof(get_tmp), dc::DiffCache, u::AbstractArray{T}) where T
    return reinterpret(T, dc.value), reinterpret(T, dc.partials)
end
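Usage would then look the same whether or not AD is running (a sketch reusing the buffers from the gm3!/gm4! example; N and u as above):

cache = dualcache(zeros(N, N))  # e.g. one DiffCache per buffer such as Ayu
buf = get_tmp(cache, u)         # outside AD: just the plain primal array
# under the transform, frule(get_tmp, cache, u) fires instead, handing back
# (value, partials), both reinterpreted to eltype(u)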

Could even get rid of the u argument to get_tmp, since it just marks the type,
and instead pass in a T argument that could be defaulted.
It would work just fine even without that argument, since the argument is no longer used to decide whether we are doing AD or not.

From there everything just works, I believe.
Since once one has hold of the buffers one is going to fill,
the process of doing mutating operations is exactly the same.
Either way we want to hit the frules for them*.

(* Technically speaking, with DualArray I think one can maybe skip some frules, like the one for setindex!, by effectively inheriting them.)


But that's incorrect, because then you'd have to use a Union for ForwardDiff, but you don't.

Those Unions were just an example. I think you only really have to start hitting unions of differentials when one wants to use more advanced features like Composite. And I think maybe if we are real smart we can get rid of them altogether.
Mostly this is just me being very cautious about types.

@ChrisRackauckas (Collaborator)

Oh, DualCache is cute; took me a while to get it. It acts differently depending on whether we are currently doing AD or not.

Yup exactly.

I think we can get this behaviour by defining an frule for get_tmp.

Yeah, that's cool, I guess you can do it through frules. The only thing is we'd need a user to define the buffer array as a big enough DualCache. Essentially, the sizing and all of that of the array doesn't really matter, as long as it's a big enough memory buffer for the AD run. The same Buffers could work with reverse mode AD as well with the locking technique, so IMO we should just standardize a very good Buffer and then power-user codes like DiffEq can just use it properly for non-allocating AD, and we can worry about how to do it more automatically as a next step.

@oxinabox (Collaborator, Author) commented Jan 19, 2020

Yeah, that's cool, I guess you can do it through frules.

Concrete examples FTW.

I am sure there is a lot of fiddliness in writing the pure SCT forward mode.
But it is almost easier than reverse mode in all ways:

  • Control flow stays the same
  • Mutation doesn't need to be undoable
  • No need to track the past

The harder bit is that it is important to handle chunking.
Reverse mode should do that too, but it's less important.
Also, in reverse mode allocating a bit extra is like w/e 🤷‍♂ since it has so much overhead already.
But it would wreck one in forward mode.
So one needs to be careful that things are tight.

I think it has all the power we need and is more flexible to user code.

I don't think it's super high priority; ForwardDiff2 is pretty great as is for many use cases.
But if we start to get a bunch of issues about mutating things that were not passed in to the code and thus have the wrong type
(which I think isn't that bad to do),
then pure source code transform is a thing to think about.

And we would need a ton of frules, like those mentioned above, for various things that don't differentiate (probably stolen directly from Zygote), and also for all the primitive mutating operations.

@MikeInnes

In general, it seems extremely unlikely to me that there's something hybrid/OO can do that pure SCT can't. Mutations, including global ones, caches, etc. are all fine; you can also do things like adding rules to zero-argument functions, functions that take structs, etc. If there's scepticism on that still, we can always work through more specific cases.
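As a sketch, in the frule convention oxinabox used above, a rule for a hypothetical zero-argument function needs nothing but the function itself to dispatch on:

noise() = randn()  # hypothetical zero-argument primal
frule(::typeof(noise), dself) = (noise(), Zero())  # no tagged argument required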

Until then you're going to have problems with duals leaking out (#31), things that won't compile (or worse, run incorrectly) due to perturbation confusion, errors on structs with type requirements or code that doesn't use eltype exactly right, supporting duals of other types, and so on. There might be a short-term advantage to compromising on those things but it seems hard to defend long-term. Done right, ForwardDiff2 really could differentiate everything other than ccall, which seems like something we'd all like to have.

@ChrisRackauckas (Collaborator)

I think it's best to have something that works first, then think about pure SCT. Right now we haven't had a pure SCT AD perform well or have a good feature set, which is a good reason to get something good working in the next month and then consider something possibly better on a multi-year time frame.

@MikeInnes

If you can get the current design working fast, you can certainly get an SCT version to work fast too. I suspect it's significantly easier than putting effort into working around all those issues with Cassette tricks, if effort is going into that.

@oxinabox (Collaborator, Author) commented Jan 20, 2020

I suspect it's significantly easier than putting effort into working around all those issues with Cassette tricks, if effort is going into that.

As I understand it, Cassette Tagging, with global handling turned off,
is basically a more convenient interface around what I proposed.
I would do it with IRTools myself, because Cassette tagging spooks me and has no docs.
But I know the FD2 team are good at Cassette tagging, so w/e.


Also it's worth remembering:
ForwardDiff2 works right now.

@MikeInnes commented Jan 21, 2020

I think you're technically right that you could twist Cassette's metadata system into behaving like (forward mode) SCT AD with some effort, but this would look very different from using it as a tagging system (and would make a lot of the features that make it slow unnecessary).

One way to look at the difference, which might be helpful, is that SCT AD works only with variables whereas OO AD (dual types or compiler-level tagging) works with values (much like static/dynamic types). A big difference is whether you can dynamically decide whether something is tagged or untagged. For example:

y = x > 0 ? x : 0

In OO AD, y can (at runtime) be either a value tagged with derivative information (x) or an untagged integer (0). In SCT AD, this distinction makes no sense; we either generate code to calculate y's derivative or we don't (and this is where things like activity analysis come in, to determine what should be statically differentiated; something that likewise doesn't make sense in an OO/tagging context).
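To make that concrete, an SCT transform of that line (in the notation of the top post) gives y a tangent slot statically, in both branches:

function do_forward_mode_branch(x, dx)
    if x > 0
        y, dy = x, dx
    else
        y, dy = 0, Zero()  # the constant branch still fills the tangent slot
    end
    return y, dy
end

whereas in OO AD the type of y itself changes at runtime.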

At first, having a variable for y's derivative seems pretty much equivalent to (dynamically) making it a Dual struct or tagging it. But once you get into the weeds of how, for example, tagged and untagged values deep inside of structs should be represented, it starts making a big difference and you end up with the issues I described above. This should also give some intuition for the compiler issues: adding a slot with a consistent type is trivial for the compiler, whereas dealing with dynamic data layouts can quickly throw off the whole thing. (A similar issue affects Zygote, since we don't know ahead of time what gradients will be nothing, but it wouldn't be a problem for forward mode, since you just reuse the primal type for the dual as you do now; the dual would be exactly as easy to infer as the primal.) I suspect that compiling and executing the code to track what is tagged and where makes the tagging approach pretty fundamentally difficult in a way that SCT isn't.

SCT AD is not strictly better than OO AD, especially in forward mode; it's a tradeoff (I leave the benefits of OO as an exercise to the reader). From my perspective it's mostly just important to be clear that these are not actually the same thing, which is a confusion that's come up on this thread and elsewhere a couple of times.

@oxinabox (Collaborator, Author)

adding a slot with a consistent type is trivial for the compiler, whereas dealing with dynamic data layouts can quickly throw off the whole thing. (A similar issue affects Zygote, since we don't know ahead of time what gradients will be nothing, but it wouldn't be a problem for forward mode since you just reuse the primal type for the dual as you do now; the dual would be exactly as easy to infer as the primal.)

It's actually difficult for both.
It's not a problem when you stick to things with obvious natural differentials of the same type, like ::Real and ::Array{<:Real} (i.e. types that basically are clear vector spaces).
But it becomes a problem for both reverse and forward AD once you start looking at structs, especially structs that might not have natural differentials, so that you have to use structural differentials.
And when you want some of those fields to be Zero (or nothing, as Zygote calls it).

Then you need to know what operations are going to be done in the future.
Which for forward mode is strictly impossible -- it hasn't happened yet.

But for reverse mode it's theoretically possible, just very hard,
since the future pullbacks are determined by operations you saw on the forward pass.

@MikeInnes

That's true, though only if you choose to represent Zero that way. You could instead make use of the field type and use a natural dual for that (if it's concrete; if it isn't, inferability is moot anyway). If you want to save some memory (Zero instead of an array of zeros) you need the union in general, but that's more a property of that optimisation than something inherent to AD, I think.

Zygote might end up getting a similar, stricter mapping from primal to adjoint types for somewhat similar reasons. But it's a bigger sacrifice for Zygote, because we might ideally want to promote based on future operations, whereas in forward mode only the primal type itself matters.
