Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolic AD of ScalarNonlinearFunction #2533

Open
odow opened this issue Aug 12, 2024 · 2 comments
Open

Symbolic AD of ScalarNonlinearFunction #2533

odow opened this issue Aug 12, 2024 · 2 comments
Labels
Project: next-gen nonlinear support Issues relating to nonlinear support Submodule: Nonlinear About the Nonlinear submodule Type: Enhancement

Comments

@odow
Copy link
Member

odow commented Aug 12, 2024

This has come up quite a few times, so I think we need this.

I don't know what the right API is. Perhaps:

MOI.Nonlinear.gradient(f) :: Dict{MOI.VariableIndex,MOI.ScalarNonlinearFunction}

The use case for this would be:

It's okay for this to have all the usual issues with symbolic AD.

@odow odow added Type: Enhancement Project: next-gen nonlinear support Issues relating to nonlinear support Submodule: Nonlinear About the Nonlinear submodule labels Aug 12, 2024
@odow
Copy link
Member Author

odow commented Sep 3, 2024

I started hacking something for this:

module SymbolicAD

import MacroTools
import MathOptInterface as MOI

# A real constant has zero derivative. `false` acts as a strong zero
# that later simplification steps can drop from products and sums.
derivative(::Real, ::MOI.VariableIndex) = false

# d(x)/dx = 1 and d(y)/dx = 0, encoded as `Bool` strong one/zero.
derivative(f::MOI.VariableIndex, x::MOI.VariableIndex) = f == x

"""
    derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex) where {T}

Return `∂f/∂x`: the sum of the coefficients of every term in `x`
(a variable may appear in more than one term).
"""
function derivative(
    f::MOI.ScalarAffineFunction{T},
    x::MOI.VariableIndex,
) where {T}
    return sum(
        (t.coefficient for t in f.terms if t.variable == x);
        init = zero(T),
    )
end

"""
    derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex) where {T}

Return `∂f/∂x` as a `ScalarAffineFunction`: the affine part of `f`
contributes a constant, and each quadratic term contributes one affine
term.
"""
function derivative(
    f::MOI.ScalarQuadraticFunction{T},
    x::MOI.VariableIndex,
) where {T}
    offset = zero(T)
    for a in f.affine_terms
        if a.variable == x
            offset += a.coefficient
        end
    end
    terms = MOI.ScalarAffineTerm{T}[]
    for q in f.quadratic_terms
        # NOTE(review): the diagonal case pushes the raw coefficient,
        # which matches MOI's `0.5 x'Qx` storage convention (a `c*x^2`
        # term is stored with coefficient `2c`) — confirm against MOI docs.
        if q.variable_1 == q.variable_2 == x
            push!(terms, MOI.ScalarAffineTerm(q.coefficient, x))
        elseif q.variable_1 == x
            push!(terms, MOI.ScalarAffineTerm(q.coefficient, q.variable_2))
        elseif q.variable_2 == x
            push!(terms, MOI.ScalarAffineTerm(q.coefficient, q.variable_1))
        end
    end
    return MOI.ScalarAffineFunction(terms, offset)
end

