torchzero 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +3 -1
- torchzero/_minimize/__init__.py +0 -0
- torchzero/_minimize/methods.py +95 -0
- torchzero/_minimize/minimize.py +518 -0
- torchzero/core/__init__.py +5 -5
- torchzero/core/chain.py +2 -1
- torchzero/core/functional.py +2 -1
- torchzero/core/module.py +75 -4
- torchzero/core/transform.py +6 -5
- torchzero/linalg/eigh.py +116 -68
- torchzero/linalg/linear_operator.py +1 -0
- torchzero/linalg/orthogonalize.py +60 -5
- torchzero/linalg/sketch.py +39 -0
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/adaptive/adagrad.py +2 -0
- torchzero/modules/adaptive/adam.py +5 -1
- torchzero/modules/adaptive/adan.py +3 -0
- torchzero/modules/adaptive/ggt.py +20 -18
- torchzero/modules/adaptive/lion.py +3 -1
- torchzero/modules/adaptive/mars.py +6 -5
- torchzero/modules/adaptive/msam.py +3 -0
- torchzero/modules/adaptive/rmsprop.py +2 -0
- torchzero/modules/adaptive/rprop.py +9 -7
- torchzero/modules/adaptive/shampoo.py +9 -1
- torchzero/modules/adaptive/soap.py +32 -29
- torchzero/modules/basis/__init__.py +2 -0
- torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero/modules/basis/soap_basis.py +254 -0
- torchzero/modules/clipping/ema_clipping.py +32 -27
- torchzero/modules/clipping/growth_clipping.py +1 -0
- torchzero/modules/experimental/__init__.py +1 -6
- torchzero/modules/experimental/coordinate_momentum.py +2 -0
- torchzero/modules/experimental/cubic_adam.py +4 -0
- torchzero/modules/grad_approximation/__init__.py +3 -2
- torchzero/modules/least_squares/gn.py +6 -0
- torchzero/modules/misc/gradient_accumulation.py +1 -0
- torchzero/modules/misc/misc.py +6 -0
- torchzero/modules/momentum/averaging.py +6 -0
- torchzero/modules/momentum/momentum.py +4 -0
- torchzero/modules/ops/__init__.py +0 -1
- torchzero/modules/ops/accumulate.py +4 -0
- torchzero/modules/ops/higher_level.py +6 -1
- torchzero/modules/second_order/inm.py +4 -0
- torchzero/modules/second_order/newton.py +11 -3
- torchzero/modules/second_order/newton_cg.py +7 -3
- torchzero/modules/second_order/nystrom.py +14 -19
- torchzero/modules/second_order/rsn.py +37 -6
- torchzero/modules/trust_region/trust_region.py +2 -1
- torchzero/utils/benchmarks/logistic.py +33 -18
- torchzero/utils/params.py +13 -1
- torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/METADATA +1 -1
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/RECORD +56 -53
- torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero/modules/experimental/eigengrad.py +0 -207
- /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/WHEEL +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/top_level.txt +0 -0
|
@@ -7,7 +7,7 @@ from ...core import Chainable, TensorTransform
|
|
|
7
7
|
from ...linalg import torch_linalg, regularize_eigh
|
|
8
8
|
from .lre_optimizers import LREOptimizerBase
|
|
9
9
|
|
|
10
|
-
def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol):
|
|
10
|
+
def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol, matrix_power=-1/2):
|
|
11
11
|
"""returns U ``(ndim, rank)``, L ``(rank, )``"""
|
|
12
12
|
if isinstance(history, torch.Tensor):
|
|
13
13
|
M = history
|
|
@@ -27,7 +27,7 @@ def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, t
|
|
|
27
27
|
if L is None or Q is None: # this means there are no finite eigenvalues
|
|
28
28
|
return None, None
|
|
29
29
|
|
|
30
|
-
U = (M @ Q) * L.
|
|
30
|
+
U = (M @ Q) * L.pow(matrix_power)
|
|
31
31
|
|
|
32
32
|
# this damping is added after computing U, this is why I didn't use one in linalg.regularize_eig
|
|
33
33
|
# that's because we damp singular values this way
|
|
@@ -44,14 +44,13 @@ class GGT(TensorTransform):
|
|
|
44
44
|
"""
|
|
45
45
|
GGT method from https://arxiv.org/pdf/1806.02958
|
|
46
46
|
|
|
47
|
-
The update rule is to stack recent gradients into M
|
|
48
|
-
|
|
47
|
+
The update rule is to stack recent gradients into M and
|
|
48
|
+
compute eigendecomposition of M M^T via eigendecomposition of M^T M.
|
|
49
49
|
|
|
50
50
|
This is equivalent to full-matrix Adagrad on recent gradients.
|
|
51
51
|
|
|
52
52
|
Args:
|
|
53
53
|
history_size (int, optional): number of past gradients to store. Defaults to 10.
|
|
54
|
-
beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
|
|
55
54
|
update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
|
|
56
55
|
eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
|
|
57
56
|
truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
|
|
@@ -105,7 +104,8 @@ class GGT(TensorTransform):
|
|
|
105
104
|
truncate: int | None = None,
|
|
106
105
|
damping: float = 1e-4,
|
|
107
106
|
rdamping: float = 0,
|
|
108
|
-
|
|
107
|
+
matrix_power: float = -1/2,
|
|
108
|
+
basis_optimizer: LREOptimizerBase | None = None,
|
|
109
109
|
concat_params: bool = True,
|
|
110
110
|
|
|
111
111
|
inner: Chainable | None = None,
|
|
@@ -114,6 +114,7 @@ class GGT(TensorTransform):
|
|
|
114
114
|
del defaults['self'], defaults['inner'], defaults['concat_params']
|
|
115
115
|
|
|
116
116
|
super().__init__(defaults, concat_params=concat_params, inner=inner)
|
|
117
|
+
self.add_projected_keys("grad", "history")
|
|
117
118
|
|
|
118
119
|
@torch.no_grad
|
|
119
120
|
def single_tensor_update(self, tensor, param, grad, loss, state, setting):
|
|
@@ -141,14 +142,15 @@ class GGT(TensorTransform):
|
|
|
141
142
|
rdamping=setting["rdamping"],
|
|
142
143
|
truncate=setting["truncate"],
|
|
143
144
|
eig_tol=setting["eig_tol"],
|
|
145
|
+
matrix_power=setting["matrix_power"],
|
|
144
146
|
)
|
|
145
147
|
|
|
146
|
-
# reproject
|
|
147
|
-
|
|
148
|
-
if
|
|
148
|
+
# reproject basis optimizer
|
|
149
|
+
basis_optimizer: LREOptimizerBase | None = setting["basis_optimizer"]
|
|
150
|
+
if basis_optimizer is not None:
|
|
149
151
|
if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
|
|
150
|
-
|
|
151
|
-
|
|
152
|
+
basis_state = state["basis_state"]
|
|
153
|
+
basis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=basis_state)
|
|
152
154
|
|
|
153
155
|
|
|
154
156
|
# store new factors
|
|
@@ -169,18 +171,18 @@ class GGT(TensorTransform):
|
|
|
169
171
|
|
|
170
172
|
L = state['L']
|
|
171
173
|
|
|
172
|
-
# step with
|
|
173
|
-
|
|
174
|
-
if
|
|
174
|
+
# step with basis optimizer
|
|
175
|
+
basis_optimizer: LREOptimizerBase | None = setting["basis_optimizer"]
|
|
176
|
+
if basis_optimizer is not None:
|
|
175
177
|
|
|
176
|
-
if "
|
|
177
|
-
|
|
178
|
+
if "basis_state" not in state: state["basis_state"] = {}
|
|
179
|
+
basis_state = state["basis_state"]
|
|
178
180
|
|
|
179
|
-
update =
|
|
181
|
+
update = basis_optimizer.step(g, L=L, Q=U, state=basis_state)
|
|
180
182
|
return update.view_as(tensor)
|
|
181
183
|
|
|
182
184
|
# or just whiten
|
|
183
185
|
z = U.T @ g
|
|
184
|
-
update = (U * L.
|
|
186
|
+
update = (U * L.pow(setting["matrix_power"])) @ z
|
|
185
187
|
return update.view_as(tensor)
|
|
186
188
|
|
|
@@ -23,9 +23,11 @@ class Lion(TensorTransform):
|
|
|
23
23
|
defaults = dict(beta1=beta1, beta2=beta2)
|
|
24
24
|
super().__init__(defaults)
|
|
25
25
|
|
|
26
|
+
self.add_projected_keys("grad", "exp_avg")
|
|
27
|
+
|
|
26
28
|
@torch.no_grad
|
|
27
29
|
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
28
30
|
beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
|
|
29
|
-
exp_avg = unpack_states(states, tensors, '
|
|
31
|
+
exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
|
|
30
32
|
return lion_(TensorList(tensors), exp_avg, beta1, beta2)
|
|
31
33
|
|
|
@@ -6,13 +6,13 @@ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
|
|
|
6
6
|
|
|
7
7
|
def mars_correction_(
|
|
8
8
|
tensors_: TensorList,
|
|
9
|
-
|
|
9
|
+
g_prev_: TensorList,
|
|
10
10
|
beta: float | NumberList,
|
|
11
11
|
scaling: float | NumberList,
|
|
12
12
|
max_norm: float | NumberList | None,
|
|
13
13
|
):
|
|
14
|
-
dg = (tensors_ -
|
|
15
|
-
|
|
14
|
+
dg = (tensors_ - g_prev_).mul_(scaling * beta / (1-beta))
|
|
15
|
+
g_prev_.copy_(tensors_)
|
|
16
16
|
|
|
17
17
|
c = tensors_.add_(dg)
|
|
18
18
|
if max_norm is not None:
|
|
@@ -63,16 +63,17 @@ class MARSCorrection(TensorTransform):
|
|
|
63
63
|
):
|
|
64
64
|
defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
|
|
65
65
|
super().__init__(defaults)
|
|
66
|
+
self.add_projected_keys("grad", "g_prev")
|
|
66
67
|
|
|
67
68
|
@torch.no_grad
|
|
68
69
|
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
69
|
-
|
|
70
|
+
g_prev = unpack_states(states, tensors, 'g_prev', init=tensors, cls=TensorList)
|
|
70
71
|
beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
|
|
71
72
|
max_norm = settings[0]['max_norm']
|
|
72
73
|
|
|
73
74
|
return mars_correction_(
|
|
74
75
|
tensors_=TensorList(tensors),
|
|
75
|
-
|
|
76
|
+
g_prev_=g_prev,
|
|
76
77
|
beta=beta,
|
|
77
78
|
scaling=scaling,
|
|
78
79
|
max_norm=max_norm,
|
|
@@ -121,6 +121,8 @@ class MSAMMomentum(TensorTransform):
|
|
|
121
121
|
defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
|
|
122
122
|
super().__init__(defaults, uses_grad=False)
|
|
123
123
|
|
|
124
|
+
self.add_projected_keys("grad", "velocity")
|
|
125
|
+
|
|
124
126
|
@torch.no_grad
|
|
125
127
|
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
126
128
|
velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
|
|
@@ -180,6 +182,7 @@ class MSAM(Transform):
|
|
|
180
182
|
super().__init__(defaults)
|
|
181
183
|
|
|
182
184
|
self.set_child('modules', modules)
|
|
185
|
+
self.add_projected_keys("grad", "velocity")
|
|
183
186
|
|
|
184
187
|
|
|
185
188
|
@torch.no_grad
|
|
@@ -38,6 +38,8 @@ class RMSprop(TensorTransform):
|
|
|
38
38
|
super().__init__(defaults, inner=inner)
|
|
39
39
|
|
|
40
40
|
self.set_child('exp_avg_sq', exp_avg_sq_tfm)
|
|
41
|
+
self.add_projected_keys("grad", "exp_avg")
|
|
42
|
+
self.add_projected_keys("grad_sq", "exp_avg_sq", "exp_avg_sq_max")
|
|
41
43
|
|
|
42
44
|
@torch.no_grad
|
|
43
45
|
def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
|
|
@@ -128,15 +128,15 @@ def rprop_(
|
|
|
128
128
|
|
|
129
129
|
class Rprop(TensorTransform):
|
|
130
130
|
"""
|
|
131
|
-
Resilient propagation. The update magnitude gets multiplied by
|
|
132
|
-
or
|
|
131
|
+
Resilient propagation. The update magnitude gets multiplied by ``nplus`` if gradient didn't change the sign,
|
|
132
|
+
or ``nminus`` if it did. Then the update is applied with the sign of the current gradient.
|
|
133
133
|
|
|
134
134
|
Additionally, if gradient changes sign, the update for that weight is reverted.
|
|
135
135
|
Next step, magnitude for that weight won't change.
|
|
136
136
|
|
|
137
137
|
Compared to pytorch this also implements backtracking update when sign changes.
|
|
138
138
|
|
|
139
|
-
This implementation is identical to
|
|
139
|
+
This implementation is identical to ``torch.optim.Rprop`` if ``backtrack`` is set to False.
|
|
140
140
|
|
|
141
141
|
Args:
|
|
142
142
|
nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
|
|
@@ -164,6 +164,8 @@ class Rprop(TensorTransform):
|
|
|
164
164
|
defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
|
|
165
165
|
super().__init__(defaults, uses_grad=False)
|
|
166
166
|
|
|
167
|
+
self.add_projected_keys("grad", "prev")
|
|
168
|
+
|
|
167
169
|
@torch.no_grad
|
|
168
170
|
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
169
171
|
step = self.global_state.get('step', 0)
|
|
@@ -196,14 +198,14 @@ class Rprop(TensorTransform):
|
|
|
196
198
|
|
|
197
199
|
class ScaleLRBySignChange(TensorTransform):
|
|
198
200
|
"""
|
|
199
|
-
learning rate gets multiplied by
|
|
200
|
-
or
|
|
201
|
+
learning rate gets multiplied by ``nplus`` if ascent/gradient didn't change the sign,
|
|
202
|
+
or ``nminus`` if it did.
|
|
201
203
|
|
|
202
204
|
This is part of RProp update rule.
|
|
203
205
|
|
|
204
206
|
Args:
|
|
205
|
-
nplus (float): learning rate gets multiplied by
|
|
206
|
-
nminus (float): learning rate gets multiplied by
|
|
207
|
+
nplus (float): learning rate gets multiplied by ``nplus`` if ascent/gradient didn't change the sign
|
|
208
|
+
nminus (float): learning rate gets multiplied by ``nminus`` if ascent/gradient changed the sign
|
|
207
209
|
lb (float): lower bound for lr.
|
|
208
210
|
ub (float): upper bound for lr.
|
|
209
211
|
alpha (float): initial learning rate.
|
|
@@ -207,6 +207,9 @@ class Shampoo(TensorTransform):
|
|
|
207
207
|
if setting["merge_small"]:
|
|
208
208
|
tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
|
|
209
209
|
|
|
210
|
+
if "inner" not in self.children:
|
|
211
|
+
state["merged"] = tensor
|
|
212
|
+
|
|
210
213
|
if 'diagonal_accumulator' in state:
|
|
211
214
|
update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
|
|
212
215
|
else:
|
|
@@ -227,10 +230,15 @@ class Shampoo(TensorTransform):
|
|
|
227
230
|
|
|
228
231
|
state["step"] += 1
|
|
229
232
|
|
|
233
|
+
|
|
230
234
|
@torch.no_grad
|
|
231
235
|
def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
|
|
236
|
+
|
|
232
237
|
if setting["merge_small"]:
|
|
233
|
-
|
|
238
|
+
if "inner" not in self.children:
|
|
239
|
+
tensor = state.pop("merged")
|
|
240
|
+
else:
|
|
241
|
+
tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
|
|
234
242
|
|
|
235
243
|
if 'diagonal_accumulator' in state:
|
|
236
244
|
dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
|
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
from operator import itemgetter
|
|
2
1
|
import warnings
|
|
2
|
+
from operator import itemgetter
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
-
from ...core import
|
|
7
|
-
from ...utils import unpack_dicts, unpack_states, TensorList, NumberList
|
|
8
|
-
from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
|
|
6
|
+
from ...core import Chainable, TensorTransform
|
|
9
7
|
from ...linalg import torch_linalg
|
|
8
|
+
from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
|
|
9
|
+
from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
|
|
10
|
+
|
|
10
11
|
|
|
11
12
|
@torch.no_grad
|
|
12
13
|
def update_soap_covariances_(
|
|
@@ -221,25 +222,38 @@ class SOAP(TensorTransform):
|
|
|
221
222
|
return TensorList(tensors).clamp(-0.1, 0.1)
|
|
222
223
|
# return TensorList(tensors).zero_()
|
|
223
224
|
|
|
224
|
-
|
|
225
225
|
fs = settings[0]
|
|
226
|
-
|
|
226
|
+
merged_updates = [] # for when exp_avg is maintained unprojected
|
|
227
|
+
merged_grads = [] # this doesn't go into preconditioner
|
|
227
228
|
projected = []
|
|
228
|
-
# ---------------------------------- project --------------------------------- #
|
|
229
229
|
|
|
230
|
-
|
|
230
|
+
# -------------------------------- inner step -------------------------------- #
|
|
231
|
+
updates = tensors
|
|
232
|
+
has_inner = "inner" in self.children
|
|
233
|
+
if has_inner:
|
|
234
|
+
updates = self.inner_step_tensors("inner", updates, clone=True,
|
|
235
|
+
params=params, grads=grads, loss=loss)
|
|
236
|
+
|
|
237
|
+
# ---------------------------------- project --------------------------------- #
|
|
238
|
+
for grad, update, state, setting in zip(tensors, updates, states, settings):
|
|
231
239
|
if setting["merge_small"]:
|
|
232
|
-
|
|
240
|
+
update, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(update, setting["max_dim"])
|
|
241
|
+
if has_inner: # grad is a different tensor, merge it too
|
|
242
|
+
grad, _, _ = _merge_small_dims(grad, setting["max_dim"])
|
|
243
|
+
else: # in this case update is still just grad
|
|
244
|
+
grad = update
|
|
233
245
|
|
|
234
|
-
|
|
246
|
+
merged_updates.append(update)
|
|
247
|
+
merged_grads.append(grad)
|
|
235
248
|
|
|
236
249
|
if state['GG'] is not None:
|
|
237
|
-
|
|
250
|
+
update = project(update, state['Q'])
|
|
251
|
+
|
|
252
|
+
projected.append(update)
|
|
238
253
|
|
|
239
|
-
projected.append(tensor)
|
|
240
254
|
|
|
241
255
|
# ------------------------ run adam in projected space ----------------------- #
|
|
242
|
-
exp_avg_proj, exp_avg_sq_proj = unpack_states(states,
|
|
256
|
+
exp_avg_proj, exp_avg_sq_proj = unpack_states(states, projected, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
|
|
243
257
|
alpha, beta1, beta2, eps = unpack_dicts(settings, "alpha", "beta1", "beta2", "eps", cls=NumberList)
|
|
244
258
|
|
|
245
259
|
# lerp exp_avg in projected space
|
|
@@ -249,15 +263,17 @@ class SOAP(TensorTransform):
|
|
|
249
263
|
# or lerp in original space and project
|
|
250
264
|
else:
|
|
251
265
|
exp_avg = exp_avg_proj
|
|
252
|
-
exp_avg.lerp_(
|
|
266
|
+
exp_avg.lerp_(merged_updates, weight=1-beta1)
|
|
253
267
|
exp_avg_proj = []
|
|
254
268
|
for t, state, setting in zip(exp_avg, states, settings):
|
|
255
269
|
if state['GG'] is not None:
|
|
256
270
|
t = project(t, state["Q"])
|
|
257
271
|
exp_avg_proj.append(t)
|
|
258
272
|
|
|
273
|
+
# lerp exp_avg_sq
|
|
259
274
|
exp_avg_sq_proj.mul_(beta2).addcmul_(projected, projected, value=1-beta2)
|
|
260
275
|
|
|
276
|
+
# adam direction
|
|
261
277
|
denom = exp_avg_sq_proj.sqrt().add_(eps)
|
|
262
278
|
dirs_proj = exp_avg_proj / denom
|
|
263
279
|
|
|
@@ -272,27 +288,14 @@ class SOAP(TensorTransform):
|
|
|
272
288
|
|
|
273
289
|
dirs.append(dir)
|
|
274
290
|
|
|
275
|
-
|
|
276
|
-
# -------------------------------- inner step -------------------------------- #
|
|
277
|
-
if "inner" in self.children:
|
|
278
|
-
tensors = self.inner_step_tensors("inner", tensors, clone=False,
|
|
279
|
-
params=params, grads=grads,loss=loss)
|
|
280
|
-
|
|
281
|
-
# we now have to re-merge small dims on updated tensors
|
|
282
|
-
merged = []
|
|
283
|
-
for tensor, state, setting in zip(tensors, states, settings):
|
|
284
|
-
if setting["merge_small"]:
|
|
285
|
-
tensor, _, _ = _merge_small_dims(tensor, setting["max_dim"])
|
|
286
|
-
merged.append(tensor)
|
|
287
|
-
|
|
288
291
|
# -------------------------- update preconditioners -------------------------- #
|
|
289
292
|
# Update is done after the gradient step to avoid using current gradients in the projection.
|
|
290
293
|
|
|
291
|
-
for
|
|
294
|
+
for grad, state, setting in zip(merged_grads, states, settings):
|
|
292
295
|
if state['GG'] is not None:
|
|
293
296
|
|
|
294
297
|
# lerp covariances
|
|
295
|
-
update_soap_covariances_(
|
|
298
|
+
update_soap_covariances_(grad, state['GG'], beta=setting["shampoo_beta"])
|
|
296
299
|
|
|
297
300
|
# (state['step'] - 1) since we start updating on 2nd step
|
|
298
301
|
if (state['step'] - 1) % setting['precond_freq'] == 0:
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from ...core import Chainable, TensorTransform
|
|
6
|
+
from ...utils import set_storage_
|
|
7
|
+
from ..adaptive.ggt import ggt_update
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _cubic_reproject(C: torch.Tensor, cu: torch.Tensor, approx: bool):
    """Reproject a third-order diagonal buffer through the change-of-basis matrix ``C``.

    ``cu`` is treated as the diagonal of a third-order tensor ``T`` (``T[i, i, i] = cu[i]``,
    zero elsewhere). Transforming every mode of ``T`` by ``C`` and reading the diagonal
    of the result gives::

        out[a] = sum_i C[a, i] ** 3 * cu[i]

    which is exactly ``C.pow(3) @ cu``. Earlier code materialised the dense
    ``n x n x n`` tensor and contracted it with ``einsum`` when ``approx`` was False;
    since ``T`` is diagonal, that produced the same values (up to floating-point
    rounding) at a much higher memory and compute cost.

    Args:
        C: change-of-basis matrix; its second dimension must match the number of
            elements in ``cu``.
        cu: diagonal entries of the third-order buffer.
        approx: kept for backward compatibility. Both values now evaluate the same
            closed form; when False, ``cu`` is flattened first because the previous
            dense construction only depended on ``cu``'s elements.

    Returns:
        Tensor with ``C.size(0)`` elements holding the reprojected diagonal.
    """
    if approx:
        return C.pow(3) @ cu

    # Closed form of the diagonal of the fully transformed tensor; no need to
    # build the n^3 intermediate.
    return C.pow(3) @ cu.reshape(-1)
|
|
19
|
+
|
|
20
|
+
class GGTBasis(TensorTransform):
    """
    Run another optimizer in GGT eigenbasis. The eigenbasis is ``rank``-sized, so it is possible to run expensive
    methods such as Full-matrix Adagrad/Adam.

    The update rule is to stack recent gradients into M and
    compute eigendecomposition of M M^T via eigendecomposition of M^T M.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Note:
        the buffers of the ``basis_opt`` are re-projected whenever basis changes. The reprojection logic is not implemented on all modules. Some supported modules are:

        ``Adagrad``, ``FullMatrixAdagrad``, ``Adam``, ``Adan``, ``Lion``, ``MARSCorrection``, ``MSAMMomentum``, ``RMSprop``, ``GGT``, ``EMA``, ``HeavyBall``, ``NAG``, ``ClipNormByEMA``, ``ClipValueByEMA``, ``NormalizeByEMA``, ``ClipValueGrowth``, ``CoordinateMomentum``, ``CubicAdam``.

        Additionally most modules with no internal buffers are supported, e.g. ``Cautious``, ``Sign``, ``ClipNorm``, ``Orthogonalize``, etc. However modules that use weight values, such as ``WeightDecay`` can't be supported, as weights can't be projected.

        Also, if you say use ``EMA`` on output of ``Pow(2)``, the exponential average will be reprojected as gradient and not as squared gradients. Use modules like ``EMASquared``, ``SqrtEMASquared`` to get correct reprojections.


    Args:
        basis_opt (Chainable): module or modules to run in GGT eigenbasis.
        history_size (int, optional): number of past gradients to store, and rank of preconditioner. Defaults to 100.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
        matrix_power (float, optional):
            forwarded to ``ggt_update`` as ``matrix_power``. Defaults to -1/2.
        approx_sq_reproject (bool, optional):
            if True, ``"grad_sq"`` buffers of ``basis_opt`` are reprojected as ``C.pow(2) @ buff``; otherwise as
            the diagonal of ``C @ diag(buff) @ C.T``, where ``C`` is the change of basis matrix. Defaults to False.
        approx_cu_reproject (bool, optional):
            forwarded as the ``approx`` flag to ``_cubic_reproject`` when reprojecting ``"grad_cu"`` buffers.
            Defaults to False.
        inner (Chainable | None, optional):
            output of this module is projected and ``basis_opt`` will run on it, but preconditioners are updated
            from original gradients.

    ## Examples:

    Adam in GGT eigenbasis:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GGTBasis(tz.m.Adam(beta2=0.99)),
        tz.m.LR(1e-3)
    )
    ```

    Full-matrix Adam in GGT eigenbasis. We can define full-matrix Adam through ``FullMatrixAdagrad``.
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GGTBasis(
            [tz.m.FullMatrixAdagrad(beta=0.99, inner=tz.m.EMA(0.9, debias=True))]
        ),
        tz.m.LR(1e-3)
    )
    ```

    LaProp in GGT eigenbasis:
    ```python

    # we define LaProp through other modules, moved it out for brevity
    laprop = (
        tz.m.RMSprop(0.95),
        tz.m.Debias(beta1=None, beta2=0.95),
        tz.m.EMA(0.95),
        tz.m.Debias(beta1=0.95, beta2=None),
    )

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GGTBasis(laprop),
        tz.m.LR(1e-3)
    )
    ```

    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – pp. 102-110.
    """

    def __init__(
        self,
        basis_opt: Chainable,
        history_size: int = 100,
        update_freq: int = 1,
        eig_tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        matrix_power: float = -1/2,
        approx_sq_reproject: bool = False,
        approx_cu_reproject: bool = False,

        inner: Chainable | None = None,
    ):
        # must stay the first statement: captures exactly the constructor arguments
        defaults = locals().copy()
        del defaults['self'], defaults['inner']

        super().__init__(defaults, concat_params=True, inner=inner)
        self.set_child("basis_opt", basis_opt)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        """Append ``tensor`` to the history and, every ``update_freq`` steps, recompute the
        GGT factors and reproject the buffers of ``basis_opt`` into the new basis."""
        history_size = setting['history_size']
        update_freq = setting['update_freq']

        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)
        history = state['history']

        # store a flattened copy so later in-place edits of ``tensor`` don't alter the history
        t = tensor.clone().view(-1)
        history.append(t)

        step = state.get('step', 0)
        state['step'] = step + 1

        if step % update_freq == 0:

            # compute new factors
            L = state.get("L", None)
            U = state.get("U", None)

            L_new, U_new = ggt_update(
                history,
                damping=setting["damping"],
                rdamping=setting["rdamping"],
                truncate=setting["truncate"],
                eig_tol=setting["eig_tol"],
                matrix_power=setting["matrix_power"],
            )

            if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
                # reproject basis optimizer
                # this happens after first step, so basis opt is initialized by then
                # note that because we concatenate parameters, each buffer will be a single rank-length vector
                C = U_new.T @ U  # change of basis matrix

                # reproject gradient-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "grad"):
                    set_storage_(buff, C @ buff)

                # reproject covariance diagonal-like buffers
                # NOTE: for a vector ``buff`` both forms evaluate sum_i C[a, i]**2 * buff[i];
                # they differ in cost and floating-point rounding only.
                approx_sq = setting["approx_sq_reproject"]
                for (buff,) in self.get_child_projected_buffers("basis_opt", "grad_sq"):
                    if approx_sq:
                        set_storage_(buff, C.pow(2) @ buff)
                    else:
                        set_storage_(buff, (C @ buff.diag_embed() @ C.T).diagonal())

                # reproject third order diagonal-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "grad_cu"):
                    buff_r = _cubic_reproject(C, buff, setting["approx_cu_reproject"])
                    set_storage_(buff, buff_r)

                # reproject covariance-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "covariance"):
                    set_storage_(buff, C @ buff @ C.T)

            # store new factors; when ``ggt_update`` yields None the previous factors are kept
            if L_new is not None:
                state["L"] = L_new
            if U_new is not None:
                state["U"] = U_new

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        """Project ``tensor`` onto ``U``, run ``basis_opt`` on the projection and map the result back.

        If no ``U`` has been stored yet, ``tensor`` is divided in place by the
        root-mean-square of the stored history (plus ``1e-8``) and returned.
        """
        g = tensor.view(-1)
        U = state.get('U', None)

        if U is None:
            # fallback to element-wise preconditioning
            history = torch.stack(tuple(state["history"]), 0)
            g /= history.square().mean(0).sqrt().add(1e-8)
            return g.view_as(tensor)

        # project
        g_proj = U.T @ g

        # step
        dir_proj = self.inner_step_tensors("basis_opt", tensors=[g_proj], clone=False, grads=[g_proj])[0]

        # unproject
        update = U @ dir_proj

        return update.view_as(tensor)
|
|
199
|
+
|