torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/dsoap.py (new file)
@@ -0,0 +1,290 @@
from operator import itemgetter

import torch

from ...core import Chainable, Transform, apply
from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims

@torch.no_grad
def update_soap_covariances_(
    grad: torch.Tensor,
    GGs_: list[torch.Tensor | None],
    beta: float | None,
):
    for i, GG in enumerate(GGs_):
        if GG is None: continue

        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

@torch.no_grad
def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
    """
    Projects the gradient to the eigenbases of the preconditioner.
    """
    for mat in Q:
        if mat is None: continue
        if len(mat) > 0:
            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
        else:
            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
            permute_order = list(range(1, len(tensors.shape))) + [0]
            tensors = tensors.permute(permute_order)

    return tensors

@torch.no_grad
def project_back(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
    """
    Projects the gradient back to the original space.
    """
    for mat in Q:
        if mat is None: continue
        if len(mat) > 0:
            tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
        else:
            permute_order = list(range(1, len(tensors.shape))) + [0]
            tensors = tensors.permute(permute_order)

    return tensors

# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
@torch.no_grad
def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
    """
    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
    """
    matrix = []
    float_data = False
    original_type = original_device = None
    for m in mat:
        if m is None: continue
        if len(m) == 0:
            matrix.append([])
            continue
        if m.dtype != torch.float:
            original_type = m.dtype
            original_device = m.device
            matrix.append(m.float())
        else:
            float_data = True
            matrix.append(m)

    final = []
    for m in matrix:
        if len(m) == 0:
            final.append([])
            continue
        try:
            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
        except Exception:
            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
            Q = Q.to(m.dtype)
        Q = torch.flip(Q, [1])

        if not float_data:
            Q = Q.to(original_device).type(original_type)
        final.append(Q)
    return final

# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
@torch.no_grad
def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
    """
    Computes the eigenbases of the preconditioner using one round of power iteration
    followed by torch.linalg.qr decomposition.
    """
    matrix = []
    orth_matrix = []
    float_data = False
    original_type = original_device = None
    for m, o in zip(GG, Q_list):
        if m is None: continue
        assert o is not None

        if len(m) == 0:
            matrix.append([])
            orth_matrix.append([])
            continue
        if m.data.dtype != torch.float:
            original_type = m.data.dtype
            original_device = m.data.device
            matrix.append(m.data.float())
            orth_matrix.append(o.data.float())
        else:
            float_data = True
            matrix.append(m.data.float())
            orth_matrix.append(o.data.float())

    final = []
    for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
        if len(m) == 0:
            final.append([])
            continue
        est_eig = torch.diag(o.T @ m @ o)
        sort_idx = torch.argsort(est_eig, descending=True)
        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
        o = o[:, sort_idx]
        power_iter = m @ o
        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable

        if not float_data:
            Q = Q.to(original_device).type(original_type)
        final.append(Q)

    return final, exp_avg_sq

class DSOAP(Transform):
    """SOAP but uses scaled gradient differences

    new args

    scale_by_s: whether to scale gradient differences by parameter differences

    y_to_ema2: whether to use gradient differences for exponential moving average too
    """
    def __init__(
        self,
        beta1: float = 0.95,
        beta2: float = 0.95,
        shampoo_beta: float | None = 0.95,
        precond_freq: int = 10,
        merge_small: bool = True,
        max_dim: int = 2_000,
        precondition_1d: bool = True,
        eps: float = 1e-8,
        decay: float | None = None,
        alpha: float = 1,
        bias_correction: bool = True,
        scale_by_s: bool = True,
        y_to_ema2: bool = False,
    ):
        defaults = dict(
            beta1=beta1,
            beta2=beta2,
            shampoo_beta=shampoo_beta,
            precond_freq=precond_freq,
            merge_small=merge_small,
            max_dim=max_dim,
            precondition_1d=precondition_1d,
            eps=eps,
            decay=decay,
            bias_correction=bias_correction,
            alpha=alpha,
            scale_by_s=scale_by_s,
            y_to_ema2=y_to_ema2,
        )
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        updates = []
        # update preconditioners
        for i, (p, t) in enumerate(zip(params, tensors)):
            state = self.state[p]
            settings = self.settings[p]
            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(settings)
            scale_by_s = settings['scale_by_s']
            y_to_ema2 = settings['y_to_ema2']

            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)

            if 'g_prev' not in state:
                state['p_prev'] = p.clone()
                state['g_prev'] = t.clone()
                updates.append(tensors[i].sign())
                continue

            p_prev = state['p_prev']
            g_prev = state['g_prev']
            s = p - p_prev
            y = t - g_prev
            if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable

            state['p_prev'].copy_(p)
            state['g_prev'].copy_(t)

            # initialize state on 1st step
            if 'GG' not in state:
                state["exp_avg"] = torch.zeros_like(t)
                if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
                else: state["exp_avg_sq"] = torch.zeros_like(t)

                if not precondition_1d and t.ndim <= 1:
                    state['GG'] = []

                else:
                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]

                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
                if len([i is not None for i in state['GG']]) == 0:
                    state['GG'] = None

                if state['GG'] is not None:
                    update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
                    state['Q'] = get_orthogonal_matrix(state['GG'])

                state['step'] = 0
                updates.append(tensors[i].sign())
                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            z_projected = None
            if state['GG'] is not None:
                if y_to_ema2: z_projected = project(y, state['Q'])
                else: z_projected = project(t, state['Q'])

            # exponential moving averages
            # this part could be foreached but I will do that at some point, it's not a big difference compared to preconditioning
            exp_avg: torch.Tensor = state["exp_avg"]
            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]

            exp_avg.lerp_(t, 1-beta1)

            if z_projected is None:
                if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
                else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
            else:
                exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)

            # project exponential moving averages if they are accumulated unprojected
            exp_avg_projected = exp_avg
            if z_projected is not None:
                exp_avg_projected = project(exp_avg, state['Q'])

            exp_avg_sq_projected = exp_avg_sq

            denom = exp_avg_sq_projected.sqrt().add_(eps)
            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')

            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            update = exp_avg_projected / denom
            if z_projected is not None:
                update = project_back(update, state["Q"])

            if settings['bias_correction']:
                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
            elif alpha is not None:
                update *= alpha

            if merge_small:
                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])

            updates.append(update)
            state["step"] += 1

            # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
                update_soap_covariances_(y, state['GG'], shampoo_beta)
                if state['step'] % settings['precond_freq'] == 0:
                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])

        return updates
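For readers unfamiliar with the SOAP machinery reused here, the following stand-alone sketch (plain PyTorch, illustrative only, not part of the package) shows what update_soap_covariances_, project and project_back compute for a single 2-D tensor: the covariances are per-axis Gram matrices, and projection rotates each axis into the corresponding eigenbasis, which the orthogonal project_back exactly undoes.

import torch

# Toy 2-D "gradient" and its two per-axis covariances, as update_soap_covariances_
# builds them by contracting the gradient with itself over all other axes.
g = torch.randn(4, 3)
GG = [g @ g.T, g.T @ g]                      # (4,4) row covariance, (3,3) column covariance
Q = [torch.linalg.eigh(m)[1] for m in GG]    # one orthogonal eigenbasis per dimension

# project: rotate each axis into its eigenbasis (same tensordot pattern as in the file)
g_rot = g
for q in Q:
    g_rot = torch.tensordot(g_rot, q, dims=([0], [0]))

# project_back: rotate back using the transposed bases
g_back = g_rot
for q in Q:
    g_back = torch.tensordot(g_back, q, dims=([0], [1]))

assert torch.allclose(g_back, g, atol=1e-4)  # orthogonal round trip recovers the input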
torchzero/modules/experimental/gradmin.py (new file)
@@ -0,0 +1,85 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Literal

import torch

from ...core import Module, Vars
from ...utils import NumberList, TensorList
from ...utils.derivatives import jacobian_wrt
from ..grad_approximation import GradApproximator, GradTarget
from ..smoothing.gaussian import Reformulation



class GradMin(Reformulation):
    """Reformulates the objective to minimize sum of gradient magnitudes via autograd.

    Args:
        loss_term (float, optional): adds loss value times this to sum of gradient magnitudes. Defaults to 1.
        relative (bool, optional): whether to make loss_term relative to gradient magnitude. Defaults to False.
        graft (bool, optional): whether to make loss term same as gradient magnitude. Defaults to False.
        square (bool, optional): whether to use sum of squared gradient magnitudes, if False uses absolute values. Defaults to False.
        mean (bool, optional): whether to use mean, if False uses sum. Defaults to True.
        maximize_grad (bool, optional): whether to maximize gradient magnitudes instead of minimizing. Defaults to False.
        create_graph (bool, optional): whether to create graph. Defaults to False.
        modify_loss (bool, optional): whether to modify the loss value to make line searches minimize new objective. Defaults to True.
    """
    def __init__(
        self,
        loss_term: float | None = 0,
        relative: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
        graft: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
        square=False,
        mean=True,
        maximize_grad=False,
        create_graph=False,
        modify_loss: bool = True,
    ):
        if (relative is not None) and (graft is not None): warnings.warn('both relative and graft loss are True, they will clash with each other')
        defaults = dict(loss_term=loss_term, relative=relative, graft=graft, square=square, mean=mean, maximize_grad=maximize_grad, create_graph=create_graph, modify_loss=modify_loss)
        super().__init__(defaults)

    @torch.no_grad
    def closure(self, backward, closure, params, vars):
        settings = self.settings[params[0]]
        loss_term = settings['loss_term']
        relative = settings['relative']
        graft = settings['graft']
        square = settings['square']
        maximize_grad = settings['maximize_grad']
        create_graph = settings['create_graph']
        modify_loss = settings['modify_loss']
        mean = settings['mean']

        with torch.enable_grad():
            for p in params: p.grad = None
            loss = closure(False)
            grads = TensorList(torch.autograd.grad(loss, params, create_graph=True))

            if square: grads = grads ** 2
            else: grads = grads.abs()

            if mean: f = grads.global_mean()
            else: f = grads.global_sum()


            if graft == 'grad_to_loss': f = f * (loss.detach()/f.detach()).detach()
            if relative == 'grad_to_loss': f = f * loss

            if loss_term is not None and loss_term != 0:
                if relative == 'loss_to_grad': loss_term = loss_term * f
                l = loss
                if graft == 'loss_to_grad': l = loss * (f.detach()/loss.detach()).detach()
                f = f + l*loss_term

            if maximize_grad: f = -f
            if modify_loss: loss = f

            grad = None
            if backward:
                for p in params: p.grad = None
                grad = TensorList(torch.autograd.grad(f, params, create_graph=create_graph))

        return loss, grad
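The key mechanism in GradMin.closure is double backpropagation: the gradient is computed with create_graph=True, its mean magnitude becomes the new objective, and a second autograd.grad call (or backward) differentiates through it. A minimal stand-alone sketch on a hypothetical toy objective, using plain PyTorch rather than the torchzero API:

import torch

# Toy objective: minimizing the mean absolute gradient of f(x) = sum(x**4)
# drives x toward a point where the gradient vanishes.
x = torch.randn(5, requires_grad=True)

def grad_magnitude_objective(x):
    loss = (x ** 4).sum()
    # create_graph=True keeps the graph so the gradient itself can be differentiated
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    return g.abs().mean()          # the reformulated objective that gets minimized

opt = torch.optim.SGD([x], lr=1e-2)
for _ in range(100):
    opt.zero_grad()
    f = grad_magnitude_objective(x)
    f.backward()                   # second backward pass, through the first gradient
    opt.step()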
torchzero/modules/experimental/reduce_outward_lr.py (new file)
@@ -0,0 +1,35 @@
import torch

from ...core import Target, Transform
from ...utils import TensorList

class ReduceOutwardLR(Transform):
    """
    When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.

    This means updates that move weights towards zero have higher learning rates.
    """
    def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
        defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
        super().__init__(defaults, uses_grad=use_grad, target=target)

    @torch.no_grad
    def transform(self, tensors, params, grads, vars):
        params = TensorList(params)
        tensors = TensorList(tensors)

        mul = self.get_settings('mul', params=params)
        s = self.settings[params[0]]
        use_grad = s['use_grad']
        invert = s['invert']

        if use_grad: cur = vars.get_grad()
        else: cur = tensors

        # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
        if invert: mask = (params * cur) > 0
        else: mask = (params * cur) < 0

        tensors.masked_set_(mask, tensors*mul)

        return tensors
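The masking rule in transform can be illustrated with plain tensors (a hypothetical single parameter, no TensorList): components whose step would grow the weight's magnitude get their effective learning rate multiplied by mul, while components shrinking the weight keep the full step.

import torch

# Hypothetical parameter and its ascent-direction update; the applied step is param - lr * update.
param  = torch.tensor([ 1.5, -2.0,  0.5, -0.1])
update = torch.tensor([ 0.3,  0.2, -0.4,  0.05])
mul = 0.5

outward = (param * update) < 0                       # opposite signs: the step grows |param|
update = torch.where(outward, update * mul, update)  # damp outward components only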
torchzero/modules/experimental/spectral.py (new file)
@@ -0,0 +1,286 @@
from abc import ABC, abstractmethod
import math
from collections import deque
from typing import Literal, Any

import torch
from ...core import Chainable, TensorwisePreconditioner
from ...utils.linalg.matrix_funcs import matrix_power_eigh
from ...utils.linalg.svd import randomized_svd
from ...utils.linalg.qr import qr_householder


class _Solver:
    @abstractmethod
    def update(self, history: deque[torch.Tensor], damping: float | None) -> tuple[Any, Any]:
        """returns stuff for apply"""
    @abstractmethod
    def apply(self, __g: torch.Tensor, __A:torch.Tensor, __B:torch.Tensor) -> torch.Tensor:
        """apply preconditioning to tensor"""

class _SVDSolver(_Solver):
    def __init__(self, driver=None): self.driver=driver
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        device = None # driver is CUDA only
        if self.driver is not None:
            device = M_hist.device
            M_hist = M_hist.cuda()

        try:
            U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver=self.driver) # pylint:disable=not-callable

            if self.driver is not None:
                U = U.to(device); S = S.to(device)

            if damping is not None and damping != 0: S.add_(damping)
            return U, S

        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
        Utg = (U.T @ g).div_(S)
        return U @ Utg

class _SVDLowRankSolver(_Solver):
    def __init__(self, q: int = 6, niter: int = 2): self.q, self.niter = q, niter
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            U, S, _ = torch.svd_lowrank(M_hist, q=self.q, niter=self.niter)
            if damping is not None and damping != 0: S.add_(damping)
            return U, S
        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
        Utg = (U.T @ g).div_(S)
        return U @ Utg

class _RandomizedSVDSolver(_Solver):
    def __init__(self, k: int = 3, driver: str | None = 'gesvda'):
        self.driver = driver
        self.k = k

    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        device = None # driver is CUDA only
        if self.driver is not None:
            device = M_hist.device
            M_hist = M_hist.cuda()

        try:
            U, S, _ = randomized_svd(M_hist, k=self.k, driver=self.driver)

            if self.driver is not None:
                U = U.to(device); S = S.to(device)

            if damping is not None and damping != 0: S.add_(damping)
            return U, S

        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
        Utg = (U.T @ g).div_(S)
        return U @ Utg

class _QRDiagonalSolver(_Solver):
    def __init__(self, sqrt=True): self.sqrt = sqrt
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
            R_diag = R.diag().abs()
            if damping is not None and damping != 0: R_diag.add_(damping)
            if self.sqrt: R_diag.sqrt_()
            return Q, R_diag
        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, R_diag: torch.Tensor):
        Qtg = (Q.T @ g).div_(R_diag)
        return Q @ Qtg

class _QRSolver(_Solver):
    def __init__(self, sqrt=True): self.sqrt = sqrt
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            # Q: d x k, R: k x k
            Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
            A = R @ R.T
            if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
            if self.sqrt: A = matrix_power_eigh(A, 0.5)
            return Q, A
        except (torch.linalg.LinAlgError):
            return None, None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        g_proj = Q.T @ g
        y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
        return Q @ y

class _QRHouseholderSolver(_Solver):
    def __init__(self, sqrt=True): self.sqrt = sqrt
    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        try:
            # Q: d x k, R: k x k
            Q, R = qr_householder(M_hist, mode='reduced') # pylint:disable=not-callable
            A = R @ R.T
            if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
            if self.sqrt: A = matrix_power_eigh(A, 0.5)
            return Q, A
        except (torch.linalg.LinAlgError):
            return None, None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        g_proj = Q.T @ g
        y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
        return Q @ y


class _EighSolver(_Solver):
    def __init__(self, sqrt=True):
        self.sqrt = sqrt

    def update(self, history, damping):
        M_hist = torch.stack(tuple(history), dim=1)
        grams = M_hist @ M_hist.T # (d, d)
        if damping is not None and damping != 0: grams.diagonal(dim1=-2, dim2=-1).add_(damping)
        try:
            L, Q = torch.linalg.eigh(grams) # L: (d,), Q: (d, d) # pylint:disable=not-callable
            L = L.abs().clamp_(min=1e-12)
            if self.sqrt: L = L.sqrt()
            return Q, L
        except torch.linalg.LinAlgError:
            return None, None

    def apply(self, g: torch.Tensor, Q: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
        Qtg = (Q.T @ g).div_(L)
        return Q @ Qtg


SOLVERS = {
    "svd": _SVDSolver(), # fallbacks on "gesvd" which basically takes ages or just hangs completely
    "svd_gesvdj": _SVDSolver("gesvdj"), # no fallback on slow "gesvd"
    "svd_gesvda": _SVDSolver("gesvda"), # approximate method for wide matrices, sometimes better sometimes worse but faster
    "svd_lowrank": _SVDLowRankSolver(), # maybe need to tune parameters for this, with current ones its slower and worse
    "randomized_svd2": _RandomizedSVDSolver(2),
    "randomized_svd3": _RandomizedSVDSolver(3),
    "randomized_svd4": _RandomizedSVDSolver(4),
    "randomized_svd5": _RandomizedSVDSolver(5),
    "eigh": _EighSolver(), # this is O(n**2) storage, but is this more accurate?
    "qr": _QRSolver(),
    "qr_householder": _QRHouseholderSolver(), # this is slower... but maybe it won't freeze? I think svd_gesvda is better
    "qrdiag": _QRDiagonalSolver(),
}

def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
    if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
    else:
        if state_[key].shape != value.shape: state_[key] = value
        else: state_[key].lerp_(value, 1-beta)

class SpectralPreconditioner(TensorwisePreconditioner):
    """Whitening preconditioner via SVD on history of past gradients or gradient differences scaled by parameter differences.

    Args:
        history_size (int, optional): number of past gradients to store for preconditioning. Defaults to 10.
        update_freq (int, optional): how often to re-compute the preconditioner. Defaults to 1.
        damping (float, optional): damping term, makes it closer to GD. Defaults to 1e-7.
        order (int, optional):
            whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
        solver (str, optional): what to use for whitening. Defaults to 'svd'.
        U_beta (float | None, optional): beta for U (probably a bad idea). Defaults to None.
        S_beta (float | None, optional): beta for S (probably a bad idea). Defaults to None.
        interval (int, optional): How often to update history. Defaults to 1 (every step).
        concat_params (bool, optional):
            whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
        scale_first (bool, optional): makes first step small, usually not needed. Defaults to False.
        inner (Chainable | None, optional): Inner modules applied after updating preconditioner and before applying it. Defaults to None.
    """
    def __init__(
        self,
        history_size: int = 10,
        update_freq: int = 1,
        damping: float = 1e-12,
        order: int = 1,
        solver: Literal['svd', 'svd_gesvdj', 'svd_gesvda', 'svd_lowrank', 'eigh', 'qr', 'qrdiag', 'qr_householder'] | _Solver | str = 'svd_gesvda',
        A_beta: float | None = None,
        B_beta: float | None = None,
        interval: int = 1,
        concat_params: bool = False,
        scale_first: bool = False,
        inner: Chainable | None = None,
    ):
        if isinstance(solver, str): solver = SOLVERS[solver]
        # history is still updated each step so Precondition's update_freq has different meaning
        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, order=order, A_beta=A_beta, B_beta=B_beta, solver=solver)
        super().__init__(defaults, uses_grad=False, concat_params=concat_params, scale_first=scale_first, inner=inner, update_freq=interval)

    @torch.no_grad
    def update_tensor(self, tensor, param, grad, state, settings):
        order = settings['order']
        history_size = settings['history_size']
        update_freq = settings['update_freq']
        damping = settings['damping']
        A_beta = settings['A_beta']
        B_beta = settings['B_beta']
        solver: _Solver = settings['solver']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        if order == 1: history.append(tensor.clone().view(-1))
        else:

            # if order=2, history is of gradient differences, order 3 is differences between differences, etc
            # normalized by parameter differences
            cur_p = param.clone()
            cur_g = tensor.clone()
            for i in range(1, order):
                if f'prev_g_{i}' not in state:
                    state[f'prev_p_{i}'] = cur_p
                    state[f'prev_g_{i}'] = cur_g
                    break

                s_k = cur_p - state[f'prev_p_{i}']
                y_k = cur_g - state[f'prev_g_{i}']
                state[f'prev_p_{i}'] = cur_p
                state[f'prev_g_{i}'] = cur_g
                cur_p = s_k
                cur_g = y_k

                if i == order - 1:
                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
                    history.append(cur_g.view(-1))

        step = state.get('step', 0)
        if step % update_freq == 0 and len(history) != 0:
            A, B = solver.update(history, damping=damping)
            maybe_lerp_(state, A_beta, 'A', A)
            maybe_lerp_(state, B_beta, 'B', B)

        if len(history) != 0:
            state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, state, settings):
        history_size = settings['history_size']
        solver: _Solver = settings['solver']

        A = state.get('A', None)
        if A is None:
            # make a conservative step to avoid issues due to different GD scaling
            return tensor.div_(max(1, tensor.abs().sum())) # pyright:ignore[reportArgumentType]

        B = state['B']
        update = solver.apply(tensor.view(-1), A, B).view_as(tensor)

        n = len(state['history'])
        if n != history_size: update.mul_(n/history_size)
        return update
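As a reference for what the SVD-family solvers above do, here is a stand-alone sketch of the whitening step (plain PyTorch, independent of the TensorwisePreconditioner machinery, with made-up sizes): stack the last k flattened gradients as columns, take a thin SVD, and divide the current gradient's components in the U basis by the damped singular values.

import torch

d, k, damping = 100, 10, 1e-12
history = [torch.randn(d) for _ in range(k)]       # stand-ins for past flattened gradients
M = torch.stack(history, dim=1)                    # (d, k) history matrix

U, S, _ = torch.linalg.svd(M, full_matrices=False) # U: (d, k), S: (k,)
S = S + damping                                    # damping keeps tiny directions from exploding

g = torch.randn(d)                                 # current gradient
update = U @ ((U.T @ g) / S)                       # whitened step within the spanned subspace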