torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,250 +0,0 @@
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Transform
-from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR
-
-@torch.no_grad
-def update_absoap_covariances_(
-    g1: torch.Tensor,
-    g2: torch.Tensor,
-    GGs_: list[torch.Tensor | None],
-    beta: float | None,
-):
-    for i, GG in enumerate(GGs_):
-        if GG is None: continue
-
-        axes = list(range(i)) + list(range(i + 1, g1.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
-Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
-class ABSOAP(Transform):
-    """SOAP but with some extra options for testing. Please note that this is experimental and isn't guaranteed to work.
-
-    Args:
-        scale_by_s - whether to scale y by s
-        gg1 - 1st vector into GGᵀ
-        gg2 - 2nd vector into GGᵀ
-        ema1 - vector into 1st momentum
-        ema2 - 2 vectors into 2nd momentum
-        rel1 - if True, multiplies gg1 by params
-        rel2 - same but for gg2
-        norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
-
-    letters:
-        p - params
-        g - grad
-        s - param difference
-        y - grad difference
-        gy - g+y
-        sy - s+y
-        sn - s normalized
-        yn - y normalized
-        gys - g + y#g
-        sys - s + y#s
-
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        bias_correction: bool = True,
-        scale_by_s: bool = True,
-        gg1: Source='g',
-        gg2: Source='g',
-        ema1: Source='g',
-        ema2: tuple[Source, Source] = ('g','g'),
-        rel1: bool=False,
-        rel2: bool=False,
-        norm: bool = False,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            bias_correction=bias_correction,
-            alpha=alpha,
-            scale_by_s=scale_by_s,
-            ema1=ema1,
-            ema2=ema2,
-            first=gg1,
-            second=gg2,
-            rel1=rel1, rel2=rel2,
-            norm=norm,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        updates = []
-        # update preconditioners
-        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-            scale_by_s = setting['scale_by_s']
-            ema1 = setting['ema1']
-            ema2 = setting['ema2']
-            first=setting['first']
-            second=setting['second']
-            rel1 = setting['rel1']; rel2 = setting['rel2']
-            norm=setting['norm']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            if 'g_prev' not in state:
-                state['p_prev'] = p.clone()
-                state['g_prev'] = t.clone()
-                # updates.append(tensors[i].clip(-0.1,0.1))
-                # continue
-
-            p_prev = state['p_prev']
-            g_prev = state['g_prev']
-            s = p - p_prev
-            y = t - g_prev
-
-            # keep malding
-            p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-            g_norm = torch.linalg.vector_norm(t) # pylint:disable=not-callable
-            s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
-            y_norm = torch.linalg.vector_norm(y) # pylint:disable=not-callable
-
-            sn = p - p_prev * (p_norm / torch.linalg.vector_norm(p_prev))# pylint:disable=not-callable
-            yn = t - g_prev * (g_norm / torch.linalg.vector_norm(g_prev))# pylint:disable=not-callable
-
-            if scale_by_s: y /= s_norm.clip(min=1e-8) # pylint:disable=not-callable
-
-            state['p_prev'].copy_(p)
-            state['g_prev'].copy_(t)
-
-            def _get(c: Source):
-                if c == 'p': return p
-                if c == 'g': return t
-                if c == 's': return s
-                if c == 'y': return y
-                if c == 'sn': return sn
-                if c == 'yn': return yn
-                if c == 'gy': return t+y
-                if c == 'sy': return s+y
-                if c == 'gys':
-                    y_scaled = y * (g_norm/y_norm.clip(min=1e-8))
-                    return t+y_scaled
-                if c == 'sys':
-                    y_scaled = y * (s_norm/y_norm.clip(min=1e-8))
-                    return s+y_scaled
-                raise RuntimeError("Big Chungus")
-
-            t1 = _get(first)
-            if rel1: t1 = t1 * p.abs().clip(min=1e-6)
-            t2 = _get(second)
-            if rel2: t2 = t2 * p.abs().clip(min=1e-6)
-
-            t_ema1 = _get(ema1)
-            t_ema2s = _get(ema2[0]), _get(ema2[1])
-
-            if norm:
-                t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
-                t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-
-                if state['GG'] is not None:
-                    update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            z1_projected = None
-            z2_projected = None
-
-            if state['GG'] is not None:
-                z1_projected = project(t_ema2s[0], state['Q'])
-                if ema2[0] == ema2[1]: z2_projected = z1_projected
-                else: z2_projected = project(t_ema2s[1], state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            exp_avg.lerp_(t_ema1, 1-beta1)
-
-            if z1_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(*t_ema2s, value=1-beta2)
-            else:
-                assert z2_projected is not None
-                exp_avg_sq.mul_(beta2).addcmul_(z1_projected, z2_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if z1_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if z1_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if setting['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
-                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-        return updates
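The 250-line hunk above matches torchzero/modules/experimental/absoap.py (-250) in the file list. Its update_absoap_covariances_ helper accumulates one Kronecker factor per gradient dimension with torch.tensordot; below is a minimal plain-PyTorch sketch of what that contraction computes for a 2-D gradient, taking g1 == g2 as in standard Shampoo/SOAP (variable names here are illustrative only, not part of the package).

import torch

G = torch.randn(4, 3)                        # stand-in for a 2-D gradient
axes0 = list(range(1, G.ndim))               # contract over every axis except dim 0
GG0 = torch.tensordot(G, G, (axes0, axes0))
assert torch.allclose(GG0, G @ G.T)          # Kronecker factor for dim 0

axes1 = [0]                                  # contract over every axis except dim 1
GG1 = torch.tensordot(G, G, (axes1, axes1))
assert torch.allclose(GG1, G.T @ G)          # Kronecker factor for dim 1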
@@ -1,112 +0,0 @@
-from operator import itemgetter
-from functools import partial
-
-import torch
-
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..lr.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
-
-
-def adadam_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    exp_avg_qu_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    precond_beta: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_sq_: TensorList | None = None,
-    max_exp_avg_qu_: TensorList | None = None,
-    params_: TensorList | None = None,
-):
-    """Returns new tensors or updates params in-place."""
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                   debiased=False,step=step,pow=pow)
-    sqrt_exp_avg_qu = sqrt_ema_sq_(tensors/(sqrt_exp_avg_sq+1e-8), exp_avg_sq_=exp_avg_qu_,
-                                   beta=precond_beta,max_exp_avg_sq_=max_exp_avg_qu_, debiased=False,step=step,pow=pow)
-
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-    # params is None, return update
-    if params_ is None: return (exp_avg_ / sqrt_exp_avg_qu.add_(eps)).lazy_mul(alpha)
-
-    # update params in-place
-    params_.addcdiv_(exp_avg_, sqrt_exp_avg_qu.add_(eps), -alpha)
-    return None
-
-class Adadam(Module):
-    """Adam with a diagonally preconditioned preconditioner. Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        precond_beta: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad: bool = False,
-        alpha: float = 1.,
-        pow: float = 2,
-        debiased: bool = True,
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,precond_beta=precond_beta,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-        self.getter = itemgetter('amsgrad','pow','debiased')
-
-    @torch.no_grad
-    def step(self, var):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        params = var.params
-
-        beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
-        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-        if amsgrad:
-            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
-        else:
-            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
-            max_exp_avg_sq = None
-            max_exp_avg_qu = None
-
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if var.is_last:
-            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-            passed_params = TensorList(var.params)
-            var.stop = True
-            var.skip_update = True
-
-        else:
-            passed_params = None
-
-        var.update = adadam_(
-            tensors=TensorList(var.get_update()),
-            exp_avg_=exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            exp_avg_qu_=exp_avg_qu,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            precond_beta=precond_beta,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_sq_=max_exp_avg_sq,
-            max_exp_avg_qu_=max_exp_avg_qu,
-            params_=passed_params,
-        )
-
-        return var
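This 112-line hunk matches torchzero/modules/experimental/adadam.py (-112) in the file list. Its adadam_ helper normalizes the gradient by one second-moment EMA and feeds that normalized gradient into a second EMA, which acts as the actual preconditioner. Below is a rough single-tensor sketch of that rule in plain PyTorch, with illustrative names; debiasing, AMSGrad, and the TensorList/Module machinery are omitted.

import torch

def adadam_sketch(g, exp_avg, exp_avg_sq, exp_avg_qu,
                  beta1=0.9, beta2=0.999, precond_beta=0.999, eps=1e-8):
    exp_avg.lerp_(g, 1 - beta1)                               # first momentum of g
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)    # EMA of g**2
    g_hat = g / (exp_avg_sq.sqrt() + 1e-8)                    # Adam-normalized gradient
    exp_avg_qu.mul_(precond_beta).addcmul_(g_hat, g_hat, value=1 - precond_beta)
    return exp_avg / (exp_avg_qu.sqrt() + eps)                # precondition by the second EMA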
@@ -1,125 +0,0 @@
-from operator import itemgetter
-from functools import partial
-
-import torch
-
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..lr.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
-
-
-def adamy_(
-    p: TensorList,
-    p_prev: TensorList,
-    g: TensorList,
-    g_prev: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_sq_: TensorList | None = None,
-    params_: TensorList | None = None,
-):
-    """Returns new tensors or updates params in-place."""
-    if step == 1:
-        p_prev.copy_(p)
-        g_prev.copy_(g)
-
-        update = g.clip(-0.1,0.1).lazy_mul_(alpha)
-        if params_ is None: return update
-        params_.sub_(update)
-        return None
-
-    s = p-p_prev
-    y = (g-g_prev).div_(s.global_vector_norm().clip(min=1e-8))
-    p_prev.copy_(p)
-    g_prev.copy_(g)
-
-    exp_avg_ = ema_(g, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-    sqrt_exp_avg_sq = sqrt_ema_sq_(y, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                   debiased=False,step=step,pow=pow)
-
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-    # params is None, return update
-    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
-
-    # update params in-place
-    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
-    return None
-
-class AdamY(Module):
-    """Adam but uses scaled gradient differences for second momentum. Please note that this is experimental and isn't guaranteed to work."""
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad: bool = False,
-        alpha: float = 1.,
-        pow: float = 2,
-        debiased: bool = True,
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-        self.getter = itemgetter('amsgrad','pow','debiased')
-
-    @torch.no_grad
-    def step(self, var):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
-        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-        if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
-        else:
-            exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
-            max_exp_avg_sq = None
-
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if var.is_last:
-            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-            passed_params = TensorList(var.params)
-            var.stop = True
-            var.skip_update = True
-
-        else:
-            passed_params = None
-
-        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
-        g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)
-
-
-        var.update = adamy_(
-            p=TensorList(var.params),
-            p_prev=p_prev,
-            g=TensorList(var.get_update()),
-            g_prev=g_prev,
-            exp_avg_=exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_sq_=max_exp_avg_sq,
-            params_=passed_params,
-        )
-
-        return var
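This 125-line hunk matches torchzero/modules/experimental/adamY.py (-125) in the file list. The removed adamy_ helper keeps Adam's first moment of g but builds the second moment from the scaled gradient difference y = (g - g_prev) / ||p - p_prev||. A rough single-tensor sketch in plain PyTorch follows (illustrative names, no debiasing or AMSGrad).

import torch

def adamy_sketch(p, p_prev, g, g_prev, exp_avg, exp_avg_sq,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    s = p - p_prev
    y = (g - g_prev) / torch.linalg.vector_norm(s).clip(min=1e-8)  # scaled gradient difference
    exp_avg.lerp_(g, 1 - beta1)                                    # first momentum still uses g
    exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1 - beta2)         # second momentum built from y
    return exp_avg / (exp_avg_sq.sqrt() + eps)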
@@ -1,172 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from ...core import Chainable, Transform
-from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-from ..optimizers.soap import (
-    get_orthogonal_matrix,
-    get_orthogonal_matrix_QR,
-    project,
-    project_back,
-)
-
-
-@torch.no_grad
-def update_adasoap_covariances_(
-    grad: torch.Tensor,
-    GGs_: list[torch.Tensor | None],
-    GG_sqs: list[torch.Tensor | None],
-    beta: float | None,
-    precond_beta: float | None,
-):
-    for i, (GG, GG_sq) in enumerate(zip(GGs_, GG_sqs)):
-        if GG is None: continue
-        assert GG_sq is not None
-
-        if precond_beta is None: GG_sq.addcmul_(GG, GG)
-        else: GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1-precond_beta)
-
-        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
-class AdaSOAP(Transform):
-    """SOAP with diagonally preconditioned GG^Ts. Please note that this is experimental and isn't guaranteed to work.
-
-    precond_beta - beta for GG^T squares
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        unprojected_exp_avg: bool = True,
-        bias_correction: bool = True,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_beta=precond_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            unprojected_exp_avg=unprojected_exp_avg,
-            bias_correction=bias_correction,
-            alpha=alpha,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        updates = []
-        # update preconditioners
-        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
-            precond_beta = setting['precond_beta']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-                    state['GG_sq'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-                    state['GG_sq'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-                    state['GG_sq'] = None
-
-                if state['GG'] is not None:
-                    assert state['GG_sq'] is not None
-                    update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                    GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                    state['Q'] = get_orthogonal_matrix(GG_precond)
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # that can mess with other modules scaling
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            t_projected = None
-            if state['GG'] is not None:
-                t_projected = project(t, state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            if unprojected_exp_avg or t_projected is None:
-                exp_avg.lerp_(t, 1-beta1)
-            else:
-                exp_avg.lerp_(t_projected, 1-beta1)
-
-            if t_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-            else:
-                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if unprojected_exp_avg and t_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if t_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if setting['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
-                GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
-
-        return updates
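This final 172-line hunk matches torchzero/modules/experimental/adasoap.py (-172) in the file list. What distinguishes the removed AdaSOAP from plain SOAP is the elementwise division of each accumulated Kronecker factor by an EMA of its square before the eigendecomposition. For one factor, in plain PyTorch (illustrative names only):

import torch

GG = torch.randn(8, 8)
GG = GG @ GG.T                                    # accumulated covariance factor
GG_sq = torch.zeros_like(GG)

precond_beta = 0.95
GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1 - precond_beta)  # EMA of GG**2
GG_precond = GG / (GG_sq + 1e-8)                  # factor passed to get_orthogonal_matrix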