torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
```diff
--- /dev/null
+++ b/torchzero/modules/optimizers/soap.py
@@ -0,0 +1,286 @@
+from operator import itemgetter
+
+import torch
+
+from ...core import Chainable, Transform, apply
+from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+
+@torch.no_grad
+def update_soap_covariances_(
+    grad: torch.Tensor,
+    GGs_: list[torch.Tensor | None],
+    beta: float | None,
+):
+    for i, GG in enumerate(GGs_):
+        if GG is None: continue
+
+        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
+        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
+        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
+
+@torch.no_grad
+def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+    """
+    Projects the gradient to the eigenbases of the preconditioner.
+    """
+    for mat in Q:
+        if mat is None: continue
+        if len(mat) > 0:
+            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+        else:
+            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+            permute_order = list(range(1, len(tensors.shape))) + [0]
+            tensors = tensors.permute(permute_order)
+
+    return tensors
+
+@torch.no_grad
+def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
+    """
+    Projects the gradient back to the original space.
+    """
+    for mat in Q:
+        if mat is None: continue
+        if len(mat) > 0:
+            tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+        else:
+            permute_order = list(range(1, len(tensors.shape))) + [0]
+            tensors = tensors.permute(permute_order)
+
+    return tensors
+
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+@torch.no_grad
+def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
+    """
+    matrix = []
+    float_data = False
+    original_type = original_device = None
+    for m in mat:
+        if m is None: continue
+        if len(m) == 0:
+            matrix.append([])
+            continue
+        if m.dtype != torch.float:
+            original_type = m.dtype
+            original_device = m.device
+            matrix.append(m.float())
+        else:
+            float_data = True
+            matrix.append(m)
+
+    final = []
+    for m in matrix:
+        if len(m) == 0:
+            final.append([])
+            continue
+        try:
+            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+        except Exception:
+            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+            Q = Q.to(m.dtype)
+        Q = torch.flip(Q, [1])
+
+        if not float_data:
+            Q = Q.to(original_device).type(original_type)
+        final.append(Q)
+    return final
+
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+@torch.no_grad
+def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using one round of power iteration
+    followed by torch.linalg.qr decomposition.
+    """
+    matrix = []
+    orth_matrix = []
+    float_data = False
+    original_type = original_device = None
+    for m,o in zip(GG, Q_list):
+        if m is None: continue
+        assert o is not None
+
+        if len(m) == 0:
+            matrix.append([])
+            orth_matrix.append([])
+            continue
+        if m.data.dtype != torch.float:
+            original_type = m.data.dtype
+            original_device = m.data.device
+            matrix.append(m.data.float())
+            orth_matrix.append(o.data.float())
+        else:
+            float_data = True
+            matrix.append(m.data.float())
+            orth_matrix.append(o.data.float())
+
+    final = []
+    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
+        if len(m)==0:
+            final.append([])
+            continue
+        est_eig = torch.diag(o.T @ m @ o)
+        sort_idx = torch.argsort(est_eig, descending=True)
+        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
+        o = o[:,sort_idx]
+        power_iter = m @ o
+        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
+
+        if not float_data:
+            Q = Q.to(original_device).type(original_type)
+        final.append(Q)
+
+    return final, exp_avg_sq
+
+class SOAP(Transform):
+    """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).
+
+    Args:
+        beta1 (float, optional): beta for first momentum. Defaults to 0.95.
+        beta2 (float, optional): beta for second momentum. Defaults to 0.95.
+        shampoo_beta (float | None, optional):
+            beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
+        precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
+        merge_small (bool, optional): Whether to merge small dims. Defaults to True.
+        max_dim (int, optional): Won't precondition dims larger than this. Defaults to 2_000.
+        precondition_1d (bool, optional):
+            Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
+        eps (float, optional):
+            epsilon for dividing first momentum by second. Defaults to 1e-8.
+        decay (float | None, optional):
+            Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
+        unprojected_exp_avg (bool, optional):
+            whether to update first momentum in unprojected space. Both true and false work and lead to different
+            results but True usually works better. Defaults to True.
+        bias_correction (bool, optional):
+            enables adam bias correction. Defaults to True.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.95,
+        beta2: float = 0.95,
+        shampoo_beta: float | None = 0.95,
+        precond_freq: int = 10,
+        merge_small: bool = True,
+        max_dim: int = 2_000,
+        precondition_1d: bool = True,
+        eps: float = 1e-8,
+        decay: float | None = None,
+        alpha: float = 1,
+        unprojected_exp_avg: bool = True,
+        bias_correction: bool = True,
+    ):
+        defaults = dict(
+            beta1=beta1,
+            beta2=beta2,
+            shampoo_beta=shampoo_beta,
+            precond_freq=precond_freq,
+            merge_small=merge_small,
+            max_dim=max_dim,
+            precondition_1d=precondition_1d,
+            eps=eps,
+            decay=decay,
+            unprojected_exp_avg=unprojected_exp_avg,
+            bias_correction=bias_correction,
+            alpha=alpha,
+        )
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        updates = []
+        # update preconditioners
+        for i,(p,t) in enumerate(zip(params, tensors)):
+            state = self.state[p]
+            settings = self.settings[p]
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+
+            if merge_small:
+                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+
+            # initialize state on 1st step
+            if 'GG' not in state:
+                state["exp_avg"] = torch.zeros_like(t)
+                state["exp_avg_sq"] = torch.zeros_like(t)
+
+                if not precondition_1d and t.ndim <= 1:
+                    state['GG'] = []
+
+                else:
+                    state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+
+                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+                if len([i is not None for i in state['GG']]) == 0:
+                    state['GG'] = None
+
+                if state['GG'] is not None:
+                    update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
+                    state['Q'] = get_orthogonal_matrix(state['GG'])
+
+                state['step'] = 0
+                updates.append(tensors[i].sign().div_(10))
+                # updates.append(tensors[i] / tensors[i].abs().sum())
+                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
+                # I use scaled update instead as to not mess up with next modules.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            t_projected = None
+            if state['GG'] is not None:
+                t_projected = project(t, state['Q'])
+
+            # exponential moving averages
+            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
+            exp_avg: torch.Tensor = state["exp_avg"]
+            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+
+            if unprojected_exp_avg or t_projected is None:
+                exp_avg.lerp_(t, 1-beta1)
+            else:
+                exp_avg.lerp_(t_projected, 1-beta1)
+
+            if t_projected is None:
+                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+            else:
+                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+
+            # project exponential moving averages if they are accumulated unprojected
+            exp_avg_projected = exp_avg
+            if unprojected_exp_avg and t_projected is not None:
+                exp_avg_projected = project(exp_avg, state['Q'])
+
+            exp_avg_sq_projected = exp_avg_sq
+
+            denom = exp_avg_sq_projected.sqrt().add_(eps)
+            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            update = exp_avg_projected / denom
+            if t_projected is not None:
+                update = project_back(update, state["Q"])
+
+            if settings['bias_correction']:
+                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
+                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
+                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
+            elif alpha is not None:
+                update *= alpha
+
+            if merge_small:
+                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+
+            updates.append(update)
+            state["step"] += 1
+
+            # Update is done after the gradient step to avoid using current gradients in the projection.
+            if state['GG'] is not None:
+                update_soap_covariances_(t, state['GG'], shampoo_beta)
+                if state['step'] % settings['precond_freq'] == 0:
+                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
+
+        return updates
```
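For orientation: `SOAP.transform` above accumulates per-dimension covariances of the gradient, rotates the gradient into their eigenbases, maintains Adam statistics there, and rotates the preconditioned result back. Below is a minimal standalone sketch of just that rotation round-trip for a single 2-D tensor, mirroring `update_soap_covariances_`, `get_orthogonal_matrix`, `project` and `project_back`; the names are illustrative and nothing here is part of the torchzero API.

```python
# Minimal sketch (not part of torchzero): eigenbasis projection round-trip for one 2-D gradient.
import torch

torch.manual_seed(0)
G = torch.randn(4, 3)  # stand-in for a gradient

# Per-dimension covariances, as in update_soap_covariances_ (EMA/accumulation omitted).
GG = [
    torch.tensordot(G, G, ([1], [1])),  # contract every axis except dim 0 -> (4, 4)
    torch.tensordot(G, G, ([0], [0])),  # contract every axis except dim 1 -> (3, 3)
]

# Eigenbases of the covariances, as in get_orthogonal_matrix.
Q = [torch.linalg.eigh(M).eigenvectors for M in GG]

# project(): contracting dim 0 with each Q in turn rotates G into the eigenbasis.
G_proj = G
for mat in Q:
    G_proj = torch.tensordot(G_proj, mat, dims=([0], [0]))
# ...exp_avg / exp_avg_sq would be maintained on G_proj here...

# project_back(): the inverse rotation.
G_back = G_proj
for mat in Q:
    G_back = torch.tensordot(G_back, mat, dims=([0], [1]))

print(torch.allclose(G_back, G, atol=1e-5))  # True, since the Qs are orthogonal
```

`get_orthogonal_matrix_QR` then refreshes the same eigenbases every `precond_freq` steps with one power-iteration pass followed by QR, which is cheaper than recomputing `eigh` from scratch.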
```diff
--- /dev/null
+++ b/torchzero/modules/optimizers/sophia_h.py
@@ -0,0 +1,129 @@
+from typing import Literal
+from collections.abc import Callable
+import torch
+
+from ...core import Module, Target, Transform, Chainable, apply
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+
+def sophia_H(
+    tensors: TensorList,
+    h: TensorList | None,
+    exp_avg_: TensorList,
+    h_exp_avg_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    precond_scale: float | NumberList,
+    clip: float | NumberList,
+    eps: float | NumberList,
+    step: int
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert h is not None
+        h_exp_avg_.lerp_(h, 1-beta2)
+
+    else:
+        assert h is None
+
+    denom = (h_exp_avg_ * precond_scale).clip_(min=eps)
+    return (exp_avg_ / denom).clip_(-clip, clip)
+
+
+class SophiaH(Module):
+    def __init__(
+        self,
+        beta1: float = 0.96,
+        beta2: float = 0.99,
+        update_freq: int = 10,
+        precond_scale: float = 1,
+        clip: float = 1,
+        eps: float = 1e-12,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = vars.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        beta1, beta2, precond_scale, clip, eps = self.get_settings(
+            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', params=params, cls=NumberList)
+
+        exp_avg, h_exp_avg = self.get_state('exp_avg', 'h_exp_avg', params=params, cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = vars.closure
+        assert closure is not None
+
+        h = None
+        if step % update_freq == 0:
+
+            grad=None
+            for i in range(n_samples):
+                u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]
+
+                if hvp_method == 'autograd':
+                    if grad is None: grad = vars.get_grad(create_graph=True)
+                    assert grad is not None
+                    Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)
+
+                elif hvp_method == 'forward':
+                    loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=vars.get_grad(), normalize=True)
+
+                elif hvp_method == 'central':
+                    loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
+
+                else:
+                    raise ValueError(hvp_method)
+
+                if h is None: h = Hvp
+                else: torch._foreach_add_(h, Hvp)
+
+            assert h is not None
+            if n_samples > 1: torch._foreach_div_(h, n_samples)
+
+        update = vars.get_update()
+        if 'inner' in self.children:
+            update = apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars)
+
+        vars.update = sophia_H(
+            tensors=TensorList(update),
+            h=TensorList(h) if h is not None else None,
+            exp_avg_=exp_avg,
+            h_exp_avg_=h_exp_avg,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            precond_scale=precond_scale,
+            clip=clip,
+            eps=eps,
+            step=step,
+        )
+        return vars
```
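`SophiaH` above relies on Hessian-vector products (the `hvp`, `hvp_fd_forward` and `hvp_fd_central` helpers from `torchzero.utils.derivatives`, which are not shown in this diff) to keep a Hutchinson-style estimate of the Hessian diagonal. The sketch below is a plain-autograd illustration of one Hutchinson sample under the assumption that those helpers compute the usual double-backward HVP; it is not the package's implementation, and the EMA/clipping steps of `sophia_H` are omitted.

```python
# Illustrative sketch (assumed, not the package's hvp helpers): one Hutchinson
# sample of diag(H) via a double-backward Hessian-vector product.
import torch

def hutchinson_diag_sample(loss_fn, params):
    """Return one unbiased sample of the Hessian diagonal, u * (H @ u)."""
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    u = [torch.randn_like(p) for p in params]          # random probe, as in SophiaH.step
    gu = sum((g * v).sum() for g, v in zip(grads, u))  # <grad, u>
    Hu = torch.autograd.grad(gu, params)                # H @ u
    return [v * hv for v, hv in zip(u, Hu)]             # E[u * (H @ u)] = diag(H)

w = torch.randn(5, requires_grad=True)
sample = hutchinson_diag_sample(lambda: (w ** 4).sum(), [w])
# The exact diagonal here is 12 * w**2; averaging over many probes converges to it.
```

`sophia_H` then divides the gradient EMA by an EMA of this estimate (scaled by `precond_scale` and floored at `eps`) and clips the result to `[-clip, clip]`.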
```diff
--- /dev/null
+++ b/torchzero/modules/projections/dct.py
@@ -0,0 +1,73 @@
+from typing import Literal
+import torch
+import torch_dct
+from .projection import Projection
+from ...core import Chainable
+
+def reverse_dims(t:torch.Tensor):
+    return t.permute(*reversed(range(t.ndim)))
+
+class DCTProjection(Projection):
+    # norm description copied from pytorch docstring
+    """Project update into Discrete Cosine Transform space, requires `torch_dct` library.
+
+    Args:
+        modules (Chainable): modules that will optimize the projected update.
+        dims (1, 2 or 3, optional):
+            applies DCT to first 1,2 or 3 dims, defaults to 3.
+        norm (str, optional):
+            Normalization mode.
+            * None - no normalization
+            * "ortho" - normalize by 1/sqrt(n)
+    """
+
+    def __init__(
+        self,
+        modules: Chainable,
+        dims: Literal[1, 2, 3] = 3,
+        norm=None,
+        project_update=True,
+        project_params=False,
+        project_grad=False,
+    ):
+        defaults = dict(dims=dims, norm=norm)
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        settings = self.settings[vars.params[0]]
+        dims = settings['dims']
+        norm = settings['norm']
+
+        projected = []
+        for u in tensors:
+            u = reverse_dims(u)
+            dim = min(u.ndim, dims)
+
+            if dim == 1: dct = torch_dct.dct(u, norm = norm)
+            elif dim == 2: dct = torch_dct.dct_2d(u, norm=norm)
+            elif dim == 3: dct = torch_dct.dct_3d(u, norm=norm)
+            else: raise ValueError(f"Unsupported number of dimensions {dim}")
+
+            projected.append(dct)
+
+        return projected
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        settings = self.settings[vars.params[0]]
+        dims = settings['dims']
+        norm = settings['norm']
+
+        unprojected = []
+        for u in tensors:
+            dim = min(u.ndim, dims)
+
+            if dim == 1: idct = torch_dct.idct(u, norm = norm)
+            elif dim == 2: idct = torch_dct.idct_2d(u, norm=norm)
+            elif dim == 3: idct = torch_dct.idct_3d(u, norm=norm)
+            else: raise ValueError(f"Unsupported number of dimensions {dim}")
+
+            unprojected.append(reverse_dims(idct))
+
+        return unprojected
```
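`DCTProjection` depends on the third-party `torch_dct` package for both transforms. As a quick standalone check (illustrative, not from the package), the DCT is invertible, which is what lets `unproject` recover an update produced by the inner modules in DCT space:

```python
# Round-trip check, assuming the third-party torch_dct package is installed.
import torch
import torch_dct

x = torch.randn(8, 16)
y = torch_dct.dct_2d(x, norm='ortho')        # forward transform, as in project() for a 2-D update
x_back = torch_dct.idct_2d(y, norm='ortho')  # inverse transform, as in unproject()
print(torch.allclose(x, x_back, atol=1e-5))  # True up to floating-point error
```

The module additionally reverses the dimension order (`reverse_dims`) before transforming, so the DCT acts on the first `dims` dimensions of each parameter, matching the docstring.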
```diff
--- /dev/null
+++ b/torchzero/modules/projections/fft.py
@@ -0,0 +1,73 @@
+import torch
+
+from ...core import Chainable
+from ...utils import vec_to_tensors
+from .projection import Projection
+
+
+class FFTProjection(Projection):
+    # norm description copied from pytorch docstring
+    """Project update into Fourrier space of real-valued inputs.
+
+    Args:
+        modules (Chainable): modules that will optimize the projected update.
+        one_d (bool, optional):
+            * If True, uses 1d fft on parameters concatenated into a vector.
+            * If False, uses n-dimensional fft on each parameter (default).
+        norm (str, optional):
+            Normalization mode.
+
+            * "forward" - normalize by 1/n
+            * "backward" - no normalization
+            * "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)
+
+            Calling the backward transform (:func:`~torch.fft.irfft`) with the same
+            normalization mode will apply an overall normalization of ``1/n`` between
+            the two transforms. This is required to make :func:`~torch.fft.irfft`
+            the exact inverse.
+
+            Default is "backward" (no normalization).
+
+            The actual torch.fft.rfft default is None, so I set it to None too. I guess None and "backward"
+            are the same.
+    """
+
+    def __init__(
+        self,
+        modules: Chainable,
+        one_d: bool = False,
+        norm=None,
+        project_update=True,
+        project_params=False,
+        project_grad=False,
+    ):
+        defaults = dict(one_d=one_d, norm=norm)
+        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
+
+    @torch.no_grad
+    def project(self, tensors, vars, current):
+        settings = self.settings[vars.params[0]]
+        one_d = settings['one_d']
+        norm = settings['norm']
+
+        # 1d fft, concatenate all parameters into a vector and calculate fft
+        if one_d:
+            vec = torch.cat([t.view(-1) for t in tensors])
+            self.global_state['length'] = len(vec)
+            return [torch.view_as_real(torch.fft.rfft(vec, norm=norm))] # pylint:disable=not-callable
+
+        # multidimensional fft for each parameter
+        return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable
+
+    @torch.no_grad
+    def unproject(self, tensors, vars, current):
+        settings = self.settings[vars.params[0]]
+        one_d = settings['one_d']
+        norm = settings['norm']
+
+        if one_d:
+            vec = torch.view_as_complex(tensors[0])
+            unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
+            return vec_to_tensors(unprojected_vec, reference=vars.params)
+
+        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, vars.params)] # pylint:disable=not-callable
```