torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0

torchzero/modules/{optimizers → adaptive}/rmsprop.py

@@ -40,7 +40,9 @@ def rmsprop_(
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))

 class RMSprop(Transform):
-    """Divides graient by EMA of gradient squares.
+    """Divides graient by EMA of gradient squares.
+
+    This implementation is identical to :code:`torch.optim.RMSprop`.

     Args:
         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
@@ -50,7 +52,8 @@ class RMSprop(Transform):
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
-        inner (Chainable | None, optional):
+        inner (Chainable | None, optional):
+            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
     """
     def __init__(
         self,
@@ -60,7 +63,7 @@ class RMSprop(Transform):
         debiased: bool = False,
         amsgrad: bool = False,
         pow: float = 2,
-        init: Literal["zeros", "update"] = "
+        init: Literal["zeros", "update"] = "zeros",
         inner: Chainable | None = None,
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
@@ -69,7 +72,7 @@ class RMSprop(Transform):
         if inner is not None:
             self.set_child('inner', inner)

-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
         centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
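
A change that recurs throughout this diff is that Transform subclasses now implement an apply_tensors(self, tensors, params, grads, loss, states, settings) hook; the corresponding removed "def" lines are truncated by the diff viewer. The following standalone sketch shows what a custom transform written against that hook could look like, assuming the 0.3.13 Transform base class dispatches to apply_tensors as these hunks suggest. The class name and scaling logic are illustrative only, not part of torchzero.

    import torch
    from torchzero.core import Transform

    class ScaleUpdate(Transform):
        # Hypothetical example module (not from the package): multiplies every
        # update tensor by a constant factor via the new apply_tensors hook.
        def __init__(self, factor: float = 0.5):
            super().__init__(dict(factor=factor), uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            factor = settings[0]['factor']  # per-parameter settings dicts, as used above
            return [t.mul_(factor) for t in tensors]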

torchzero/modules/{optimizers → adaptive}/rprop.py

@@ -135,7 +135,8 @@ class Rprop(Transform):
     Next step, magnitude for that weight won't change.

     Compared to pytorch this also implements backtracking update when sign changes.
-
+
+    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.

     Args:
         nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
@@ -164,7 +165,7 @@ class Rprop(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -223,7 +224,7 @@ class ScaleLRBySignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -257,8 +258,6 @@ class BacktrackOnSignChange(Transform):
     This is part of RProp update rule.

     Args:
-        normalize (bool, optional): renormalize update after masking. Defaults to False.
-        eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
         use_grad (bool, optional):
             if True, tracks sign change of the gradient,
             otherwise track sign change of the update. Defaults to True.
@@ -272,7 +271,7 @@ class BacktrackOnSignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -294,12 +293,29 @@ class BacktrackOnSignChange(Transform):
         return tensors

 class SignConsistencyMask(Transform):
-    """
+    """
+    Outputs a mask of sign consistency of current and previous inputs.
+
+    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
+
+    Examples:
+
+        GD that skips update for weights where gradient sign changed compared to previous gradient.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyMask()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(self,target: Target = 'update'):
         super().__init__({}, uses_grad=False, target = target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
         prev.copy_(tensors)
@@ -307,7 +323,23 @@ class SignConsistencyMask(Transform):


 class SignConsistencyLRs(Transform):
-    """
+    """Outputs per-weight learning rates based on consecutive sign consistency.
+
+    The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+
+    Examples:
+
+        GD scaled by consecutive gradient sign consistency
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyLRs()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(
         self,
         nplus: float = 1.2,
@@ -321,7 +353,7 @@ class SignConsistencyLRs(Transform):
         super().__init__(defaults, uses_grad=False, target = target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
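
The sign-consistency logic that the new SignConsistencyMask docstring describes can be checked on a toy tensor. This standalone snippet is plain PyTorch, not torchzero code; it mirrors the prev.mul_(tensors).gt_(0) line in the hunk above.

    import torch

    prev = torch.tensor([0.5, -1.0,  2.0, -0.3])  # previous update
    curr = torch.tensor([0.1, -2.0, -1.0,  0.4])  # current update
    mask = (prev * curr) > 0                      # True where the sign did not flip
    print(mask.int())                             # -> 1, 1, 0, 0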

torchzero/modules/adaptive/sam.py (new file)

@@ -0,0 +1,163 @@
+from contextlib import nullcontext
+import torch
+from ...utils import TensorList, NumberList
+from ...core import Module
+
+
+class SAM(Module):
+    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+        asam (bool, optional):
+            enables ASAM variant which makes perturbation relative to weight magnitudes.
+            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
+            The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
+            it has larger :code:`rho` by default.
+
+    Examples:
+        SAM-SGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.LR(1e-2)
+            )
+
+        SAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+    """
+    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
+        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+
+        params = var.params
+        closure = var.closure
+        zero_grad = var.zero_grad
+        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
+        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
+        s = self.defaults
+        eps = s['eps']
+        asam = s['asam']
+
+        # 1/p + 1/q = 1
+        # okay, authors of SAM paper, I will manually solve your equation
+        # so q = -p/(1-p)
+        q = -p / (1-p)
+        # as a validation for 2 it is -2 / -1 = 2
+
+        @torch.no_grad
+        def sam_closure(backward=True):
+            orig_grads = None
+            if not backward:
+                # if backward is False, make sure this doesn't modify gradients
+                # to avoid issues
+                orig_grads = [p.grad for p in params]
+
+            # gradient at initial parameters
+            zero_grad()
+            with torch.enable_grad():
+                closure()
+
+            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
+            grad_abs = grad.abs()
+
+            # compute e
+            term1 = grad.sign().mul_(rho)
+            term2 = grad_abs.pow(q-1)
+
+            if asam:
+                grad_abs.mul_(torch._foreach_abs(params))
+
+            denom = grad_abs.pow_(q).sum().pow(1/p)
+
+            e = term1.mul_(term2).div_(denom.clip(min=eps))
+
+            if asam:
+                e.mul_(torch._foreach_pow(params, 2))
+
+            # calculate loss and gradient approximation of inner problem
+            torch._foreach_add_(params, e)
+            if backward:
+                zero_grad()
+                with torch.enable_grad():
+                    # this sets .grad attributes
+                    sam_loss = closure()
+
+            else:
+                sam_loss = closure(False)
+
+            # and restore initial parameters
+            torch._foreach_sub_(params, e)
+
+            if orig_grads is not None:
+                for param,orig_grad in zip(params, orig_grads):
+                    param.grad = orig_grad
+
+            return sam_loss
+
+        var.closure = sam_closure
+        return var
+
+# different class because defaults for SAM are bad for ASAM
+class ASAM(SAM):
+    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+
+    Examples:
+        ASAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ASAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+    """
+    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
+        super().__init__(rho=rho, p=p, eps=eps, asam=True)
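
Since the new SAM module replaces var.closure and calls it both as closure() and closure(False), the training loop has to supply a closure that accepts a backward flag and populates .grad when it is True. A minimal usage sketch under that assumption (the model, data, and loss below are placeholders; the module chain is the SAM-Adam example from the docstring above):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)
    opt = tz.Modular(model.parameters(), tz.m.SAM(), tz.m.Adam(), tz.m.LR(1e-2))

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(10):
        opt.step(closure)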

torchzero/modules/{optimizers → adaptive}/shampoo.py

@@ -17,6 +17,7 @@ def update_shampoo_preconditioner_(
     update_freq: int,
     exp_override: int | None,
     beta: float | None,
+    reg: float
 ):
     for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
         if accumulator is None: continue
@@ -28,6 +29,8 @@ def update_shampoo_preconditioner_(

         if step % update_freq == 0:
             matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
+            if reg != 0:
+                accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
             set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))


@@ -59,7 +62,7 @@ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
     if tensor.shape[sort_idxs[0]] > max_dim:
         return tensor, None, None

-    tensor = tensor.permute(*sort_idxs)
+    tensor = tensor.permute(*sort_idxs.tolist())
     flatten_end_idx = 0
     flat_sizes = []
     flat_numel = 1
@@ -80,19 +83,27 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     if flat_sizes is None: return tensor
     assert sort_idxs is not None
     tensor = tensor.unflatten(0, flat_sizes)
-    return tensor.permute(*np.argsort(sort_idxs))
+    return tensor.permute(*np.argsort(sort_idxs).tolist())


 class Shampoo(Transform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

+    .. note::
+        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
+
+    .. note::
+        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+
+    .. note::
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+
     Args:
         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
         beta (float | None, optional):
             if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
-        matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to
+        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
         merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
         max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
         precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
@@ -101,32 +112,58 @@ class Shampoo(Transform):
             module applied after updating preconditioners and before applying preconditioning.
             For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
             Defaults to None.
+
+    Examples:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with Shampoo preconditioner
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.LR(1e-3)
+            )
     """
     def __init__(
         self,
         decay: float | None = None,
         beta: float | None = None,
+        reg: float = 1e-12,
         update_freq: int = 10,
-        exp_override: int | None =
+        exp_override: int | None = 2,
         merge_small: bool = True,
         max_dim: int = 2_000,
         precondition_1d: bool = True,
         adagrad_eps: float = 1e-8,
         inner: Chainable | None = None,
     ):
-        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
+        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
         super().__init__(defaults, uses_grad=False)

         if inner is not None:
             self.set_child('inner', inner)

-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         merged_tensors = [] # target with merged dims

         # update preconditioners
         for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
-            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
-                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
+            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
+                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)

             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -161,6 +198,7 @@ class Shampoo(Transform):
                 update_freq=update_freq,
                 exp_override=exp_override,
                 beta=beta,
+                reg=reg,
             )

         # inner step
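
The new reg argument adds reg * I to the Shampoo accumulator before the inverse matrix root is taken, a small Tikhonov term that keeps the eigendecomposition-based matrix power finite when the accumulator is rank deficient. A standalone toy illustration of the effect (plain PyTorch, not the package's matrix_power_eigh; reg is exaggerated here so the difference is visible):

    import torch

    def inv_matrix_power(A, p):
        # symmetric matrix power via eigendecomposition
        L, Q = torch.linalg.eigh(A)
        return Q @ torch.diag(L ** p) @ Q.T

    A = torch.diag(torch.tensor([4.0, 1.0, 0.0]))             # singular accumulator
    reg = 1e-4                                                # exaggerated for illustration
    unstable = inv_matrix_power(A, -0.25)                     # inf/nan along the null direction
    stable = inv_matrix_power(A + reg * torch.eye(3), -0.25)  # finite everywhere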

torchzero/modules/{optimizers → adaptive}/soap.py

@@ -1,9 +1,10 @@
 from operator import itemgetter
+import warnings

 import torch

 from ...core import Chainable, Transform, apply_transform
-from ...modules.
+from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims

 @torch.no_grad
 def update_soap_covariances_(
@@ -24,11 +25,9 @@ def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
     Projects the gradient to the eigenbases of the preconditioner.
     """
     for mat in Q:
-        if mat is None:
-            if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
         else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
             permute_order = list(range(1, len(tensors.shape))) + [0]
             tensors = tensors.permute(permute_order)

@@ -40,8 +39,7 @@ def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
     Projects the gradient back to the original space.
     """
     for mat in Q:
-        if mat is None:
-            if len(mat) > 0:
+        if mat is not None and len(mat) > 0:
             tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
         else:
             permute_order = list(range(1, len(tensors.shape))) + [0]
@@ -55,37 +53,23 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
     """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)

     final = []
-    for m in
-
+    for m in mat:
+
+        if m is None or len(m) == 0:
             final.append([])
             continue
+
         try:
             _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except
+        except torch.linalg.LinAlgError:
             _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
             Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])

-
-        Q = Q.to(original_device).type(original_type)
+        Q = torch.flip(Q, [1])
         final.append(Q)
+
     return final

 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
@@ -95,42 +79,24 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     Computes the eigenbases of the preconditioner using one round of power iteration
     followed by torch.linalg.qr decomposition.
     """
-
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
+    final = []

-
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
+    for ind, (m,o) in enumerate(zip(GG, Q_list)):

-
-
-        if len(m)==0:
+        # skip 1d or large dims
+        if m is None or len(m) == 0:
             final.append([])
             continue
+        assert o is not None
+
         est_eig = torch.diag(o.T @ m @ o)
         sort_idx = torch.argsort(est_eig, descending=True)
         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable

-
-
+        power_iter = m @ o[:, sort_idx]
+        Q, _ = torch.linalg.qr(power_iter.to(torch.float32)) # pylint:disable=not-callable
+        Q = Q.to(power_iter.dtype)
+
         final.append(Q)

     return final, exp_avg_sq
@@ -156,6 +122,24 @@ class SOAP(Transform):
             learning rate. Defaults to 1.
         bias_correction (bool, optional):
             enables adam bias correction. Defaults to True.
+
+    Examples:
+        SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))
+
+        Stabilized SOAP:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SOAP(),
+                tz.m.NormalizeByEMA(max_ema_growth=1.2),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -187,7 +171,7 @@ class SOAP(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
@@ -200,7 +184,7 @@ class SOAP(Transform):
             # initialize state on 1st step
             if 'GG' not in state:
                 state["exp_avg"] = torch.zeros_like(t)
-                state["
+                state["exp_avg_sq_projected"] = torch.zeros_like(t)

                 if not precondition_1d and t.ndim <= 1:
                     state['GG'] = []
@@ -214,7 +198,10 @@ class SOAP(Transform):

                 if state['GG'] is not None:
                     update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
+                    try: state['Q'] = get_orthogonal_matrix(state['GG'])
+                    except torch.linalg.LinAlgError as e:
+                        warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                        state["GG"] = None

                 state['step'] = 0
                 updates.append(tensors[i].clip(-0.1, 0.1))
@@ -230,22 +217,20 @@ class SOAP(Transform):
             # exponential moving averages
             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
             exp_avg: torch.Tensor = state["exp_avg"]
-
+            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]

             exp_avg.lerp_(t, 1-beta1)

             if t_projected is None:
-
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
             else:
-
+                exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)

             # project exponential moving averages if they are accumulated unprojected
             exp_avg_projected = exp_avg
             if t_projected is not None:
                 exp_avg_projected = project(exp_avg, state['Q'])

-            exp_avg_sq_projected = exp_avg_sq
-
             denom = exp_avg_sq_projected.sqrt().add_(eps)
             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')

@@ -273,6 +258,8 @@ class SOAP(Transform):
             if state['GG'] is not None:
                 update_soap_covariances_(t, state['GG'], shampoo_beta)
                 if state['step'] % setting['precond_freq'] == 0:
-
-
+                    try:
+                        state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
+                    except torch.linalg.LinAlgError:
+                        pass
         return updates
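
The rewritten get_orthogonal_matrix_QR keeps the previous eigenbasis and refreshes it with a single power-iteration step followed by QR, now routed through float32. A standalone toy version of that refinement idea (plain PyTorch, not the package code):

    import torch

    d = 8
    M = torch.randn(d, d)
    GG = M @ M.T                           # SPD covariance accumulator
    _, Q = torch.linalg.eigh(GG)           # full eigendecomposition, as on the first step

    g = torch.randn(d, 1)
    GG = 0.99 * GG + 0.01 * (g @ g.T)      # covariance drifts as new gradients arrive
    Q_refined, _ = torch.linalg.qr(GG @ Q) # one power-iteration step + QR re-orthogonalization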