"""
    derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)

Return an expression for `∂f/∂x`, built by recursively applying the
chain, product, and quotient rules.

Throws `MOI.UnsupportedNonlinearOperator` when no rule applies; this
currently includes `u(x)^p(x)` with a non-constant exponent and the
two-argument `atan`.
"""
function derivative(f::MOI.ScalarNonlinearFunction, x::MOI.VariableIndex)
    # Unary operators handled by special-cased rules first.
    if length(f.args) == 1
        u = only(f.args)
        if f.head == :+
            # d/dx(+u) = du/dx
            return derivative(u, x)
        elseif f.head == :-
            # d/dx(-u) = -(du/dx)
            return MOI.ScalarNonlinearFunction(:-, Any[derivative(u, x)])
        elseif f.head == :abs
            # d/dx(abs(u)) = sign(u) * du/dx, written with ifelse so the
            # result is a plain MOI expression.
            scale = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[MOI.ScalarNonlinearFunction(:>=, Any[u, 0]), 1, -1],
            )
            return MOI.ScalarNonlinearFunction(:*, Any[scale, derivative(u, x)])
        elseif f.head == :sign
            # sign is piecewise constant: its derivative is zero almost
            # everywhere (the kink at 0 is ignored, as is usual in AD).
            return false
        end
        # Generic univariate operators: look up a symbolic expression
        # for f' in MOI's table of univariate derivatives.
        for (key, df, _) in MOI.Nonlinear.SYMBOLIC_UNIVARIATE_EXPRESSIONS
            if key == f.head
                # The chain rule: d(f(g(x))) / dx = f'(g(x)) * g'(x)
                u = only(f.args)
                # `df` is a Julia `Expr` in the symbol `:x`; substitute
                # `u` for `:x` and convert nested `:call` expressions
                # into ScalarNonlinearFunction nodes.
                df_du = MacroTools.postwalk(df) do node
                    if node === :x
                        return u
                    elseif Meta.isexpr(node, :call)
                        op, args = node.args[1], node.args[2:end]
                        return MOI.ScalarNonlinearFunction(op, args)
                    end
                    return node
                end
                du_dx = derivative(u, x)
                return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
            end
        end
    end
    # N-ary operators.
    if f.head == :+
        # d/dx(+(args...)) = +(d/dx args)
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:+, args)
    elseif f.head == :-
        # d/dx(-(args...)) = -(d/dx args)
        # Note that - is not unary here because that would be caught above.
        args = Any[derivative(arg, x) for arg in f.args]
        return MOI.ScalarNonlinearFunction(:-, args)
    elseif f.head == :*
        # Product rule: d/dx(*(args...)) = sum(d{i}/dx * args\{i})
        sum_terms = Any[]
        for i in 1:length(f.args)
            # Copy the product and replace the i-th factor with its
            # derivative.
            g = MOI.ScalarNonlinearFunction(:*, copy(f.args))
            g.args[i] = derivative(f.args[i], x)
            push!(sum_terms, g)
        end
        return MOI.ScalarNonlinearFunction(:+, sum_terms)
    elseif f.head == :^
        @assert length(f.args) == 2
        u, p = f.args
        du_dx = derivative(u, x)
        dp_dx = derivative(p, x)
        if _iszero(dp_dx)
            # p is constant and does not depend on x:
            # d/dx(u^p) = p * u^(p-1) * du/dx
            df_du = MOI.ScalarNonlinearFunction(
                :*,
                Any[p, MOI.ScalarNonlinearFunction(:^, Any[u, p-1])],
            )
            du_dx = derivative(u, x)
            return MOI.ScalarNonlinearFunction(:*, Any[df_du, du_dx])
        else
            # u(x)^p(x): the general power rule is not implemented, so
            # control falls through to the UnsupportedNonlinearOperator
            # error below.
        end
    elseif f.head == :/
        # Quotient rule: d/dx(u / v) = ((du/dx)*v - u*(dv/dx)) / v^2
        @assert length(f.args) == 2
        u, v = f.args
        du_dx, dv_dx = derivative(u, x), derivative(v, x)
        return MOI.ScalarNonlinearFunction(
            :/,
            Any[
                MOI.ScalarNonlinearFunction(
                    :-,
                    Any[
                        MOI.ScalarNonlinearFunction(:*, Any[du_dx, v]),
                        MOI.ScalarNonlinearFunction(:*, Any[u, dv_dx]),
                    ],
                ),
                MOI.ScalarNonlinearFunction(:^, Any[v, 2]),
            ],
        )
    elseif f.head == :ifelse
        @assert length(f.args) == 3
        # Pick the derivative of the active branch; the condition itself
        # is treated as constant.
        return MOI.ScalarNonlinearFunction(
            :ifelse,
            Any[f.args[1], derivative(f.args[2], x), derivative(f.args[3], x)],
        )
    elseif f.head == :atan
        # TODO: two-argument atan(y, x). Currently falls through to the
        # error below.
    elseif f.head == :min
        # d/dx(min(args...)): build an ifelse chain (right-to-left) that
        # selects the derivative of whichever argument attains the min.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(<=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head == :max
        # d/dx(max(args...)): as for :min, selecting the maximizer.
        g = derivative(f.args[end], x)
        for i in length(f.args)-1:-1:1
            g = MOI.ScalarNonlinearFunction(
                :ifelse,
                Any[
                    MOI.ScalarNonlinearFunction(:(>=), Any[f.args[i], f]),
                    derivative(f.args[i], x),
                    g,
                ],
            )
        end
        return g
    elseif f.head in (:(>=), :(<=), :(<), :(>), :(==))
        # Comparisons are piecewise constant: derivative is zero almost
        # everywhere.
        return false
    end
    err = MOI.UnsupportedNonlinearOperator(
        f.head,
        "the operator does not support symbolic differentiation",
    )
    return throw(err)
end

