torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/higher_order/higher_order_newton.py

@@ -13,7 +13,7 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
 from ...utils.derivatives import (
-
+    flatten_jacobian,
     jacobian_wrt,
 )

@@ -148,21 +148,16 @@ class HigherOrderNewton(Module):
     """A basic arbitrary order newton's method with optional trust region and proximal penalty.

     This constructs an nth order taylor approximation via autograd and minimizes it with
-    scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
+    ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.

-
-
+    The hessian of taylor approximation is easier to evaluate, plus it can be evaluated in a batched mode,
+    so it can be more efficient in very specific instances.

-
-
-    as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
-
-
-    .. warning::
-        this uses roughly O(N^order) memory and solving the subproblem can be very expensive.
-
-    .. warning::
-        "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
+    Notes:
+        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
+        - this uses roughly O(N^order) memory and solving the subproblem is very expensive.
+        - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

     Args:

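The new notes above lean on torchzero's closure convention. A minimal sketch of a conforming closure (the objective and parameter below are placeholders, and the setup assumes the ``tz.Modular`` entry point used in the examples later in this diff):

```python
import torch
import torchzero as tz

x = torch.randn(10, requires_grad=True)
opt = tz.Modular([x], tz.m.HigherOrderNewton())  # module path assumed from this diff's layout

def closure(backward=True):
    loss = (x ** 2).sum()   # placeholder objective
    if backward:
        opt.zero_grad()     # reset gradients before re-deriving them
        loss.backward()     # only done when the caller actually needs grads
    return loss

loss = opt.step(closure)    # HigherOrderNewton re-calls closure(False) internally
```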
@@ -178,7 +173,7 @@ class HigherOrderNewton(Module):
         increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
         decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
         trust_init (float | None, optional):
-            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on
+            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on ``"proximal"``. Defaults to None.
         trust_tol (float, optional):
             Maximum ratio of expected loss reduction to actual reduction for trust region increase.
             Should be 1 or higher. Defaults to 2.
@@ -191,11 +186,14 @@ class HigherOrderNewton(Module):
         self,
         order: int = 4,
         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        nplus: float =
+        nplus: float = 3.5,
         nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
         init: float | None = None,
         eta: float = 1e-6,
         max_attempts = 10,
+        boundary_tol: float = 1e-2,
         de_iters: int | None = None,
         vectorize: bool = True,
     ):
@@ -203,7 +201,7 @@ class HigherOrderNewton(Module):
             if trust_method == 'bounds': init = 1
             else: init = 0.1

-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
         super().__init__(defaults)

     @torch.no_grad
@@ -222,6 +220,9 @@ class HigherOrderNewton(Module):
         de_iters = settings['de_iters']
         max_attempts = settings['max_attempts']
         vectorize = settings['vectorize']
+        boundary_tol = settings['boundary_tol']
+        rho_good = settings['rho_good']
+        rho_bad = settings['rho_bad']

         # ------------------------ calculate grad and hessian ------------------------ #
         with torch.enable_grad():
@@ -241,7 +242,7 @@ class HigherOrderNewton(Module):
             T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
             with torch.no_grad() if is_last else nullcontext():
                 # the shape is (ndim, ) * order
-                T =
+                T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                 derivatives.append(T)

         x0 = torch.cat([p.ravel() for p in params])
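For intuition: each pass of this loop differentiates the previous Taylor term, so after two passes ``T`` holds the hessian with shape ``(ndim, ndim)``. A rough stand-alone sketch of the order-2 case using only stock autograd (not the library's ``jacobian_wrt``/``flatten_jacobian`` helpers):

```python
import torch

def f(x):
    return (x ** 4).sum()  # any scalar-valued objective

x = torch.randn(5)
g = torch.autograd.functional.jacobian(f, x)  # order-1 term, shape (5,)
H = torch.autograd.functional.hessian(f, x)   # order-2 term, shape (5, 5) == (ndim,) * 2
print(g.shape, H.shape)
```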
@@ -254,8 +255,13 @@ class HigherOrderNewton(Module):

         # load trust region value
         trust_value = self.global_state.get('trust_region', init)
-        if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']

+        # make sure its not too small or too large
+        finfo = torch.finfo(x0.dtype)
+        if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
+            trust_value = self.global_state['trust_region'] = settings['init']
+
+        # determine tr and prox values
         if trust_method is None: trust_method = 'none'
         else: trust_method = trust_method.lower()

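The replacement for the old hard-coded ``1e-8``/``1e16`` bounds is dtype-aware. The same reset condition in isolation (the values below are illustrative):

```python
import torch

finfo = torch.finfo(torch.float32)
print(finfo.tiny, finfo.max)  # ~1.18e-38 and ~3.40e+38

trust_value, nplus, init = 1e-40, 3.5, 1.0
if trust_value < finfo.tiny * 2 or trust_value > finfo.max / (2 * nplus):
    trust_value = init        # re-initialize instead of underflowing or overflowing
print(trust_value)
```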
@@ -297,13 +303,15 @@ class HigherOrderNewton(Module):

         rho = reduction / (max(pred_reduction, 1e-8))
         # failed step
-        if rho <
+        if rho < rho_bad:
             self.global_state['trust_region'] = trust_value * nminus

         # very good step
-        elif rho >
-
-
+        elif rho > rho_good:
+            step = (x_star - x0)
+            magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
+            if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
+                # close to boundary
                 self.global_state['trust_region'] = trust_value * nplus

         # if the ratio is high enough then accept the proposed step
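Restated as a pure function (a sketch only; names and defaults mirror this diff, surrounding bookkeeping omitted): shrink when the ratio of actual to predicted reduction is poor, grow only when the ratio is very good and the accepted step actually reached the trust-region boundary:

```python
import torch

def updated_trust_region(trust_value, rho, step, trust_method,
                         nplus=3.5, nminus=0.25,
                         rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-2):
    if rho < rho_bad:                  # failed step
        return trust_value * nminus
    if rho > rho_good:                 # very good step
        magn = torch.linalg.vector_norm(step)
        if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
            return trust_value * nplus # step hit the boundary, expand
    return trust_value                 # otherwise keep the radius
```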
torchzero/modules/least_squares/__init__.py

@@ -0,0 +1 @@
+from .gn import SumOfSquares, GaussNewton
torchzero/modules/least_squares/gn.py

@@ -0,0 +1,161 @@
+import torch
+from ...core import Module
+
+from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...utils import vec_to_tensors
+from ...utils.linalg import linear_operator
+class SumOfSquares(Module):
+    """Sets loss to be the sum of squares of values returned by the closure.
+
+    This is meant to be used to test least squares methods against ordinary minimization methods.
+
+    To use this, the closure should return a vector of values to minimize sum of squares of.
+    Please add the `backward` argument, it will always be False but it is required.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+
+        if closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False)
+                        loss = loss.pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False)
+                return loss.pow(2).sum()
+
+            var.closure = sos_closure
+
+        if var.loss is not None:
+            var.loss = var.loss.pow(2).sum()
+
+        if var.loss_approx is not None:
+            var.loss_approx = var.loss_approx.pow(2).sum()
+
+        return var
+
+
+class GaussNewton(Module):
+    """Gauss-newton method.
+
+    To use this, the closure should return a vector of values to minimize sum of squares of.
+    Please add the ``backward`` argument, it will always be False but it is required.
+    Gradients will be calculated via batched autograd within this module, you don't need to
+    implement the backward pass. Please see below for an example.
+
+    Note:
+        This method requires ``ndim^2`` memory, however, if it is used within ``tz.m.TrustCG`` trust region,
+        the memory requirement is ``ndim*m``, where ``m`` is number of values in the output.
+
+    Args:
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        batched (bool, optional): whether to use vmapping. Defaults to True.
+
+    Examples:
+
+    minimizing the rosenbrock function:
+    ```python
+    def rosenbrock(X):
+        x1, x2 = X
+        return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
+
+    X = torch.tensor([-1.1, 2.5], requires_grad=True)
+    opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+
+    # define the closure for line search
+    def closure(backward=True):
+        return rosenbrock(X)
+
+    # minimize
+    for iter in range(10):
+        loss = opt.step(closure)
+        print(f'{loss = }')
+    ```
+
+    training a neural network with a matrix-free GN trust region:
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.GaussNewton()),
+    )
+
+    def closure(backward=True):
+        y_hat = model(X) # (64, 10)
+        return (y_hat - y).pow(2).mean(0) # (10, )
+
+    for i in range(100):
+        losses = opt.step(closure)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+    """
+    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
+        super().__init__(defaults=dict(batched=batched, reg=reg))
+
+    @torch.no_grad
+    def update(self, var):
+        params = var.params
+        batched = self.defaults['batched']
+
+        closure = var.closure
+        assert closure is not None
+
+        # gauss newton direction
+        with torch.enable_grad():
+            f = var.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)
+            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+        var.loss = f.pow(2).sum()
+
+        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
+        Gtf = G.T @ f.detach() # (ndim)
+        self.global_state["Gtf"] = Gtf
+        var.grad = vec_to_tensors(Gtf, var.params)
+
+        # set closure to calculate sum of squares for line searches etc
+        if var.closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False).pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False).pow(2).sum()
+                return loss
+
+            var.closure = sos_closure
+
+    @torch.no_grad
+    def apply(self, var):
+        reg = self.defaults['reg']
+
+        G = self.global_state['G']
+        Gtf = self.global_state['Gtf']
+
+        GtG = G.T @ G # (ndim, ndim)
+        if reg != 0:
+            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+
+        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+
+        var.update = vec_to_tensors(v, var.params)
+        return var
+
+    def get_H(self, var):
+        G = self.global_state['G']
+        return linear_operator.AtA(G)
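For reference, ``apply`` solves the Gauss-Newton normal equations ``(GᵀG + reg·I) v = Gᵀf``. A small self-contained check on a linear residual, where one step from zero must land on the least-squares solution (toy data, not part of the package):

```python
import torch

torch.manual_seed(0)
A = torch.randn(10, 3)
b = torch.randn(10)
x = torch.zeros(3)

f = A @ x - b                  # residual vector, shape (10,)
G = A                          # jacobian of f wrt x, shape (10, 3)
Gtf = G.T @ f                  # same quantity the module stores as "Gtf"
GtG = G.T @ G                  # same quantity as in apply()

v = torch.linalg.lstsq(GtG, Gtf).solution
x_new = x - v                  # one Gauss-Newton step
print(torch.allclose(A.T @ (A @ x_new - b), torch.zeros(3), atol=1e-5))  # True
```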
torchzero/modules/line_search/__init__.py

@@ -1,5 +1,5 @@
-from .adaptive import
+from .adaptive import AdaptiveTracking
 from .backtracking import AdaptiveBacktracking, Backtracking
 from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .strong_wolfe import StrongWolfe
+from .strong_wolfe import StrongWolfe
torchzero/modules/line_search/_polyinterp.py

@@ -0,0 +1,289 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+        . Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt):
+                    if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+                    if crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                        F_cp = np.polyval(coeff, crit_pt)
+                        if np.isreal(F_cp) and F_cp < f_min:
+                            x_sol = np.real(crit_pt)
+                            f_min = np.real(F_cp)
+
+            if(plot):
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+# polynomial interpolation
+# this code is based on https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+# this one is modified where instead of clipping the solution by bounds, it tries a lower degree polynomial
+# all the way to bisection
+def _within_bounds(x, lb, ub):
+    if lb is not None and x < lb: return False
+    if ub is not None and x > ub: return False
+    return True
+
+def _quad_interp(points):
+    assert points.shape[0] == 2, points.shape
+    if points[0, 0] == 0:
+        denom = 2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])
+        if abs(denom) > 1e-32:
+            return -points[0, 2] * points[1, 0] ** 2 / denom
+    else:
+        denom = (points[0, 0] - points[1, 0]) ** 2
+        if denom > 1e-32:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / denom
+            if a > 1e-32:
+                return points[0, 0] - points[0, 2]/(2*a)
+    return None
+
+def _cubic_interp(points, lb, ub):
+    assert points.shape[0] == 2, points.shape
+    denom = points[0, 0] - points[1, 0]
+    if abs(denom) > 1e-32:
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / denom)
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            denom = points[1, 2] - points[0, 2] + 2 * d2
+            if abs(denom) > 1e-32:
+                x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / denom)
+                if _within_bounds(x_sol, lb, ub): return x_sol
+
+    # try quadratic interpolations
+    x_sol = _quad_interp(points)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    return None
+
+def _poly_interp(points, lb, ub):
+    no_points = points.shape[0]
+    assert no_points > 2, points.shape
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    # define linear constraints
+    A = np.zeros((0, order + 1))
+    b = np.zeros((0, 1))
+
+    # add linear constraints on function values
+    for i in range(no_points):
+        if not np.isnan(points[i, 1]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order, -1, -1):
+                constraint[0, order - j] = points[i, 0] ** j
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 1])
+
+    # add linear constraints on gradient values
+    for i in range(no_points):
+        if not np.isnan(points[i, 2]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order):
+                constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 2])
+
+    # check if system is solvable
+    if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+        return None
+
+    # solve linear system for interpolating polynomial
+    coeff = np.linalg.solve(A, b)
+
+    # compute critical points
+    dcoeff = np.zeros(order)
+    for i in range(len(coeff) - 1):
+        dcoeff[i] = coeff[i] * (order - i)
+
+    lower = np.min(points[:, 0]) if lb is None else lb
+    upper = np.max(points[:, 0]) if ub is None else ub
+
+    crit_pts = np.array([lower, upper])
+    crit_pts = np.append(crit_pts, points[:, 0])
+
+    if not np.isinf(dcoeff).any():
+        roots = np.roots(dcoeff)
+        crit_pts = np.append(crit_pts, roots)
+
+    # test critical points
+    f_min = np.inf
+    x_sol = None
+    for crit_pt in crit_pts:
+        if np.isreal(crit_pt):
+            if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+            if _within_bounds(crit_pt, lb, ub):
+                F_cp = np.polyval(coeff, crit_pt)
+                if np.isreal(F_cp) and F_cp < f_min:
+                    x_sol = np.real(crit_pt)
+                    f_min = np.real(F_cp)
+
+    return x_sol
+
+def polyinterp2(points, lb, ub, unbounded: bool = False):
+    no_points = points.shape[0]
+    if no_points <= 1:
+        return (lb + ub)/2
+
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if not unbounded:
+        if lb is None:
+            lb = x_min
+        if ub is None:
+            ub = x_max
+
+    if no_points == 2 and order == 2:
+        x_sol = _quad_interp(points)
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points == 2 and order == 3:
+        x_sol = _cubic_interp(points, lb, ub) # includes fallback on _quad_interp
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points <= 2: # order < 2
+        return (lb + ub)/2
+
+    if no_points == 3:
+        for p in (points[:2], points[1:], points[::2]):
+            x_sol = _cubic_interp(p, lb, ub)
+            if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    x_sol = _poly_interp(points, lb, ub)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+    return polyinterp2(points[1:], lb, ub)
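A quick usage sketch for the new entry point ``polyinterp2`` (assuming the file above is importable as ``torchzero.modules.line_search._polyinterp``): rows are ``[x, f, g]``, and with two fully specified points the cubic branch is taken, falling back to quadratic interpolation and finally bisection if the fit is degenerate or out of bounds:

```python
import numpy as np
from torchzero.modules.line_search._polyinterp import polyinterp2

# f(0) = 1 with slope -1, f(1) = 0.5 with slope 0.2
points = np.array([
    [0.0, 1.0, -1.0],
    [1.0, 0.5,  0.2],
])
step = polyinterp2(points, lb=0.0, ub=1.0)
print(step)  # ~0.88, the minimizer of the interpolating cubic on [0, 1]
```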