Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Truncate printing of large expressions #3575

Merged
merged 10 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 30 additions & 61 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,79 +144,48 @@ function _needs_parentheses(x::GenericNonlinearExpr)
return x.head in _PREFIX_OPERATORS && length(x.args) > 1
end

function function_string(::MIME"text/plain", x::GenericNonlinearExpr)
io, stack = IOBuffer(), Any[x]
while !isempty(stack)
arg = pop!(stack)
if arg isa GenericNonlinearExpr
if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1
if _needs_parentheses(arg.args[1])
print(io, "(")
end
if _needs_parentheses(arg.args[end])
push!(stack, ")")
end
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
if _needs_parentheses(arg.args[i])
push!(stack, "(")
end
push!(stack, " $(arg.head) ")
if _needs_parentheses(arg.args[i-1])
push!(stack, ")")
end
end
push!(stack, arg.args[1])
else
print(io, arg.head, "(")
push!(stack, ")")
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
push!(stack, ", ")
end
if length(arg.args) >= 1
push!(stack, arg.args[1])
end
end
else
print(io, arg)
end
end
seekstart(io)
return read(io, String)
end
_parens(::MIME) = "(", ")", "", "", ""
_parens(::MIME"text/latex") = "\\left(", "\\right)", "{", "}", "\\textsf"

function function_string(::MIME"text/latex", x::GenericNonlinearExpr)
function function_string(mime::MIME, x::GenericNonlinearExpr)
p_left, p_right, p_open, p_close, p_textsf = _parens(mime)
io, stack = IOBuffer(), Any[x]
while !isempty(stack)
arg = pop!(stack)
if arg isa GenericNonlinearExpr
if arg.head in _PREFIX_OPERATORS && length(arg.args) > 1
print(io, "{")
push!(stack, "}")
if _needs_parentheses(arg.args[1])
print(io, "\\left(")
end
if _needs_parentheses(arg.args[end])
push!(stack, "\\right)")
end
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
if _needs_parentheses(arg.args[i])
push!(stack, "\\left(")
print(io, p_open)
push!(stack, p_close)
l = ceil(_TERM_LIMIT_FOR_PRINTING[] / 2)
r = floor(_TERM_LIMIT_FOR_PRINTING[] / 2)
skip_indices = (1+l):(length(arg.args)-r)
for i in length(arg.args):-1:1
if i in skip_indices
if i == skip_indices[end]
push!(
stack,
_terms_omitted(mime, length(skip_indices)),
)
push!(stack, " $(arg.head) $p_open")
end
continue
elseif _needs_parentheses(arg.args[i])
push!(stack, p_right)
push!(stack, arg.args[i])
push!(stack, p_left)
else
push!(stack, arg.args[i])
end
push!(stack, "} $(arg.head) {")
if _needs_parentheses(arg.args[i-1])
push!(stack, "\\right)")
if i > 1
push!(stack, "$p_close $(arg.head) $p_open")
end
end
push!(stack, arg.args[1])
else
print(io, "\\textsf{", arg.head, "}\\left({")
push!(stack, "}\\right)")
print(io, p_textsf, p_open, arg.head, p_close, p_left, p_open)
push!(stack, p_close * p_right)
for i in length(arg.args):-1:2
push!(stack, arg.args[i])
push!(stack, "}, {")
push!(stack, "$p_close, $p_open")
end
if length(arg.args) >= 1
push!(stack, arg.args[1])
Expand Down
39 changes: 37 additions & 2 deletions src/print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,41 @@ function _term_string(coef, factor)
end
end

"""
const _TERM_LIMIT_FOR_PRINTING = Ref{Int}(60)
A global constant used to control when terms are omitted when printing
expressions.
Get and set this value using `_TERM_LIMIT_FOR_PRINTING[]`.
```julia
julia> _TERM_LIMIT_FOR_PRINTING[]
60
julia> _TERM_LIMIT_FOR_PRINTING[] = 10
10
```
"""
const _TERM_LIMIT_FOR_PRINTING = Ref{Int}(60)

_terms_omitted(::MIME, n::Int) = "[[...$n terms omitted...]]"

function _terms_omitted(::MIME"text/latex", n::Int)
return "[[\\ldots\\text{$n terms omitted}\\ldots]]"
end

function _terms_to_truncated_string(mode, terms)
m = _TERM_LIMIT_FOR_PRINTING[]
if length(terms) <= 2 * m
return join(terms)
end
k_l = iseven(m) ? m + 1 : m + 2
k_r = iseven(m) ? m - 1 : m - 2
block = _terms_omitted(mode, div(length(terms), 2) - m)
return string(join(terms[1:k_l]), block, join(terms[(end-k_r):end]))
end

# TODO(odow): remove show_constant in JuMP 1.0
function function_string(mode, a::GenericAffExpr, show_constant = true)
if length(linear_terms(a)) == 0
Expand All @@ -616,7 +651,7 @@ function function_string(mode, a::GenericAffExpr, show_constant = true)
terms[2*elm] = _term_string(coef, function_string(mode, var))
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = join(terms)
ret = _terms_to_truncated_string(mode, terms)
if show_constant && !_is_zero_for_printing(a.constant)
ret = string(
ret,
Expand Down Expand Up @@ -645,7 +680,7 @@ function function_string(mode, q::GenericQuadExpr)
terms[2*elm] = _term_string(coef, factor)
end
terms[1] = terms[1] == " - " ? "-" : ""
ret = join(terms)
ret = _terms_to_truncated_string(mode, terms)
aff_str = function_string(mode, q.aff)
if aff_str == "0"
return ret
Expand Down
15 changes: 15 additions & 0 deletions test/test_nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,4 +978,19 @@ function test_variable_ref_type()
return
end

function test_printing_truncation()
model = Model()
@variable(model, x[1:100])
y = @expression(model, sum(sin.(x) .* 2))
@test occursin(
"(sin(x[72]) * 2.0) + [[...41 terms omitted...]] + (sin(x[30]) * 2.0)",
function_string(MIME("text/plain"), y),
)
@test occursin(
"{\\left({\\textsf{sin}\\left({x[72]}\\right)} * {2.0}\\right) + {[[\\ldots\\text{41 terms omitted}\\ldots]]} + {\\left({\\textsf{sin}\\left({x[30]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[29]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[28]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[27]}\\right)} * {2.0}\\right)} + {\\left({\\textsf{sin}\\left({x[26]}\\right)} * {2.0}\\right)}",
function_string(MIME("text/latex"), y),
)
return
end

end # module
20 changes: 20 additions & 0 deletions test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -967,4 +967,24 @@ function test_print_text_latex_interval_set()
return
end

function test_truncated_printing()
model = Model()
@variable(model, x[1:1000])
y = sum(x)
s = function_string(MIME("text/plain"), y)
@test occursin("x[30] + [[...940 terms omitted...]] + x[971]", s)
@test occursin(
"x_{30} + [[\\ldots\\text{940 terms omitted}\\ldots]] + x_{971}",
function_string(MIME("text/latex"), y),
)
ret = JuMP._TERM_LIMIT_FOR_PRINTING[]
JuMP._TERM_LIMIT_FOR_PRINTING[] = 3
@test function_string(MIME("text/plain"), y) ==
"x[1] + x[2] + [[...997 terms omitted...]] + x[1000]"
@test function_string(MIME("text/latex"), y) ==
"x_{1} + x_{2} + [[\\ldots\\text{997 terms omitted}\\ldots]] + x_{1000}"
JuMP._TERM_LIMIT_FOR_PRINTING[] = ret
return
end

end
Loading