# Fallback: anything we do not know how to simplify is returned as-is.
simplify(f) = f

"""
    simplify(f::MOI.ScalarAffineFunction{T}) where {T}

Canonicalize `f`, collapsing it to its constant when no terms remain.
"""
function simplify(f::MOI.ScalarAffineFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    return isempty(g.terms) ? g.constant : g
end

"""
    simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}

Canonicalize `f`; if no quadratic terms survive, demote to an affine
function and simplify that in turn.
"""
function simplify(f::MOI.ScalarQuadraticFunction{T}) where {T}
    g = MOI.Utilities.canonical(f)
    if !isempty(g.quadratic_terms)
        return g
    end
    return simplify(MOI.ScalarAffineFunction(g.affine_terms, g.constant))
end

# If every argument of `f` is a numeric literal and `Base` defines an
# operator with the same name, fold `f` into a constant. Otherwise
# return `f` unchanged.
function _eval_if_constant(f::MOI.ScalarNonlinearFunction)
    if !all(_isnum, f.args) || !hasproperty(Base, f.head)
        return f
    end
    op = getproperty(Base, f.head)
    return op(f.args...)
end

# Non-SNF values are already as constant as they will get.
_eval_if_constant(f) = f

"""
    simplify(f::MOI.ScalarNonlinearFunction)

Recursively simplify the arguments of `f` in-place, apply the
head-specific rule via `Val` dispatch, and constant-fold the result.
"""
function simplify(f::MOI.ScalarNonlinearFunction)
    map!(simplify, f.args, f.args)
    return _eval_if_constant(simplify(Val(f.head), f))
end

simplify(::Val, f::MOI.ScalarNonlinearFunction) = f

# Numeric-literal predicates. Only Bool, Integer, and Float64 values are
# treated as constants; every other type (including nested functions)
# conservatively answers `false`.
_iszero(v::Union{Bool,Integer,Float64}) = iszero(v)
_iszero(::Any) = false

_isone(v::Union{Bool,Integer,Float64}) = isone(v)
_isone(::Any) = false

_isnum(::Union{Bool,Integer,Float64}) = true
_isnum(::Any) = false

# Syntactic tests: is `f` a ScalarNonlinearFunction with the given head
# (and, in the three-argument form, exactly `n` arguments)? Values of
# any other type never match.
_isexpr(::Any, ::Symbol, n::Int = 0) = false

_isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol) = f.head == head

function _isexpr(f::MOI.ScalarNonlinearFunction, head::Symbol, n::Int)
    return f.head == head && length(f.args) == n
end

"""
    simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)

Simplify an n-ary product: flatten nested products, drop unit factors,
fold numeric factors into a single constant (kept at the position of
the first constant encountered), and collapse to `false` (a strong
zero) if any factor is zero.
"""
function simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
    args = Any[]
    const_index = 0  # position in `args` of the folded numeric constant
    for a in f.args
        if _isexpr(a, :*)
            # Flatten x * (y * z) into x * y * z.
            append!(args, a.args)
        elseif _iszero(a)
            # 0 * anything => 0.
            return false
        elseif _isone(a)
            # Drop the multiplicative identity.
        elseif a isa Real
            if const_index == 0
                push!(args, a)
                const_index = length(args)
            else
                args[const_index] *= a
            end
        else
            push!(args, a)
        end
    end
    if isempty(args)
        # The empty product is the multiplicative identity.
        return true
    elseif length(args) == 1
        return only(args)
    end
    return MOI.ScalarNonlinearFunction(:*, args)
end

"""
    simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)

Simplify an n-ary sum: unwrap `+(x)`, rewrite `x + (-y)` as `x - y`,
flatten nested sums, drop zero terms, and fold numeric terms into a
single constant (kept at the position of the first one encountered).
"""
function simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
    if length(f.args) == 1
        # +(x) => x
        return only(f.args)
    elseif length(f.args) == 2 && _isexpr(f.args[2], :-, 1)
        # x + (-y) => x - y
        return MOI.ScalarNonlinearFunction(
            :-,
            Any[f.args[1], f.args[2].args[1]],
        )
    end
    args = Any[]
    const_index = 0  # position in `args` of the folded numeric constant
    for a in f.args
        if _isexpr(a, :+)
            # Flatten x + (y + z) into x + y + z.
            append!(args, a.args)
        elseif _iszero(a)
            # Drop the additive identity.
        elseif a isa Real
            if const_index == 0
                push!(args, a)
                const_index = length(args)
            else
                args[const_index] += a
            end
        else
            push!(args, a)
        end
    end
    if isempty(args)
        # The empty sum is the additive identity (a strong zero).
        return false
    elseif length(args) == 1
        return only(args)
    end
    return MOI.ScalarNonlinearFunction(:+, args)
