torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0

torchzero/modules/adaptive/adagrad.py

@@ -169,7 +169,7 @@ class FullMatrixAdagrad(TensorTransform):
     """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

     Note:
-        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.
+        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.GGT``.

     Args:
         reg (float, optional): regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.
@@ -190,7 +190,7 @@ class FullMatrixAdagrad(TensorTransform):

    Plain full-matrix adagrad
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrd(),
        tz.m.LR(1e-2),
@@ -199,7 +199,7 @@ class FullMatrixAdagrad(TensorTransform):

    Full-matrix RMSprop
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.99),
        tz.m.LR(1e-2),
@@ -208,7 +208,7 @@ class FullMatrixAdagrad(TensorTransform):

    Full-matrix Adam
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
        tz.m.Debias(0.9, 0.999),
@@ -240,22 +240,22 @@ class FullMatrixAdagrad(TensorTransform):
     def single_tensor_update(self, tensor, param, grad, loss, state, setting):

         G = tensor.ravel()
-
+        GGT = torch.outer(G, G)

         # initialize
         if "accumulator" not in state:
             init = setting['init']
-            if init == 'identity': state['accumulator'] = torch.eye(
-            elif init == 'zeros': state['accumulator'] = torch.zeros_like(
-            elif init == 'GGT': state['accumulator'] =
+            if init == 'identity': state['accumulator'] = torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype)
+            elif init == 'zeros': state['accumulator'] = torch.zeros_like(GGT)
+            elif init == 'GGT': state['accumulator'] = GGT.clone()
             else: raise ValueError(init)

         # update
         beta = setting['beta']
         accumulator: torch.Tensor = state["accumulator"]

-        if beta is None: accumulator.add_(
-        else: accumulator.lerp_(
+        if beta is None: accumulator.add_(GGT)
+        else: accumulator.lerp_(GGT, 1-beta)

         # update number of GGᵀ in accumulator for divide
         state['num_GGTs'] = state.get('num_GGTs', 0) + 1
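For context, the accumulator updated in this hunk is the classic full-matrix Adagrad statistic, a running sum (or EMA) of g gᵀ, which the module later uses to precondition the update with an inverse matrix root. A minimal standalone sketch of that idea in plain torch; the function name and the explicit eigh-based inverse root below are illustrative, not torchzero's API:

```python
import torch

def full_matrix_adagrad_step(g: torch.Tensor, accumulator: torch.Tensor, reg: float = 1e-12):
    # accumulate the outer product of the flattened gradient: A <- A + g gᵀ
    G = g.ravel()
    accumulator += torch.outer(G, G)
    # precondition with A^(-1/2) via an eigendecomposition of the regularized accumulator
    A = accumulator + reg * torch.eye(G.numel(), dtype=G.dtype, device=G.device)
    L, Q = torch.linalg.eigh(A)
    update = Q @ ((Q.T @ G) * L.clamp(min=1e-12).rsqrt())
    return update.view_as(g), accumulator

g = torch.randn(5)
acc = torch.zeros(5, 5)
update, acc = full_matrix_adagrad_step(g, acc)
```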
torchzero/modules/adaptive/adahessian.py

@@ -86,7 +86,7 @@ class AdaHessian(Transform):
    Using AdaHessian:

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.AdaHessian(),
        tz.m.LR(0.1)
@@ -97,7 +97,7 @@ class AdaHessian(Transform):
    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
torchzero/modules/adaptive/adam.py

@@ -2,7 +2,7 @@ import torch

 from ...core import Chainable, Module, TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..
+from ..opt_utils import debiased_step_size


 class Adam(TensorTransform):
torchzero/modules/adaptive/adaptive_heavyball.py

@@ -30,7 +30,7 @@ class AdaptiveHeavyBall(TensorTransform):
     """
     def __init__(self, f_star: float = 0):
         defaults = dict(f_star=f_star)
-        super().__init__(defaults,
+        super().__init__(defaults, uses_loss=True)

     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
torchzero/modules/adaptive/esgd.py

@@ -48,7 +48,7 @@ class ESGD(Transform):
    Using ESGD:
    ```python

-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ESGD(),
        tz.m.LR(0.1)
@@ -59,7 +59,7 @@ class ESGD(Transform):
    ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
torchzero/modules/adaptive/ggt.py (new file)

@@ -0,0 +1,186 @@
+from collections import deque
+from typing import Literal, Any
+import warnings
+
+import torch
+from ...core import Chainable, TensorTransform
+from ...linalg import torch_linalg, regularize_eigh
+from .lre_optimizers import LREOptimizerBase
+
+def ggt_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping, truncate, eig_tol):
+    """returns U ``(ndim, rank)``, L ``(rank, )``"""
+    if isinstance(history, torch.Tensor):
+        M = history
+    else:
+        M = torch.stack(tuple(history), dim=1)# / len(history)
+
+    MtM = M.T @ M
+    if damping != 0:
+        MtM.add_(torch.eye(MtM.size(0), device=MtM.device, dtype=MtM.dtype).mul_(damping))
+
+    try:
+        L, Q = torch_linalg.eigh(MtM, retry_float64=True)
+
+        # damping is already added to MTM, rdamping is added afterwards
+        L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eig_tol, damping=0, rdamping=0)
+
+        if L is None or Q is None: # this means there are no finite eigenvalues
+            return None, None
+
+        U = (M @ Q) * L.rsqrt()
+
+        # this damping is added after computing U, this is why I didn't use one in linalg.regularize_eig
+        # that's because we damp singular values this way
+        if rdamping != 0:
+            L.add_(rdamping * L[-1]) # L is sorted in ascending order
+
+        return L, U
+
+    except torch.linalg.LinAlgError:
+        return None, None
+
+
+class GGT(TensorTransform):
+    """
+    GGT method from https://arxiv.org/pdf/1806.02958
+
+    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
+    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't neeed V.
+
+    This is equivalent to full-matrix Adagrad on recent gradients.
+
+    Args:
+        history_size (int, optional): number of past gradients to store. Defaults to 10.
+        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
+        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
+        truncate (int, optional): number of larges eigenvalues to keep. None to disable. Defaults to None.
+        damping (float, optional): damping value. Defaults to 1e-4.
+        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
+        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
+        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
+
+    ## Examples:
+
+    Limited-memory Adagrad
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(),
+        tz.m.LR(0.1)
+    )
+    ```
+    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+    ```python
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.GGT(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.ClipNormByEMA(max_ema_growth=1.2),
+        tz.m.LR(0.01)
+    )
+    ```
+    Reference:
+        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
+    """
+
+    def __init__(
+        self,
+        history_size: int = 100,
+        update_freq: int = 1,
+        eig_tol: float = 1e-7,
+        truncate: int | None = None,
+        damping: float = 1e-4,
+        rdamping: float = 0,
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        concat_params: bool = True,
+
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults['concat_params']
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
+        update_freq = setting['update_freq']
+
+        if 'history' not in state: state['history'] = deque(maxlen=history_size)
+        history = state['history']
+
+        t = tensor.clone().view(-1)
+        history.append(t)
+
+        step = state.get('step', 0)
+        state['step'] = step + 1
+
+        if step % update_freq == 0 :
+
+            # compute new factors
+            L = state.get("L", None)
+            U = state.get("U", None)
+
+            L_new, U_new = ggt_update(
+                history,
+                damping=setting["damping"],
+                rdamping=setting["rdamping"],
+                truncate=setting["truncate"],
+                eig_tol=setting["eig_tol"],
+            )
+
+            # reproject eigenbasis optimizer
+            eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+            if eigenbasis_optimizer is not None:
+                if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
+                    eigenbasis_state = state["eigenbasis_state"]
+                    eigenbasis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=eigenbasis_state)
+
+
+            # store new factors
+            if L_new is not None: state["L"] = L_new
+            if U_new is not None: state["U"] = U_new
+
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        g = tensor.view(-1)
+        U = state.get('U', None)
+
+        if U is None:
+            # fallback to element-wise preconditioning
+            history = torch.stack(tuple(state["history"]), 0)
+            g /= history.square().mean(0).sqrt().add(1e-8)
+            return g.view_as(tensor)
+
+        L = state['L']
+
+        # step with eigenbasis optimizer
+        eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
+        if eigenbasis_optimizer is not None:
+
+            if "eigenbasis_state" not in state: state["eigenbasis_state"] = {}
+            eigenbasis_state = state["eigenbasis_state"]
+
+            update = eigenbasis_optimizer.step(g, L=L, Q=U, state=eigenbasis_state)
+            return update.view_as(tensor)
+
+        # or just whiten
+        z = U.T @ g
+        update = (U * L.rsqrt()) @ z
+        return update.view_as(tensor)
+
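The docstring above describes the identity `ggt_update` relies on: with the last k gradients stacked as the columns of M and M = U S Vᵀ, we have MᵀM = V S² Vᵀ, so an eigendecomposition of the small k×k matrix MᵀM yields S² and V, and U = M V S⁻¹ without ever decomposing an n×n matrix. A short self-contained check in plain torch (the variable names are illustrative, not the library's API):

```python
import torch

n, k = 1000, 10
M = torch.randn(n, k)                      # columns play the role of recent gradients
g = torch.randn(n)

S2, V = torch.linalg.eigh(M.T @ M)         # eigenvalues of MᵀM are the squared singular values
U = (M @ V) * S2.rsqrt()                   # left singular vectors, shape (n, k)

# whitened update U S⁻¹ Uᵀ g, matching the "or just whiten" branch of single_tensor_apply
update = (U * S2.rsqrt()) @ (U.T @ g)

# same result as the SVD-based formula from the docstring
U_svd, S_svd, _ = torch.linalg.svd(M, full_matrices=False)
reference = U_svd @ ((U_svd.T @ g) / S_svd)
assert torch.allclose(update, reference, atol=1e-5)
```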
torchzero/modules/adaptive/lion.py

@@ -1,10 +1,11 @@
+from typing import Any
 import torch

 from ...core import TensorTransform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states


-def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
+def lion_(tensors: TensorList | Any, exp_avg_: TensorList | Any, beta1, beta2,):
     update = exp_avg_.lerp(tensors, 1-beta1).sign_()
     exp_avg_.lerp_(tensors, 1-beta2)
     return update
torchzero/modules/adaptive/lre_optimizers.py (new file)

@@ -0,0 +1,299 @@
+"""subspace optimizers to be used in a low rank eigenbasis
+
+three opts support this - GGT and experimental AdaNystrom and Eigengrad
+
+I could define repoject on a module but because most opts use per-parameter state that is complicated"""
+
+import math
+from abc import ABC, abstractmethod
+from typing import Any, cast
+
+import torch
+
+from ...linalg import matrix_power_eigh, torch_linalg
+from .lion import lion_
+
+class LREOptimizerBase(ABC):
+    """Optimizer to run in a low rank eigenbasis.
+
+    notes:
+
+    1. it shouldn't store any states in self, everything should be in state.
+       This is because this may be called on multiple parameters in a sequence
+
+    2. apply is always called first, than reproject whenever eigenbasis gets updated
+
+    3. L is variance in the eigenbasis.
+    """
+    @abstractmethod
+    def step(self, g: torch.Tensor, L: torch.Tensor, Q: torch.Tensor, state: dict) -> torch.Tensor:
+        ...
+
+    @abstractmethod
+    def reproject(self, L_old: torch.Tensor, Q_old: torch.Tensor,
+                  L_new: torch.Tensor, Q_new: torch.Tensor, state: dict) -> None:
+        ...
+
+class Whiten(LREOptimizerBase):
+    """This simply applies whitening and is equivalent to not running an optimizer in the eigenbasis"""
+    def step(self, g, L, Q, state): return (Q * L.rsqrt()) @ (Q.T @ g)
+    def reproject(self, L_old, Q_old, L_new, Q_new, state): pass
+
+class EMA(LREOptimizerBase):
+    """Maintains exponential moving average of gradients in the low rank eigenbasis. Nesterov setting is experimental"""
+    def __init__(self, beta=0.9, nesterov:bool=False, cautious:bool=False, whiten:bool=True):
+        self.beta = beta
+        self.nesterov = nesterov
+        self.whiten = whiten
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+
+        exp_avg = state["exp_avg"]
+        exp_avg.lerp_(g, 1-self.beta)
+
+        if self.nesterov:
+            dir = (g + exp_avg * self.beta) / (1 + self.beta)
+        else:
+            dir = exp_avg
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        if self.whiten: return (Q * L.rsqrt()) @ dir
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+
+
+def adam(g:torch.Tensor, state:dict, beta1, beta2, eps):
+
+    if "exp_avg" not in state:
+        state["exp_avg"] = torch.zeros_like(g)
+        state["exp_avg_sq"] = torch.zeros_like(g)
+        state["current_step"] = 1
+
+    exp_avg = state["exp_avg"]
+    exp_avg_sq = state["exp_avg_sq"]
+    current_step = state["current_step"]
+
+    exp_avg.lerp_(g, 1-beta1)
+    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1-beta2)
+    denom = exp_avg_sq.sqrt().add_(eps)
+
+    bias_correction1 = 1.0 - (beta1 ** current_step)
+    bias_correction2 = 1.0 - (beta2 ** current_step)
+    alpha = math.sqrt(bias_correction2) / bias_correction1
+    state["current_step"] = current_step + 1
+
+    return (exp_avg * alpha) / denom
+
+def _squared_reproject(C: torch.Tensor, sq: torch.Tensor, exact: bool):
+    if exact:
+        return (C @ sq.diag_embed() @ C.T).diagonal()
+
+    return C.square() @ sq
+
+class Adam(LREOptimizerBase):
+    """Runs Adam in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, cautious:bool=False, eps=1e-8, exact_reproject:bool=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.cautious = cautious
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        dir = adam(g, state, self.beta1, self.beta2, self.eps)
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+class FullMatrixAdam(LREOptimizerBase):
+    """Runs full-matrix Adam in low rank eigenbasis.
+    The preconditioner is updated whenever basis is updated"""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, matrix_power=-1/2, abs=True, cautious:bool=False):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.matrix_power = matrix_power
+        self.abs = abs
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        # initialize
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["covariance"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+            state["preconditioner"] = torch.eye(g.numel(), device=g.device, dtype=g.dtype)
+            state["reprojected"] = True
+            state["current_step"] = 1
+
+        exp_avg = state["exp_avg"]
+        covariance = state["covariance"]
+        current_step = state["current_step"]
+
+        # update buffers
+        exp_avg.lerp_(g, 1-self.beta1)
+        covariance.lerp_(g.outer(g), weight=1-self.beta2)
+
+        # correct bias
+        bias_correction1 = 1.0 - (self.beta1 ** current_step)
+        exp_avg = exp_avg / bias_correction1
+
+        # after reprojecting update the preconditioner
+        if state["reprojected"]:
+            state["reprojected"] = False
+
+            bias_correction2 = 1.0 - (self.beta2 ** current_step)
+            covariance = covariance / bias_correction2
+
+            reg = torch.eye(covariance.size(0), device=covariance.device, dtype=covariance.dtype).mul_(self.eps)
+            covariance = covariance + reg
+
+            # compute matrix power
+            try:
+                state["preconditioner"] = matrix_power_eigh(covariance, self.matrix_power, abs=self.abs)
+
+            except torch.linalg.LinAlgError:
+
+                # fallback to diagonal
+                state["preconditioner"] = covariance.diagonal().rsqrt().diag_embed()
+
+        # compute the update
+        state["current_step"] = current_step + 1
+        preconditioner = state["preconditioner"]
+        dir = preconditioner @ exp_avg
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+
+        state["reprojected"] = True
+
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["covariance"] = C @ state["covariance"] @ C.T
+
+class Lion(LREOptimizerBase):
+    """Runs Lion in the low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.99, cautious:bool=False):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.cautious = cautious
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+
+        dir = cast(torch.Tensor, lion_(g, state["exp_avg"], beta1=self.beta1, beta2=self.beta2))
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+        state["exp_avg"] = C @ state["exp_avg"]
+
+
+class Grams(LREOptimizerBase):
+    """Runs Grams in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, exact_reproject=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+        dir = adam(g, state, self.beta1, self.beta2, self.eps)
+        return Q @ dir.copysign(g)
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
+
+
+class LaProp(LREOptimizerBase):
+    """Runs LaProp in low rank eigenbasis."""
+    def __init__(self, beta1=0.9, beta2=0.95, eps=1e-8, cautious:bool=False, exact_reproject=True):
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eps = eps
+        self.cautious = cautious
+        self.exact_reproject = exact_reproject
+
+    def step(self, g, L, Q, state):
+        g = Q.T @ g
+
+        if "exp_avg" not in state:
+            state["exp_avg"] = torch.zeros_like(g)
+            state["exp_avg_sq"] = torch.zeros_like(g)
+            state["current_step"] = 1
+
+        exp_avg = state["exp_avg"]
+        exp_avg_sq = state["exp_avg_sq"]
+        current_step = state["current_step"]
+
+        # update second moments
+        exp_avg_sq.mul_(self.beta2).addcmul_(g, g, value=1-self.beta2)
+        bias_correction2 = 1.0 - (self.beta2 ** current_step)
+
+        # divide by bias corrected second moments
+        dir = g / (exp_avg_sq / bias_correction2).sqrt().add_(self.eps)
+
+        # update first moments and bias correct
+        exp_avg.lerp_(dir, 1-self.beta1)
+        bias_correction1 = 1.0 - (self.beta1 ** current_step)
+        dir = exp_avg / bias_correction1
+
+        if self.cautious:
+            mask = (g * dir) > 0
+            dir *= mask
+
+        state["current_step"] = current_step + 1
+        return Q @ dir
+
+    def reproject(self, L_old, Q_old, L_new, Q_new, state):
+        if "exp_avg" not in state: return
+        C = Q_new.T @ Q_old
+
+        state["exp_avg"] = C @ state["exp_avg"]
+        state["exp_avg_sq"] = _squared_reproject(C, state["exp_avg_sq"], self.exact_reproject)
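These subspace optimizers keep their state in the coordinates of the current eigenbasis, so `reproject` maps coordinates from the old basis to the new one with C = Q_newᵀ Q_old. A small sketch of why that is the right map, in plain torch (illustrative only): when the new basis spans the same subspace, the state still represents the same ambient-space vector, and for a diagonal second moment the two branches of `_squared_reproject` agree.

```python
import torch

n, k = 50, 5
Q_old, _ = torch.linalg.qr(torch.randn(n, k))   # old orthonormal basis, shape (n, k)
R, _ = torch.linalg.qr(torch.randn(k, k))       # a rotation within the same subspace
Q_new = Q_old @ R                               # new orthonormal basis

exp_avg = torch.randn(k)                        # first-moment state in old coordinates
C = Q_new.T @ Q_old                             # change-of-coordinates matrix, shape (k, k)

# the momentum vector expressed in the full parameter space is preserved
assert torch.allclose(Q_old @ exp_avg, Q_new @ (C @ exp_avg), atol=1e-5)

# for a diagonal second moment, both forms of _squared_reproject give the same result
sq = torch.rand(k)
exact = (C @ sq.diag_embed() @ C.T).diagonal()
assert torch.allclose(exact, C.square() @ sq, atol=1e-6)
```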
torchzero/modules/adaptive/mars.py

@@ -35,7 +35,7 @@ class MARSCorrection(TensorTransform):

    Mars-AdamW
    ```python
-    optimizer = tz.
+    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.95),
        tz.m.Adam(beta1=0.95, beta2=0.99),
@@ -46,7 +46,7 @@ class MARSCorrection(TensorTransform):

    Mars-Lion
    ```python
-    optimizer = tz.
+    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.9),
        tz.m.Lion(beta1=0.9),
torchzero/modules/adaptive/matrix_momentum.py

@@ -4,7 +4,7 @@ import torch

 from ...core import Chainable, Transform, HVPMethod
 from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
-from ..
+from ..opt_utils import initial_step_size


 class MatrixMomentum(Transform):
torchzero/modules/adaptive/msam.py

@@ -4,7 +4,7 @@ import torch

 from ...core import Chainable, Module, Transform, TensorTransform, step, Objective
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
-from ..
+from ..opt_utils import ema_
 from ..momentum.momentum import nag_


@@ -99,7 +99,7 @@ class MSAMMomentum(TensorTransform):

    ```python

-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.MSAM(1e-3)
    )
@@ -109,7 +109,7 @@ class MSAMMomentum(TensorTransform):
    To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
        tz.m.Debias(0.9, 0.999),
@@ -166,7 +166,7 @@ class MSAM(Transform):
    AdamW-MSAM

    ```py
-    opt = tz.
+    opt = tz.Optimizer(
        bench.parameters(),
        tz.m.MSAMObjective(
            [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],