torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +55 -22
- tests/test_tensorlist.py +3 -3
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +20 -130
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +76 -26
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +15 -15
- torchzero/modules/quasi_newton/lsr1.py +18 -17
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- torchzero/modules/second_order/newton.py +38 -21
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +19 -19
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8.dist-info/RECORD +0 -130
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/curveball.py:

@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch

-from ...core import Module, Target, Transform, Chainable,
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central

@@ -47,27 +47,27 @@ class CurveBall(Module):
         if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
+    def step(self, var):

-        params =
+        params = var.params
         settings = self.settings[params[0]]
         hvp_method = settings['hvp_method']
         h = settings['h']

-        precond_lr, momentum, reg = self.get_settings('
+        precond_lr, momentum, reg = self.get_settings(params, 'precond_lr', 'momentum', 'reg', cls=NumberList)


-        closure =
+        closure = var.closure
         assert closure is not None

-        z, Hz = self.get_state('z', 'Hz',
+        z, Hz = self.get_state(params, 'z', 'Hz', cls=TensorList)

         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)
             Hvp = hvp(params, grad, z)

         elif hvp_method == 'forward':
-            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=
+            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)

         elif hvp_method == 'central':
             loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
@@ -79,11 +79,11 @@ class CurveBall(Module):
         Hz.set_(Hvp + z*reg)


-        update =
+        update = var.get_update()
         if 'inner' in self.children:
-            update =
+            update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)

         z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
-
+        var.update = z.neg()

-        return
+        return var
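The CurveBall changes above show the step-signature pattern that recurs throughout this diff: `step` now receives a single `var` object and reads `var.params`, `var.closure`, `var.get_grad()` and `var.get_update()` from it, writing its result to `var.update` and returning `var`. As a rough sketch only, with the public import path and module name assumed rather than taken from the 0.3.10 docs (inside the package the diff uses the relative `from ...core import Module`), a custom module in that style might look like:

import torch
from torchzero.core import Module  # assumed public path for the Module base class

class ScaleUpdate(Module):
    """Hypothetical example module: multiplies the current update by a constant factor."""
    def __init__(self, factor: float = 0.1):
        defaults = dict(factor=factor)          # per-parameter settings, same pattern as the modules above
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        factor = self.settings[var.params[0]]['factor']  # settings are keyed by parameter
        update = var.get_update()                        # current update, falls back to the gradient
        var.update = [u * factor for u in update]
        return var                                       # modules hand var to the next module in the chain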
torchzero/modules/experimental/diagonal_higher_order_newton.py (new file):

@@ -0,0 +1,225 @@
+import itertools
+import math
+import warnings
+from collections.abc import Callable
+from contextlib import nullcontext
+from functools import partial
+from typing import Any, Literal
+
+import numpy as np
+import scipy.optimize
+import torch
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_wrt,
+    hvp,
+)
+
+def _poly_eval_diag(s: np.ndarray, c, derivatives):
+    val = float(c) + (derivatives[0] * s).sum(-1)
+
+    if len(derivatives) > 1:
+        for i, d_diag in enumerate(derivatives[1:], 2):
+            val += (d_diag * (s**i)).sum(-1) / math.factorial(i)
+
+    return val
+
+def _proximal_poly_v_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+    """Computes the value of the proximal polynomial approximation."""
+    if x.ndim == 2: x = x.T
+    s = x - x0
+
+    val = _poly_eval_diag(s, c, derivatives)
+
+    penalty = 0
+    if prox != 0:
+        penalty = (prox / 2) * (s**2).sum(-1)
+
+    return val + penalty
+
+def _proximal_poly_g_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+    """Computes the gradient of the proximal polynomial approximation."""
+    s = x - x0
+
+    g = derivatives[0].copy()
+
+    if len(derivatives) > 1:
+        for i, d_diag in enumerate(derivatives[1:], 2):
+            g += d_diag * (s**(i - 1)) / math.factorial(i - 1)
+
+    if prox != 0:
+        g += prox * s
+
+    return g
+
+def _proximal_poly_H_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+    """Computes the Hessian of the proximal polynomial approximation."""
+    s = x - x0
+    n = x.shape[0]
+
+    if len(derivatives) < 2:
+        H_diag = np.zeros(n, dtype=s.dtype)
+    else:
+        H_diag = derivatives[1].copy()
+
+    if len(derivatives) > 2:
+        for i, d_diag in enumerate(derivatives[2:], 3):
+            H_diag += d_diag * (s**(i - 2)) / math.factorial(i - 2)
+
+    if prox != 0:
+        H_diag += prox
+
+    return np.diag(H_diag)
+
+def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
+    derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
+    x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
+    bounds = None
+    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+    # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
+    if bounds is None:
+        if len(derivatives) == 1: method = 'bfgs'
+        else: method = 'trust-exact'
+    else:
+        if len(derivatives) == 1: method = 'l-bfgs-b'
+        else: method = 'trust-constr'
+
+    x_init = x0.copy()
+    v0 = _proximal_poly_v_diag(x0, c, prox, x0, derivatives)
+    if de_iters is not None and de_iters != 0:
+        if de_iters == -1: de_iters = None # let scipy decide
+        res = scipy.optimize.differential_evolution(
+            _proximal_poly_v_diag,
+            bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
+            args=(c, prox, x0.copy(), derivatives),
+            maxiter=de_iters,
+            vectorized=True,
+        )
+        if res.fun < v0: x_init = res.x
+
+    res = scipy.optimize.minimize(
+        _proximal_poly_v_diag,
+        x_init,
+        method=method,
+        args=(c, prox, x0.copy(), derivatives),
+        jac=_proximal_poly_g_diag,
+        hess=_proximal_poly_H_diag,
+        bounds=bounds
+    )
+
+    return torch.from_numpy(res.x).to(x), res.fun
+
+
+
+class DiagonalHigherOrderNewton(Module):
+    """
+    Hvp with ones doesn't give you the diagonal unless derivatives are diagonal, but somehow it still works,
+    except it doesn't work in all cases except ones where it works.
+    """
+    def __init__(
+        self,
+        order: int = 4,
+        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
+        increase: float = 1.5,
+        decrease: float = 0.75,
+        trust_init: float | None = None,
+        trust_tol: float = 1,
+        de_iters: int | None = None,
+        vectorize: bool = True,
+    ):
+        if trust_init is None:
+            if trust_method == 'bounds': trust_init = 1
+            else: trust_init = 0.1
+
+        defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        order = settings['order']
+        increase = settings['increase']
+        decrease = settings['decrease']
+        trust_tol = settings['trust_tol']
+        trust_init = settings['trust_init']
+        trust_method = settings['trust_method']
+        de_iters = settings['de_iters']
+
+        trust_value = self.global_state.get('trust_value', trust_init)
+
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        with torch.enable_grad():
+            loss = var.loss = var.loss_approx = closure(False)
+
+            g = torch.autograd.grad(loss, params, create_graph=True)
+            var.grad = list(g)
+
+            derivatives = [g]
+            T = g # current derivatives tensor diagonal
+            ones = [torch.ones_like(t) for t in g]
+
+            # get all derivatives up to order
+            for o in range(2, order + 1):
+                T = hvp(params, T, ones, create_graph=o != order)
+                derivatives.append(T)
+
+        x0 = torch.cat([p.ravel() for p in params])
+
+        if trust_method is None: trust_method = 'none'
+        else: trust_method = trust_method.lower()
+
+        if trust_method == 'none':
+            trust_region = None
+            prox = 0
+
+        elif trust_method == 'bounds':
+            trust_region = trust_value
+            prox = 0
+
+        elif trust_method == 'proximal':
+            trust_region = None
+            prox = 1 / trust_value
+
+        else:
+            raise ValueError(trust_method)
+
+        x_star, expected_loss = _poly_minimize(
+            trust_region=trust_region,
+            prox=prox,
+            de_iters=de_iters,
+            c=loss.item(),
+            x=x0,
+            derivatives=[torch.cat([t.ravel() for t in d]) for d in derivatives],
+        )
+
+        # trust region
+        if trust_method != 'none':
+            expected_reduction = loss - expected_loss
+
+            vec_to_tensors_(x_star, params)
+            loss_star = closure(False)
+            vec_to_tensors_(x0, params)
+            reduction = loss - loss_star
+
+            # failed step
+            if reduction <= 0:
+                x_star = x0
+                self.global_state['trust_value'] = trust_value * decrease
+
+            # very good step
+            elif expected_reduction / reduction <= trust_tol:
+                self.global_state['trust_value'] = trust_value * increase
+
+        difference = vec_to_tensors(x0 - x_star, params)
+        var.update = list(difference)
+        return var
+
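Reading the helpers in the new file above: `_poly_eval_diag` evaluates a Taylor-style model whose higher-order terms are purely elementwise (diagonal), built from repeated Hessian-vector products with a vector of ones, and `_proximal_poly_v_diag` adds an optional proximal penalty. With s = x - x0, the model handed to scipy is, as a reading of the code rather than a statement from the package docs:

$$
m(x_0 + s) \;=\; c \;+\; \sum_j g_j s_j \;+\; \sum_{k=2}^{\text{order}} \frac{1}{k!} \sum_j d^{(k)}_j s_j^{\,k} \;+\; \frac{\text{prox}}{2}\,\lVert s \rVert_2^2
$$

where g = derivatives[0] and d^(k) are the diagonal k-th derivative estimates. The 'bounds' trust method sets prox = 0 and instead constrains |s_j| <= trust_value, while 'proximal' sets prox = 1/trust_value with no bounds; trust_value is then multiplied by `increase` or `decrease` depending on how the realized reduction compares with the model's expected reduction.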
torchzero/modules/experimental/eigendescent.py (new file):

@@ -0,0 +1,117 @@
+from contextlib import nullcontext
+import warnings
+from collections.abc import Callable
+from functools import partial
+import itertools
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_wrt, jacobian_and_hessian_wrt, hessian_mat,
+)
+
+def _batched_dot(x, y):
+    return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
+
+def _cosine_similarity(x, y):
+    denom = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(y, dim=-1).clip(min=torch.finfo(x.dtype).eps) # pylint:disable=not-callable
+    return _batched_dot(x, y) / denom
+
+class EigenDescent(Module):
+    """
+    Uses eigenvectors corresponding to certain eigenvalues. Please note that this is experimental and isn't guaranteed to work.
+
+    Args:
+        mode (str, optional):
+            - largest - use largest eigenvalue unless all eigenvalues are negative, then smallest is used.
+            - smallest - use smallest eigenvalue unless all eigenvalues are positive, then largest is used.
+            - mean-sign - use mean of eigenvectors multiplied by 1 or -1 if they point in opposite direction from gradient.
+            - mean-dot - use mean of eigenvectors multiplied by dot product with gradient.
+            - mean-cosine - use mean of eigenvectors multiplied by cosine similarity with gradient.
+            - mm - for testing.
+
+            Defaults to 'mean-sign'.
+        hessian_method (str, optional): how to calculate hessian. Defaults to "autograd".
+        vectorize (bool, optional): how to calculate hessian. Defaults to True.
+
+    """
+    def __init__(
+        self,
+        mode: Literal['largest', 'smallest','magnitude', 'mean-sign', 'mean-dot', 'mean-cosine', 'mm'] = 'mean-sign',
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+    ):
+        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, mode=mode)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        mode = settings['mode']
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        if hessian_method == 'autograd':
+            with torch.enable_grad():
+                loss = var.loss = var.loss_approx = closure(False)
+                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                g_list = [t[0] for t in g_list] # remove leading dim from loss
+                var.grad = g_list
+                H = hessian_list_to_mat(H_list)
+
+        elif hessian_method in ('func', 'autograd.functional'):
+            strat = 'forward-mode' if vectorize else 'reverse-mode'
+            with torch.enable_grad():
+                g_list = var.get_grad(retain_graph=True)
+                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+        else:
+            raise ValueError(hessian_method)
+
+
+        # ----------------------------------- solve ---------------------------------- #
+        g = torch.cat([t.ravel() for t in g_list])
+        L, Q = torch.linalg.eigh(H) # L is sorted # pylint:disable=not-callable
+        if mode == 'largest':
+            # smallest eigenvalue if all eigenvalues are negative else largest
+            if L[-1] <= 0: d = Q[0]
+            else: d = Q[-1]
+
+        elif mode == 'smallest':
+            # smallest eigenvalue if negative eigenvalues exist else largest
+            if L[0] <= 0: d = Q[0]
+            else: d = Q[-1]
+
+        elif mode == 'magnitude':
+            # largest by magnitude
+            if L[0].abs() > L[-1].abs(): d = Q[0]
+            else: d = Q[-1]
+
+        elif mode == 'mean-dot':
+            d = ((g.unsqueeze(0) @ Q).squeeze(0) * Q).mean(1)
+
+        elif mode == 'mean-sign':
+            d = ((g.unsqueeze(0) @ Q).squeeze(0).sign() * Q).mean(1)
+
+        elif mode == 'mean-cosine':
+            d = (Q * _cosine_similarity(Q, g)).mean(1)
+
+        elif mode == 'mm':
+            d = (g.unsqueeze(0) @ Q).squeeze(0) / g.numel()
+
+        else:
+            raise ValueError(mode)
+
+        var.update = vec_to_tensors(g.dot(d).sign() * d, params)
+        return var
+
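For orientation, the solve step above eigendecomposes the dense Hessian and builds the direction from selected eigenvectors. As a reading of the code (torch.linalg.eigh returns the eigenvalues L in ascending order), the default 'mean-sign' mode and the final update written to var.update are:

$$
H = Q \Lambda Q^\top, \qquad d \;=\; \frac{1}{n}\sum_{i=1}^{n} \operatorname{sign}\!\big(g^\top q_i\big)\, q_i, \qquad \text{update} \;=\; \operatorname{sign}\!\big(g^\top d\big)\, d
$$

with q_i the eigenvector columns of Q and g the flattened gradient. Since the full Hessian is formed and factorized, this is only practical for small parameter counts.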
torchzero/modules/experimental/etf.py (new file):

@@ -0,0 +1,172 @@
+from typing import cast
+import warnings
+
+import torch
+
+from ...core import Module
+from ...utils import vec_to_tensors, vec_to_tensors_
+
+
+class ExponentialTrajectoryFit(Module):
+    """A method. Please note that this is experimental and isn't guaranteed to work."""
+    def __init__(self, step_size=1e-3):
+        defaults = dict(step_size = step_size)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        assert closure is not None
+        step_size = self.settings[var.params[0]]['step_size']
+
+        # 1. perform 3 GD steps to obtain 4 points
+        points = [torch.cat([p.view(-1) for p in var.params])]
+        for i in range(3):
+            if i == 0: grad = var.get_grad()
+            else:
+                with torch.enable_grad(): closure()
+                grad = [cast(torch.Tensor, p.grad) for p in var.params]
+
+            # GD step
+            torch._foreach_sub_(var.params, grad, alpha=step_size)
+
+            points.append(torch.cat([p.view(-1) for p in var.params]))
+
+        assert len(points) == 4, len(points)
+        x0, x1, x2, x3 = points
+        dim = x0.numel()
+
+        # 2. fit a generalized exponential curve
+        d0 = (x1 - x0).unsqueeze(1) # column vectors
+        d1 = (x2 - x1).unsqueeze(1)
+        d2 = (x3 - x2).unsqueeze(1)
+
+        # cat
+        D1 = torch.cat([d0, d1], dim=1)
+        D2 = torch.cat([d1, d2], dim=1)
+
+        # if points are collinear this will happen on sphere and a quadratic "line search" will minimize it
+        if x0.numel() >= 2:
+            if torch.linalg.matrix_rank(D1) < 2: # pylint:disable=not-callable
+                pass # need to put a quadratic fit there
+
+        M = D2 @ torch.linalg.pinv(D1) # pylint:disable=not-callable # this defines the curve
+
+        # now we can predict x*
+        I = torch.eye(dim, device=x0.device, dtype=x0.dtype)
+        B = I - M
+        z = x1 - M @ x0
+
+        x_star = torch.linalg.lstsq(B, z).solution # pylint:disable=not-callable
+
+        vec_to_tensors_(x0, var.params)
+        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
+        var.update = list(difference)
+        return var
+
+
+
+class ExponentialTrajectoryFitV2(Module):
+    """Should be better than one above, except it isn't. Please note that this is experimental and isn't guaranteed to work."""
+    def __init__(self, step_size=1e-3, num_steps: int= 4):
+        defaults = dict(step_size = step_size, num_steps=num_steps)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        assert closure is not None
+        step_size = self.settings[var.params[0]]['step_size']
+        num_steps = self.settings[var.params[0]]['num_steps']
+
+        # 1. perform 3 GD steps to obtain 4 points (or more)
+        grad = var.get_grad()
+        points = [torch.cat([p.view(-1) for p in var.params])]
+        point_grads = [torch.cat([g.view(-1) for g in grad])]
+
+        for i in range(num_steps):
+            # GD step
+            torch._foreach_sub_(var.params, grad, alpha=step_size)
+
+            points.append(torch.cat([p.view(-1) for p in var.params]))
+
+            closure(backward=True)
+            grad = [cast(torch.Tensor, p.grad) for p in var.params]
+            point_grads.append(torch.cat([g.view(-1) for g in grad]))
+
+
+        X = torch.stack(points, 1) # dim, num_steps+1
+        G = torch.stack(point_grads, 1)
+        dim = points[0].numel()
+
+        X = torch.cat([X, torch.ones(1, num_steps+1, dtype=G.dtype, device=G.device)])
+
+        P = G @ torch.linalg.pinv(X) # pylint:disable=not-callable
+        A = P[:, :dim]
+        b = -P[:, dim]
+
+        # symmetrize
+        A = 0.5 * (A + A.T)
+
+        # predict x*
+        x_star = torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+        vec_to_tensors_(points[0], var.params)
+        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
+        var.update = list(difference)
+        return var
+
+
+
+
+def _fit_exponential(y0, y1, y2):
+    """x0, x1 and x2 are assumed to be 0, 1, 2"""
+    r = (y2 - y1) / (y1 - y0)
+    ones = r==1
+    r[ones] = 0
+    B = (y1 - y0) / (r - 1)
+    A = y0 - B
+
+    A[ones] = 0
+    B[ones] = 0
+    return A, B, r
+
+class PointwiseExponential(Module):
+    """A stupid method (for my youtube channel). Please note that this is experimental and isn't guaranteed to work."""
+    def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
+        defaults = dict(reg=reg, steps=steps, step_size=step_size)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[var.params[0]]
+        step_size = settings['step_size']
+        reg = settings['reg']
+        steps = settings['steps']
+
+        # 1. perform 2 GD steps to obtain 3 points
+        points = [torch.cat([p.view(-1) for p in var.params])]
+        for i in range(2):
+            if i == 0: grad = var.get_grad()
+            else:
+                with torch.enable_grad(): closure()
+                grad = [cast(torch.Tensor, p.grad) for p in var.params]
+
+            # GD step
+            torch._foreach_sub_(var.params, grad, alpha=step_size)
+
+            points.append(torch.cat([p.view(-1) for p in var.params]))
+
+        assert len(points) == 3, len(points)
+        y0, y1, y2 = points
+
+        A, B, r = _fit_exponential(y0, y1, y2)
+        r = r.clip(max = 1-reg)
+        x_star = A + B * r**steps
+
+        vec_to_tensors_(y0, var.params)
+        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
+        var.update = list(difference)
+        return var
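`_fit_exponential` above fits, independently for every coordinate, a curve y_k = A + B·r^k through three consecutive iterates taken at k = 0, 1, 2. The closed form it implements, and the extrapolation PointwiseExponential then applies (with r clipped to at most 1 - reg so the geometric term stays contractive), is:

$$
r = \frac{y_2 - y_1}{y_1 - y_0}, \qquad B = \frac{y_1 - y_0}{r - 1}, \qquad A = y_0 - B, \qquad x^\star = A + B\, r^{\text{steps}} \;\longrightarrow\; A \quad (\text{steps}\to\infty,\ |r|<1)
$$

For example, iterates 3, 2.5, 2.25 give r = 0.5, B = 1, A = 2, so the extrapolated limit is 2.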
torchzero/modules/experimental/gradmin.py:

@@ -5,7 +5,7 @@ from typing import Literal

 import torch

-from ...core import Module,
+from ...core import Module, Var
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
@@ -42,7 +42,7 @@ class GradMin(Reformulation):
         super().__init__(defaults)

     @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, var):
         settings = self.settings[params[0]]
         loss_term = settings['loss_term']
         relative = settings['relative']
torchzero/modules/experimental/newton_solver.py:

@@ -3,13 +3,13 @@ from typing import Any, Literal, overload

 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, apply_transform, Modular
 from ...utils import TensorList, as_tensorlist
 from ...utils.derivatives import hvp
 from ..quasi_newton import LBFGS

 class NewtonSolver(Module):
-    """Matrix free newton via with any custom solver (
+    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
     def __init__(
         self,
         solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
@@ -26,9 +26,9 @@ class NewtonSolver(Module):
             self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]
@@ -39,7 +39,7 @@ class NewtonSolver(Module):
         warm_start = settings['warm_start']

         # ---------------------- Hessian vector product function --------------------- #
-        grad =
+        grad = var.get_grad(create_graph=True)

         def H_mm(x):
             with torch.enable_grad():
@@ -50,11 +50,11 @@ class NewtonSolver(Module):
         # -------------------------------- inner step -------------------------------- #
         b = as_tensorlist(grad)
         if 'inner' in self.children:
-            b = as_tensorlist(
+            b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))

         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
-        if warm_start: x0 = self.get_state('prev_x',
+        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
         if x0 is None: x = b.zeros_like().requires_grad_(True)
         else: x = x0.clone().requires_grad_(True)

@@ -76,13 +76,13 @@ class NewtonSolver(Module):
            assert loss is not None
            if min(loss, loss/initial_loss) < tol: break

-           print(f'{loss = }')
+           # print(f'{loss = }')

        if warm_start:
            assert x0 is not None
            x0.copy_(x)

-
-        return
+        var.update = x.detach()
+        return var

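Finally, the `solver=lambda p: Modular(p, LBFGS())` default in NewtonSolver above illustrates how modules are composed into an optimizer through `Modular`. A hedged end-to-end sketch follows; the top-level import paths and the closure-accepting `step` call are assumptions inferred from this diff (closures are called as `closure(False)` / `closure(backward=True)` throughout), not checked against the 0.3.10 documentation:

import torch
from torchzero.core import Modular                 # Modular is imported from ...core in the diff above
from torchzero.modules.quasi_newton import LBFGS   # assumed public path; the diff uses `from ..quasi_newton import LBFGS`

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

opt = Modular(model.parameters(), LBFGS())         # mirrors the NewtonSolver default shown above

def closure(backward=True):
    # modules call this with backward=False for extra evaluations and backward=True for gradients
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)                              # assumes the usual torch.optim-style closure interface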