Skip to content

Commit

Permalink
Merge pull request #39 from tkelman/nolazy
Browse files Browse the repository at this point in the history
Refactor eval_univariate to avoid depending on Lazy.jl
  • Loading branch information
mlubin committed Apr 1, 2017
2 parents 0f5bde7 + 95b349c commit 15e1d37
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
1 change: 0 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
julia 0.5
Calculus
Lazy
DataStructures
MathProgBase
NaNMath 0.2.1
Expand Down
1 change: 0 additions & 1 deletion src/ReverseDiffSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module ReverseDiffSparse
using Base.Meta
using ForwardDiff
import Calculus
import Lazy
import MathProgBase
# Override basic math functions to return NaN instead of throwing errors.
# This is what NLP solvers expect, and
Expand Down
35 changes: 27 additions & 8 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,21 +336,38 @@ end
export forward_eval_ϵ


switchblock = Expr(:block)
exprs = Expr[]
for i = 1:length(univariate_operators)
op = univariate_operators[i]
deriv_expr = univariate_operator_deriv[i]
ex = :(return $op(x), $deriv_expr::T)
push!(switchblock.args,i,ex)
ex = :(return $op(x), $deriv_expr::T)
push!(exprs, ex)
end
switchexpr = Expr(:macrocall, Expr(:.,:Lazy,quot(Symbol("@switch"))), :operator_id,switchblock)

function binaryswitch(ids, exprs)
if length(exprs) <= 3
out = Expr(:if, Expr(:call, :(==), :operator_id, ids[1]), exprs[1])
if length(exprs) > 1
push!(out.args, binaryswitch(ids[2:end], exprs[2:end]))
end
return out
else
mid = length(exprs) >>> 1
return Expr(:if, Expr(:call, :(<=), :operator_id, ids[mid]),
binaryswitch(ids[1:mid], exprs[1:mid]),
binaryswitch(ids[mid+1:end], exprs[mid+1:end]))
end
end
switchexpr = binaryswitch(1:length(exprs), exprs)

@eval @inline function eval_univariate{T}(operator_id,x::T)
$switchexpr
error("No match for operator_id")
end

# TODO: optimize sin/cos/exp
switchblock = Expr(:block)
ids = Int[]
exprs = Expr[]
for i = 1:length(univariate_operators)
op = univariate_operators[i]
if op == :asec || op == :acsc || op == :asecd || op == :acscd || op == :acsch || op == :trigamma
Expand All @@ -366,11 +383,13 @@ for i = 1:length(univariate_operators)
else
deriv_expr = Calculus.differentiate(univariate_operator_deriv[i],:x)
end
ex = :(return $deriv_expr::T)
push!(switchblock.args,i,ex)
ex = :(return $deriv_expr::T)
push!(ids, i)
push!(exprs, ex)
end
switchexpr = Expr(:macrocall, Expr(:.,:Lazy,quot(Symbol("@switch"))), :operator_id,switchblock)
switchexpr = binaryswitch(ids, exprs)

@eval @inline function eval_univariate_2nd_deriv{T}(operator_id,x::T,fval::T)
$switchexpr
error("No match for operator_id")
end

0 comments on commit 15e1d37

Please sign in to comment.