diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index fbe7e631753..fa075adccf3 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -144,79 +144,48 @@ function _needs_parentheses(x::GenericNonlinearExpr) return x.head in _PREFIX_OPERATORS && length(x.args) > 1 end -function function_string(::MIME"text/plain", x::GenericNonlinearExpr) - io, stack = IOBuffer(), Any[x] - while !isempty(stack) - arg = pop!(stack) - if arg isa GenericNonlinearExpr - if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1 - if _needs_parentheses(arg.args[1]) - print(io, "(") - end - if _needs_parentheses(arg.args[end]) - push!(stack, ")") - end - for i in length(arg.args):-1:2 - push!(stack, arg.args[i]) - if _needs_parentheses(arg.args[i]) - push!(stack, "(") - end - push!(stack, " $(arg.head) ") - if _needs_parentheses(arg.args[i-1]) - push!(stack, ")") - end - end - push!(stack, arg.args[1]) - else - print(io, arg.head, "(") - push!(stack, ")") - for i in length(arg.args):-1:2 - push!(stack, arg.args[i]) - push!(stack, ", ") - end - if length(arg.args) >= 1 - push!(stack, arg.args[1]) - end - end - else - print(io, arg) - end - end - seekstart(io) - return read(io, String) -end +_parens(::MIME) = "(", ")", "", "", "" +_parens(::MIME"text/latex") = "\\left(", "\\right)", "{", "}", "\\textsf" -function function_string(::MIME"text/latex", x::GenericNonlinearExpr) +function function_string(mime::MIME, x::GenericNonlinearExpr) + p_left, p_right, p_open, p_close, p_textsf = _parens(mime) io, stack = IOBuffer(), Any[x] while !isempty(stack) arg = pop!(stack) if arg isa GenericNonlinearExpr if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1 - print(io, "{") - push!(stack, "}") - if _needs_parentheses(arg.args[1]) - print(io, "\\left(") - end - if _needs_parentheses(arg.args[end]) - push!(stack, "\\right)") - end - for i in length(arg.args):-1:2 - push!(stack, arg.args[i]) - if _needs_parentheses(arg.args[i]) - push!(stack, "\\left(") + print(io, p_open) + push!(stack, p_close) + l = ceil(_TERM_LIMIT_FOR_PRINTING[] / 2) + r = floor(_TERM_LIMIT_FOR_PRINTING[] / 2) + skip_indices = (1+l):(length(arg.args)-r) + for i in length(arg.args):-1:1 + if i in skip_indices + if i == skip_indices[end] + push!( + stack, + _terms_omitted(mime, length(skip_indices)), + ) + push!(stack, " $(arg.head) $p_open") + end + continue + elseif _needs_parentheses(arg.args[i]) + push!(stack, p_right) + push!(stack, arg.args[i]) + push!(stack, p_left) + else + push!(stack, arg.args[i]) end - push!(stack, "} $(arg.head) {") - if _needs_parentheses(arg.args[i-1]) - push!(stack, "\\right)") + if i > 1 + push!(stack, "$p_close $(arg.head) $p_open") end end - push!(stack, arg.args[1]) else - print(io, "\\textsf{", arg.head, "}\\left({") - push!(stack, "}\\right)") + print(io, p_textsf, p_open, arg.head, p_close, p_left, p_open) + push!(stack, p_close * p_right) for i in length(arg.args):-1:2 push!(stack, arg.args[i]) - push!(stack, "}, {") + push!(stack, "$p_close, $p_open") end if length(arg.args) >= 1 push!(stack, arg.args[1]) diff --git a/src/print.jl b/src/print.jl index c48269fafb1..5847ab3922b 100644 --- a/src/print.jl +++ b/src/print.jl @@ -605,6 +605,41 @@ function _term_string(coef, factor) end end +""" + const _TERM_LIMIT_FOR_PRINTING = Ref{Int}(60) + +A global constant used to control when terms are omitted when printing +expressions. + +Get and set this value using `_TERM_LIMIT_FOR_PRINTING[]`. + +```julia +julia> _TERM_LIMIT_FOR_PRINTING[] +60 + +julia> _TERM_LIMIT_FOR_PRINTING[] = 10 +10 +``` +""" +const _TERM_LIMIT_FOR_PRINTING = Ref{Int}(60) + +_terms_omitted(::MIME, n::Int) = "[[...$n terms omitted...]]" + +function _terms_omitted(::MIME"text/latex", n::Int) + return "[[\\ldots\\text{$n terms omitted}\\ldots]]" +end + +function _terms_to_truncated_string(mode, terms) + m = _TERM_LIMIT_FOR_PRINTING[] + if length(terms) <= 2 * m + return join(terms) + end + k_l = iseven(m) ? m + 1 : m + 2 + k_r = iseven(m) ? m - 1 : m - 2 + block = _terms_omitted(mode, div(length(terms), 2) - m) + return string(join(terms[1:k_l]), block, join(terms[(end-k_r):end])) +end + # TODO(odow): remove show_constant in JuMP 1.0 function function_string(mode, a::GenericAffExpr, show_constant = true) if length(linear_terms(a)) == 0 @@ -616,7 +651,7 @@ function function_string(mode, a::GenericAffExpr, show_constant = true) terms[2*elm] = _term_string(coef, function_string(mode, var)) end terms[1] = terms[1] == " - " ? "-" : "" - ret = join(terms) + ret = _terms_to_truncated_string(mode, terms) if show_constant && !_is_zero_for_printing(a.constant) ret = string( ret, @@ -645,7 +680,7 @@ function function_string(mode, q::GenericQuadExpr) terms[2*elm] = _term_string(coef, factor) end terms[1] = terms[1] == " - " ? "-" : "" - ret = join(terms) + ret = _terms_to_truncated_string(mode, terms) aff_str = function_string(mode, q.aff) if aff_str == "0" return ret diff --git a/test/test_nlp_expr.jl b/test/test_nlp_expr.jl index def51c63ff1..da60d3a13bd 100644 --- a/test/test_nlp_expr.jl +++ b/test/test_nlp_expr.jl @@ -978,4 +978,19 @@ function test_variable_ref_type() return end +function test_printing_truncation() + model = Model() + @variable(model, x[1:100]) + y = @expression(model, sum(sin.(x) .* 2)) + @test occursin( + "(sin(x[72]) * 2.0) + [[...41 terms omitted...]] + (sin(x[30]) * 2.0)", + function_string(MIME("text/plain"), y), + ) + @test occursin( + "{\\left({\\textsf{sin}\\left({x[72]}\\right)} * {2.0}\\right) + {[[\\ldots\\text{41 terms omitted}\\ldots]]} + {\\left({\\textsf{sin}\\left({x[30]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[29]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[28]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[27]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[26]}\\right)} * {2.0}\\right)}", + function_string(MIME("text/latex"), y), + ) + return +end + end # module diff --git a/test/test_print.jl b/test/test_print.jl index 5ed3eda1286..2eb4945e756 100644 --- a/test/test_print.jl +++ b/test/test_print.jl @@ -967,4 +967,24 @@ function test_print_text_latex_interval_set() return end +function test_truncated_printing() + model = Model() + @variable(model, x[1:1000]) + y = sum(x) + s = function_string(MIME("text/plain"), y) + @test occursin("x[30] + [[...940 terms omitted...]] + x[971]", s) + @test occursin( + "x_{30} + [[\\ldots\\text{940 terms omitted}\\ldots]] + x_{971}", + function_string(MIME("text/latex"), y), + ) + ret = JuMP._TERM_LIMIT_FOR_PRINTING[] + JuMP._TERM_LIMIT_FOR_PRINTING[] = 3 + @test function_string(MIME("text/plain"), y) == + "x[1] + x[2] + [[...997 terms omitted...]] + x[1000]" + @test function_string(MIME("text/latex"), y) == + "x_{1} + x_{2} + [[\\ldots\\text{997 terms omitted}\\ldots]] + x_{1000}" + JuMP._TERM_LIMIT_FOR_PRINTING[] = ret + return +end + end