Skip to content

Commit

Permalink
Change a few fields to AbstractVector to support views (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
abelsiqueira authored and mlubin committed Oct 14, 2018
1 parent eb4526b commit bc96bfd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
# a general DAG. If we have a DAG, then need to associate storage with each edge of the DAG.
# user_input_buffer and user_output_buffer are used as temporary storage
# when handling user-defined functions
function forward_eval(storage::Vector{T}, partials_storage::AbstractVector{T},
function forward_eval(storage::AbstractVector{T}, partials_storage::AbstractVector{T},
nd::AbstractVector{NodeData}, adj, const_values,
parameter_values, x_values::Vector{T},
parameter_values, x_values::AbstractVector{T},
subexpression_values, user_input_buffer=[],
user_output_buffer=[];
user_operators::UserOperatorRegistry=UserOperatorRegistry()) where T
Expand Down Expand Up @@ -215,7 +215,7 @@ export forward_eval
# need to recompute the real components.
# Computes partials_storage_ϵ as well
# We assume that forward_eval has already been called.
function forward_eval_ϵ(storage::Vector{T},
function forward_eval_ϵ(storage::AbstractVector{T},
storage_ϵ::AbstractVector{ForwardDiff.Partials{N, T}},
partials_storage::AbstractVector{T},
partials_storage_ϵ::AbstractVector{ForwardDiff.Partials{N, T}},
Expand Down
6 changes: 3 additions & 3 deletions src/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# assumes partials_storage is already updated
# dense gradient output, assumes initialized to zero
# if subexpressions are present, must run reverse_eval on subexpression tapes afterwards
function reverse_eval(reverse_storage::Vector{T}, partials_storage::Vector{T},
function reverse_eval(reverse_storage::AbstractVector{T}, partials_storage::AbstractVector{T},
nd::Vector{NodeData},adj) where {T}

@assert length(reverse_storage) >= length(nd)
Expand Down Expand Up @@ -38,7 +38,7 @@ export reverse_eval

# assume we've already run the reverse pass, now just extract the answer
# given the scaling value
function reverse_extract(output::Vector{T}, reverse_storage::Vector{T}, nd::Vector{NodeData},
function reverse_extract(output::AbstractVector{T}, reverse_storage::AbstractVector{T}, nd::Vector{NodeData},
adj, subexpression_output, scale_value::T) where {T}

@assert length(reverse_storage) >= length(nd)
Expand All @@ -62,7 +62,7 @@ export reverse_extract
# Compute directional derivatives of the reverse pass, goes with forward_eval_ϵ
# to compute hessian-vector products.
function reverse_eval_ϵ(output_ϵ::AbstractVector{ForwardDiff.Partials{N,T}},
reverse_storage::Vector{T}, reverse_storage_ϵ, partials_storage::Vector{T},
reverse_storage::AbstractVector{T}, reverse_storage_ϵ, partials_storage::AbstractVector{T},
partials_storage_ϵ, nd::Vector{NodeData}, adj, subexpression_output,
subexpression_output_ϵ, scale_value::T, scale_value_ϵ::ForwardDiff.Partials{N,T}) where {N,T}

Expand Down
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ reverse_extract(grad,reverse_storage,nd,adj,[],1.0)
true_grad = [2*x[1]*cos(x[1]^2), -4*sin(x[2]*4)/5]
@test isapprox(grad,true_grad)

# Testing view
xx = [1.0,2.0,3.0,4.0,5.0]
xv = @view xx[2:3]
fval_view = forward_eval(storage,partials_storage,nd,adj,const_values,[],xv,[])
@test fval_view == fval

grad_view = zeros(2)
reverse_eval(reverse_storage,partials_storage,nd,adj)
reverse_extract(grad_view,reverse_storage,nd,adj,[],1.0)
@test grad_view == grad

grad_view = zeros(5)
gv = @view grad_view[2:2:end]
reverse_eval(reverse_storage,partials_storage,nd,adj)
reverse_extract(gv,reverse_storage,nd,adj,[],1.0)
@test gv == grad

# subexpressions

nd_outer = [NodeData(SUBEXPRESSION,1,-1)]
Expand Down

0 comments on commit bc96bfd

Please sign in to comment.