torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/absoap.py (new file)
@@ -0,0 +1,350 @@

from operator import itemgetter

import torch
from typing import Literal
from ...core import Chainable, Transform, apply
from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

@torch.no_grad
def update_soap_covariances_(
    g1: torch.Tensor,
    g2: torch.Tensor,
    GGs_: list[torch.Tensor | None],
    beta: float | None,
):
    for i, GG in enumerate(GGs_):
        if GG is None: continue

        axes = list(range(i)) + list(range(i + 1, g1.ndim)) # this works fine with 1d params
        if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
        else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

@torch.no_grad
def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
    """
    Projects the gradient to the eigenbases of the preconditioner.
    """
    for mat in Q:
        if mat is None: continue
        if len(mat) > 0:
            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
        else:
            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
            permute_order = list(range(1, len(tensors.shape))) + [0]
            tensors = tensors.permute(permute_order)

    return tensors

@torch.no_grad
def project_back(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
    """
    Projects the gradient back to the original space.
    """
    for mat in Q:
        if mat is None: continue
        if len(mat) > 0:
            tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
        else:
            permute_order = list(range(1, len(tensors.shape))) + [0]
            tensors = tensors.permute(permute_order)

    return tensors

# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
@torch.no_grad
def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
    """
    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
    """
    matrix = []
    float_data = False
    original_type = original_device = None
    for m in mat:
        if m is None: continue
        if len(m) == 0:
            matrix.append([])
            continue
        if m.dtype != torch.float:
            original_type = m.dtype
            original_device = m.device
            matrix.append(m.float())
        else:
            float_data = True
            matrix.append(m)

    final = []
    for m in matrix:
        if len(m) == 0:
            final.append([])
            continue
        try:
            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
        except Exception:
            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
            Q = Q.to(m.dtype)
        Q = torch.flip(Q, [1])

        if not float_data:
            Q = Q.to(original_device).type(original_type)
        final.append(Q)
    return final

# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
@torch.no_grad
def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
    """
    Computes the eigenbases of the preconditioner using one round of power iteration
    followed by torch.linalg.qr decomposition.
    """
    matrix = []
    orth_matrix = []
    float_data = False
    original_type = original_device = None
    for m, o in zip(GG, Q_list):
        if m is None: continue
        assert o is not None

        if len(m) == 0:
            matrix.append([])
            orth_matrix.append([])
            continue
        if m.data.dtype != torch.float:
            original_type = m.data.dtype
            original_device = m.data.device
            matrix.append(m.data.float())
            orth_matrix.append(o.data.float())
        else:
            float_data = True
            matrix.append(m.data.float())
            orth_matrix.append(o.data.float())

    final = []
    for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
        if len(m) == 0:
            final.append([])
            continue
        est_eig = torch.diag(o.T @ m @ o)
        sort_idx = torch.argsort(est_eig, descending=True)
        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
        o = o[:, sort_idx]
        power_iter = m @ o
        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable

        if not float_data:
            Q = Q.to(original_device).type(original_type)
        final.append(Q)

    return final, exp_avg_sq

Source = Literal['p', 'g', 's', 'y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
class ABSOAP(Transform):
    """SOAP but with two extra letters included in its name in order to improve convergence.

    New args:

        scale_by_s: whether to scale gradient differences by parameter differences.

        y_to_ema2: whether to use gradient differences for the exponential moving average too.
    """
    def __init__(
        self,
        beta1: float = 0.95,
        beta2: float = 0.95,
        shampoo_beta: float | None = 0.95,
        precond_freq: int = 10,
        merge_small: bool = True,
        max_dim: int = 2_000,
        precondition_1d: bool = True,
        eps: float = 1e-8,
        decay: float | None = None,
        alpha: float = 1,
        bias_correction: bool = True,
        scale_by_s: bool = True,
        first: Source = 'g',
        second: Source = 'g',
        ema1: Source = 'g',
        ema2: tuple[Source, Source] = ('g', 'g'),
        rel1: bool = False,
        rel2: bool = False,
        norm: bool = False,
    ):
        defaults = dict(
            beta1=beta1,
            beta2=beta2,
            shampoo_beta=shampoo_beta,
            precond_freq=precond_freq,
            merge_small=merge_small,
            max_dim=max_dim,
            precondition_1d=precondition_1d,
            eps=eps,
            decay=decay,
            bias_correction=bias_correction,
            alpha=alpha,
            scale_by_s=scale_by_s,
            ema1=ema1,
            ema2=ema2,
            first=first,
            second=second,
            rel1=rel1, rel2=rel2,
            norm=norm,
        )
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        updates = []
        # update preconditioners
        for i, (p, t) in enumerate(zip(params, tensors)):
            state = self.state[p]
            settings = self.settings[p]
            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
            scale_by_s = settings['scale_by_s']
            ema1 = settings['ema1']
            ema2 = settings['ema2']
            first = settings['first']
            second = settings['second']
            rel1 = settings['rel1']; rel2 = settings['rel2']
            norm = settings['norm']

            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)

            if 'g_prev' not in state:
                state['p_prev'] = p.clone()
                state['g_prev'] = t.clone()
                updates.append(tensors[i].sign())
                continue

            p_prev = state['p_prev']
            g_prev = state['g_prev']
            s = p - p_prev
            y = t - g_prev

            # keep malding
            p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
            g_norm = torch.linalg.vector_norm(t) # pylint:disable=not-callable
            s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
            y_norm = torch.linalg.vector_norm(y) # pylint:disable=not-callable

            sn = p - p_prev * (p_norm / torch.linalg.vector_norm(p_prev)) # pylint:disable=not-callable
            yn = t - g_prev * (g_norm / torch.linalg.vector_norm(g_prev)) # pylint:disable=not-callable

            if scale_by_s: y /= s_norm.clip(min=1e-8) # pylint:disable=not-callable

            state['p_prev'].copy_(p)
            state['g_prev'].copy_(t)

            def _get(c: Source):
                if c == 'p': return p
                if c == 'g': return t
                if c == 's': return s
                if c == 'y': return y
                if c == 'sn': return sn
                if c == 'yn': return yn
                if c == 'gy': return t+y
                if c == 'sy': return s+y
                if c == 'gys':
                    y_scaled = y * (g_norm/y_norm.clip(min=1e-8))
                    return t+y_scaled
                if c == 'sys':
                    y_scaled = y * (s_norm/y_norm.clip(min=1e-8))
                    return s+y_scaled
                raise RuntimeError("Big Chungus")

            t1 = _get(first)
            if rel1: t1 = t1 * p.abs().clip(min=1e-6)
            t2 = _get(second)
            if rel2: t2 = t2 * p.abs().clip(min=1e-6)

            t_ema1 = _get(ema1)
            t_ema2s = _get(ema2[0]), _get(ema2[1])

            if norm:
                t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
                t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable

            # initialize state on 1st step
            if 'GG' not in state:
                state["exp_avg"] = torch.zeros_like(t)
                state["exp_avg_sq"] = torch.ones_like(t)

                if not precondition_1d and t.ndim <= 1:
                    state['GG'] = []

                else:
                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1 < sh < max_dim else None for sh in t.shape]

                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
                if len([i is not None for i in state['GG']]) == 0:
                    state['GG'] = None

                if state['GG'] is not None:
                    update_soap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
                    state['Q'] = get_orthogonal_matrix(state['GG'])

                state['step'] = 0
                updates.append(tensors[i].sign())
                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            z1_projected = None
            z2_projected = None

            if state['GG'] is not None:
                z1_projected = project(t_ema2s[0], state['Q'])
                if ema2[0] == ema2[1]: z2_projected = z1_projected
                else: z2_projected = project(t_ema2s[1], state['Q'])

            # exponential moving averages
            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
            exp_avg: torch.Tensor = state["exp_avg"]
            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]

            exp_avg.lerp_(t_ema1, 1-beta1)

            if z1_projected is None:
                exp_avg_sq.mul_(beta2).addcmul_(*t_ema2s, value=1-beta2)
            else:
                assert z2_projected is not None
                exp_avg_sq.mul_(beta2).addcmul_(z1_projected, z2_projected, value=1-beta2)

            # project exponential moving averages if they are accumulated unprojected
            exp_avg_projected = exp_avg
            if z1_projected is not None:
                exp_avg_projected = project(exp_avg, state['Q'])

            exp_avg_sq_projected = exp_avg_sq

            denom = exp_avg_sq_projected.sqrt().add_(eps)
            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')

            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            update = exp_avg_projected / denom
            if z1_projected is not None:
                update = project_back(update, state["Q"])

            if settings['bias_correction']:
                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
            elif alpha is not None:
                update *= alpha

            if merge_small:
                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])

            updates.append(update)
            state["step"] += 1

            # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
                update_soap_covariances_(t1, t2, state['GG'], shampoo_beta)
                if state['step'] % settings['precond_freq'] == 0:
                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])

        return updates
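For intuition, the core of this module is the `project` / `project_back` pair: each matrix in `Q` rotates one dimension of a tensor into the preconditioner's eigenbasis, and contracting against the transposed basis undoes the rotation. The following is a minimal, self-contained sketch of that round trip in plain torch; the toy shapes and variable names are illustrative only and are not part of the diff, and the permute branch used for skipped dimensions is omitted.

import torch

torch.manual_seed(0)
g = torch.randn(4, 3)                      # a toy "gradient"
GG = [g @ g.T, g.T @ g]                    # one covariance per dimension, as in update_soap_covariances_
Q = [torch.linalg.eigh(c).eigenvectors for c in GG]  # one orthogonal eigenbasis per dimension

# project: contract dim 0 with each Q in turn (the contracted dim cycles to the back each time)
proj = g
for mat in Q:
    proj = torch.tensordot(proj, mat, dims=([0], [0]))

# project_back: same loop, but contracting against Q's columns, which undoes the rotation
back = proj
for mat in Q:
    back = torch.tensordot(back, mat, dims=([0], [1]))

print(torch.allclose(back, g, atol=1e-5))  # True: the bases are orthogonal, so the round trip is lossless

Because each basis is orthogonal the projection loses nothing; ABSOAP runs its Adam-style moment updates in that rotated space and rotates the result back before returning it.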
torchzero/modules/experimental/adadam.py (new file)
@@ -0,0 +1,111 @@

from operator import itemgetter
from functools import partial

import torch

from ...core import Module, Target, Transform
from ...utils import NumberList, TensorList
from ..functional import (
    debias, debiased_step_size,
    ema_,
    sqrt_ema_sq_,
)
from ..lr.lr import lazy_lr
from ..momentum.experimental import sqrt_nag_ema_sq_
from ..momentum.momentum import nag_


def adadam_(
    tensors: TensorList,
    exp_avg_: TensorList,
    exp_avg_sq_: TensorList,
    exp_avg_qu_: TensorList,
    alpha: float | NumberList,
    beta1: float | NumberList,
    beta2: float | NumberList,
    precond_beta: float | NumberList,
    eps: float | NumberList,
    step: int,
    pow: float = 2,
    debiased: bool = True,
    max_exp_avg_sq_: TensorList | None = None,
    max_exp_avg_qu_: TensorList | None = None,
    params_: TensorList | None = None,
):
    """Returns new tensors or updates params in-place."""
    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=True)

    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                   debiased=False, step=step, pow=pow)
    sqrt_exp_avg_qu = sqrt_ema_sq_(tensors/(sqrt_exp_avg_sq+1e-8), exp_avg_sq_=exp_avg_qu_,
                                   beta=precond_beta, max_exp_avg_sq_=max_exp_avg_qu_, debiased=False, step=step, pow=pow)

    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)

    # params is None, return update
    if params_ is None: return (exp_avg_ / sqrt_exp_avg_qu.add_(eps)).lazy_mul(alpha)

    # update params in-place
    params_.addcdiv_(exp_avg_, sqrt_exp_avg_qu.add_(eps), -alpha)
    return None

class Adadam(Module):
    """Adam with a diagonally preconditioned preconditioner and a graceful name."""
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        precond_beta: float = 0.999,
        eps: float = 1e-8,
        amsgrad: bool = False,
        alpha: float = 1.,
        pow: float = 2,
        debiased: bool = True,
    ):
        defaults = dict(beta1=beta1, beta2=beta2, precond_beta=precond_beta, eps=eps, alpha=alpha, amsgrad=amsgrad, pow=pow, debiased=debiased)
        super().__init__(defaults)
        self.getter = itemgetter('amsgrad', 'pow', 'debiased')

    @torch.no_grad
    def step(self, vars):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        beta1, beta2, precond_beta, eps, alpha = self.get_settings('beta1', 'beta2', 'precond_beta', 'eps', 'alpha', params=vars.params, cls=NumberList)
        amsgrad, pow, debiased = self.getter(self.settings[vars.params[0]])

        if amsgrad:
            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state('exp_avg', 'exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', params=vars.params, cls=TensorList)
        else:
            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state('exp_avg', 'exp_avg_sq', 'exp_avg_qu', params=vars.params, cls=TensorList)
            max_exp_avg_sq = None
            max_exp_avg_qu = None

        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
        if vars.is_last:
            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
            passed_params = TensorList(vars.params)
            vars.stop = True
            vars.skip_update = True

        else:
            passed_params = None

        vars.update = adadam_(
            tensors=TensorList(vars.get_update()),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            exp_avg_qu_=exp_avg_qu,
            alpha=alpha,
            beta1=beta1,
            beta2=beta2,
            precond_beta=precond_beta,
            eps=eps,
            step=step,
            pow=pow,
            debiased=debiased,
            max_exp_avg_sq_=max_exp_avg_sq,
            max_exp_avg_qu_=max_exp_avg_qu,
            params_=passed_params,
        )

        return vars
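In plain terms, `adadam_` keeps the ordinary Adam second moment of the update, then tracks a second EMA of the update after it has been divided by the square root of that moment, and uses the square root of this second EMA as the actual denominator. A rough plain-tensor sketch of that rule follows; the helper name and driver are illustrative, not the torchzero API, and bias correction and the amsgrad branch are omitted.

import torch

# v tracks g**2 as in Adam, q tracks (g / sqrt(v))**2, and the step divides by sqrt(q) instead of sqrt(v).
def adadam_step(p, g, m, v, q, lr=1e-3, beta1=0.9, beta2=0.999, precond_beta=0.999, eps=1e-8):
    m.lerp_(g, 1 - beta1)                                   # first moment (EMA of the update)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)           # ordinary Adam second moment
    g_hat = g / (v.sqrt() + 1e-8)                           # update preconditioned by sqrt(v)
    q.mul_(precond_beta).addcmul_(g_hat, g_hat, value=1 - precond_beta)  # second moment of the preconditioned update
    p.sub_(lr * m / (q.sqrt() + eps))                       # divide by the diagonally preconditioned preconditioner

p = torch.randn(5)
m, v, q = torch.zeros(5), torch.zeros(5), torch.zeros(5)
for _ in range(3):
    adadam_step(p, torch.randn(5), m, v, q)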
torchzero/modules/experimental/adamY.py (new file)
@@ -0,0 +1,135 @@

from operator import itemgetter
from functools import partial

import torch

from ...core import Module, Target, Transform
from ...utils import NumberList, TensorList
from ..functional import (
    debias, debiased_step_size,
    ema_,
    sqrt_ema_sq_,
)
from ..lr.lr import lazy_lr
from ..momentum.experimental import sqrt_nag_ema_sq_
from ..momentum.momentum import nag_


def adamy_(
    p: TensorList,
    p_prev: TensorList,
    g: TensorList,
    g_prev: TensorList,
    exp_avg_: TensorList,
    exp_avg_sq_: TensorList,
    alpha: float | NumberList,
    beta1: float | NumberList,
    beta2: float | NumberList,
    eps: float | NumberList,
    step: int,
    pow: float = 2,
    debiased: bool = True,
    max_exp_avg_sq_: TensorList | None = None,
    params_: TensorList | None = None,
):
    """Returns new tensors or updates params in-place."""
    if step == 1:
        p_prev.copy_(p)
        g_prev.copy_(g)

        update = g.sign().lazy_mul_(alpha*0.1)
        if params_ is None: return update
        params_.sub_(update)
        return None

    s = p - p_prev
    y = (g - g_prev).div_(s.global_vector_norm().clip(min=1e-8))
    p_prev.copy_(p)
    g_prev.copy_(g)

    exp_avg_ = ema_(g, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=True)

    sqrt_exp_avg_sq = sqrt_ema_sq_(y, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                   debiased=False, step=step, pow=pow)

    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)

    # params is None, return update
    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)

    # update params in-place
    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
    return None

class AdamY(Module):
    """Adam but uses scaled gradient differences for second momentum.

    Args:
        beta1 (float, optional): momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum. Defaults to 0.999.
        eps (float, optional): epsilon. Defaults to 1e-8.
        alpha (float, optional): learning rate. Defaults to 1.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power used in second momentum power and root. Defaults to 2.
        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        amsgrad: bool = False,
        alpha: float = 1.,
        pow: float = 2,
        debiased: bool = True,
    ):
        defaults = dict(beta1=beta1, beta2=beta2, eps=eps, alpha=alpha, amsgrad=amsgrad, pow=pow, debiased=debiased)
        super().__init__(defaults)
        self.getter = itemgetter('amsgrad', 'pow', 'debiased')

    @torch.no_grad
    def step(self, vars):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        beta1, beta2, eps, alpha = self.get_settings('beta1', 'beta2', 'eps', 'alpha', params=vars.params, cls=NumberList)
        amsgrad, pow, debiased = self.getter(self.settings[vars.params[0]])

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=vars.params, cls=TensorList)
        else:
            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=vars.params, cls=TensorList)
            max_exp_avg_sq = None

        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
        if vars.is_last:
            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
            passed_params = TensorList(vars.params)
            vars.stop = True
            vars.skip_update = True

        else:
            passed_params = None

        p_prev = self.get_state('p_prev', params=vars.params, cls=TensorList)
        g_prev = self.get_state('g_prev', params=vars.params, cls=TensorList)

        vars.update = adamy_(
            p=TensorList(vars.params),
            p_prev=p_prev,
            g=TensorList(vars.get_update()),
            g_prev=g_prev,
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            alpha=alpha,
            beta1=beta1,
            beta2=beta2,
            eps=eps,
            step=step,
            pow=pow,
            debiased=debiased,
            max_exp_avg_sq_=max_exp_avg_sq,
            params_=passed_params,
        )

        return vars
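The distinguishing step in `adamy_` is that the second moment is fed with y = (g - g_prev) / ||p - p_prev|| rather than with the gradient itself, so the denominator tracks a finite-difference curvature estimate instead of gradient magnitude. A rough plain-tensor sketch of one step follows; the names are illustrative, not the torchzero API, and bias correction and the amsgrad branch are omitted.

import torch

def adamy_step(p, g, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    if "p_prev" not in state:                          # first step: no difference available yet
        state["p_prev"], state["g_prev"] = p.clone(), g.clone()
        state["m"], state["v"] = torch.zeros_like(p), torch.zeros_like(p)
        p.sub_(0.1 * lr * g.sign())                    # small sign step, mirroring the step == 1 branch above
        return
    s = p - state["p_prev"]
    y = (g - state["g_prev"]) / torch.linalg.vector_norm(s).clip(min=1e-8)
    state["p_prev"].copy_(p); state["g_prev"].copy_(g)
    state["m"].lerp_(g, 1 - beta1)                     # first moment still uses the gradient
    state["v"].mul_(beta2).addcmul_(y, y, value=1 - beta2)  # second moment uses scaled gradient differences
    p.sub_(lr * state["m"] / (state["v"].sqrt() + eps))

state = {}
p = torch.randn(5)
for _ in range(3):
    adamy_step(p, torch.randn(5), state)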