end

"""
    simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)

Simplify unary and binary subtraction: `-(-(x)) => x`, `0 - x => -x`,
`x - 0 => x`, `x - x => 0`, and `x - (-y) => x + y`.
"""
function simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
    n = length(f.args)
    if n == 1 && _isexpr(f.args[1], :-, 1)
        # -(-(x)) => x
        return f.args[1].args[1]
    elseif n == 2
        u, v = f.args
        if _iszero(u)
            # 0 - x => -x
            return MOI.ScalarNonlinearFunction(:-, Any[v])
        elseif _iszero(v)
            # x - 0 => x
            return u
        elseif u == v
            # x - x => 0 (a strong zero)
            return false
        elseif _isexpr(v, :-, 1)
            # x - -(y) => x + y
            return MOI.ScalarNonlinearFunction(:+, Any[u, v.args[1]])
        end
    end
    return f
end

"""
    simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)

Simplify powers with a constant base or exponent.

The rules `x^0 => 1`, `x^1 => x`, and `1^x => 1` hold for every real
`x` in Julia (including `0^0 == 1` and `1^NaN == 1`).

The rewrite `0^x => 0` is intentionally NOT applied: it is invalid for
`x <= 0`, because `0^0 == 1` and `0.0^-1 == Inf`, so applying it would
change the value of the simplified expression.
"""
function simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
    if _iszero(f.args[2])
        # x^0 => 1
        return true
    elseif _isone(f.args[2])
        # x^1 => x
        return f.args[1]
    elseif _isone(f.args[1])
        # 1^x => 1
        return true
    end
    return f
end

"""
    variables(f::MOI.AbstractScalarFunction)

Return the variables appearing in `f`, in order of first appearance
and without duplicates.
"""
function variables(f::MOI.AbstractScalarFunction)
    seen = MOI.VariableIndex[]
    variables!(seen, f)
    return seen
end

# Constants contain no variables.
variables(::Real) = MOI.VariableIndex[]
variables!(ret, ::Real) = nothing

# Append the variables of `f` to `ret`, preserving first-seen order and
# skipping duplicates. Deduplication happens in the VariableIndex
# method; all other methods delegate to it.
function variables!(ret, f::MOI.VariableIndex)
    f in ret || push!(ret, f)
    return
end

function variables!(ret, f::MOI.ScalarAffineTerm)
    variables!(ret, f.variable)
    return
end

function variables!(ret, f::MOI.ScalarAffineFunction)
    for t in f.terms
        variables!(ret, t)
    end
    return
end

function variables!(ret, f::MOI.ScalarQuadraticTerm)
    variables!(ret, f.variable_1)
    variables!(ret, f.variable_2)
    return
end

function variables!(ret, f::MOI.ScalarQuadraticFunction)
    for t in f.affine_terms
        variables!(ret, t)
    end
    for t in f.quadratic_terms
        variables!(ret, t)
    end
    return
end

function variables!(ret, f::MOI.ScalarNonlinearFunction)
    for a in f.args
        variables!(ret, a)
    end
    return
end

"""
    gradient(f)

Return a `Dict` mapping each variable of `f` to the simplified symbolic
derivative of `f` with respect to that variable.
"""
gradient(::Real) = Dict{MOI.VariableIndex,Any}()

function gradient(f::MOI.AbstractScalarFunction)
    ret = Dict{MOI.VariableIndex,Any}()
    for x in variables(f)
        ret[x] = simplify(derivative(f, x))
    end
    return ret
end

end

using JuMP, Test

