torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/absoap.py (deleted)
@@ -1,253 +0,0 @@
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Transform
-from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR
-
-@torch.no_grad
-def update_absoap_covariances_(
-    g1: torch.Tensor,
-    g2: torch.Tensor,
-    GGs_: list[torch.Tensor | None],
-    beta: float | None,
-):
-    for i, GG in enumerate(GGs_):
-        if GG is None: continue
-
-        axes = list(range(i)) + list(range(i + 1, g1.ndim)) # this works fine with 1d params
-        if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
-        else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
-
-
-Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
-class ABSOAP(Transform):
-    """SOAP but with some extra options for testing.
-
-    .. warning::
-        This module is just for testing my stupid ideas.
-
-    Args:
-        scale_by_s - whether to scale y by s
-        gg1 - 1st vector into GGᵀ
-        gg2 - 2nd vector into GGᵀ
-        ema1 - vector into 1st momentum
-        ema2 - 2 vectors into 2nd momentum
-        rel1 - if True, multiplies gg1 by params
-        rel2 - same but for gg2
-        norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
-
-    letters:
-        p - params
-        g - grad
-        s - param difference
-        y - grad difference
-        gy - g+y
-        sy - s+y
-        sn - s normalized
-        yn - y normalized
-        gys - g + y#g
-        sys - s + y#s
-
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        bias_correction: bool = True,
-        scale_by_s: bool = True,
-        gg1: Source='g',
-        gg2: Source='g',
-        ema1: Source='g',
-        ema2: tuple[Source, Source] = ('g','g'),
-        rel1: bool=False,
-        rel2: bool=False,
-        norm: bool = False,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            bias_correction=bias_correction,
-            alpha=alpha,
-            scale_by_s=scale_by_s,
-            ema1=ema1,
-            ema2=ema2,
-            first=gg1,
-            second=gg2,
-            rel1=rel1, rel2=rel2,
-            norm=norm,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        updates = []
-        # update preconditioners
-        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-            scale_by_s = setting['scale_by_s']
-            ema1 = setting['ema1']
-            ema2 = setting['ema2']
-            first=setting['first']
-            second=setting['second']
-            rel1 = setting['rel1']; rel2 = setting['rel2']
-            norm=setting['norm']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            if 'g_prev' not in state:
-                state['p_prev'] = p.clone()
-                state['g_prev'] = t.clone()
-                # updates.append(tensors[i].clip(-0.1,0.1))
-                # continue
-
-            p_prev = state['p_prev']
-            g_prev = state['g_prev']
-            s = p - p_prev
-            y = t - g_prev
-
-            # keep malding
-            p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-            g_norm = torch.linalg.vector_norm(t) # pylint:disable=not-callable
-            s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
-            y_norm = torch.linalg.vector_norm(y) # pylint:disable=not-callable
-
-            sn = p - p_prev * (p_norm / torch.linalg.vector_norm(p_prev))# pylint:disable=not-callable
-            yn = t - g_prev * (g_norm / torch.linalg.vector_norm(g_prev))# pylint:disable=not-callable
-
-            if scale_by_s: y /= s_norm.clip(min=1e-8) # pylint:disable=not-callable
-
-            state['p_prev'].copy_(p)
-            state['g_prev'].copy_(t)
-
-            def _get(c: Source):
-                if c == 'p': return p
-                if c == 'g': return t
-                if c == 's': return s
-                if c == 'y': return y
-                if c == 'sn': return sn
-                if c == 'yn': return yn
-                if c == 'gy': return t+y
-                if c == 'sy': return s+y
-                if c == 'gys':
-                    y_scaled = y * (g_norm/y_norm.clip(min=1e-8))
-                    return t+y_scaled
-                if c == 'sys':
-                    y_scaled = y * (s_norm/y_norm.clip(min=1e-8))
-                    return s+y_scaled
-                raise RuntimeError("Big Chungus")
-
-            t1 = _get(first)
-            if rel1: t1 = t1 * p.abs().clip(min=1e-6)
-            t2 = _get(second)
-            if rel2: t2 = t2 * p.abs().clip(min=1e-6)
-
-            t_ema1 = _get(ema1)
-            t_ema2s = _get(ema2[0]), _get(ema2[1])
-
-            if norm:
-                t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
-                t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-
-                if state['GG'] is not None:
-                    update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            z1_projected = None
-            z2_projected = None
-
-            if state['GG'] is not None:
-                z1_projected = project(t_ema2s[0], state['Q'])
-                if ema2[0] == ema2[1]: z2_projected = z1_projected
-                else: z2_projected = project(t_ema2s[1], state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            exp_avg.lerp_(t_ema1, 1-beta1)
-
-            if z1_projected is None:
-                exp_avg_sq.mul_(beta2).addcmul_(*t_ema2s, value=1-beta2)
-            else:
-                assert z2_projected is not None
-                exp_avg_sq.mul_(beta2).addcmul_(z1_projected, z2_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if z1_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if z1_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if setting['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
-                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-        return updates
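Note: the removed ABSOAP transform above follows the Shampoo/SOAP pattern of keeping one Gram-matrix accumulator per tensor dimension, while letting the two vectors fed into GGᵀ be swapped out. A rough plain-PyTorch sketch of just that per-axis accumulation (illustrative helper name, not torchzero API):

import torch

def accumulate_axis_covariances(g1, g2, GGs, beta=None):
    # For each dimension i, contract g1 and g2 over every other dimension to
    # get a (shape[i], shape[i]) Gram matrix, then fold it into GGs[i]
    # (plain sum when beta is None, EMA via lerp_ otherwise).
    for i, GG in enumerate(GGs):
        if GG is None:
            continue
        axes = list(range(i)) + list(range(i + 1, g1.ndim))
        outer = torch.tensordot(g1, g2, dims=(axes, axes))
        if beta is None:
            GG.add_(outer)
        else:
            GG.lerp_(outer, 1 - beta)

g = torch.randn(8, 4)
GGs = [torch.zeros(8, 8), torch.zeros(4, 4)]
accumulate_axis_covariances(g, g, GGs, beta=0.95)

ABSOAP's extra letters ('s', 'y', 'sn', ...) only change which two tensors play the roles of g1 and g2 here.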
torchzero/modules/experimental/adadam.py (deleted)
@@ -1,118 +0,0 @@
-from operator import itemgetter
-from functools import partial
-
-import torch
-
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
-
-
-def adadam_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    exp_avg_qu_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    precond_beta: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_sq_: TensorList | None = None,
-    max_exp_avg_qu_: TensorList | None = None,
-    params_: TensorList | None = None,
-):
-    """Returns new tensors or updates params in-place."""
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                   debiased=False,step=step,pow=pow)
-    sqrt_exp_avg_qu = sqrt_ema_sq_(tensors/(sqrt_exp_avg_sq+1e-8), exp_avg_sq_=exp_avg_qu_,
-                                   beta=precond_beta,max_exp_avg_sq_=max_exp_avg_qu_, debiased=False,step=step,pow=pow)
-
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-    # params is None, return update
-    if params_ is None: return (exp_avg_ / sqrt_exp_avg_qu.add_(eps)).lazy_mul(alpha)
-
-    # update params in-place
-    params_.addcdiv_(exp_avg_, sqrt_exp_avg_qu.add_(eps), -alpha)
-    return None
-
-class Adadam(Module):
-    """Adam with a diagonally preconditioned preconditioner.
-
-    Verdict: I haven't tested this yet.
-
-    .. warning::
-        Experimental.
-    """
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        precond_beta: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad: bool = False,
-        alpha: float = 1.,
-        pow: float = 2,
-        debiased: bool = True,
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,precond_beta=precond_beta,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-        self.getter = itemgetter('amsgrad','pow','debiased')
-
-    @torch.no_grad
-    def step(self, var):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        params = var.params
-
-        beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
-        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-        if amsgrad:
-            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
-        else:
-            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
-            max_exp_avg_sq = None
-            max_exp_avg_qu = None
-
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if var.is_last:
-            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-            passed_params = TensorList(var.params)
-            var.stop = True
-            var.skip_update = True
-
-        else:
-            passed_params = None
-
-        var.update = adadam_(
-            tensors=TensorList(var.get_update()),
-            exp_avg_=exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            exp_avg_qu_=exp_avg_qu,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            precond_beta=precond_beta,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_sq_=max_exp_avg_sq,
-            max_exp_avg_qu_=max_exp_avg_qu,
-            params_=passed_params,
-        )
-
-        return var
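Note: per its docstring, the removed Adadam module is "Adam with a diagonally preconditioned preconditioner": the gradient is first normalized by an ordinary Adam second moment, and a second EMA of the square of that normalized gradient becomes the denominator actually used. A plain-tensor sketch of one step, with hypothetical state names and debiasing/AMSGrad omitted:

import torch

def adadam_step(g, exp_avg, exp_avg_sq, exp_avg_qu,
                beta1=0.9, beta2=0.999, precond_beta=0.999, eps=1e-8):
    exp_avg.lerp_(g, 1 - beta1)                         # first moment of g
    exp_avg_sq.lerp_(g * g, 1 - beta2)                  # ordinary Adam second moment
    g_hat = g / (exp_avg_sq.sqrt() + 1e-8)              # Adam-normalized gradient
    exp_avg_qu.lerp_(g_hat * g_hat, 1 - precond_beta)   # second moment of the normalized gradient
    return exp_avg / (exp_avg_qu.sqrt() + eps)          # update direction

g = torch.randn(10)
state = [torch.zeros(10) for _ in range(3)]
direction = adadam_step(g, *state)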
torchzero/modules/experimental/adamY.py (deleted)
@@ -1,131 +0,0 @@
-from operator import itemgetter
-from functools import partial
-
-import torch
-
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
-
-
-def adamy_(
-    p: TensorList,
-    p_prev: TensorList,
-    g: TensorList,
-    g_prev: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_sq_: TensorList | None = None,
-    params_: TensorList | None = None,
-):
-    """Returns new tensors or updates params in-place."""
-    if step == 1:
-        p_prev.copy_(p)
-        g_prev.copy_(g)
-
-        update = g.clip(-0.1,0.1).lazy_mul_(alpha)
-        if params_ is None: return update
-        params_.sub_(update)
-        return None
-
-    s = p-p_prev
-    y = (g-g_prev).div_(s.global_vector_norm().clip(min=1e-8))
-    p_prev.copy_(p)
-    g_prev.copy_(g)
-
-    exp_avg_ = ema_(g, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
-    sqrt_exp_avg_sq = sqrt_ema_sq_(y, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-                                   debiased=False,step=step,pow=pow)
-
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-    # params is None, return update
-    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
-
-    # update params in-place
-    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
-    return None
-
-class AdamY(Module):
-    """Adam but uses scaled gradient differences for second momentum.
-
-    Verdict: I haven't tested this yet.
-
-    .. warning::
-        Experimental.
-    """
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad: bool = False,
-        alpha: float = 1.,
-        pow: float = 2,
-        debiased: bool = True,
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-        self.getter = itemgetter('amsgrad','pow','debiased')
-
-    @torch.no_grad
-    def step(self, var):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
-        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
-
-        if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
-        else:
-            exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
-            max_exp_avg_sq = None
-
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if var.is_last:
-            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
-            passed_params = TensorList(var.params)
-            var.stop = True
-            var.skip_update = True
-
-        else:
-            passed_params = None
-
-        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
-        g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)
-
-
-        var.update = adamy_(
-            p=TensorList(var.params),
-            p_prev=p_prev,
-            g=TensorList(var.get_update()),
-            g_prev=g_prev,
-            exp_avg_=exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_sq_=max_exp_avg_sq,
-            params_=passed_params,
-        )
-
-        return var
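Note: the removed AdamY module keeps Adam's first moment on the gradient but builds the second moment from y = (g - g_prev) / ||p - p_prev||, i.e. a gradient difference scaled by the parameter step norm. A plain-tensor sketch of one post-warmup step (illustrative only; the real module operates on TensorLists and handles the first step separately):

import torch

def adamy_step(p, p_prev, g, g_prev, exp_avg, exp_avg_sq,
               beta1=0.9, beta2=0.999, eps=1e-8):
    s = p - p_prev
    y = (g - g_prev) / torch.linalg.vector_norm(s).clamp(min=1e-8)
    p_prev.copy_(p); g_prev.copy_(g)       # roll the history forward
    exp_avg.lerp_(g, 1 - beta1)            # first moment still tracks g
    exp_avg_sq.lerp_(y * y, 1 - beta2)     # second moment tracks y instead of g
    return exp_avg / (exp_avg_sq.sqrt() + eps)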
torchzero/modules/experimental/adam_lambertw.py (deleted)
@@ -1,149 +0,0 @@
-from operator import itemgetter
-from functools import partial
-import math
-import torch
-
-from ...core import Module, Target, Transform, apply_transform, Chainable
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
-
-
-def _lambertw_newton_raphson(x: TensorList, iterations=5):
-    # z = torch.zeros_like(x)
-    # mask_neg = x < 0
-    # mask_pos = ~mask_neg
-
-    # z[mask_pos] = torch.log(x[mask_pos] + 1.0)
-
-    # x_neg = x[mask_neg]
-    # z_neg = -1.0 + torch.sqrt(2.0 * (1.0 + math.e * x_neg))
-    # z[mask_neg] = z_neg
-
-    # x is always positive
-    z = (x+1).log_()
-    for _ in range(iterations):
-        exp_z = z.exp()
-        numerator = z * exp_z - x
-        denominator = exp_z * (z + 1.0) + 1e-8
-        delta = numerator / denominator
-        z -= delta
-    return z
-
-# https://github.com/gmgeorg/torchlambertw/blob/main/torchlambertw/special.py
-def _lambertw_winitzki(x: TensorList):
-    x_log1p = x.log1p()
-    return x_log1p * (1.0 - x_log1p.log1p() / (2.0 + x_log1p))
-
-
-def adam_lambertw_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_xpx_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_xpx_: TensorList | None = None,
-    iterations: int | None = 5,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """Returns new tensors."""
-    tensors_abs = tensors.abs().clip_(max=20)
-    tensors_xpx = tensors_abs.pow_(tensors_abs)
-    exp_avg_xpx_.lerp_(tensors_xpx, 1-beta2)
-
-    if max_exp_avg_xpx_ is not None:
-        max_exp_avg_xpx_.maximum_(exp_avg_xpx_)
-        exp_avg_xpx_ = max_exp_avg_xpx_
-
-    if inner is not None:
-        assert params is not None
-        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
-
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-
-    if iterations is None or iterations < 1: exp_avg_xpx_ = _lambertw_winitzki(exp_avg_xpx_)
-    else: exp_avg_xpx_ = _lambertw_newton_raphson(exp_avg_xpx_, iterations)
-
-    return (exp_avg_.lazy_mul(alpha) / exp_avg_xpx_.add_(eps))
-
-class AdamLambertW(Transform):
-    """Adam but uses abs x^x and LambertW instead of square and sqrt.
-    The gradient will be clipped to 20 because float32 which you have to use otherwise you're PC will explode.
-
-    Args:
-        beta1 (float, optional): momentum. Defaults to 0.9.
-        beta2 (float, optional): second momentum. Defaults to 0.999.
-        eps (float, optional): epsilon. Defaults to 1e-8.
-        alpha (float, optional): learning rate. Defaults to 1.
-        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-        pow (float, optional): power used in second momentum power and root. Defaults to 2.
-        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
-        iterations (int, optional): 0 or None means Winitzki approximation otherwise number of newton raphson iterations.
-    """
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad: bool = False,
-        alpha: float = 1.,
-        pow: float = 2,
-        debiased: bool = True,
-        iterations: int | None = 5,
-        inner: Chainable | None = None
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased, iterations=iterations)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
-        amsgrad,pow,debiased,iterations = itemgetter('amsgrad','pow','debiased','iterations')(settings[0])
-
-        if amsgrad:
-            exp_avg, exp_avg_xpx, max_exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', 'max_exp_avg_xpx', cls=TensorList)
-        else:
-            exp_avg, exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', cls=TensorList)
-            max_exp_avg_xpx = None
-
-
-        return adam_lambertw_(
-            tensors=TensorList(tensors),
-            exp_avg_=exp_avg,
-            exp_avg_xpx_=exp_avg_xpx,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_xpx_=max_exp_avg_xpx,
-            iterations=iterations,
-
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-
-        )
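Note: the removed AdamLambertW module replaces Adam's square/sqrt pair with |x|^|x| and the Lambert W function, evaluated either by the Winitzki approximation or a few Newton-Raphson iterations on z*exp(z) = x. A self-contained plain-tensor check that the Newton iteration used above does solve that equation:

import torch

def lambertw_newton(x, iterations=5):
    z = torch.log(x + 1.0)                              # initial guess, x >= 0
    for _ in range(iterations):
        ez = z.exp()
        z = z - (z * ez - x) / (ez * (z + 1.0) + 1e-8)  # Newton step on z*exp(z) - x
    return z

x = torch.tensor([0.1, 1.0, 10.0])
z = lambertw_newton(x)
print(torch.allclose(z * z.exp(), x, atol=1e-5))        # True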