torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
````diff
@@ -0,0 +1,116 @@
+# pylint:disable=not-callable
+"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+import math
+import warnings
+
+import torch
+
+from ....core import Chainable, TensorTransform
+from ._psgd_utils import _initialize_lra_state_
+from .psgd import lift2single, precond_grad_lra, update_precond_lra_whiten
+
+# matches
+class PSGDLRAWhiten(TensorTransform):
+    """Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+    Args:
+        rank (int, optional):
+            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
+        init_scale (float | None, optional):
+            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+        damping (float, optional):
+            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
+        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
+        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+        concat_params (bool, optional):
+            if True, treats all parameters as concatenated to a single vector.
+            If False, each parameter is preconditioned separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+    ###Examples:
+
+    Pure PSGD LRA:
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.LRAWhiten(),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    Momentum into preconditioner (whitens momentum):
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.EMA(0.9),
+        tz.m.LRAWhiten(),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    Updating the preconditioner from gradients and applying it to momentum:
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    """
+    def __init__(
+        self,
+        rank: int = 10,
+        init_scale: float | None = None,
+        lr_preconditioner=0.1,
+        betaL=0.9,
+        damping=1e-9,
+        grad_clip_max_amp=float("inf"),
+        update_probability=1.0,
+
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults["inner"], defaults["self"]
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        _initialize_lra_state_(tensor, state, setting)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+
+        g = tensor.ravel().unsqueeze(1) # column vector
+
+        UVd = state["UVd"]
+        if UVd[2] is None: # initialize d on the fly
+            UVd[2] = (torch.mean(g**4) + setting["damping"]**4)**(-1/8) * torch.ones_like(g)
+
+        if torch.rand([]) < setting["update_probability"]: # update preconditioner
+            update_precond_lra_whiten(
+                UVd=UVd,
+                Luvd=state["Luvd"],
+                g=g,
+                lr=setting["lr_preconditioner"],
+                betaL=setting["betaL"],
+                damping=setting["damping"],
+            )
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+
+        g = tensor.ravel().unsqueeze(1)
+        pre_grad = precond_grad_lra(UVd=state["UVd"], g=g)
+
+        # norm clipping
+        grad_clip_max_amp = setting["grad_clip_max_amp"]
+        if grad_clip_max_amp < float("inf"): # clip preconditioned gradient
+            amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
+            if amp > grad_clip_max_amp:
+                pre_grad *= grad_clip_max_amp/amp
+
+        return pre_grad.view_as(tensor)
````
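For orientation only: the module above maintains a diagonal-plus-low-rank factorization of a whitening preconditioner, i.e. a matrix P with E[(P g)(P g)^T] ≈ I. The sketch below builds the dense equivalent from a batch of synthetic gradient samples; it illustrates the whitening objective, not the PSGD-LRA update rule from the diff.

```python
import torch

def dense_whitening_preconditioner(grads: torch.Tensor, damping: float = 1e-9) -> torch.Tensor:
    """grads: (num_samples, n) gradient samples; returns (E[g g^T] + damping*I)^(-1/2) as an (n, n) matrix."""
    n = grads.shape[1]
    cov = grads.T @ grads / grads.shape[0]                     # empirical E[g g^T]
    L, Q = torch.linalg.eigh(cov + damping * torch.eye(n))     # symmetric eigendecomposition
    return Q @ torch.diag(L.clamp_min(damping).rsqrt()) @ Q.T  # inverse matrix square root

grads = torch.randn(256, 16)                                   # synthetic gradient samples
P = dense_whitening_preconditioner(grads)
whitened = grads @ P                                           # each row is P @ g
cov_w = whitened.T @ whitened / whitened.shape[0]
print(torch.allclose(cov_w, torch.eye(16), atol=1e-3))         # ~identity: the gradients are whitened
```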
````diff
@@ -304,7 +304,7 @@ class SignConsistencyMask(TensorTransform):
     GD that skips update for weights where gradient sign changed compared to previous gradient.
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Mul(tz.m.SignConsistencyMask()),
         tz.m.LR(1e-2)
@@ -334,7 +334,7 @@ class SignConsistencyLRs(TensorTransform):
 
     ```python
 
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Mul(tz.m.SignConsistencyLRs()),
         tz.m.LR(1e-2)
@@ -31,7 +31,7 @@ class SAM(Transform):
    SAM-SGD:
 
    ```py
-   opt = tz.
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.SAM(),
        tz.m.LR(1e-2)
@@ -41,7 +41,7 @@ class SAM(Transform):
    SAM-Adam:
 
    ```
-   opt = tz.
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.SAM(),
        tz.m.Adam(),
@@ -149,7 +149,7 @@ class ASAM(SAM):
    ASAM-SGD:
 
    ```py
-   opt = tz.
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.ASAM(),
        tz.m.LR(1e-2)
@@ -159,7 +159,7 @@ class ASAM(SAM):
    ASAM-Adam:
 
    ```
-   opt = tz.
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.ASAM(),
        tz.m.Adam(),
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Sequence, Iterable
 
 import numpy as np
 import torch
````
````diff
@@ -82,6 +82,31 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
         tensor = tensor.unflatten(0, flat_sizes)
     return tensor.permute(*np.argsort(sort_idxs).tolist())
 
+def diagonal_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor]):
+    """computes number of parameters"""
+    if isinstance(params, torch.nn.Module): params = params.parameters()
+    if isinstance(params, torch.Tensor): params = [params,]
+    params = list(params)
+    return sum(p.numel() for p in params)
+
+def kronecker_memory(params: torch.nn.Module | torch.Tensor | Iterable[torch.Tensor], merge_small:bool=True, max_dim:int=10_000):
+    """computes total size of tensors required to store shampoo preconditioner"""
+    if isinstance(params, torch.nn.Module): params = params.parameters()
+    if isinstance(params, torch.Tensor): params = [params,]
+    params = list(params)
+
+    memory = 0
+    for p in params:
+        if merge_small:
+            p, _, _ = _merge_small_dims(p, max_dim)
+        for dim in p.size():
+            if dim > max_dim: memory += dim
+            else: memory += dim**2
+
+    return memory
+
+
+
 
 class Shampoo(TensorTransform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
````
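A quick way to read the two helpers added above: `diagonal_memory` counts one slot per parameter (what a diagonal preconditioner stores), while `kronecker_memory` counts the per-dimension Kronecker factor matrices Shampoo stores, falling back to a size-`dim` diagonal for any dimension larger than `max_dim`. A hedged usage sketch follows; the import path mirrors the file touched in this diff, and whether these helpers are re-exported at a higher level is not visible here.

```python
import torch.nn as nn
# import path is an assumption based on the file shown in this diff
from torchzero.modules.adaptive.shampoo import diagonal_memory, kronecker_memory

model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))

# one scalar per parameter: 784*512 + 512 + 512*10 + 10
print(diagonal_memory(model))

# one (dim x dim) factor per tensor dimension, e.g. the (512, 784) weight
# contributes 512**2 + 784**2 entries; dims above max_dim would contribute only dim
print(kronecker_memory(model, merge_small=False, max_dim=10_000))
```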
````diff
@@ -112,7 +137,7 @@ class Shampoo(TensorTransform):
     Shampoo grafted to Adam
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.GraftModules(
             direction = tz.m.Shampoo(),
@@ -125,7 +150,7 @@ class Shampoo(TensorTransform):
     Adam with Shampoo preconditioner
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
         tz.m.Debias(0.9, 0.999),
@@ -132,7 +132,7 @@ class SOAP(TensorTransform):
     SOAP:
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SOAP(),
         tz.m.LR(1e-3)
@@ -141,7 +141,7 @@ class SOAP(TensorTransform):
     Stabilized SOAP:
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SOAP(),
         tz.m.NormalizeByEMA(max_ema_growth=1.2),
@@ -156,7 +156,7 @@ class SOAP(TensorTransform):
         shampoo_beta: float | None = 0.95,
         precond_freq: int = 10,
         merge_small: bool = True,
-        max_dim: int =
+        max_dim: int = 4096,
         precondition_1d: bool = True,
         eps: float = 1e-8,
         debias: bool = True,
@@ -50,7 +50,7 @@ class SophiaH(Transform):
 
     ```python
 
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SophiaH(),
         tz.m.LR(0.1)
@@ -63,7 +63,7 @@ class SophiaH(Transform):
 
     ```python
 
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
         tz.m.LR(0.1)
@@ -161,7 +161,7 @@ class ClipValue(TensorTransform):
 
     Gradient clipping:
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ClipValue(1),
         tz.m.Adam(),
@@ -171,7 +171,7 @@ class ClipValue(TensorTransform):
 
     Update clipping:
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.ClipValue(1),
@@ -211,7 +211,7 @@ class ClipNorm(TensorTransform):
 
     Gradient norm clipping:
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.ClipNorm(1),
         tz.m.Adam(),
@@ -221,7 +221,7 @@ class ClipNorm(TensorTransform):
 
     Update norm clipping:
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.ClipNorm(1),
@@ -277,7 +277,7 @@ class Normalize(TensorTransform):
     Examples:
     Gradient normalization:
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Normalize(1),
         tz.m.Adam(),
@@ -288,7 +288,7 @@ class Normalize(TensorTransform):
     Update normalization:
 
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.Normalize(1),
@@ -378,7 +378,7 @@ class Centralize(TensorTransform):
 
     Standard gradient centralization:
     ```python
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Centralize(dim=0),
         tz.m.LR(1e-2),
@@ -7,7 +7,7 @@ from ...core import Chainable, TensorTransform
 
 from ...utils import TensorList, safe_dict_update_, unpack_dicts, unpack_states
 from ..quasi_newton.quasi_newton import HessianUpdateStrategy
-from ..
+from ..opt_utils import safe_clip
 
 
 class ConguateGradientBase(TensorTransform, ABC):
@@ -68,7 +68,7 @@ class ConguateGradientBase(TensorTransform, ABC):
         self.increment_counter("step", start=0)
 
         # initialize on first step
-        if self.global_state.get('stage', "first
+        if self.global_state.get('stage', "first update") == "first update":
             g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
             d_prev.copy_(tensors)
             g_prev.copy_(tensors)
@@ -1,8 +1,13 @@
 """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
+from .adanystrom import AdaNystrom
+from .common_directions_whiten import CommonDirectionsWhiten
 from .coordinate_momentum import CoordinateMomentum
+from .cubic_adam import CubicAdam, SubspaceCubicAdam
 from .curveball import CurveBall
+from .eigen_sr1 import EigenSR1
 
 # from dct import DCTProjection
+from .eigengrad import Eigengrad
 from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
````
````diff
@@ -0,0 +1,258 @@
+# pylint: disable = non-ascii-name
+import torch
+
+from ...core import Chainable, TensorTransform
+from ...linalg import (
+    OrthogonalizeMethod,
+    orthogonalize,
+    regularize_eigh,
+    torch_linalg,
+)
+from ...linalg.linear_operator import Eigendecomposition
+from ..adaptive.lre_optimizers import LREOptimizerBase
+from .eigengrad import _eigengrad_update_state_, eigengrad_apply
+
+
+def weighted_eigen_plus_rank1_mm(
+    # A1 = Q1 @ diag(L1) @ Q1.T
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+
+    # K2 = v2 @ v2.T
+    v2: torch.Tensor,
+
+    # second matrix
+    B: torch.Tensor,
+
+    # weights
+    w1: float,
+    w2: float,
+
+) -> torch.Tensor:
+    """
+    Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    Returns ``(n, k)``
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
+        B (torch.Tensor): shape ``(n, k)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+
+    """
+    # sketch A1
+    QTB = Q1.T @ B # (rank, k)
+    LQTB = L1.unsqueeze(1) * QTB # (rank, k)
+    sketch1 = Q1 @ LQTB # (n, k)
+
+    # skecth A2
+    vB = v2 @ B
+    sketch2 = v2.outer(vB)
+
+    return w1 * sketch1 + w2 * sketch2
+
+
+def adanystrom_update(
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+    v2: torch.Tensor,
+    w1: float,
+    w2: float,
+    oversampling_p: int,
+    rank: int,
+    eig_tol: float,
+    damping: float,
+    rdamping: float,
+    orthogonalize_method: OrthogonalizeMethod,
+
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
+    where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
+
+    returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
+
+    Args:
+        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
+        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
+        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
+        w1 (float): weight for A1.
+        w2 (float): weight for A2.
+    """
+    n = Q1.shape[0]
+    device = Q1.device
+    dtype = Q1.dtype
+    l = rank + oversampling_p
+
+    # gaussian test matrix
+    Omega = torch.randn(n, l, device=device, dtype=dtype)
+
+    # sketch
+    AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
+    Q = orthogonalize(AOmega, orthogonalize_method)
+
+    AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
+    QTAQ = Q.T @ AQ
+
+    W = (QTAQ + QTAQ.T) / 2.0
+
+    # compute new L and Q
+    try:
+        L_prime, S = torch_linalg.eigh(W, retry_float64=True)
+    except torch.linalg.LinAlgError:
+        return L1, Q1
+
+    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
+
+    if L_prime is None or S is None:
+        return L1, Q1
+
+    return L_prime, Q @ S
+
+
+# def adanystrom_update2(
+#     L1: torch.Tensor,
+#     Q1: torch.Tensor,
+#     v2: torch.Tensor,
+#     w1: float,
+#     w2: float,
+#     rank: int,
+# ):
+#     def A_mm(X):
+#         return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
+
+#     return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)
+
+class AdaNystrom(TensorTransform):
+    """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
+
+    Args:
+        rank (_type_): rank of Nyström approximation.
+        w1 (float, optional): weight of current covariance matrix. Defaults to 0.95.
+        w2 (float, optional): weight of new gradient in covariance matrix. Defaults to 0.05.
+        oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
+        eig_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
+        damping (float, optional):
+            added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
+        rdamping (float, optional):
+            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
+        mm_tol (float, optional):
+            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
+        mm_truncate (int | None, optional):
+            uses top k eigenvalues to compute the update. Defaults to None.
+        mm_damping (float, optional):
+            added to eigenvalues when computing the update. Defaults to 1e-4.
+        mm_rdamping (float, optional):
+            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
+        id_reg (float, optional):
+            multiplier to identity matrix added to preconditioner before computing update
+            If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
+            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
+        concat_params (bool, optional):
+            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
+        update_freq (int, optional): update frequency. Defaults to 1.
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+    """
+    def __init__(
+        self,
+        rank:int = 100,
+        beta=0.95,
+        oversampling: int = 10,
+        eig_tol: float | None = 1e-32,
+        damping: float = 0,
+        rdamping: float = 0,
+        mm_tol: float = 0,
+        mm_truncate: int | None = None,
+        mm_damping: float = 0,
+        mm_rdamping: float = 0,
+        id_reg: float | None = None,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+        eigenbasis_optimizer: LREOptimizerBase | None = None,
+        orthogonalize_interval: int | None = 100,
+
+        concat_params: bool = True,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        for k in ["self", "concat_params", "inner", "update_freq"]:
+            del defaults[k]
+
+        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
+
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+        state["step"] = state.get("step", 0) + 1
+        rank = setting["rank"]
+        device = tensor.device
+        dtype = tensor.dtype
+        beta = setting["beta"]
+
+        try:
+            if "L" not in state:
+                # use just tensor and zero L and Q with zero weight
+
+                L, Q = adanystrom_update(
+                    L1=torch.zeros(rank, device=device, dtype=dtype),
+                    Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
+                    v2=tensor.ravel(),
+                    w1=0,
+                    w2=1-beta,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                state["L"] = state["L_reg"] = L
+                state["Q"] = state["Q_reg"] = Q
+
+            else:
+                L = state["L"]
+                Q = state["Q"]
+
+                w1 = beta
+                w2 = 1 - w1
+
+                # compute new factors (this function truncates them)
+                L_new, Q_new = adanystrom_update(
+                    L1=L,
+                    Q1=Q,
+                    v2=tensor.ravel(),
+                    w1=w1,
+                    w2=w2,
+                    rank=rank,
+                    oversampling_p=setting["oversampling"],
+                    eig_tol=setting["eig_tol"],
+                    damping=setting["damping"],
+                    rdamping=setting["rdamping"],
+                    orthogonalize_method=setting["orthogonalize_method"],
+                )
+
+                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
+
+        except torch.linalg.LinAlgError:
+            pass
+
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        if "L_reg" not in state:
+            return tensor.clip(-0.1, 0.1)
+
+        if "eigenbasis_state" not in state:
+            state["eigenbasis_state"] = {}
+
+        return eigengrad_apply(
+            tensor=tensor,
+            L_reg = state["L_reg"],
+            Q_reg = state["Q_reg"],
+            beta = setting["beta"],
+            step = state["step"],
+            debias = True,
+            id_reg = setting["id_reg"],
+            eigenbasis_optimizer = setting["eigenbasis_optimizer"],
+            eigenbasis_state = state["eigenbasis_state"]
+        )
````
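The core trick in the new `adanystrom_update` above is that the sketch `A @ Omega` never forms the dense matrix: `weighted_eigen_plus_rank1_mm` multiplies through the factors `Q1 diag(L1) Q1^T` and `v v^T`. Below is a small self-contained check of that identity; it reproduces the factored arithmetic inline rather than importing the experimental module, and the dense construction is only an illustrative reference.

```python
import torch

n, rank, k = 12, 4, 3
Q1, _ = torch.linalg.qr(torch.randn(n, rank))     # orthonormal (n, rank) eigenvector block
L1 = torch.rand(rank)                             # eigenvalues of A1
v2 = torch.randn(n)                               # A2 = v2 v2^T
B = torch.randn(n, k)
w1, w2 = 0.95, 0.05

# dense reference: (w1*A1 + w2*A2) @ B
dense = (w1 * (Q1 @ torch.diag(L1) @ Q1.T) + w2 * torch.outer(v2, v2)) @ B

# factored form, as in the diff: (rank, k) intermediates for A1, one inner product for A2
QTB = Q1.T @ B
factored = w1 * (Q1 @ (L1.unsqueeze(1) * QTB)) + w2 * torch.outer(v2, v2 @ B)

print(torch.allclose(dense, factored, atol=1e-4))  # True
```

`adanystrom_update` then applies the usual Nyström recipe on top of this matrix-free product: sketch with a Gaussian test matrix, orthogonalize, form `W = Q^T (A Q)`, symmetrize, eigendecompose, and keep the regularized top-`rank` pairs.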