# Test SymbolicAD.derivative against hand-computed derivatives w.r.t. x.
# Each pair is `f => df/dx`; results are compared with `isapprox` (`≈`),
# which was garbled in the transcribed listing and is restored here.
function test_derivative()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # derivative(::Real, ::MOI.VariableIndex)
        1.0=>0.0,
        1.23=>0.0,
        # derivative(f::MOI.VariableIndex, x::MOI.VariableIndex)
        x=>1.0,
        y=>0.0,
        # derivative(f::MOI.ScalarAffineFunction{T}, x::MOI.VariableIndex)
        1.0*x=>1.0,
        1.0*x+2.0=>1.0,
        2.0*x+2.0=>2.0,
        2.0*x+y+2.0=>2.0,
        2.0*x+y+z+2.0=>2.0,
        # derivative(f::MOI.ScalarQuadraticFunction{T}, x::MOI.VariableIndex)
        QuadExpr(1.0 * x)=>1.0,
        QuadExpr(1.0 * x + 0.0 * y)=>1.0,
        x*y=>1.0*y,
        y*x=>1.0*y,
        x^2=>2.0*x,
        x^2+3x+4=>2.0*x+3.0,
        (x-1.0)^2=>2.0*(x-1),
        (3*x+1.0)^2=>6.0*(3x+1),
        # Univariate
        #   f.head == :+
        @force_nonlinear(+x)=>1,
        @force_nonlinear(+sin(x))=>cos(x),
        #   f.head == :-
        @force_nonlinear(-sin(x))=>-cos(x),
        #   f.head == :abs
        @force_nonlinear(
            abs(sin(x))
        )=>op_ifelse(op_greater_than_or_equal_to(sin(x), 0), 1, -1)*cos(x),
        #   f.head == :sign
        sign(x)=>false,
        # SYMBOLIC_UNIVARIATE_EXPRESSIONS
        sin(x)=>cos(x),
        cos(x)=>-sin(x),
        log(x)=>1/x,
        log(2x)=>1/(2x)*2.0,
        # f.head == :+
        sin(x)+cos(x)=>cos(x)-sin(x),
        # f.head == :-
        sin(x)-cos(x)=>cos(x)+sin(x),
        # f.head == :*
        @force_nonlinear(*(x, y, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, x, z))=>@force_nonlinear(*(y, z)),
        @force_nonlinear(*(y, z, x))=>@force_nonlinear(*(y, z)),
        # :^
        sin(x)^2=>@force_nonlinear(*(2.0, sin(x), cos(x))),
        sin(x)^1=>cos(x),
        # :/
        @force_nonlinear(/(x, 2))=>0.5,
        @force_nonlinear(
            x^2 / (x + 1)
        )=>@force_nonlinear((*(2, x, x + 1) - x^2) / (x + 1)^2),
        # :ifelse
        op_ifelse(z, x^2, x)=>op_ifelse(z, 2x, 1),
        # :atan
        # :min
        min(x, x^2)=>op_ifelse(op_less_than_or_equal_to(x, min(x, x^2)), 1, 2x),
        # :max
        max(
            x,
            x^2,
        )=>op_ifelse(op_greater_than_or_equal_to(x, max(x, x^2)), 1, 2x),
        # comparisons
        op_greater_than_or_equal_to(x, y)=>false,
        op_equal_to(x, y)=>false,
    ]
        g = SymbolicAD.derivative(moi_function(f), index(x))
        h = SymbolicAD.simplify(g)
        if !(h ≈ moi_function(fp))
            @show h
            @show f
        end
        @test h ≈ moi_function(fp)
    end
    return
end

# Test SymbolicAD.gradient: each pair is `f => Dict(variable => ∂f/∂var)`.
# The per-entry comparison uses `isapprox` (`≈`), which was garbled in
# the transcribed listing and is restored here.
function test_gradient()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>Dict(),
        # ::AffExpr
        x=>Dict(x => 1),
        x+y=>Dict(x => 1, y => 1),
        2x+y=>Dict(x => 2, y => 1),
        2x+3y+1=>Dict(x => 2, y => 3),
        # ::QuadExpr
        2x^2+3y+z=>Dict(x => 4x, y => 3, z => 1),
        # ::NonlinearExpr
        sin(x)=>Dict(x => cos(x)),
        sin(x + y)=>Dict(x => cos(x + y), y => cos(x + y)),
        sin(x + 2y)=>Dict(x => cos(x + 2y), y => cos(x + 2y) * 2),
    ]
        g = SymbolicAD.gradient(moi_function(f))
        h = Dict{MOI.VariableIndex,Any}(
            index(k) => moi_function(v) for (k, v) in fp
        )
        @test length(g) == length(h)
        for k in keys(g)
            @test g[k] ≈ h[k]
        end
    end
    return
end

