torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/diagonal_higher_order_newton.py (deleted)
@@ -1,225 +0,0 @@
-import itertools
-import math
-import warnings
-from collections.abc import Callable
-from contextlib import nullcontext
-from functools import partial
-from typing import Any, Literal
-
-import numpy as np
-import scipy.optimize
-import torch
-
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    jacobian_wrt,
-    hvp,
-)
-
-def _poly_eval_diag(s: np.ndarray, c, derivatives):
-    val = float(c) + (derivatives[0] * s).sum(-1)
-
-    if len(derivatives) > 1:
-        for i, d_diag in enumerate(derivatives[1:], 2):
-            val += (d_diag * (s**i)).sum(-1) / math.factorial(i)
-
-    return val
-
-def _proximal_poly_v_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-    """Computes the value of the proximal polynomial approximation."""
-    if x.ndim == 2: x = x.T
-    s = x - x0
-
-    val = _poly_eval_diag(s, c, derivatives)
-
-    penalty = 0
-    if prox != 0:
-        penalty = (prox / 2) * (s**2).sum(-1)
-
-    return val + penalty
-
-def _proximal_poly_g_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-    """Computes the gradient of the proximal polynomial approximation."""
-    s = x - x0
-
-    g = derivatives[0].copy()
-
-    if len(derivatives) > 1:
-        for i, d_diag in enumerate(derivatives[1:], 2):
-            g += d_diag * (s**(i - 1)) / math.factorial(i - 1)
-
-    if prox != 0:
-        g += prox * s
-
-    return g
-
-def _proximal_poly_H_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-    """Computes the Hessian of the proximal polynomial approximation."""
-    s = x - x0
-    n = x.shape[0]
-
-    if len(derivatives) < 2:
-        H_diag = np.zeros(n, dtype=s.dtype)
-    else:
-        H_diag = derivatives[1].copy()
-
-    if len(derivatives) > 2:
-        for i, d_diag in enumerate(derivatives[2:], 3):
-            H_diag += d_diag * (s**(i - 2)) / math.factorial(i - 2)
-
-    if prox != 0:
-        H_diag += prox
-
-    return np.diag(H_diag)
-
-def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
-    derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
-    x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-    bounds = None
-    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
-
-    # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
-    if bounds is None:
-        if len(derivatives) == 1: method = 'bfgs'
-        else: method = 'trust-exact'
-    else:
-        if len(derivatives) == 1: method = 'l-bfgs-b'
-        else: method = 'trust-constr'
-
-    x_init = x0.copy()
-    v0 = _proximal_poly_v_diag(x0, c, prox, x0, derivatives)
-    if de_iters is not None and de_iters != 0:
-        if de_iters == -1: de_iters = None # let scipy decide
-        res = scipy.optimize.differential_evolution(
-            _proximal_poly_v_diag,
-            bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
-            args=(c, prox, x0.copy(), derivatives),
-            maxiter=de_iters,
-            vectorized=True,
-        )
-        if res.fun < v0: x_init = res.x
-
-    res = scipy.optimize.minimize(
-        _proximal_poly_v_diag,
-        x_init,
-        method=method,
-        args=(c, prox, x0.copy(), derivatives),
-        jac=_proximal_poly_g_diag,
-        hess=_proximal_poly_H_diag,
-        bounds=bounds
-    )
-
-    return torch.from_numpy(res.x).to(x), res.fun
-
-
-
-class DiagonalHigherOrderNewton(Module):
-    """
-    Hvp with ones doesn't give you the diagonal unless derivatives are diagonal, but somehow it still works,
-    except it doesn't work in all cases except ones where it works.
-    """
-    def __init__(
-        self,
-        order: int = 4,
-        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        increase: float = 1.5,
-        decrease: float = 0.75,
-        trust_init: float | None = None,
-        trust_tol: float = 1,
-        de_iters: int | None = None,
-        vectorize: bool = True,
-    ):
-        if trust_init is None:
-            if trust_method == 'bounds': trust_init = 1
-            else: trust_init = 0.1
-
-        defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        order = settings['order']
-        increase = settings['increase']
-        decrease = settings['decrease']
-        trust_tol = settings['trust_tol']
-        trust_init = settings['trust_init']
-        trust_method = settings['trust_method']
-        de_iters = settings['de_iters']
-
-        trust_value = self.global_state.get('trust_value', trust_init)
-
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-
-            g = torch.autograd.grad(loss, params, create_graph=True)
-            var.grad = list(g)
-
-            derivatives = [g]
-            T = g # current derivatives tensor diagonal
-            ones = [torch.ones_like(t) for t in g]
-
-            # get all derivatives up to order
-            for o in range(2, order + 1):
-                T = hvp(params, T, ones, create_graph=o != order)
-                derivatives.append(T)
-
-        x0 = torch.cat([p.ravel() for p in params])
-
-        if trust_method is None: trust_method = 'none'
-        else: trust_method = trust_method.lower()
-
-        if trust_method == 'none':
-            trust_region = None
-            prox = 0
-
-        elif trust_method == 'bounds':
-            trust_region = trust_value
-            prox = 0
-
-        elif trust_method == 'proximal':
-            trust_region = None
-            prox = 1 / trust_value
-
-        else:
-            raise ValueError(trust_method)
-
-        x_star, expected_loss = _poly_minimize(
-            trust_region=trust_region,
-            prox=prox,
-            de_iters=de_iters,
-            c=loss.item(),
-            x=x0,
-            derivatives=[torch.cat([t.ravel() for t in d]) for d in derivatives],
-        )
-
-        # trust region
-        if trust_method != 'none':
-            expected_reduction = loss - expected_loss
-
-            vec_to_tensors_(x_star, params)
-            loss_star = closure(False)
-            vec_to_tensors_(x0, params)
-            reduction = loss - loss_star
-
-            # failed step
-            if reduction <= 0:
-                x_star = x0
-                self.global_state['trust_value'] = trust_value * decrease
-
-            # very good step
-            elif expected_reduction / reduction <= trust_tol:
-                self.global_state['trust_value'] = trust_value * increase
-
-        difference = vec_to_tensors(x0 - x_star, params)
-        var.update = list(difference)
-        return var
-
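For context (not part of the package diff): the removed DiagonalHigherOrderNewton builds a diagonal higher-order Taylor model from the gradient and repeated Hessian-vector products with a vector of ones, then minimizes that model with scipy.optimize.minimize, using either box bounds x0 ± trust_value ('bounds') or a proximal penalty ('proximal') as the trust region. With notation introduced here for exposition (s = x - x0, diagonal derivative estimates d^(i), proximal weight ρ), the model that _poly_eval_diag and _proximal_poly_v_diag evaluate reads

    m(x) = c + g^\top s + \sum_{i=2}^{k} \frac{(d^{(i)})^\top s^{\odot i}}{i!} + \frac{\rho}{2}\lVert s \rVert_2^2, \qquad s = x - x_0,

where s^{\odot i} denotes the elementwise i-th power and k is the `order` setting.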
torchzero/modules/experimental/eigendescent.py (deleted)
@@ -1,117 +0,0 @@
-from contextlib import nullcontext
-import warnings
-from collections.abc import Callable
-from functools import partial
-import itertools
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    jacobian_wrt, jacobian_and_hessian_wrt, hessian_mat,
-)
-
-def _batched_dot(x, y):
-    return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
-
-def _cosine_similarity(x, y):
-    denom = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(y, dim=-1).clip(min=torch.finfo(x.dtype).eps) # pylint:disable=not-callable
-    return _batched_dot(x, y) / denom
-
-class EigenDescent(Module):
-    """
-    Uses eigenvectors corresponding to certain eigenvalues. Please note that this is experimental and isn't guaranteed to work.
-
-    Args:
-        mode (str, optional):
-            - largest - use largest eigenvalue unless all eigenvalues are negative, then smallest is used.
-            - smallest - use smallest eigenvalue unless all eigenvalues are positive, then largest is used.
-            - mean-sign - use mean of eigenvectors multiplied by 1 or -1 if they point in opposite direction from gradient.
-            - mean-dot - use mean of eigenvectors multiplied by dot product with gradient.
-            - mean-cosine - use mean of eigenvectors multiplied by cosine similarity with gradient.
-            - mm - for testing.
-
-            Defaults to 'mean-sign'.
-        hessian_method (str, optional): how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional): how to calculate hessian. Defaults to True.
-
-    """
-    def __init__(
-        self,
-        mode: Literal['largest', 'smallest','magnitude', 'mean-sign', 'mean-dot', 'mean-cosine', 'mm'] = 'mean-sign',
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-    ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, mode=mode)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        mode = settings['mode']
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                var.grad = g_list
-                H = hessian_list_to_mat(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = var.get_grad(retain_graph=True)
-                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
-
-
-        # ----------------------------------- solve ---------------------------------- #
-        g = torch.cat([t.ravel() for t in g_list])
-        L, Q = torch.linalg.eigh(H) # L is sorted # pylint:disable=not-callable
-        if mode == 'largest':
-            # smallest eigenvalue if all eigenvalues are negative else largest
-            if L[-1] <= 0: d = Q[0]
-            else: d = Q[-1]
-
-        elif mode == 'smallest':
-            # smallest eigenvalue if negative eigenvalues exist else largest
-            if L[0] <= 0: d = Q[0]
-            else: d = Q[-1]
-
-        elif mode == 'magnitude':
-            # largest by magnitude
-            if L[0].abs() > L[-1].abs(): d = Q[0]
-            else: d = Q[-1]
-
-        elif mode == 'mean-dot':
-            d = ((g.unsqueeze(0) @ Q).squeeze(0) * Q).mean(1)
-
-        elif mode == 'mean-sign':
-            d = ((g.unsqueeze(0) @ Q).squeeze(0).sign() * Q).mean(1)
-
-        elif mode == 'mean-cosine':
-            d = (Q * _cosine_similarity(Q, g)).mean(1)
-
-        elif mode == 'mm':
-            d = (g.unsqueeze(0) @ Q).squeeze(0) / g.numel()
-
-        else:
-            raise ValueError(mode)
-
-        var.update = vec_to_tensors(g.dot(d).sign() * d, params)
-        return var
-
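For context (not part of the package diff): writing the eigendecomposition of the Hessian as H = QΛQᵀ with eigenvectors q_j and gradient g (notation introduced here, and one reading of the code above), the removed EigenDescent 'mean-sign' mode averages the eigenvectors after flipping each toward the gradient, and the final update signs the result against g:

    d = \frac{1}{n}\sum_{j=1}^{n} \operatorname{sign}(g^\top q_j)\, q_j, \qquad \text{update} = \operatorname{sign}(g^\top d)\, d.

The 'mean-dot' and 'mean-cosine' modes replace the sign weight with g^\top q_j and \cos(g, q_j) respectively.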
torchzero/modules/experimental/etf.py (deleted)
@@ -1,172 +0,0 @@
-from typing import cast
-import warnings
-
-import torch
-
-from ...core import Module
-from ...utils import vec_to_tensors, vec_to_tensors_
-
-
-class ExponentialTrajectoryFit(Module):
-    """A method. Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(self, step_size=1e-3):
-        defaults = dict(step_size = step_size)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        assert closure is not None
-        step_size = self.settings[var.params[0]]['step_size']
-
-        # 1. perform 3 GD steps to obtain 4 points
-        points = [torch.cat([p.view(-1) for p in var.params])]
-        for i in range(3):
-            if i == 0: grad = var.get_grad()
-            else:
-                with torch.enable_grad(): closure()
-                grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-            # GD step
-            torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-            points.append(torch.cat([p.view(-1) for p in var.params]))
-
-        assert len(points) == 4, len(points)
-        x0, x1, x2, x3 = points
-        dim = x0.numel()
-
-        # 2. fit a generalized exponential curve
-        d0 = (x1 - x0).unsqueeze(1) # column vectors
-        d1 = (x2 - x1).unsqueeze(1)
-        d2 = (x3 - x2).unsqueeze(1)
-
-        # cat
-        D1 = torch.cat([d0, d1], dim=1)
-        D2 = torch.cat([d1, d2], dim=1)
-
-        # if points are collinear this will happen on sphere and a quadratic "line search" will minimize it
-        if x0.numel() >= 2:
-            if torch.linalg.matrix_rank(D1) < 2: # pylint:disable=not-callable
-                pass # need to put a quadratic fit there
-
-        M = D2 @ torch.linalg.pinv(D1) # pylint:disable=not-callable # this defines the curve
-
-        # now we can predict x*
-        I = torch.eye(dim, device=x0.device, dtype=x0.dtype)
-        B = I - M
-        z = x1 - M @ x0
-
-        x_star = torch.linalg.lstsq(B, z).solution # pylint:disable=not-callable
-
-        vec_to_tensors_(x0, var.params)
-        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-        var.update = list(difference)
-        return var
-
-
-
-class ExponentialTrajectoryFitV2(Module):
-    """Should be better than one above, except it isn't. Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(self, step_size=1e-3, num_steps: int= 4):
-        defaults = dict(step_size = step_size, num_steps=num_steps)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        assert closure is not None
-        step_size = self.settings[var.params[0]]['step_size']
-        num_steps = self.settings[var.params[0]]['num_steps']
-
-        # 1. perform 3 GD steps to obtain 4 points (or more)
-        grad = var.get_grad()
-        points = [torch.cat([p.view(-1) for p in var.params])]
-        point_grads = [torch.cat([g.view(-1) for g in grad])]
-
-        for i in range(num_steps):
-            # GD step
-            torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-            points.append(torch.cat([p.view(-1) for p in var.params]))
-
-            closure(backward=True)
-            grad = [cast(torch.Tensor, p.grad) for p in var.params]
-            point_grads.append(torch.cat([g.view(-1) for g in grad]))
-
-
-        X = torch.stack(points, 1) # dim, num_steps+1
-        G = torch.stack(point_grads, 1)
-        dim = points[0].numel()
-
-        X = torch.cat([X, torch.ones(1, num_steps+1, dtype=G.dtype, device=G.device)])
-
-        P = G @ torch.linalg.pinv(X) # pylint:disable=not-callable
-        A = P[:, :dim]
-        b = -P[:, dim]
-
-        # symmetrize
-        A = 0.5 * (A + A.T)
-
-        # predict x*
-        x_star = torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
-
-        vec_to_tensors_(points[0], var.params)
-        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-        var.update = list(difference)
-        return var
-
-
-
-
-def _fit_exponential(y0, y1, y2):
-    """x0, x1 and x2 are assumed to be 0, 1, 2"""
-    r = (y2 - y1) / (y1 - y0)
-    ones = r==1
-    r[ones] = 0
-    B = (y1 - y0) / (r - 1)
-    A = y0 - B
-
-    A[ones] = 0
-    B[ones] = 0
-    return A, B, r
-
-class PointwiseExponential(Module):
-    """A stupid method (for my youtube channel). Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
-        defaults = dict(reg=reg, steps=steps, step_size=step_size)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        assert closure is not None
-        settings = self.settings[var.params[0]]
-        step_size = settings['step_size']
-        reg = settings['reg']
-        steps = settings['steps']
-
-        # 1. perform 2 GD steps to obtain 3 points
-        points = [torch.cat([p.view(-1) for p in var.params])]
-        for i in range(2):
-            if i == 0: grad = var.get_grad()
-            else:
-                with torch.enable_grad(): closure()
-                grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-            # GD step
-            torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-            points.append(torch.cat([p.view(-1) for p in var.params]))
-
-        assert len(points) == 3, len(points)
-        y0, y1, y2 = points
-
-        A, B, r = _fit_exponential(y0, y1, y2)
-        r = r.clip(max = 1-reg)
-        x_star = A + B * r**steps
-
-        vec_to_tensors_(y0, var.params)
-        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-        var.update = list(difference)
-        return var
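For context (not part of the package diff): the removed ExponentialTrajectoryFit fits a linear recursion on the iterate differences, x_{k+1} - x^\star \approx M (x_k - x^\star) with M = D_2 D_1^{+}, so the predicted limit solves (I - M)\,x^\star = x_1 - M x_0, which is the torch.linalg.lstsq call above. The _fit_exponential helper instead fits, per coordinate, y(k) = A + B r^k through the three iterates at k = 0, 1, 2 (notation introduced here):

    r = \frac{y_2 - y_1}{y_1 - y_0}, \qquad B = \frac{y_1 - y_0}{r - 1}, \qquad A = y_0 - B, \qquad x^\star \approx A + B\,r^{K},

with K = steps and r clipped to at most 1 - reg in PointwiseExponential.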
torchzero/modules/experimental/soapy.py (deleted)
@@ -1,163 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from ...core import Chainable, Transform
-from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-from ..optimizers.soap import (
-    update_soap_covariances_,
-    get_orthogonal_matrix,
-    get_orthogonal_matrix_QR,
-    project,
-    project_back,
-)
-
-class SOAPY(Transform):
-    """Adam but uses scaled gradient differences for GGᵀ. Please note that this is experimental and isn't guaranteed to work.
-
-    New args:
-        scale_by_s - whether to scale gradient differences by parameter differences
-        y_to_ema2 - whether to use gradient differences for exponential moving average too
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        bias_correction: bool = True,
-        scale_by_s: bool = True,
-        y_to_ema2: bool = False,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            bias_correction=bias_correction,
-            alpha=alpha,
-            scale_by_s=scale_by_s,
-            y_to_ema2=y_to_ema2,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        updates = []
-        # update preconditioners
-        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-            scale_by_s = setting['scale_by_s']
-            y_to_ema2 = setting['y_to_ema2']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            if 'g_prev' not in state:
-                state['p_prev'] = p.clone()
-                state['g_prev'] = t.clone()
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue
-
-            p_prev = state['p_prev']
-            g_prev = state['g_prev']
-            s = p - p_prev
-            y = t - g_prev
-            if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
-
-            state['p_prev'].copy_(p)
-            state['g_prev'].copy_(t)
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
-                else: state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-
-                if state['GG'] is not None:
-                    update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            z_projected = None
-            if state['GG'] is not None:
-                if y_to_ema2: z_projected = project(y, state['Q'])
-                else: z_projected = project(t, state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            exp_avg.lerp_(t, 1-beta1)
-
-            if z_projected is None:
-                if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
-                else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-            else:
-                exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if z_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if z_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if setting['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_soap_covariances_(y, state['GG'], shampoo_beta)
-                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-        return updates
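For context (not part of the package diff): where SOAP accumulates the raw gradient into the Kronecker-factor covariances, the removed SOAPY feeds a scaled gradient difference into update_soap_covariances_ instead. Reading the code above, with parameters x_k and gradients g_k (notation introduced here),

    y_k = \frac{g_k - g_{k-1}}{\max(\lVert x_k - x_{k-1} \rVert_2,\ 10^{-8})},

and the rest of the step follows SOAP: project into the preconditioner eigenbasis, apply Adam-style moments, and project back.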