Skip to content

Commit

Permalink
stage direct method
Browse files Browse the repository at this point in the history
  • Loading branch information
bzhangcw committed Aug 28, 2023
1 parent 2b30c3f commit 8a422f5
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 52 deletions.
27 changes: 15 additions & 12 deletions src/algorithms/utr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ function Base.iterate(
)
end
k₂ = 0
γ = 1.5
γ₁ = 8.0
γ₂ = 5.0
η = 1.0
ρ = 1.0

Expand All @@ -339,13 +340,15 @@ function Base.iterate(
n = state.∇f |> length
# dual estimate
λ₁ = 0.0
θ = σ
while true

# if not accepted
# λ (dual) must increase
v, θ, λ₁, kᵢ = TrustRegionCholesky(
H,
H + σ * I,
state.∇f,
Δ;
λ₁=λ₁
λ₁=θ - σ
)
state.α = 1.0
fx = iter.f(state.z + v * state.α)
Expand All @@ -357,21 +360,21 @@ function Base.iterate(
k₂ += 1
@debug """inner""" v |> norm, Δ, θ, λ₁, kᵢ, df, ρₐ
Δ = min(Δ, v |> norm)
if> 1e-8) && ((df < 0) || ((df < Df) && (ρₐ < 0.6))) # not satisfactory
if abs(λ₁) >= 1e-3 # too cvx or ncvx
if> 1e-8) && ((df < 0) || ((df < Df) && (ρₐ < 0.2))) # not satisfactory
if abs(λ₁) >= 1e-8 # too cvx or ncvx
σ = 0.0
else
σ *= γ
σ *= γ
end
# dec radius
Δ /= γ
Df /= γ
# in this case, λ (dual) must increase
Δ /= γ
Df /= γ

continue
end
if ρₐ > 0.9
σ /= γ
Δ *= γ
σ /= γ
Δ *= γ
end
# do this when accept
state.σ = max(1e-12, σ / grad_regularizer)
Expand Down
8 changes: 4 additions & 4 deletions test/convex/test_logistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ using .LP
using LoopVectorization
using LIBSVMFileIO

bool_opt = true
bool_opt = false
bool_plot = false
bool_q_preprocessed = true
f1(A, d=2) = sqrt.(sum(abs2.(A), dims=d))

ε = 1e-6 # * max(g(x0) |> norm, 1)
λ = 1e-7
if bool_q_preprocessed
name = "a4a"
# name = "a4a"
# name = "a9a"
# name = "w4a"
# name = "covtype"
name = "covtype"
# name = "news20"
# name = "rcv1"

Expand Down Expand Up @@ -226,7 +226,7 @@ if bool_opt
ru = UTR(name=Symbol("Universal-TRS"))(;
x0=copy(x0), f=loss, g=g, H=H,
maxiter=10000, tol=1e-6, freq=1,
direction=:warm, bool_subp_exact=0
direction=:warm, bool_subp_exact=1
)
end

Expand Down
81 changes: 45 additions & 36 deletions test/convex/test_soft_maximum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,47 +35,50 @@ using LinearAlgebra
using Statistics
using LinearOperators
using Optim
using SparseArrays
using .LP


using LIBSVMFileIO

bool_plot = true
bool_plot = false
bool_opt = false

Random.seed!(2)
n = 500
m = 1000
μ = 5e-2
X = [rand(Float64, n) * 2 .- 1 for _ in 1:m]
Xm = hcat(X)'
y = rand(Float64, m) * 2 .- 1
# y = max.(y, 0)
# loss
x0 = ones(n) / 10
function loss_orig(w)
loss_single(x, y0) = exp((w' * x - y0) / μ)
_pure = loss_single.(X, y) |> sum
return μ * log(_pure)
end
function grad_orig(w)
a = (Xm * w - y) / μ
ax = exp.(a)
π0 = ax / (ax |> sum)
= Xm' * π0
return
end
function hess(w)
a = (Xm * w - y) / μ
ax = exp.(a)
π0 = ax / (ax |> sum)
return 1 / μ * (Xm' * Diagonal(π0) * Xm - Xm' * π0 * π0' * Xm)
bool_setup = true

if bool_setup
Random.seed!(2)
n = 500
m = 1000
μ = 5e-2
X = [rand(Float64, n) * 2 .- 1 for _ in 1:m]
Xm = hcat(X)'
y = rand(Float64, m) * 2 .- 1
# y = max.(y, 0)
# loss
x0 = ones(n) / 10
function loss_orig(w)
loss_single(x, y0) = exp((w' * x - y0) / μ)
_pure = loss_single.(X, y) |> sum
return μ * log(_pure)
end
function grad_orig(w)
a = (Xm * w - y) / μ
ax = exp.(a)
π0 = ax / (ax |> sum)
= Xm' * π0
return
end
function hess(w)
a = (Xm * w - y) / μ
ax = exp.(a)
π0 = ax / (ax |> sum)
return Symmetric(sparse(1 / μ * (Xm' * Diagonal(π0) * Xm - Xm' * π0 * π0' * Xm)))
end
∇₀ = grad_orig(zeros(x0 |> size))
grad(w) = grad_orig(w) - ∇₀
loss(w) = loss_orig(w) - ∇₀'w
ε = 1e-5
end
∇₀ = grad_orig(zeros(x0 |> size))
grad(w) = grad_orig(w) - ∇₀
loss(w) = loss_orig(w) - ∇₀'w
ε = 1e-5

if bool_opt
# compare with GD and LBFGS, Trust region newton,
options = Optim.Options(
Expand All @@ -94,8 +97,8 @@ if bool_opt
)
r_newton = Optim.optimize(
loss, grad, hess, x0,
Newton(; alphaguess=LineSearches.InitialStatic(),
linesearch=LineSearches.BackTracking()), options;
Newton(; alphaguess=LineSearches.InitialStatic()
), options;
inplace=false
)
r = HSODM()(;
Expand All @@ -112,6 +115,12 @@ if bool_opt
maxtime=10000,
direction=:warm
)

ru = UTR(;)(;
x0=copy(x0), f=loss, g=grad, H=hess,
maxiter=10000, tol=1e-6, freq=1,
direction=:warm, bool_subp_exact=1
)
end

if bool_plot
Expand Down

0 comments on commit 8a422f5

Please sign in to comment.