# Test SymbolicAD.simplify: each pair is `f => simplify(f)`. Results are
# compared with `isapprox` (`≈`), which was garbled in the transcribed
# listing and is restored here.
function test_simplify()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # simplify(f)
        x=>x,
        # simplify(f::MOI.ScalarAffineFunction{T})
        AffExpr(2.0)=>2.0,
        # simplify(f::MOI.ScalarQuadraticFunction{T})
        QuadExpr(x + 1)=>x+1,
        # simplify(f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(sin(*(3, x^0)))=>sin(3),
        sin(log(x))=>sin(log(x)),
        op_ifelse(z, x, 0)=>op_ifelse(z, x, 0),
        # simplify(::Val{:*}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(*(x, *(y, z)))=>@force_nonlinear(*(x, y, z)),
        @force_nonlinear(
            *(x, *(y, z, *(x, 2)))
        )=>@force_nonlinear(*(x, y, z, x, 2)),
        @force_nonlinear(*(x, 3, 2))=>@force_nonlinear(*(x, 6)),
        @force_nonlinear(*(3, x, 2))=>@force_nonlinear(*(6, x)),
        @force_nonlinear(*(x, 1))=>x,
        @force_nonlinear(*(-(x, x), 1))=>0,
        # simplify(::Val{:+}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(+(x, +(y, z)))=>@force_nonlinear(+(x, y, z)),
        +(sin(x), -cos(x))=>sin(x)-cos(x),
        @force_nonlinear(+(x, 1, 2))=>@force_nonlinear(+(x, 3)),
        @force_nonlinear(+(1, x, 2))=>@force_nonlinear(+(3, x)),
        @force_nonlinear(+(x, 0))=>x,
        @force_nonlinear(+(-(x, x), 0))=>0,
        # simplify(::Val{:-}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(-(-(x)))=>x,
        @force_nonlinear(-(x, 0))=>x,
        @force_nonlinear(-(0, x))=>@force_nonlinear(-x),
        @force_nonlinear(-(x, x))=>0,
        @force_nonlinear(-(x, -y))=>@force_nonlinear(x + y),
        @force_nonlinear(-(x, y))=>@force_nonlinear(x - y),
        # simplify(::Val{:^}, f::MOI.ScalarNonlinearFunction)
        @force_nonlinear(^(x, 0))=>1,
        @force_nonlinear(^(x, 1))=>x,
        @force_nonlinear(^(0, x))=>0,
        @force_nonlinear(^(1, x))=>1,
        x^y=>x^y,
    ]
        g = SymbolicAD.simplify(moi_function(f))
        if !(g ≈ moi_function(fp))
            @show f
            @show g
        end
        @test g ≈ moi_function(fp)
    end
    return
end

# Test SymbolicAD.variables: each pair is `f => [expected variables]`.
# The expected lists are order-sensitive: variables are recorded in
# order of first appearance within the MOI function (which may differ
# from the order written in the JuMP expression, e.g. `x^2+y`).
function test_variable()
    model = Model()
    @variable(model, x)
    @variable(model, y)
    @variable(model, z)
    @testset "$f" for (f, fp) in Any[
        # ::Real
        1.0=>[],
        # ::VariableRef,
        x=>[x],
        # ::AffExpr
        AffExpr(2.0)=>[],
        x+1=>[x],
        2x+1=>[x],
        2x+y+1=>[x, y],
        y+1+z=>[y, z],
        # ::QuadExpr
        zero(QuadExpr)=>[],
        QuadExpr(x + 1)=>[x],
        QuadExpr(x + 1 + y)=>[x, y],
        x^2=>[x],
        x^2+x=>[x],
        x^2+y=>[y, x],
        x*y=>[x, y],
        y*x=>[y, x],
        # ::NonlinearExpr
        sin(x)=>[x],
        sin(x + y)=>[x, y],
        sin(x)*cos(y)=>[x, y],
    ]
        @test SymbolicAD.variables(moi_function(f)) == index.(fp)
    end
    return
end

# Run the full SymbolicAD test suite; each sub-suite reports separately.
@testset "SymbolicAD" begin
    @testset "derivative" begin
        test_derivative()
    end
    @testset "simplify" begin
        test_simplify()
    end
    @testset "variable" begin
        test_variable()
    end
    @testset "gradient" begin
        test_gradient()
    end
end

nothing

I think the trick for integrating this into MathOptSymbolicAD is to have an efficient interpreter that re-uses expression values across the primal and derivatives evaluation. The symbolic expression trees are always going to be fundamentally limited.

@odow
Copy link
Member Author

odow commented Sep 11, 2024

Thinking on this, I should probably merge this first into MathOptSymbolicAD.jl, get it working, and then we can add MathOptSymbolicAD as MOI.Nonlinear.SymbolicAD.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Project: next-gen nonlinear support Issues relating to nonlinear support Submodule: Nonlinear About the Nonlinear submodule Type: Enhancement
Development

No branches or pull requests

1 participant