diff --git a/src/forward.jl b/src/forward.jl index 22e54e5..b3db6b9 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -10,9 +10,9 @@ # a general DAG. If we have a DAG, then need to associate storage with each edge of the DAG. # user_input_buffer and user_output_buffer are used as temporary storage # when handling user-defined functions -function forward_eval(storage::Vector{T}, partials_storage::AbstractVector{T}, +function forward_eval(storage::AbstractVector{T}, partials_storage::AbstractVector{T}, nd::AbstractVector{NodeData}, adj, const_values, - parameter_values, x_values::Vector{T}, + parameter_values, x_values::AbstractVector{T}, subexpression_values, user_input_buffer=[], user_output_buffer=[]; user_operators::UserOperatorRegistry=UserOperatorRegistry()) where T @@ -215,7 +215,7 @@ export forward_eval # need to recompute the real components. # Computes partials_storage_ϵ as well # We assume that forward_eval has already been called. -function forward_eval_ϵ(storage::Vector{T}, +function forward_eval_ϵ(storage::AbstractVector{T}, storage_ϵ::AbstractVector{ForwardDiff.Partials{N, T}}, partials_storage::AbstractVector{T}, partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N, T}}, diff --git a/src/reverse.jl b/src/reverse.jl index 2e13c88..b27cf6b 100644 --- a/src/reverse.jl +++ b/src/reverse.jl @@ -5,7 +5,7 @@ # assumes partials_storage is already updated # dense gradient output, assumes initialized to zero # if subexpressions are present, must run reverse_eval on subexpression tapes afterwards -function reverse_eval(reverse_storage::Vector{T}, partials_storage::Vector{T}, +function reverse_eval(reverse_storage::AbstractVector{T}, partials_storage::AbstractVector{T}, nd::Vector{NodeData},adj) where {T} @assert length(reverse_storage) >= length(nd) @@ -38,7 +38,7 @@ export reverse_eval # assume we've already run the reverse pass, now just extract the answer # given the scaling value -function reverse_extract(output::Vector{T}, reverse_storage::Vector{T}, nd::Vector{NodeData}, +function reverse_extract(output::AbstractVector{T}, reverse_storage::AbstractVector{T}, nd::Vector{NodeData}, adj, subexpression_output, scale_value::T) where {T} @assert length(reverse_storage) >= length(nd) @@ -62,7 +62,7 @@ export reverse_extract # Compute directional derivatives of the reverse pass, goes with forward_eval_ϵ # to compute hessian-vector products. function reverse_eval_ϵ(output_ϵ::AbstractVector{ForwardDiff.Partials{N,T}}, - reverse_storage::Vector{T}, reverse_storage_ϵ, partials_storage::Vector{T}, + reverse_storage::AbstractVector{T}, reverse_storage_ϵ, partials_storage::AbstractVector{T}, partials_storage_ϵ, nd::Vector{NodeData}, adj, subexpression_output, subexpression_output_ϵ, scale_value::T, scale_value_ϵ::ForwardDiff.Partials{N,T}) where {N,T} diff --git a/test/runtests.jl b/test/runtests.jl index d0d96d0..af7a80e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,23 @@ reverse_extract(grad,reverse_storage,nd,adj,[],1.0) true_grad = [2*x[1]*cos(x[1]^2), -4*sin(x[2]*4)/5] @test isapprox(grad,true_grad) +# Testing view +xx = [1.0,2.0,3.0,4.0,5.0] +xv = @view xx[2:3] +fval_view = forward_eval(storage,partials_storage,nd,adj,const_values,[],xv,[]) +@test fval_view == fval + +grad_view = zeros(2) +reverse_eval(reverse_storage,partials_storage,nd,adj) +reverse_extract(grad_view,reverse_storage,nd,adj,[],1.0) +@test grad_view == grad + +grad_view = zeros(5) +gv = @view grad_view[2:2:end] +reverse_eval(reverse_storage,partials_storage,nd,adj) +reverse_extract(gv,reverse_storage,nd,adj,[],1.0) +@test gv == grad + # subexpressions nd_outer = [NodeData(SUBEXPRESSION,1,-1)]