torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
|
@@ -5,7 +5,7 @@ from typing import Any, Literal
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from ...core import Module,
|
|
8
|
+
from ...core import Module, Var
|
|
9
9
|
|
|
10
10
|
GradTarget = Literal['update', 'grad', 'closure']
|
|
11
11
|
_Scalar = torch.Tensor | float
|
|
@@ -17,50 +17,50 @@ class GradApproximator(Module, ABC):
|
|
|
17
17
|
Args:
|
|
18
18
|
defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
|
|
19
19
|
target (str, optional):
|
|
20
|
-
whether to set `
|
|
20
|
+
whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
|
|
21
21
|
"""
|
|
22
22
|
def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
|
|
23
23
|
super().__init__(defaults)
|
|
24
24
|
self._target: GradTarget = target
|
|
25
25
|
|
|
26
26
|
@abstractmethod
|
|
27
|
-
def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None,
|
|
27
|
+
def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, var: Var) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
|
|
28
28
|
"""Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
|
|
29
29
|
|
|
30
|
-
def pre_step(self,
|
|
30
|
+
def pre_step(self, var: Var) -> Var | None:
|
|
31
31
|
"""This runs once before each step, whereas `approximate` may run multiple times per step if further modules
|
|
32
32
|
evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
|
|
33
|
-
return
|
|
33
|
+
return var
|
|
34
34
|
|
|
35
35
|
@torch.no_grad
|
|
36
|
-
def step(self,
|
|
37
|
-
ret = self.pre_step(
|
|
38
|
-
if isinstance(ret,
|
|
36
|
+
def step(self, var):
|
|
37
|
+
ret = self.pre_step(var)
|
|
38
|
+
if isinstance(ret, Var): var = ret
|
|
39
39
|
|
|
40
|
-
if
|
|
41
|
-
params, closure, loss =
|
|
40
|
+
if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
|
|
41
|
+
params, closure, loss = var.params, var.closure, var.loss
|
|
42
42
|
|
|
43
43
|
if self._target == 'closure':
|
|
44
44
|
|
|
45
45
|
def approx_closure(backward=True):
|
|
46
46
|
if backward:
|
|
47
47
|
# set loss to None because closure might be evaluated at different points
|
|
48
|
-
grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None,
|
|
48
|
+
grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, var=var)
|
|
49
49
|
for p, g in zip(params, grad): p.grad = g
|
|
50
50
|
return l if l is not None else l_approx
|
|
51
51
|
return closure(False)
|
|
52
52
|
|
|
53
|
-
|
|
54
|
-
return
|
|
53
|
+
var.closure = approx_closure
|
|
54
|
+
return var
|
|
55
55
|
|
|
56
|
-
# if
|
|
57
|
-
# warnings.warn('Using grad approximator when `
|
|
58
|
-
grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss,
|
|
59
|
-
if loss_approx is not None:
|
|
60
|
-
if loss is not None:
|
|
61
|
-
if self._target == 'grad':
|
|
62
|
-
elif self._target == 'update':
|
|
56
|
+
# if var.grad is not None:
|
|
57
|
+
# warnings.warn('Using grad approximator when `var.grad` is already set.')
|
|
58
|
+
grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, var=var)
|
|
59
|
+
if loss_approx is not None: var.loss_approx = loss_approx
|
|
60
|
+
if loss is not None: var.loss = var.loss_approx = loss
|
|
61
|
+
if self._target == 'grad': var.grad = list(grad)
|
|
62
|
+
elif self._target == 'update': var.update = list(grad)
|
|
63
63
|
else: raise ValueError(self._target)
|
|
64
|
-
return
|
|
64
|
+
return var
|
|
65
65
|
|
|
66
66
|
_FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
|
|
@@ -90,6 +90,19 @@ _RFD_FUNCS = {
|
|
|
90
90
|
|
|
91
91
|
|
|
92
92
|
class RandomizedFDM(GradApproximator):
|
|
93
|
+
"""_summary_
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
|
|
97
|
+
n_samples (int, optional): number of random gradient samples. Defaults to 1.
|
|
98
|
+
formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
|
|
99
|
+
distribution (Distributions, optional): distribution. Defaults to "rademacher".
|
|
100
|
+
If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
|
|
101
|
+
pre_generate (bool, optional):
|
|
102
|
+
whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
|
|
103
|
+
seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
|
|
104
|
+
target (GradTarget, optional): what to set on var. Defaults to "closure".
|
|
105
|
+
"""
|
|
93
106
|
PRE_MULTIPLY_BY_H = True
|
|
94
107
|
def __init__(
|
|
95
108
|
self,
|
|
@@ -99,8 +112,8 @@ class RandomizedFDM(GradApproximator):
|
|
|
99
112
|
distribution: Distributions = "rademacher",
|
|
100
113
|
beta: float = 0,
|
|
101
114
|
pre_generate = True,
|
|
102
|
-
target: GradTarget = "closure",
|
|
103
115
|
seed: int | None | torch.Generator = None,
|
|
116
|
+
target: GradTarget = "closure",
|
|
104
117
|
):
|
|
105
118
|
defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
|
|
106
119
|
super().__init__(defaults, target=target)
|
|
@@ -118,16 +131,16 @@ class RandomizedFDM(GradApproximator):
|
|
|
118
131
|
else: self.global_state['generator'] = None
|
|
119
132
|
return self.global_state['generator']
|
|
120
133
|
|
|
121
|
-
def pre_step(self,
|
|
122
|
-
h, beta = self.get_settings('h', 'beta'
|
|
123
|
-
settings = self.settings[
|
|
134
|
+
def pre_step(self, var):
|
|
135
|
+
h, beta = self.get_settings(var.params, 'h', 'beta')
|
|
136
|
+
settings = self.settings[var.params[0]]
|
|
124
137
|
n_samples = settings['n_samples']
|
|
125
138
|
distribution = settings['distribution']
|
|
126
139
|
pre_generate = settings['pre_generate']
|
|
127
140
|
|
|
128
141
|
if pre_generate:
|
|
129
|
-
params = TensorList(
|
|
130
|
-
generator = self._get_generator(settings['seed'],
|
|
142
|
+
params = TensorList(var.params)
|
|
143
|
+
generator = self._get_generator(settings['seed'], var.params)
|
|
131
144
|
perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
|
|
132
145
|
|
|
133
146
|
if self.PRE_MULTIPLY_BY_H:
|
|
@@ -152,11 +165,11 @@ class RandomizedFDM(GradApproximator):
|
|
|
152
165
|
torch._foreach_lerp_(cur_flat, new_flat, betas)
|
|
153
166
|
|
|
154
167
|
@torch.no_grad
|
|
155
|
-
def approximate(self, closure, params, loss,
|
|
168
|
+
def approximate(self, closure, params, loss, var):
|
|
156
169
|
params = TensorList(params)
|
|
157
170
|
loss_approx = None
|
|
158
171
|
|
|
159
|
-
h = self.
|
|
172
|
+
h = NumberList(self.settings[p]['h'] for p in params)
|
|
160
173
|
settings = self.settings[params[0]]
|
|
161
174
|
n_samples = settings['n_samples']
|
|
162
175
|
fd_fn = _RFD_FUNCS[settings['formula']]
|
|
@@ -220,29 +233,29 @@ class MeZO(GradApproximator):
|
|
|
220
233
|
distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
|
|
221
234
|
).mul_(h)
|
|
222
235
|
|
|
223
|
-
def pre_step(self,
|
|
224
|
-
h = self.
|
|
225
|
-
settings = self.settings[
|
|
236
|
+
def pre_step(self, var):
|
|
237
|
+
h = NumberList(self.settings[p]['h'] for p in var.params)
|
|
238
|
+
settings = self.settings[var.params[0]]
|
|
226
239
|
n_samples = settings['n_samples']
|
|
227
240
|
distribution = settings['distribution']
|
|
228
241
|
|
|
229
|
-
step =
|
|
242
|
+
step = var.current_step
|
|
230
243
|
|
|
231
244
|
# create functions that generate a deterministic perturbation from seed based on current step
|
|
232
245
|
prt_fns = []
|
|
233
246
|
for i in range(n_samples):
|
|
234
247
|
|
|
235
|
-
prt_fn = partial(self._seeded_perturbation, params=
|
|
248
|
+
prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
|
|
236
249
|
prt_fns.append(prt_fn)
|
|
237
250
|
|
|
238
251
|
self.global_state['prt_fns'] = prt_fns
|
|
239
252
|
|
|
240
253
|
@torch.no_grad
|
|
241
|
-
def approximate(self, closure, params, loss,
|
|
254
|
+
def approximate(self, closure, params, loss, var):
|
|
242
255
|
params = TensorList(params)
|
|
243
256
|
loss_approx = None
|
|
244
257
|
|
|
245
|
-
h = self.
|
|
258
|
+
h = NumberList(self.settings[p]['h'] for p in params)
|
|
246
259
|
settings = self.settings[params[0]]
|
|
247
260
|
n_samples = settings['n_samples']
|
|
248
261
|
fd_fn = _RFD_FUNCS[settings['formula']]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .higher_order_newton import HigherOrderNewton
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import math
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from contextlib import nullcontext
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import scipy.optimize
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from ...core import Chainable, Module, apply_transform
|
|
14
|
+
from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
|
|
15
|
+
from ...utils.derivatives import (
|
|
16
|
+
hessian_list_to_mat,
|
|
17
|
+
jacobian_wrt,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
_LETTERS = 'abcdefghijklmnopqrstuvwxyz'
|
|
21
|
+
def _poly_eval(s: np.ndarray, c, derivatives):
|
|
22
|
+
val = float(c)
|
|
23
|
+
for i,T in enumerate(derivatives, 1):
|
|
24
|
+
s1 = ''.join(_LETTERS[:i]) # abcd
|
|
25
|
+
s2 = ',...'.join(_LETTERS[:i]) # a,b,c,d
|
|
26
|
+
# this would make einsum('abcd,a,b,c,d', T, x, x, x, x)
|
|
27
|
+
val += np.einsum(f"...{s1},...{s2}", T, *(s for _ in range(i))) / math.factorial(i)
|
|
28
|
+
return val
|
|
29
|
+
|
|
30
|
+
def _proximal_poly_v(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
|
|
31
|
+
if x.ndim == 2: x = x.T # DE passes (ndim, batch_size)
|
|
32
|
+
s = x - x0
|
|
33
|
+
val = _poly_eval(s, c, derivatives)
|
|
34
|
+
penalty = 0
|
|
35
|
+
if prox != 0: penalty = (prox / 2) * (s**2).sum(-1) # proximal penalty
|
|
36
|
+
return val + penalty
|
|
37
|
+
|
|
38
|
+
def _proximal_poly_g(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
|
|
39
|
+
s = x - x0
|
|
40
|
+
g = derivatives[0].copy()
|
|
41
|
+
if len(derivatives) > 1:
|
|
42
|
+
for i, T in enumerate(derivatives[1:], 2):
|
|
43
|
+
s1 = ''.join(_LETTERS[:i]) # abcd
|
|
44
|
+
s2 = ','.join(_LETTERS[1:i]) # b,c,d
|
|
45
|
+
# this would make einsum('abcd,b,c,d->a', T, x, x, x)
|
|
46
|
+
g += np.einsum(f"{s1},{s2}->a", T, *(s for _ in range(i-1))) / math.factorial(i - 1)
|
|
47
|
+
|
|
48
|
+
g_prox = 0
|
|
49
|
+
if prox != 0: g_prox = prox * s
|
|
50
|
+
return g + g_prox
|
|
51
|
+
|
|
52
|
+
def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
|
|
53
|
+
s = x - x0
|
|
54
|
+
n = x.shape[0]
|
|
55
|
+
if len(derivatives) == 1:
|
|
56
|
+
H = np.zeros(n, n)
|
|
57
|
+
else:
|
|
58
|
+
H = derivatives[1].copy()
|
|
59
|
+
if len(derivatives) > 2:
|
|
60
|
+
for i, T in enumerate(derivatives[2:], 3):
|
|
61
|
+
s1 = ''.join(_LETTERS[:i]) # abcd
|
|
62
|
+
s2 = ','.join(_LETTERS[2:i]) # c,d
|
|
63
|
+
# this would make einsum('abcd,c,d->ab', T, x, x, x)
|
|
64
|
+
H += np.einsum(f"{s1},{s2}->ab", T, *(s for _ in range(i-2))) / math.factorial(i - 2)
|
|
65
|
+
|
|
66
|
+
H_prox = 0
|
|
67
|
+
if prox != 0: H_prox = np.eye(n) * prox
|
|
68
|
+
return H + H_prox
|
|
69
|
+
|
|
70
|
+
def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
|
|
71
|
+
derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
|
|
72
|
+
x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
|
|
73
|
+
bounds = None
|
|
74
|
+
if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
|
|
75
|
+
|
|
76
|
+
# if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
|
|
77
|
+
if bounds is None:
|
|
78
|
+
if len(derivatives) == 1: method = 'bfgs'
|
|
79
|
+
else: method = 'trust-exact'
|
|
80
|
+
else:
|
|
81
|
+
if len(derivatives) == 1: method = 'l-bfgs-b'
|
|
82
|
+
else: method = 'trust-constr'
|
|
83
|
+
|
|
84
|
+
x_init = x0.copy()
|
|
85
|
+
v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
|
|
86
|
+
if de_iters is not None and de_iters != 0:
|
|
87
|
+
if de_iters == -1: de_iters = None # let scipy decide
|
|
88
|
+
res = scipy.optimize.differential_evolution(
|
|
89
|
+
_proximal_poly_v,
|
|
90
|
+
bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
|
|
91
|
+
args=(c, prox, x0.copy(), derivatives),
|
|
92
|
+
maxiter=de_iters,
|
|
93
|
+
vectorized=True,
|
|
94
|
+
)
|
|
95
|
+
if res.fun < v0: x_init = res.x
|
|
96
|
+
|
|
97
|
+
res = scipy.optimize.minimize(
|
|
98
|
+
_proximal_poly_v,
|
|
99
|
+
x_init,
|
|
100
|
+
method=method,
|
|
101
|
+
args=(c, prox, x0.copy(), derivatives),
|
|
102
|
+
jac=_proximal_poly_g,
|
|
103
|
+
hess=_proximal_poly_H,
|
|
104
|
+
bounds=bounds
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
return torch.from_numpy(res.x).to(x), res.fun
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class HigherOrderNewton(Module):
|
|
112
|
+
"""
|
|
113
|
+
A basic arbitrary order newton's method with optional trust region and proximal penalty.
|
|
114
|
+
It is recommended to enable at least one of trust region or proximal penalty.
|
|
115
|
+
|
|
116
|
+
This constructs an nth order taylor approximation via autograd and minimizes it with
|
|
117
|
+
scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
|
|
118
|
+
|
|
119
|
+
This uses n^order memory, where n is number of decision variables, and I am not aware
|
|
120
|
+
of any problems where this is more efficient than newton's method. It can minimize
|
|
121
|
+
rosenbrock in a single step, but that step probably takes more time than newton.
|
|
122
|
+
And there are way more efficient tensor methods out there but they tend to be
|
|
123
|
+
significantly more complex.
|
|
124
|
+
|
|
125
|
+
Args:
|
|
126
|
+
|
|
127
|
+
order (int, optional):
|
|
128
|
+
Order of the method, number of taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
|
|
129
|
+
trust_method (str | None, optional):
|
|
130
|
+
Method used for trust region.
|
|
131
|
+
- "bounds" - the model is minimized within bounds defined by trust region.
|
|
132
|
+
- "proximal" - the model is minimized with penalty for going too far from current point.
|
|
133
|
+
- "none" - disables trust region.
|
|
134
|
+
|
|
135
|
+
Defaults to 'bounds'.
|
|
136
|
+
increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
|
|
137
|
+
decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
|
|
138
|
+
trust_init (float | None, optional):
|
|
139
|
+
initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on :code:`"proximal"`. Defaults to None.
|
|
140
|
+
trust_tol (float, optional):
|
|
141
|
+
Maximum ratio of expected loss reduction to actual reduction for trust region increase.
|
|
142
|
+
Should 1 or higer. Defaults to 2.
|
|
143
|
+
de_iters (int | None, optional):
|
|
144
|
+
If this is specified, the model is minimized via differential evolution first to possibly escape local minima,
|
|
145
|
+
then it is passed to scipy.optimize.minimize. Defaults to None.
|
|
146
|
+
vectorize (bool, optional): whether to enable vectorized jacobians (usually faster). Defaults to True.
|
|
147
|
+
"""
|
|
148
|
+
def __init__(
|
|
149
|
+
self,
|
|
150
|
+
order: int = 4,
|
|
151
|
+
trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
|
|
152
|
+
increase: float = 1.5,
|
|
153
|
+
decrease: float = 0.75,
|
|
154
|
+
trust_init: float | None = None,
|
|
155
|
+
trust_tol: float = 2,
|
|
156
|
+
de_iters: int | None = None,
|
|
157
|
+
vectorize: bool = True,
|
|
158
|
+
):
|
|
159
|
+
if trust_init is None:
|
|
160
|
+
if trust_method == 'bounds': trust_init = 1
|
|
161
|
+
else: trust_init = 0.1
|
|
162
|
+
|
|
163
|
+
defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
|
|
164
|
+
super().__init__(defaults)
|
|
165
|
+
|
|
166
|
+
@torch.no_grad
|
|
167
|
+
def step(self, var):
|
|
168
|
+
params = TensorList(var.params)
|
|
169
|
+
closure = var.closure
|
|
170
|
+
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
171
|
+
|
|
172
|
+
settings = self.settings[params[0]]
|
|
173
|
+
order = settings['order']
|
|
174
|
+
increase = settings['increase']
|
|
175
|
+
decrease = settings['decrease']
|
|
176
|
+
trust_tol = settings['trust_tol']
|
|
177
|
+
trust_init = settings['trust_init']
|
|
178
|
+
trust_method = settings['trust_method']
|
|
179
|
+
de_iters = settings['de_iters']
|
|
180
|
+
vectorize = settings['vectorize']
|
|
181
|
+
|
|
182
|
+
trust_value = self.global_state.get('trust_value', trust_init)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# ------------------------ calculate grad and hessian ------------------------ #
|
|
186
|
+
with torch.enable_grad():
|
|
187
|
+
loss = var.loss = var.loss_approx = closure(False)
|
|
188
|
+
|
|
189
|
+
g_list = torch.autograd.grad(loss, params, create_graph=True)
|
|
190
|
+
var.grad = list(g_list)
|
|
191
|
+
|
|
192
|
+
g = torch.cat([t.ravel() for t in g_list])
|
|
193
|
+
n = g.numel()
|
|
194
|
+
derivatives = [g]
|
|
195
|
+
T = g # current derivatives tensor
|
|
196
|
+
|
|
197
|
+
# get all derivative up to order
|
|
198
|
+
for o in range(2, order + 1):
|
|
199
|
+
is_last = o == order
|
|
200
|
+
T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
|
|
201
|
+
with torch.no_grad() if is_last else nullcontext():
|
|
202
|
+
# the shape is (ndim, ) * order
|
|
203
|
+
T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
|
|
204
|
+
derivatives.append(T)
|
|
205
|
+
|
|
206
|
+
x0 = torch.cat([p.ravel() for p in params])
|
|
207
|
+
|
|
208
|
+
if trust_method is None: trust_method = 'none'
|
|
209
|
+
else: trust_method = trust_method.lower()
|
|
210
|
+
|
|
211
|
+
if trust_method == 'none':
|
|
212
|
+
trust_region = None
|
|
213
|
+
prox = 0
|
|
214
|
+
|
|
215
|
+
elif trust_method == 'bounds':
|
|
216
|
+
trust_region = trust_value
|
|
217
|
+
prox = 0
|
|
218
|
+
|
|
219
|
+
elif trust_method == 'proximal':
|
|
220
|
+
trust_region = None
|
|
221
|
+
prox = 1 / trust_value
|
|
222
|
+
|
|
223
|
+
else:
|
|
224
|
+
raise ValueError(trust_method)
|
|
225
|
+
|
|
226
|
+
x_star, expected_loss = _poly_minimize(
|
|
227
|
+
trust_region=trust_region,
|
|
228
|
+
prox=prox,
|
|
229
|
+
de_iters=de_iters,
|
|
230
|
+
c=loss.item(),
|
|
231
|
+
x=x0,
|
|
232
|
+
derivatives=derivatives,
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
# trust region
|
|
236
|
+
if trust_method != 'none':
|
|
237
|
+
expected_reduction = loss - expected_loss
|
|
238
|
+
|
|
239
|
+
vec_to_tensors_(x_star, params)
|
|
240
|
+
loss_star = closure(False)
|
|
241
|
+
vec_to_tensors_(x0, params)
|
|
242
|
+
reduction = loss - loss_star
|
|
243
|
+
|
|
244
|
+
# failed step
|
|
245
|
+
if reduction <= 0:
|
|
246
|
+
x_star = x0
|
|
247
|
+
self.global_state['trust_value'] = trust_value * decrease
|
|
248
|
+
|
|
249
|
+
# very good step
|
|
250
|
+
elif expected_reduction / reduction <= trust_tol:
|
|
251
|
+
self.global_state['trust_value'] = trust_value * increase
|
|
252
|
+
|
|
253
|
+
difference = vec_to_tensors(x0 - x_star, params)
|
|
254
|
+
var.update = list(difference)
|
|
255
|
+
return var
|
|
256
|
+
|
|
@@ -14,7 +14,6 @@ def backtracking_line_search(
|
|
|
14
14
|
beta: float = 0.5,
|
|
15
15
|
c: float = 1e-4,
|
|
16
16
|
maxiter: int = 10,
|
|
17
|
-
a_min: float | None = None,
|
|
18
17
|
try_negative: bool = False,
|
|
19
18
|
) -> float | None:
|
|
20
19
|
"""
|
|
@@ -26,7 +25,6 @@ def backtracking_line_search(
|
|
|
26
25
|
beta: The factor by which to decrease alpha in each iteration
|
|
27
26
|
c: The constant for the Armijo sufficient decrease condition
|
|
28
27
|
max_iter: Maximum number of backtracking iterations (default: 10).
|
|
29
|
-
min_alpha: Minimum allowable step size to prevent near-zero values (default: 1e-16).
|
|
30
28
|
|
|
31
29
|
Returns:
|
|
32
30
|
step size
|
|
@@ -45,10 +43,6 @@ def backtracking_line_search(
|
|
|
45
43
|
# decrease alpha
|
|
46
44
|
a *= beta
|
|
47
45
|
|
|
48
|
-
# alpha too small
|
|
49
|
-
if a_min is not None and a < a_min:
|
|
50
|
-
return a_min
|
|
51
|
-
|
|
52
46
|
# fail
|
|
53
47
|
if try_negative:
|
|
54
48
|
def inv_objective(alpha): return f(-alpha)
|
|
@@ -59,7 +53,6 @@ def backtracking_line_search(
|
|
|
59
53
|
beta=beta,
|
|
60
54
|
c=c,
|
|
61
55
|
maxiter=maxiter,
|
|
62
|
-
a_min=a_min,
|
|
63
56
|
try_negative=False,
|
|
64
57
|
)
|
|
65
58
|
if v is not None: return -v
|
|
@@ -67,17 +60,28 @@ def backtracking_line_search(
|
|
|
67
60
|
return None
|
|
68
61
|
|
|
69
62
|
class Backtracking(LineSearch):
|
|
63
|
+
"""Backtracking line search satisfying the Armijo condition.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
init (float, optional): initial step size. Defaults to 1.0.
|
|
67
|
+
beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
|
|
68
|
+
c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
|
|
69
|
+
maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
|
|
70
|
+
adaptive (bool, optional):
|
|
71
|
+
when enabled, if line search failed, initial step size is reduced.
|
|
72
|
+
Otherwise it is reset to initial value. Defaults to True.
|
|
73
|
+
try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
|
|
74
|
+
"""
|
|
70
75
|
def __init__(
|
|
71
76
|
self,
|
|
72
77
|
init: float = 1.0,
|
|
73
78
|
beta: float = 0.5,
|
|
74
79
|
c: float = 1e-4,
|
|
75
80
|
maxiter: int = 10,
|
|
76
|
-
min_alpha: float | None = None,
|
|
77
81
|
adaptive=True,
|
|
78
82
|
try_negative: bool = False,
|
|
79
83
|
):
|
|
80
|
-
defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,
|
|
84
|
+
defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
|
|
81
85
|
super().__init__(defaults=defaults)
|
|
82
86
|
self.global_state['beta_scale'] = 1.0
|
|
83
87
|
|
|
@@ -86,20 +90,20 @@ class Backtracking(LineSearch):
|
|
|
86
90
|
self.global_state['beta_scale'] = 1.0
|
|
87
91
|
|
|
88
92
|
@torch.no_grad
|
|
89
|
-
def search(self, update,
|
|
90
|
-
init, beta, c, maxiter,
|
|
91
|
-
'init', 'beta', 'c', 'maxiter', '
|
|
93
|
+
def search(self, update, var):
|
|
94
|
+
init, beta, c, maxiter, adaptive, try_negative = itemgetter(
|
|
95
|
+
'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
|
|
92
96
|
|
|
93
|
-
objective = self.make_objective(
|
|
97
|
+
objective = self.make_objective(var=var)
|
|
94
98
|
|
|
95
99
|
# # directional derivative
|
|
96
|
-
d = -sum(t.sum() for t in torch._foreach_mul(
|
|
100
|
+
d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
|
|
97
101
|
|
|
98
102
|
# scale beta (beta is multiplicative and i think may be better than scaling initial step size)
|
|
99
103
|
if adaptive: beta = beta * self.global_state['beta_scale']
|
|
100
104
|
|
|
101
105
|
step_size = backtracking_line_search(objective, d, init=init,beta=beta,
|
|
102
|
-
c=c,maxiter=maxiter,
|
|
106
|
+
c=c,maxiter=maxiter, try_negative=try_negative)
|
|
103
107
|
|
|
104
108
|
# found an alpha that reduces loss
|
|
105
109
|
if step_size is not None:
|
|
@@ -114,19 +118,34 @@ def _lerp(start,end,weight):
|
|
|
114
118
|
return start + weight * (end - start)
|
|
115
119
|
|
|
116
120
|
class AdaptiveBacktracking(LineSearch):
|
|
121
|
+
"""Adaptive backtracking line search. After each line search procedure, a new initial step size is set
|
|
122
|
+
such that optimal step size in the procedure would be found on the second line search iteration.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
init (float, optional): step size for the first step. Defaults to 1.0.
|
|
126
|
+
beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
|
|
127
|
+
c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
|
|
128
|
+
maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
|
|
129
|
+
target_iters (int, optional):
|
|
130
|
+
target number of iterations that would be performed until optimal step size is found. Defaults to 1.
|
|
131
|
+
nplus (float, optional):
|
|
132
|
+
Multiplier to initial step size if it was found to be the optimal step size. Defaults to 2.0.
|
|
133
|
+
scale_beta (float, optional):
|
|
134
|
+
Momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
|
|
135
|
+
try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
|
|
136
|
+
"""
|
|
117
137
|
def __init__(
|
|
118
138
|
self,
|
|
119
139
|
init: float = 1.0,
|
|
120
140
|
beta: float = 0.5,
|
|
121
141
|
c: float = 1e-4,
|
|
122
142
|
maxiter: int = 20,
|
|
123
|
-
min_alpha: float | None = None,
|
|
124
143
|
target_iters = 1,
|
|
125
144
|
nplus = 2.0,
|
|
126
145
|
scale_beta = 0.0,
|
|
127
146
|
try_negative: bool = False,
|
|
128
147
|
):
|
|
129
|
-
defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,
|
|
148
|
+
defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
|
|
130
149
|
super().__init__(defaults=defaults)
|
|
131
150
|
|
|
132
151
|
self.global_state['beta_scale'] = 1.0
|
|
@@ -138,15 +157,15 @@ class AdaptiveBacktracking(LineSearch):
|
|
|
138
157
|
self.global_state['initial_scale'] = 1.0
|
|
139
158
|
|
|
140
159
|
@torch.no_grad
|
|
141
|
-
def search(self, update,
|
|
142
|
-
init, beta, c, maxiter,
|
|
143
|
-
'init','beta','c','maxiter','
|
|
160
|
+
def search(self, update, var):
|
|
161
|
+
init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
|
|
162
|
+
'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
|
|
144
163
|
|
|
145
|
-
objective = self.make_objective(
|
|
164
|
+
objective = self.make_objective(var=var)
|
|
146
165
|
|
|
147
166
|
# directional derivative (0 if c = 0 because it is not needed)
|
|
148
167
|
if c == 0: d = 0
|
|
149
|
-
else: d = -sum(t.sum() for t in torch._foreach_mul(
|
|
168
|
+
else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
|
|
150
169
|
|
|
151
170
|
# scale beta
|
|
152
171
|
beta = beta * self.global_state['beta_scale']
|
|
@@ -155,7 +174,7 @@ class AdaptiveBacktracking(LineSearch):
|
|
|
155
174
|
init = init * self.global_state['initial_scale']
|
|
156
175
|
|
|
157
176
|
step_size = backtracking_line_search(objective, d, init=init, beta=beta,
|
|
158
|
-
c=c,maxiter=maxiter,
|
|
177
|
+
c=c,maxiter=maxiter, try_negative=try_negative)
|
|
159
178
|
|
|
160
179
|
# found an alpha that reduces loss
|
|
161
180
|
if step_size is not None:
|