torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/higher_order/higher_order_newton.py (new file)

@@ -0,0 +1,319 @@
+import itertools
+import math
+import warnings
+from collections.abc import Callable
+from contextlib import nullcontext
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import scipy.optimize
+import torch
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_wrt,
+)
+
+_LETTERS = 'abcdefghijklmnopqrstuvwxyz'
+def _poly_eval(s: np.ndarray, c, derivatives):
+    val = float(c)
+    for i,T in enumerate(derivatives, 1):
+        s1 = ''.join(_LETTERS[:i]) # abcd
+        s2 = ',...'.join(_LETTERS[:i]) # a,b,c,d
+        # this would make einsum('abcd,a,b,c,d', T, x, x, x, x)
+        val += np.einsum(f"...{s1},...{s2}", T, *(s for _ in range(i))) / math.factorial(i)
+    return val
+
+def _proximal_poly_v(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+    if x.ndim == 2: x = x.T # DE passes (ndim, batch_size)
+    s = x - x0
+    val = _poly_eval(s, c, derivatives)
+    penalty = 0
+    if prox != 0: penalty = (prox / 2) * (s**2).sum(-1) # proximal penalty
+    return val + penalty
+
+def _proximal_poly_g(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+    s = x - x0
+    g = derivatives[0].copy()
+    if len(derivatives) > 1:
+        for i, T in enumerate(derivatives[1:], 2):
+            s1 = ''.join(_LETTERS[:i]) # abcd
+            s2 = ','.join(_LETTERS[1:i]) # b,c,d
+            # this would make einsum('abcd,b,c,d->a', T, x, x, x)
+            g += np.einsum(f"{s1},{s2}->a", T, *(s for _ in range(i-1))) / math.factorial(i - 1)
+
+    g_prox = 0
+    if prox != 0: g_prox = prox * s
+    return g + g_prox
+
+def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+    s = x - x0
+    n = x.shape[0]
+    if len(derivatives) == 1:
+        H = np.zeros((n, n))
+    else:
+        H = derivatives[1].copy()
+    if len(derivatives) > 2:
+        for i, T in enumerate(derivatives[2:], 3):
+            s1 = ''.join(_LETTERS[:i]) # abcd
+            s2 = ','.join(_LETTERS[2:i]) # c,d
+            # this would make einsum('abcd,c,d->ab', T, x, x, x)
+            H += np.einsum(f"{s1},{s2}->ab", T, *(s for _ in range(i-2))) / math.factorial(i - 2)
+
+    H_prox = 0
+    if prox != 0: H_prox = np.eye(n) * prox
+    return H + H_prox
+
+def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
+    derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
+    x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
+
+    # notes
+    # 1. since we have exact hessian we use trust methods
+
+    # 2. if len(derivatives) is 1, only gradient is available,
+    # thus use slsqp depending on whether trust region is enabled
+    # this is just so that I can test that trust region works
+    if trust_region is None:
+        if len(derivatives) == 1: raise RuntimeError("trust region must be enabled because 1st order has no minima")
+        method = 'trust-exact'
+        de_bounds = list(zip(x0 - 10, x0 + 10))
+        constraints = None
+
+    else:
+        if len(derivatives) == 1: method = 'slsqp'
+        else: method = 'trust-constr'
+        de_bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+        def l2_bound_f(x):
+            if x.ndim == 2: return np.sum((x - x0[:,None])**2, axis=0)[None,:] # DE passes (ndim, batch_size) and expects (M, S)
+            return np.sum((x - x0)**2, axis=0)
+
+        def l2_bound_g(x):
+            return 2 * (x - x0)
+
+        def l2_bound_h(x, v):
+            return v[0] * 2 * np.eye(x0.shape[0])
+
+        constraint = scipy.optimize.NonlinearConstraint(
+            fun=l2_bound_f,
+            lb=0, # 0 <= ||x-x0||^2
+            ub=trust_region**2, # ||x-x0||^2 <= R^2
+            jac=l2_bound_g, # pyright:ignore[reportArgumentType]
+            hess=l2_bound_h,
+            keep_feasible=False
+        )
+        constraints = [constraint]
+
+    x_init = x0.copy()
+    v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
+
+    # ---------------------------------- run DE ---------------------------------- #
+    if de_iters is not None and de_iters != 0:
+        if de_iters == -1: de_iters = None # let scipy decide
+
+        # DE needs bounds so use linf ig
+        res = scipy.optimize.differential_evolution(
+            _proximal_poly_v,
+            de_bounds,
+            args=(c, prox, x0.copy(), derivatives),
+            maxiter=de_iters,
+            vectorized=True,
+            constraints = constraints,
+            updating='deferred',
+        )
+        if res.fun < v0 and np.all(np.isfinite(res.x)): x_init = res.x
+
+    # ------------------------------- run minimize ------------------------------- #
+    try:
+        res = scipy.optimize.minimize(
+            _proximal_poly_v,
+            x_init,
+            method=method,
+            args=(c, prox, x0.copy(), derivatives),
+            jac=_proximal_poly_g,
+            hess=_proximal_poly_H,
+            constraints = constraints,
+        )
+    except ValueError:
+        return x, -float('inf')
+    return torch.from_numpy(res.x).to(x), res.fun
+
+
+
+class HigherOrderNewton(Module):
+    """A basic arbitrary order Newton's method with optional trust region and proximal penalty.
+
+    This constructs an nth order Taylor approximation via autograd and minimizes it with
+    scipy.optimize.minimize trust region Newton solvers, with an optional proximal penalty.
+
+    .. note::
+        In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires a closure to be passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        This uses roughly O(N^order) memory, and solving the subproblem can be very expensive.
+
+    .. warning::
+        "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
+
+    Args:
+
+        order (int, optional):
+            Order of the method, number of taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
+        trust_method (str | None, optional):
+            Method used for trust region.
+
+            - "bounds" - the model is minimized within bounds defined by trust region.
+            - "proximal" - the model is minimized with a penalty for going too far from the current point.
+            - "none" - disables trust region.
+
+            Defaults to 'bounds'.
+        increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
+        decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
+        trust_init (float | None, optional):
+            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on :code:`"proximal"`. Defaults to None.
+        trust_tol (float, optional):
+            Maximum ratio of expected loss reduction to actual reduction for trust region increase.
+            Should be 1 or higher. Defaults to 2.
+        de_iters (int | None, optional):
+            If this is specified, the model is minimized via differential evolution first to possibly escape local minima,
+            then it is passed to scipy.optimize.minimize. Defaults to None.
+        vectorize (bool, optional): whether to enable vectorized jacobians (usually faster). Defaults to True.
+    """
+    def __init__(
+        self,
+        order: int = 4,
+        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
+        nplus: float = 2,
+        nminus: float = 0.25,
+        init: float | None = None,
+        eta: float = 1e-6,
+        max_attempts = 10,
+        de_iters: int | None = None,
+        vectorize: bool = True,
+    ):
+        if init is None:
+            if trust_method == 'bounds': init = 1
+            else: init = 0.1
+
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
+
+        settings = self.settings[params[0]]
+        order = settings['order']
+        nplus = settings['nplus']
+        nminus = settings['nminus']
+        eta = settings['eta']
+        init = settings['init']
+        trust_method = settings['trust_method']
+        de_iters = settings['de_iters']
+        max_attempts = settings['max_attempts']
+        vectorize = settings['vectorize']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        with torch.enable_grad():
+            loss = var.loss = var.loss_approx = closure(False)
+
+            g_list = torch.autograd.grad(loss, params, create_graph=True)
+            var.grad = list(g_list)
+
+            g = torch.cat([t.ravel() for t in g_list])
+            n = g.numel()
+            derivatives = [g]
+            T = g # current derivatives tensor
+
+            # get all derivatives up to order
+            for o in range(2, order + 1):
+                is_last = o == order
+                T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
+                with torch.no_grad() if is_last else nullcontext():
+                    # the shape is (ndim, ) * order
+                    T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
+                    derivatives.append(T)
+
+        x0 = torch.cat([p.ravel() for p in params])
+
+        success = False
+        x_star = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            # load trust region value
+            trust_value = self.global_state.get('trust_region', init)
+            if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']
+
+            if trust_method is None: trust_method = 'none'
+            else: trust_method = trust_method.lower()
+
+            if trust_method == 'none':
+                trust_region = None
+                prox = 0
+
+            elif trust_method == 'bounds':
+                trust_region = trust_value
+                prox = 0
+
+            elif trust_method == 'proximal':
+                trust_region = None
+                prox = 1 / trust_value
+
+            else:
+                raise ValueError(trust_method)
+
+            # minimize the model
+            x_star, expected_loss = _poly_minimize(
+                trust_region=trust_region,
+                prox=prox,
+                de_iters=de_iters,
+                c=loss.item(),
+                x=x0,
+                derivatives=derivatives,
+            )
+
+            # update trust region
+            if trust_method == 'none':
+                success = True
+            else:
+                pred_reduction = loss - expected_loss
+
+                vec_to_tensors_(x_star, params)
+                loss_star = closure(False)
+                vec_to_tensors_(x0, params)
+                reduction = loss - loss_star
+
+                rho = reduction / (max(pred_reduction, 1e-8))
+                # failed step
+                if rho < 0.25:
+                    self.global_state['trust_region'] = trust_value * nminus
+
+                # very good step
+                elif rho > 0.75:
+                    diff = trust_value - (x0 - x_star).abs_()
+                    if (diff.amin() / trust_value) > 1e-4: # hits boundary
+                        self.global_state['trust_region'] = trust_value * nplus
+
+                # if the ratio is high enough then accept the proposed step
+                success = rho > eta
+
+        assert x_star is not None
+        if success:
+            difference = vec_to_tensors(x0 - x_star, params)
+            var.update = list(difference)
+        else:
+            var.update = params.zeros_like()
+        return var
+
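A note on wiring up the new module: the docstring above says HigherOrderNewton should come first in the chain and that the closure must accept a ``backward`` argument. A minimal sketch of one way this might look, assuming the class is re-exported as ``tz.m.HigherOrderNewton`` in the same way the Backtracking docstring examples later in this diff use ``tz.m`` (the two-parameter problem and the loop are hypothetical, not taken from the package):

import torch
import torchzero as tz

# hypothetical two-parameter problem (Rosenbrock)
x = torch.tensor([-1.1, 2.5], requires_grad=True)
opt = tz.Modular([x], tz.m.HigherOrderNewton(order=4))

def closure(backward=True):
    # HigherOrderNewton calls closure(False) and differentiates the returned
    # loss itself, so backward=False must skip the .backward() call
    loss = (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)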
torchzero/modules/line_search/__init__.py

@@ -1,5 +1,5 @@
-from .
-from .backtracking import
-from .
+from .adaptive import AdaptiveLineSearch
+from .backtracking import AdaptiveBacktracking, Backtracking
+from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .
+from .strong_wolfe import StrongWolfe
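For reference, after this change the line search subpackage exposes these names directly; a small sketch (whether they are additionally re-exported elsewhere, e.g. under ``tz.m``, is not shown in this hunk):

from torchzero.modules.line_search import (
    AdaptiveBacktracking,
    AdaptiveLineSearch,
    Backtracking,
    LineSearchBase,
    ScipyMinimizeScalar,
    StrongWolfe,
)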
torchzero/modules/line_search/adaptive.py (new file)

@@ -0,0 +1,99 @@
+import math
+from collections.abc import Callable
+from operator import itemgetter
+
+import torch
+
+from .line_search import LineSearchBase
+
+
+
+def adaptive_tracking(
+    f,
+    x_0,
+    maxiter: int,
+    nplus: float = 2,
+    nminus: float = 0.5,
+):
+    f_0 = f(0)
+
+    t = x_0
+    f_t = f(t)
+
+    # backtrack
+    if f_t > f_0:
+        while f_t > f_0:
+            maxiter -= 1
+            if maxiter < 0: return 0, f_0
+            t = t*nminus
+            f_t = f(t)
+        return t, f_t
+
+    # forwardtrack
+    f_prev = f_t
+    t *= nplus
+    f_t = f(t)
+    if f_prev < f_t: return t / nplus, f_prev
+    while f_prev >= f_t:
+        maxiter -= 1
+        if maxiter < 0: return t, f_t
+        f_prev = f_t
+        t *= nplus
+        f_t = f(t)
+    return t / nplus, f_prev
+
+class AdaptiveLineSearch(LineSearchBase):
+    """Adaptive line search, similar to backtracking but also has a forward tracking mode.
+    Currently doesn't check the weak curvature condition.
+
+    Args:
+        init (float, optional): initial step size. Defaults to 1.0.
+        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        adaptive (bool, optional):
+            when enabled, if line search failed, beta size is reduced.
+            Otherwise it is reset to initial value. Defaults to True.
+    """
+    def __init__(
+        self,
+        init: float = 1.0,
+        nplus: float = 2,
+        nminus: float = 0.5,
+        maxiter: int = 10,
+        adaptive=True,
+    ):
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive,)
+        super().__init__(defaults=defaults)
+        self.global_state['beta_scale'] = 1.0
+
+    def reset(self):
+        super().reset()
+        self.global_state['beta_scale'] = 1.0
+
+    @torch.no_grad
+    def search(self, update, var):
+        init, nplus, nminus, maxiter, adaptive = itemgetter(
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.settings[var.params[0]])
+
+        objective = self.make_objective(var=var)
+
+        # # directional derivative
+        # d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+
+        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
+        beta_scale = self.global_state.get('beta_scale', 1)
+        x_prev = self.global_state.get('prev_x', 1)
+
+        if adaptive: nminus = nminus * beta_scale
+
+
+        step_size, f = adaptive_tracking(objective, x_prev, maxiter, nplus=nplus, nminus=nminus)
+
+        # found an alpha that reduces loss
+        if step_size != 0:
+            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            return step_size
+
+        # on fail reduce beta scale value
+        self.global_state['beta_scale'] /= 1.5
+        return 0
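To illustrate the bracketing logic of ``adaptive_tracking`` above: it backtracks (multiplying by ``nminus``) while the trial value is worse than ``f(0)``, otherwise it forward-tracks (multiplying by ``nplus``) until the value stops improving and returns the last improving step. A standalone sketch on a 1-D quadratic, using the module path from the file list (the helper may not be part of the public API):

from torchzero.modules.line_search.adaptive import adaptive_tracking

def f(t):
    # loss along the search direction, minimum at t = 3
    return (t - 3.0) ** 2

# start from the previous step size 1.0 and allow 10 evaluations
t, f_t = adaptive_tracking(f, 1.0, maxiter=10, nplus=2, nminus=0.5)
print(t, f_t)  # forward-tracks 1 -> 2 -> 4 -> 8, then returns (4.0, 1.0)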
torchzero/modules/line_search/backtracking.py

@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import
+from .line_search import LineSearchBase
 
 
 def backtracking_line_search(
@@ -14,19 +14,17 @@ def backtracking_line_search(
     beta: float = 0.5,
     c: float = 1e-4,
     maxiter: int = 10,
-    a_min: float | None = None,
     try_negative: bool = False,
 ) -> float | None:
     """
 
     Args:
-
-
-
+        f: evaluates step size along some descent direction.
+        g_0: directional derivative along the descent direction.
+        init: initial step size.
         beta: The factor by which to decrease alpha in each iteration
         c: The constant for the Armijo sufficient decrease condition
-
-        min_alpha: Minimum allowable step size to prevent near-zero values (default: 1e-16).
+        maxiter: Maximum number of backtracking iterations (default: 10).
 
     Returns:
         step size
@@ -34,21 +32,21 @@ def backtracking_line_search(
 
     a = init
     f_x = f(0)
+    f_prev = None
 
     for iteration in range(maxiter):
         f_a = f(a)
 
-        if
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_x): return a / beta
+        f_prev = f_a
+
+        if f_a < f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
             # found an acceptable alpha
             return a
 
         # decrease alpha
         a *= beta
 
-        # alpha too small
-        if a_min is not None and a < a_min:
-            return a_min
-
     # fail
     if try_negative:
         def inv_objective(alpha): return f(-alpha)
@@ -59,25 +57,56 @@ def backtracking_line_search(
             beta=beta,
             c=c,
             maxiter=maxiter,
-            a_min=a_min,
             try_negative=False,
         )
         if v is not None: return -v
 
     return None
 
-class Backtracking(
+class Backtracking(LineSearchBase):
+    """Backtracking line search satisfying the Armijo condition.
+
+    Args:
+        init (float, optional): initial step size. Defaults to 1.0.
+        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
+        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        adaptive (bool, optional):
+            when enabled, if line search failed, beta is reduced.
+            Otherwise it is reset to initial value. Defaults to True.
+        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+
+    Examples:
+        Gradient descent with backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Backtracking()
+            )
+
+        LBFGS with backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(),
+                tz.m.Backtracking()
+            )
+
+    """
     def __init__(
         self,
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
         maxiter: int = 10,
-        min_alpha: float | None = None,
         adaptive=True,
         try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,
+        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
         super().__init__(defaults=defaults)
         self.global_state['beta_scale'] = 1.0
 
@@ -86,20 +115,20 @@ class Backtracking(LineSearch):
         self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
-    def search(self, update,
-        init, beta, c, maxiter,
-        'init', 'beta', 'c', 'maxiter', '
+    def search(self, update, var):
+        init, beta, c, maxiter, adaptive, try_negative = itemgetter(
+            'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
 
-        objective = self.make_objective(
+        objective = self.make_objective(var=var)
 
         # # directional derivative
-        d = -sum(t.sum() for t in torch._foreach_mul(
+        d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
         # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
         if adaptive: beta = beta * self.global_state['beta_scale']
 
         step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-            c=c,maxiter=maxiter,
+            c=c,maxiter=maxiter, try_negative=try_negative)
 
         # found an alpha that reduces loss
         if step_size is not None:
@@ -113,20 +142,35 @@ class Backtracking(LineSearch):
 def _lerp(start,end,weight):
     return start + weight * (end - start)
 
-class AdaptiveBacktracking(
+class AdaptiveBacktracking(LineSearchBase):
+    """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
+    such that the optimal step size from that procedure would be found on the second line search iteration.
+
+    Args:
+        init (float, optional): step size for the first step. Defaults to 1.0.
+        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
+        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        target_iters (int, optional):
+            target number of iterations that would be performed until the optimal step size is found. Defaults to 1.
+        nplus (float, optional):
+            Multiplier to initial step size if it was found to be the optimal step size. Defaults to 2.0.
+        scale_beta (float, optional):
+            Momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
+        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+    """
     def __init__(
         self,
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
         maxiter: int = 20,
-        min_alpha: float | None = None,
         target_iters = 1,
         nplus = 2.0,
         scale_beta = 0.0,
         try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,
+        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
         super().__init__(defaults=defaults)
 
         self.global_state['beta_scale'] = 1.0
@@ -138,15 +182,15 @@ class AdaptiveBacktracking(LineSearch):
         self.global_state['initial_scale'] = 1.0
 
     @torch.no_grad
-    def search(self, update,
-        init, beta, c, maxiter,
-        'init','beta','c','maxiter','
+    def search(self, update, var):
+        init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
+            'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
 
-        objective = self.make_objective(
+        objective = self.make_objective(var=var)
 
         # directional derivative (0 if c = 0 because it is not needed)
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
 
         # scale beta
         beta = beta * self.global_state['beta_scale']
@@ -155,7 +199,7 @@ class AdaptiveBacktracking(LineSearch):
         init = init * self.global_state['initial_scale']
 
         step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-            c=c,maxiter=maxiter,
+            c=c,maxiter=maxiter, try_negative=try_negative)
 
         # found an alpha that reduces loss
         if step_size is not None: