torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/spectral.py (+79 -204)

@@ -2,181 +2,41 @@ from abc import ABC, abstractmethod
 import math
 from collections import deque
 from typing import Literal, Any
+import itertools

 import torch
-from ...core import Chainable,
+from ...core import Chainable, TensorwiseTransform
 from ...utils.linalg.matrix_funcs import matrix_power_eigh
 from ...utils.linalg.svd import randomized_svd
 from ...utils.linalg.qr import qr_householder

+def spectral_update(history, damping, rdamping, true_damping: bool):
+    M_hist = torch.stack(tuple(history), dim=1)
+    device = M_hist.device
+    M_hist = M_hist.cuda()

-
-
-
-        """returns stuff for apply"""
-    @abstractmethod
-    def apply(self, __g: torch.Tensor, __A:torch.Tensor, __B:torch.Tensor) -> torch.Tensor:
-        """apply preconditioning to tensor"""
+    try:
+        U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver='gesvda') # pylint:disable=not-callable
+        U = U.to(device); S = S.to(device)

-
-
-
-
-
-
-
-
+        if damping != 0 or rdamping != 0:
+            if rdamping != 0: rdamping *= torch.linalg.vector_norm(S) # pylint:disable=not-callable
+            Iu = damping + rdamping
+            if true_damping:
+                S.pow_(2)
+                Iu **= 2
+            S.add_(Iu)
+            if true_damping: S.sqrt_()

-
-            U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver=self.driver) # pylint:disable=not-callable
+        return U, 1/S

-
-
+    except torch.linalg.LinAlgError:
+        return None, None

-
-
+def spectral_apply(g: torch.Tensor, U: torch.Tensor, S_inv: torch.Tensor):
+    Utg = (U.T @ g)*S_inv
+    return U @ Utg

-        except torch.linalg.LinAlgError:
-            return None, None
-
-    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-        Utg = (U.T @ g).div_(S)
-        return U @ Utg
-
-class _SVDLowRankSolver(_Solver):
-    def __init__(self, q: int = 6, niter: int = 2): self.q, self.niter = q, niter
-    def update(self, history, damping):
-        M_hist = torch.stack(tuple(history), dim=1)
-        try:
-            U, S, _ = torch.svd_lowrank(M_hist, q=self.q, niter=self.niter)
-            if damping is not None and damping != 0: S.add_(damping)
-            return U, S
-        except torch.linalg.LinAlgError:
-            return None, None
-
-    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-        Utg = (U.T @ g).div_(S)
-        return U @ Utg
-
-class _RandomizedSVDSolver(_Solver):
-    def __init__(self, k: int = 3, driver: str | None = 'gesvda'):
-        self.driver = driver
-        self.k = k
-
-    def update(self, history, damping):
-        M_hist = torch.stack(tuple(history), dim=1)
-        device = None # driver is CUDA only
-        if self.driver is not None:
-            device = M_hist.device
-            M_hist = M_hist.cuda()
-
-        try:
-            U, S, _ = randomized_svd(M_hist, k=self.k, driver=self.driver)
-
-            if self.driver is not None:
-                U = U.to(device); S = S.to(device)
-
-            if damping is not None and damping != 0: S.add_(damping)
-            return U, S
-
-        except torch.linalg.LinAlgError:
-            return None, None
-
-    def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-        Utg = (U.T @ g).div_(S)
-        return U @ Utg
-
-class _QRDiagonalSolver(_Solver):
-    def __init__(self, sqrt=True): self.sqrt = sqrt
-    def update(self, history, damping):
-        M_hist = torch.stack(tuple(history), dim=1)
-        try:
-            Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
-            R_diag = R.diag().abs()
-            if damping is not None and damping != 0: R_diag.add_(damping)
-            if self.sqrt: R_diag.sqrt_()
-            return Q, R_diag
-        except torch.linalg.LinAlgError:
-            return None, None
-
-    def apply(self, g: torch.Tensor, Q: torch.Tensor, R_diag: torch.Tensor):
-        Qtg = (Q.T @ g).div_(R_diag)
-        return Q @ Qtg
-
-class _QRSolver(_Solver):
-    def __init__(self, sqrt=True): self.sqrt = sqrt
-    def update(self, history, damping):
-        M_hist = torch.stack(tuple(history), dim=1)
-        try:
-            # Q: d x k, R: k x k
-            Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
-            A = R @ R.T
-            if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
-            if self.sqrt: A = matrix_power_eigh(A, 0.5)
-            return Q, A
-        except (torch.linalg.LinAlgError):
-            return None,None
-
-    def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-        g_proj = Q.T @ g
-        y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
-        return Q @ y
-
-class _QRHouseholderSolver(_Solver):
-    def __init__(self, sqrt=True): self.sqrt = sqrt
-    def update(self, history, damping):
-        M_hist = torch.stack(tuple(history), dim=1)
-        try:
-            # Q: d x k, R: k x k
-            Q, R = qr_householder(M_hist, mode='reduced') # pylint:disable=not-callable
-            A = R @ R.T
-            if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
-            if self.sqrt: A = matrix_power_eigh(A, 0.5)
-            return Q, A
-        except (torch.linalg.LinAlgError):
-            return None,None
-
-    def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-        g_proj = Q.T @ g
-        y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
-        return Q @ y
-
-
-class _EighSolver(_Solver):
-    def __init__(self, sqrt=True):
-        self.sqrt = sqrt
-
-    def update(self, history, damping):
-        M_hist = torch.stack(tuple(history), dim=1)
-        grams = M_hist @ M_hist.T # (d, d)
-        if damping is not None and damping != 0: grams.diagonal(dim1=-2, dim2=-1).add_(damping)
-        try:
-            L, Q = torch.linalg.eigh(grams) # L: (d,), Q: (d, d) # pylint:disable=not-callable
-            L = L.abs().clamp_(min=1e-12)
-            if self.sqrt: L = L.sqrt()
-            return Q, L
-        except torch.linalg.LinAlgError:
-            return None, None
-
-    def apply(self, g: torch.Tensor, Q: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
-        Qtg = (Q.T @ g).div_(L)
-        return Q @ Qtg
-
-
-SOLVERS = {
-    "svd": _SVDSolver(), # fallbacks on "gesvd" which basically takes ages or just hangs completely
-    "svd_gesvdj": _SVDSolver("gesvdj"), # no fallback on slow "gesvd"
-    "svd_gesvda": _SVDSolver("gesvda"), # approximate method for wide matrices, sometimes better sometimes worse but faster
-    "svd_lowrank": _SVDLowRankSolver(), # maybe need to tune parameters for this, with current ones its slower and worse
-    "randomized_svd2": _RandomizedSVDSolver(2),
-    "randomized_svd3": _RandomizedSVDSolver(3),
-    "randomized_svd4": _RandomizedSVDSolver(4),
-    "randomized_svd5": _RandomizedSVDSolver(5),
-    "eigh": _EighSolver(), # this is O(n**2) storage, but is this more accurate?
-    "qr": _QRSolver(),
-    "qr_householder": _QRHouseholderSolver(), # this is slower... but maybe it won't freeze? I think svd_gesvda is better
-    "qrdiag": _QRDiagonalSolver(),
-}

 def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
     if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value

@@ -184,63 +44,76 @@ def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
         if state_[key].shape != value.shape: state_[key] = value
         else: state_[key].lerp_(value, 1-beta)

-class SpectralPreconditioner(
-    """
+class SpectralPreconditioner(TensorwiseTransform):
+    """
+    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate U (Uᵀg)/S.
+    This is equivalent to full matrix Adagrad with accumulator initialized to zeros,
+    except only recent :code:`history_size` gradients are used.
+    However this doesn't require N^2 memory and is computationally less expensive than Shampoo.

     Args:
-        history_size (int, optional): number of past gradients to store
-        update_freq (int, optional):
-        damping (float, optional): damping
+        history_size (int, optional): number of past gradients to store. Defaults to 10.
+        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+        damping (float, optional): damping value. Defaults to 1e-4.
+        rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
         order (int, optional):
-
-
-
-
-
-
-
-
-
-
-        inner (Chainable | None, optional): Inner modules applied after updating preconditioner and before applying it. Defaults to None.
+            order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
+        true_damping (bool, optional):
+            If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
+        U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
+        S_beta (float | None, optional): momentum for 1/S (too unstable, don't use). Defaults to None.
+        interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
+        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to False.
+        normalize (bool, optional): whether to normalize gradients, this doesn't work well so don't use it. Defaults to False.
+        centralize (bool, optional): whether to centralize gradients, this doesn't work well so don't use it. Defaults to False.
+        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
     """
+
     def __init__(
         self,
         history_size: int = 10,
         update_freq: int = 1,
-        damping: float = 1e-
+        damping: float = 1e-4,
+        rdamping: float = 0,
         order: int = 1,
-
-
-
+        true_damping: bool = True,
+        U_beta: float | None = None,
+        S_beta: float | None = None,
         interval: int = 1,
         concat_params: bool = False,
-
+        normalize: bool=False,
+        centralize:bool = False,
        inner: Chainable | None = None,
     ):
-        if isinstance(solver, str): solver = SOLVERS[solver]
         # history is still updated each step so Precondition's update_freq has different meaning
-        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, order=order,
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params,
+        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, S_beta=S_beta, normalize=normalize, centralize=centralize)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)

     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state, settings):
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
         order = settings['order']
         history_size = settings['history_size']
         update_freq = settings['update_freq']
         damping = settings['damping']
-
-
-
+        rdamping = settings['rdamping']
+        true_damping = settings['true_damping']
+        U_beta = settings['U_beta']
+        S_beta = settings['S_beta']
+        normalize = settings['normalize']
+        centralize = settings['centralize']

         if 'history' not in state: state['history'] = deque(maxlen=history_size)
         history = state['history']

-        if order == 1:
+        if order == 1:
+            t = tensor.clone().view(-1)
+            if centralize: t -= t.mean()
+            if normalize: t /= torch.linalg.vector_norm(t).clip(min=1e-8) # pylint:disable=not-callable
+            history.append(t)
         else:

             # if order=2, history is of gradient differences, order 3 is differences between differences, etc
-            #
+            # scaled by parameter differences
             cur_p = param.clone()
             cur_g = tensor.clone()
             for i in range(1, order):

@@ -257,32 +130,34 @@ class SpectralPreconditioner(TensorwisePreconditioner):
                 cur_g = y_k

                 if i == order - 1:
-                    cur_g = cur_g
+                    if centralize: cur_g = cur_g - cur_g.mean()
+                    if normalize: cur_g = cur_g / torch.linalg.vector_norm(cur_g).clip(min=1e-8) # pylint:disable=not-callable
+                else: cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
             history.append(cur_g.view(-1))

         step = state.get('step', 0)
         if step % update_freq == 0 and len(history) != 0:
-
-            maybe_lerp_(state,
-            maybe_lerp_(state,
+            U, S_inv = spectral_update(history, damping=damping, rdamping=rdamping, true_damping=true_damping)
+            maybe_lerp_(state, U_beta, 'U', U)
+            maybe_lerp_(state, S_beta, 'S_inv', S_inv)

         if len(history) != 0:
             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         history_size = settings['history_size']
-        solver: _Solver = settings['solver']

-
-        if
+        U = state.get('U', None)
+        if U is None:
             # make a conservative step to avoid issues due to different GD scaling
             return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

-
-        update =
+        S_inv = state['S_inv']
+        update = spectral_apply(tensor.view(-1), U, S_inv).view_as(tensor)

         n = len(state['history'])
-
+        mh = min(history_size, 10)
+        if n <= mh: update.mul_(n/mh)
         return update

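For orientation, the math behind the new spectral_update/spectral_apply pair above is compact: stack the last k gradients into a matrix M, take a thin SVD M = U S Vᵀ, and precondition a gradient g as U diag(1/S)(Uᵀ g), with damping folded into the squared singular values. The sketch below restates that idea in plain PyTorch; the function and variable names are illustrative and not part of the torchzero API, and it deliberately ignores the module's rdamping, momentum, and CUDA-driver options.

import torch

def svd_precondition(history: list[torch.Tensor], g: torch.Tensor, damping: float = 1e-4) -> torch.Tensor:
    # stack the k most recent gradient vectors into a (d, k) matrix
    M = torch.stack(list(history), dim=1)
    # thin SVD: U spans the subspace of the recent gradients
    U, S, _ = torch.linalg.svd(M, full_matrices=False)
    # "true damping": add the damping to the squared singular values, mimicking Adagrad
    S = (S.pow(2) + damping ** 2).sqrt()
    # project g onto the subspace, rescale by 1/S, and project back
    return U @ ((U.T @ g) / S)

# illustrative usage on random data
history = [torch.randn(32) for _ in range(10)]
print(svd_precondition(history, torch.randn(32)).shape)  # torch.Size([32])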
torchzero/modules/experimental/structured_newton.py (+12 -12)

@@ -6,7 +6,7 @@ from typing import Literal

 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     hessian_list_to_mat,

@@ -19,7 +19,7 @@ from ...utils.derivatives import (


 class StructuredNewton(Module):
-    """TODO
+    """TODO. Please note that this is experimental and isn't guaranteed to work.
     Args:
         structure (str, optional): structure.
         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.

@@ -55,9 +55,9 @@ class StructuredNewton(Module):
         self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]

@@ -68,19 +68,19 @@ class StructuredNewton(Module):

         # ------------------------ calculate grad and hessian ------------------------ #
         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)
             def Hvp_fn1(x):
                 return hvp(params, grad, x, retain_graph=True)
             Hvp_fn = Hvp_fn1

         elif hvp_method == 'forward':
-            grad =
+            grad = var.get_grad()
             def Hvp_fn2(x):
                 return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
             Hvp_fn = Hvp_fn2

         elif hvp_method == 'central':
-            grad =
+            grad = var.get_grad()
             def Hvp_fn3(x):
                 return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
             Hvp_fn = Hvp_fn3

@@ -88,9 +88,9 @@ class StructuredNewton(Module):
         else: raise ValueError(hvp_method)

         # -------------------------------- inner step -------------------------------- #
-        update =
+        update = var.get_update()
         if 'inner' in self.children:
-            update =
+            update = apply_transform(self.children['inner'], update, params=params, grads=grad, var=var)

         # hessian
         if structure.startswith('diagonal'):

@@ -99,8 +99,8 @@ class StructuredNewton(Module):
             if structure == 'diagonal_abs': torch._foreach_abs_(H)
             torch._foreach_add_(H, reg)
             torch._foreach_div_(update, H)
-
-            return
+            var.update = update
+            return var

         # hessian
         raise NotImplementedError(structure)
torchzero/modules/experimental/subspace_preconditioners.py (+13 -10)

@@ -5,7 +5,7 @@ import torch

 # import torchzero as tz

-from ...core import Transform, Chainable,
+from ...core import Transform, Chainable, apply_transform
 from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
 from ...utils import TensorList, vec_to_tensors_

@@ -38,15 +38,15 @@ def apply_subspace_preconditioner(
     return basis @ update_projected # d

 class RandomSubspacePreconditioning(Transform):
-    """
+    """Whitens in random slowly changing subspace. Please note that this is experimental and isn't guaranteed to work."""
     def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
         super().__init__(defaults, uses_grad=False)

         if inner is not None: self.set_child('inner', inner)

-    def
-        settings =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        settings = settings[0]
         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']
         beta = settings['beta']

@@ -65,7 +65,7 @@ class RandomSubspacePreconditioning(Transform):
         update_subspace_preconditioner_(g, basis, accumulator, beta)

         if 'inner' in self.children:
-            tensors =
+            tensors = apply_transform(self.children['inner'], tensors, params, grads)
             g = torch.cat([t.view(-1) for t in tensors])

         try:

@@ -78,9 +78,12 @@ class RandomSubspacePreconditioning(Transform):


 class HistorySubspacePreconditioning(Transform):
-    """
+    """Whitens in subspace spanned by history of gradient differences.
+    Please note that this is experimental and isn't guaranteed to work.

-
+    Args:
+        beta - for preconditioner itself in the basis.
+        basis_beta - how much basis is allowed to change.
     """
     def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)

@@ -88,8 +91,8 @@ class HistorySubspacePreconditioning(Transform):

         if inner is not None: self.set_child('inner', inner)

-    def
-        settings =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        settings = settings[0]

         g = torch.cat([t.view(-1) for t in tensors])
         k = settings['k']

@@ -122,7 +125,7 @@ class HistorySubspacePreconditioning(Transform):
         update_subspace_preconditioner_(g, basis, accumulator, beta)

         if 'inner' in self.children:
-            tensors =
+            tensors = apply_transform(self.children['inner'], tensors, params, grads)
             g = torch.cat([t.view(-1) for t in tensors])

         try:
torchzero/modules/experimental/tada.py (+38 -0, new file)

@@ -0,0 +1,38 @@
+from collections import deque
+
+import torch
+
+from ...core import Chainable, TensorwiseTransform
+from ...utils.linalg import matrix_power_eigh
+
+
+class TAda(TensorwiseTransform):
+    """3rd order whitening (maybe normalizes skewness). Please note that this is experimental and isn't guaranteed to work."""
+    def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
+        defaults = dict(history_size=history_size, reg=reg)
+        super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
+        reg = settings['reg']
+        if 'history' not in state:
+            state['history'] = deque(maxlen=settings['history_size'])
+
+        g = tensor.view(-1)
+        history = state['history']
+        history.append(g.clone())
+
+        I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
+        g_k = history[0]
+        outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+        if len(history) > 1:
+            for g_k in list(history)[1:]:
+                outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+
+        state['outer'] = outer.add_(I)
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        outer = state['outer']
+        P = matrix_power_eigh(outer, -1/2)
+        return (P @ tensor.ravel()).view_as(tensor)
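The accumulation the new TAda transform performs (a sum of gradient outer products, each weighted by its alignment with the current gradient, followed by an inverse matrix square root) can be illustrated outside the Transform machinery. In the sketch below, inv_sqrt_psd is a hypothetical stand-in for torchzero's internal matrix_power_eigh(outer, -1/2); the rest mirrors the arithmetic visible in the diff.

import torch

def inv_sqrt_psd(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # A^(-1/2) for a symmetric positive semi-definite matrix via eigendecomposition
    L, Q = torch.linalg.eigh(A)
    return Q @ torch.diag(L.clamp(min=eps).rsqrt()) @ Q.T

def tada_direction(history: list[torch.Tensor], g: torch.Tensor, reg: float = 1e-8) -> torch.Tensor:
    outer = torch.eye(g.numel(), dtype=g.dtype, device=g.device) * reg
    for g_k in history:
        # each stored gradient contributes g_k g_kᵀ, weighted by its alignment with the current gradient
        outer += torch.outer(g_k, g_k) * torch.dot(g_k, g).clip(min=reg)
    return inv_sqrt_psd(outer) @ g

# illustrative usage
history = [torch.randn(16) for _ in range(5)]
print(tada_direction(history, torch.randn(16)).shape)  # torch.Size([16])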
torchzero/modules/grad_approximation/fdm.py (+2 -2)

@@ -93,14 +93,14 @@ class FDM(GradApproximator):
     Args:
         h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
-        target (GradTarget, optional): what to set on
+        target (GradTarget, optional): what to set on var. Defaults to 'closure'.
     """
     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
         defaults = dict(h=h, formula=formula)
         super().__init__(defaults, target=target)

     @torch.no_grad
-    def approximate(self, closure, params, loss,
+    def approximate(self, closure, params, loss, var):
         grads = []
         loss_approx = None

torchzero/modules/grad_approximation/forward_gradient.py (+5 -5)

@@ -17,13 +17,13 @@ class ForwardGradient(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
         beta (float, optional):
-
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
-            whether to pre-generate gradient samples before each step. Defaults to True.
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         jvp_method (str, optional):
-            how to calculate jacobian vector product, note that with `forward` and 'central' this is
+            how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
-        target (GradTarget, optional): what to set on
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
     """
     PRE_MULTIPLY_BY_H = False
     def __init__(

@@ -41,7 +41,7 @@ class ForwardGradient(RandomizedFDM):
         self.defaults['jvp_method'] = jvp_method

     @torch.no_grad
-    def approximate(self, closure, params, loss,
+    def approximate(self, closure, params, loss, var):
         params = TensorList(params)
         loss_approx = None

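The ForwardGradient docstring above rests on the standard forward-gradient estimator: draw a random direction v, compute the directional derivative of the loss along v with a single jacobian-vector product, and use (∇f·v)·v as the gradient estimate. Below is a minimal sketch of that estimator using torch.func.jvp for a single tensor parameter; it is illustrative only and ignores the module's beta momentum, pre-generated samples, and finite-difference jvp_method options.

import torch
from torch.func import jvp

def forward_gradient(f, x: torch.Tensor) -> torch.Tensor:
    # random Gaussian direction
    v = torch.randn_like(x)
    # directional derivative of f at x along v, computed in one forward pass
    _, dfv = jvp(f, (x,), (v,))
    # gradient estimate: (grad f . v) * v
    return dfv * v

# illustrative usage on a quadratic
x = torch.randn(8)
print(forward_gradient(lambda t: (t ** 2).sum(), x))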