torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0

torchzero/modules/experimental/newtonnewton.py (new file)
@@ -0,0 +1,88 @@
+import itertools
+import warnings
+from collections.abc import Callable
+from contextlib import nullcontext
+from functools import partial
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_wrt,
+)
+from ..second_order.newton import (
+    cholesky_solve,
+    eigh_solve,
+    least_squares_solve,
+    lu_solve,
+)
+
+
+class NewtonNewton(Module):
+    """
+    Method that I thought of and then it worked.
+
+    1. Calculate newton step by solving Hx=g
+
+    2. Calculate jacobian of x wrt parameters and call it H2
+
+    3. Solve H2 x2 = x for x2.
+
+    4. Optionally, repeat (if order is higher than 3.)
+
+    Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
+    """
+    def __init__(
+        self,
+        reg: float = 1e-6,
+        order: int = 3,
+        search_negative: bool = False,
+        vectorize: bool = True,
+        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ):
+        defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_tfm=eigval_tfm, search_negative=search_negative)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        vectorize = settings['vectorize']
+        order = settings['order']
+        search_negative = settings['search_negative']
+        eigval_tfm = settings['eigval_tfm']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        with torch.enable_grad():
+            loss = var.loss = var.loss_approx = closure(False)
+            g_list = torch.autograd.grad(loss, params, create_graph=True)
+            var.grad = list(g_list)
+
+            xp = torch.cat([t.ravel() for t in g_list])
+            I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
+
+            for o in range(2, order + 1):
+                is_last = o == order
+                H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
+                with torch.no_grad() if is_last else nullcontext():
+                    H = hessian_list_to_mat(H_list)
+                    if reg != 0: H = H + I * reg
+
+                    x = None
+                    if search_negative or (is_last and eigval_tfm is not None):
+                        x = eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                    if x is None: x = cholesky_solve(H, xp)
+                    if x is None: x = lu_solve(H, xp)
+                    if x is None: x = least_squares_solve(H, xp)
+                    xp = x.squeeze()
+
+        var.update = vec_to_tensors(xp, params)
+        return var
+
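The NewtonNewton docstring above lays out the algorithm step by step. As a rough standalone illustration of that idea (not the torchzero module itself, and without its regularization or solver fallbacks), the sketch below forms the Hessian H, solves Hx = g for the Newton step, differentiates x with respect to the parameters to get H2, and solves H2 x2 = x. The test function and all names here are illustrative.

import torch

def f(w):
    # small convex test objective (illustrative only)
    return (w[0] - 1.0)**4 + (w[1] + 2.0)**4 + w[0]**2 + w[1]**2

w = torch.tensor([3.0, 0.0], requires_grad=True)

for it in range(10):
    loss = f(w)
    g = torch.autograd.grad(loss, w, create_graph=True)[0]

    # Hessian built row by row, keeping the graph so the Newton step stays differentiable
    H = torch.stack([torch.autograd.grad(g[i], w, create_graph=True)[0] for i in range(w.numel())])
    x = torch.linalg.solve(H, g)                 # step 1: Newton step, solve Hx = g

    # step 2: Jacobian of the Newton step w.r.t. the parameters ("H2")
    H2 = torch.stack([torch.autograd.grad(x[i], w, retain_graph=True)[0] for i in range(w.numel())])
    x2 = torch.linalg.solve(H2, x)               # step 3: solve H2 x2 = x

    with torch.no_grad():
        w -= x2                                  # apply the higher-order step
    print(it, f(w).item())

This mirrors the structure of NewtonNewton's loop with order=3; the module additionally adds reg*I, falls back through eigh/Cholesky/LU/least-squares solvers, and supports higher orders.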

torchzero/modules/experimental/reduce_outward_lr.py
@@ -1,30 +1,33 @@
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList
+from ...utils import TensorList, unpack_states, unpack_dicts
 
 class ReduceOutwardLR(Transform):
     """
     When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
+
+    A note on this is that it sounded good but its really bad in practice.
     """
     def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
 
-        mul =
-        s =
+        mul = [s['mul'] for s in settings]
+        s = settings[0]
         use_grad = s['use_grad']
         invert = s['invert']
 
-        if use_grad: cur =
+        if use_grad: cur = grads
         else: cur = tensors
+        assert cur is not None
 
         # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
         if invert: mask = (params * cur) > 0
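For reference, here is a minimal tensor-level sketch of the rule stated in ReduceOutwardLR's docstring: entries whose update sign matches the corresponding weight's sign get scaled by `mul`. It uses plain tensors rather than the Transform/settings machinery above, and `scale_matching_signs` is an illustrative name.

import torch

def scale_matching_signs(update: torch.Tensor, param: torch.Tensor, mul: float = 0.5) -> torch.Tensor:
    mask = (param * update) > 0                   # signs match
    return torch.where(mask, update * mul, update)

p = torch.tensor([1.0, -2.0, 3.0])
u = torch.tensor([0.5,  0.5, -0.1])
print(scale_matching_signs(u, p))                 # tensor([0.2500, 0.5000, -0.1000])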

torchzero/modules/experimental/soapy.py
@@ -2,147 +2,22 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Transform
+from ...core import Chainable, Transform
 from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-
-
-
-
-
-
-)
-    for i, GG in enumerate(GGs_):
-        if GG is None: continue
-
-        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-@torch.no_grad
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
+from ..optimizers.soap import (
+    update_soap_covariances_,
+    get_orthogonal_matrix,
+    get_orthogonal_matrix_QR,
+    project,
+    project_back,
+)
 
 class SOAPY(Transform):
-    """
-
-    new args
-
-    scale by s whether to scale gradient differences by parameter differences
+    """Adam but uses scaled gradient differences for GGᵀ. Please note that this is experimental and isn't guaranteed to work.
 
-
+    New args:
+        scale_by_s - whether to scale gradient differences by parameter differences
+        y_to_ema2 - whether to use gradient differences for exponential moving average too
     """
     def __init__(
         self,
@@ -178,16 +53,14 @@ class SOAPY(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(
-            scale_by_s =
-            y_to_ema2 =
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
+            scale_by_s = setting['scale_by_s']
+            y_to_ema2 = setting['y_to_ema2']
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -268,7 +141,7 @@ class SOAPY(Transform):
             if z_projected is not None:
                 update = project_back(update, state["Q"])
 
-            if
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -284,7 +157,7 @@ class SOAPY(Transform):
             # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
                 update_soap_covariances_(y, state['GG'], shampoo_beta)
-                if state['step'] %
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
 
         return updates
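SOAPY's new docstring describes accumulating the GGᵀ preconditioner statistics from scaled gradient differences rather than raw gradients. A heavily simplified sketch of that idea follows; it collapses SOAP's per-axis tensordot accumulation into a single outer product on a flattened tensor, and scaling y by the parameter-difference norm is an assumption here, not necessarily the scaling torchzero uses.

import torch

def soapy_ggt_update(GG, g, g_prev, p, p_prev, beta=0.95, scale_by_s=True):
    # EMA of y yᵀ instead of g gᵀ, where y is a (possibly scaled) gradient difference
    y = (g - g_prev).ravel()
    if scale_by_s:
        s = (p - p_prev).ravel()
        y = y / s.norm().clamp(min=1e-12)         # assumed scaling by the step norm
    return GG.lerp_(torch.outer(y, y), 1 - beta)

# toy usage on a flattened 4-element parameter
p_prev, p = torch.randn(4), torch.randn(4)
g_prev, g = torch.randn(4), torch.randn(4)
GG = torch.zeros(4, 4)
soapy_ggt_update(GG, g, g_prev, p, p_prev)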