torchzero 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +3 -1
- torchzero/_minimize/__init__.py +0 -0
- torchzero/_minimize/methods.py +95 -0
- torchzero/_minimize/minimize.py +518 -0
- torchzero/core/__init__.py +5 -5
- torchzero/core/chain.py +2 -1
- torchzero/core/functional.py +2 -1
- torchzero/core/module.py +75 -4
- torchzero/core/transform.py +6 -5
- torchzero/linalg/eigh.py +116 -68
- torchzero/linalg/linear_operator.py +1 -0
- torchzero/linalg/orthogonalize.py +60 -5
- torchzero/linalg/sketch.py +39 -0
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/adaptive/adagrad.py +2 -0
- torchzero/modules/adaptive/adam.py +5 -1
- torchzero/modules/adaptive/adan.py +3 -0
- torchzero/modules/adaptive/ggt.py +20 -18
- torchzero/modules/adaptive/lion.py +3 -1
- torchzero/modules/adaptive/mars.py +6 -5
- torchzero/modules/adaptive/msam.py +3 -0
- torchzero/modules/adaptive/rmsprop.py +2 -0
- torchzero/modules/adaptive/rprop.py +9 -7
- torchzero/modules/adaptive/shampoo.py +9 -1
- torchzero/modules/adaptive/soap.py +32 -29
- torchzero/modules/basis/__init__.py +2 -0
- torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero/modules/basis/soap_basis.py +254 -0
- torchzero/modules/clipping/ema_clipping.py +32 -27
- torchzero/modules/clipping/growth_clipping.py +1 -0
- torchzero/modules/experimental/__init__.py +1 -6
- torchzero/modules/experimental/coordinate_momentum.py +2 -0
- torchzero/modules/experimental/cubic_adam.py +4 -0
- torchzero/modules/grad_approximation/__init__.py +3 -2
- torchzero/modules/least_squares/gn.py +6 -0
- torchzero/modules/misc/gradient_accumulation.py +1 -0
- torchzero/modules/misc/misc.py +6 -0
- torchzero/modules/momentum/averaging.py +6 -0
- torchzero/modules/momentum/momentum.py +4 -0
- torchzero/modules/ops/__init__.py +0 -1
- torchzero/modules/ops/accumulate.py +4 -0
- torchzero/modules/ops/higher_level.py +6 -1
- torchzero/modules/second_order/inm.py +4 -0
- torchzero/modules/second_order/newton.py +11 -3
- torchzero/modules/second_order/newton_cg.py +7 -3
- torchzero/modules/second_order/nystrom.py +14 -19
- torchzero/modules/second_order/rsn.py +37 -6
- torchzero/modules/trust_region/trust_region.py +2 -1
- torchzero/utils/benchmarks/logistic.py +33 -18
- torchzero/utils/params.py +13 -1
- torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/METADATA +1 -1
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/RECORD +56 -53
- torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero/modules/experimental/eigengrad.py +0 -207
- /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/WHEEL +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/adanystrom.py
@@ -1,258 +0,0 @@
-# pylint: disable = non-ascii-name
-import torch
-
-from ...core import Chainable, TensorTransform
-from ...linalg import (
-    OrthogonalizeMethod,
-    orthogonalize,
-    regularize_eigh,
-    torch_linalg,
-)
-from ...linalg.linear_operator import Eigendecomposition
-from ..adaptive.lre_optimizers import LREOptimizerBase
-from .eigengrad import _eigengrad_update_state_, eigengrad_apply
-
-
-def weighted_eigen_plus_rank1_mm(
-    # A1 = Q1 @ diag(L1) @ Q1.T
-    L1: torch.Tensor,
-    Q1: torch.Tensor,
-
-    # K2 = v2 @ v2.T
-    v2: torch.Tensor,
-
-    # second matrix
-    B: torch.Tensor,
-
-    # weights
-    w1: float,
-    w2: float,
-
-) -> torch.Tensor:
-    """
-    Computes ``(w1 * A1 + w2 * A2) @ B``, where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
-
-    Returns ``(n, k)``
-
-    Args:
-        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
-        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
-        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)``.
-        B (torch.Tensor): shape ``(n, k)``.
-        w1 (float): weight for A1.
-        w2 (float): weight for A2.
-
-    """
-    # sketch A1
-    QTB = Q1.T @ B # (rank, k)
-    LQTB = L1.unsqueeze(1) * QTB # (rank, k)
-    sketch1 = Q1 @ LQTB # (n, k)
-
-    # skecth A2
-    vB = v2 @ B
-    sketch2 = v2.outer(vB)
-
-    return w1 * sketch1 + w2 * sketch2
-
-
-def adanystrom_update(
-    L1: torch.Tensor,
-    Q1: torch.Tensor,
-    v2: torch.Tensor,
-    w1: float,
-    w2: float,
-    oversampling_p: int,
-    rank: int,
-    eig_tol: float,
-    damping: float,
-    rdamping: float,
-    orthogonalize_method: OrthogonalizeMethod,
-
-) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-    """computes the Nyström approximation of ``(w1 * A1 + w2 * A2)``,
-    where ``A1`` is an eigendecomposition, ``A2`` is symmetric rank 1.
-
-    returns L of shape ``(k, )`` and Q of shape ``(n, k)``.
-
-    Args:
-        L1 (torch.Tensor): eigenvalues of A1, shape ``(rank,)``.
-        Q1 (torch.Tensor): eigenvectors of A1, shape ``(n, rank)``.
-        v2 (torch.Tensor): vector such that ``v v^T = A2``, shape ``(n,)`` or ``(n, 1)``.
-        w1 (float): weight for A1.
-        w2 (float): weight for A2.
-    """
-    n = Q1.shape[0]
-    device = Q1.device
-    dtype = Q1.dtype
-    l = rank + oversampling_p
-
-    # gaussian test matrix
-    Omega = torch.randn(n, l, device=device, dtype=dtype)
-
-    # sketch
-    AOmega = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Omega, w1, w2)
-    Q = orthogonalize(AOmega, orthogonalize_method)
-
-    AQ = weighted_eigen_plus_rank1_mm(L1, Q1, v2, Q, w1, w2)
-    QTAQ = Q.T @ AQ
-
-    W = (QTAQ + QTAQ.T) / 2.0
-
-    # compute new L and Q
-    try:
-        L_prime, S = torch_linalg.eigh(W, retry_float64=True)
-    except torch.linalg.LinAlgError:
-        return L1, Q1
-
-    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
-
-    if L_prime is None or S is None:
-        return L1, Q1
-
-    return L_prime, Q @ S
-
-
-# def adanystrom_update2(
-#     L1: torch.Tensor,
-#     Q1: torch.Tensor,
-#     v2: torch.Tensor,
-#     w1: float,
-#     w2: float,
-#     rank: int,
-# ):
-#     def A_mm(X):
-#         return weighted_eigen_plus_rank1_mm(L1=L1, Q1=Q1, v2=v2, B=X, w1=w1, w2=w2)
-
-#     return nystrom_approximation(A_mm, A_mm=A_mm, ndim=v2.numel(), rank=rank, device=L1.device, dtype=L1.dtype)
-
-class AdaNystrom(TensorTransform):
-    """Adagrad/RMSprop/Adam with Nyström-approximated covariance matrix.
-
-    Args:
-        rank (_type_): rank of Nyström approximation.
-        w1 (float, optional): weight of current covariance matrix. Defaults to 0.95.
-        w2 (float, optional): weight of new gradient in covariance matrix. Defaults to 0.05.
-        oversampling (int, optional): number of extra random vectors (top rank eigenvalues are kept). Defaults to 10.
-        eig_tol (float, optional):
-            removes eigenvalues this much smaller than largest eigenvalue when updating the preconditioner. Defaults to 1e-7.
-        damping (float, optional):
-            added to eigenvalues when updating the preconditioner. Defaults to 1e-8.
-        rdamping (float, optional):
-            added to eigenvalues when updating the preconditioner, relative to largest eigenvalue. Defaults to 0.
-        mm_tol (float, optional):
-            removes eigenvalues this much smaller than largest eigenvalue when computing the update. Defaults to 1e-7.
-        mm_truncate (int | None, optional):
-            uses top k eigenvalues to compute the update. Defaults to None.
-        mm_damping (float, optional):
-            added to eigenvalues when computing the update. Defaults to 1e-4.
-        mm_rdamping (float, optional):
-            added to eigenvalues when computing the update, relative to largest eigenvalue. Defaults to 0.
-        id_reg (float, optional):
-            multiplier to identity matrix added to preconditioner before computing update
-            If this value is given, solution from Nyström sketch-and-solve will be used to compute the update.
-            This value can't be too small (i.e. less than 1e-5) or the solver will be very unstable. Defaults to None.
-        concat_params (bool, optional):
-            whether to precondition all parameters at once if True, or each separately if False. Defaults to True.
-        update_freq (int, optional): update frequency. Defaults to 1.
-        inner (Chainable | None, optional): inner modules. Defaults to None.
-    """
-    def __init__(
-        self,
-        rank:int = 100,
-        beta=0.95,
-        oversampling: int = 10,
-        eig_tol: float | None = 1e-32,
-        damping: float = 0,
-        rdamping: float = 0,
-        mm_tol: float = 0,
-        mm_truncate: int | None = None,
-        mm_damping: float = 0,
-        mm_rdamping: float = 0,
-        id_reg: float | None = None,
-        orthogonalize_method: OrthogonalizeMethod = 'qr',
-        eigenbasis_optimizer: LREOptimizerBase | None = None,
-        orthogonalize_interval: int | None = 100,
-
-        concat_params: bool = True,
-        update_freq: int = 1,
-        inner: Chainable | None = None,
-    ):
-        defaults = locals().copy()
-        for k in ["self", "concat_params", "inner", "update_freq"]:
-            del defaults[k]
-
-        super().__init__(defaults, concat_params=concat_params, inner=inner, update_freq=update_freq)
-
-    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
-        state["step"] = state.get("step", 0) + 1
-        rank = setting["rank"]
-        device = tensor.device
-        dtype = tensor.dtype
-        beta = setting["beta"]
-
-        try:
-            if "L" not in state:
-                # use just tensor and zero L and Q with zero weight
-
-                L, Q = adanystrom_update(
-                    L1=torch.zeros(rank, device=device, dtype=dtype),
-                    Q1=torch.zeros((tensor.numel(), rank), device=device, dtype=dtype),
-                    v2=tensor.ravel(),
-                    w1=0,
-                    w2=1-beta,
-                    rank=rank,
-                    oversampling_p=setting["oversampling"],
-                    eig_tol=setting["eig_tol"],
-                    damping=setting["damping"],
-                    rdamping=setting["rdamping"],
-                    orthogonalize_method=setting["orthogonalize_method"],
-                )
-
-                state["L"] = state["L_reg"] = L
-                state["Q"] = state["Q_reg"] = Q
-
-            else:
-                L = state["L"]
-                Q = state["Q"]
-
-                w1 = beta
-                w2 = 1 - w1
-
-                # compute new factors (this function truncates them)
-                L_new, Q_new = adanystrom_update(
-                    L1=L,
-                    Q1=Q,
-                    v2=tensor.ravel(),
-                    w1=w1,
-                    w2=w2,
-                    rank=rank,
-                    oversampling_p=setting["oversampling"],
-                    eig_tol=setting["eig_tol"],
-                    damping=setting["damping"],
-                    rdamping=setting["rdamping"],
-                    orthogonalize_method=setting["orthogonalize_method"],
-                )
-
-                _eigengrad_update_state_(state=state, setting=setting, L_new=L_new, Q_new=Q_new)
-
-        except torch.linalg.LinAlgError:
-            pass
-
-    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
-        if "L_reg" not in state:
-            return tensor.clip(-0.1, 0.1)
-
-        if "eigenbasis_state" not in state:
-            state["eigenbasis_state"] = {}
-
-        return eigengrad_apply(
-            tensor=tensor,
-            L_reg = state["L_reg"],
-            Q_reg = state["Q_reg"],
-            beta = setting["beta"],
-            step = state["step"],
-            debias = True,
-            id_reg = setting["id_reg"],
-            eigenbasis_optimizer = setting["eigenbasis_optimizer"],
-            eigenbasis_state = state["eigenbasis_state"]
-        )
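For orientation: the removed AdaNystrom module kept a low-rank eigendecomposition (L, Q) of a gradient covariance estimate and refreshed it each step with a randomized Nyström sketch. The helper it was built on relies on the fact that (w1 * Q diag(L) Q^T + w2 * v v^T) @ B can be computed from the factors alone, without forming the n x n matrix. A minimal standalone check of that identity in plain PyTorch (variable names here are illustrative, not torchzero APIs):

import torch

torch.manual_seed(0)
n, rank, k = 50, 8, 5
Q = torch.linalg.qr(torch.randn(n, rank)).Q     # orthonormal eigenvectors, (n, rank)
L = torch.rand(rank)                            # eigenvalues, (rank,)
v = torch.randn(n)                              # rank-1 factor
B = torch.randn(n, k)                           # matrix being multiplied
w1, w2 = 0.95, 0.05

# factored product: O(n * rank * k) work instead of O(n^2 * k)
factored = w1 * (Q @ (L.unsqueeze(1) * (Q.T @ B))) + w2 * torch.outer(v, v @ B)

# dense reference
dense = (w1 * Q @ torch.diag(L) @ Q.T + w2 * torch.outer(v, v)) @ B
assert torch.allclose(factored, dense, atol=1e-4)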
torchzero/modules/experimental/common_directions_whiten.py
@@ -1,142 +0,0 @@
-from collections import deque
-from typing import Literal
-
-import torch
-
-from torchzero.core import Chainable, TensorTransform
-from torchzero.linalg import matrix_power_eigh, torch_linalg, orthogonalize, OrthogonalizeMethod, regularize_eigh
-from torchzero.utils import TensorList, vec_to_tensors_
-
-
-def update_subspace_preconditioner_(
-    grad: torch.Tensor, # store grads and basis as vectors for matmul
-    basis: torch.Tensor, # ndim, k
-    accumulator_: torch.Tensor, # k, k
-    beta: float | None,
-):
-    projected = basis.T @ grad # k
-    outer = torch.outer(projected, projected)
-
-    if beta is None: accumulator_.add_(outer)
-    else: accumulator_.lerp_(outer, 1-beta)
-
-# yeah so I can also run subspace opts in this basis
-def apply_subspace_preconditioner(
-    tensor: torch.Tensor,
-    basis: torch.Tensor, # ndim, k
-    accumulator: torch.Tensor,
-    tol: float,
-    truncate: int | None,
-    damping: float,
-    rdamping: float,
-):
-    L, Q = torch_linalg.eigh(accumulator, retry_float64=True)
-    L, Q = regularize_eigh(L=L, Q=Q, truncate=truncate, tol=tol, damping=damping, rdamping=rdamping)
-
-    if L is None or Q is None:
-        return tensor.clip(-0.1, 0.1)
-
-    preconditioner = (Q * L.rsqrt().unsqueeze(-2)) @ Q.mH
-
-    tensor_projected = basis.T @ tensor # k
-    update_projected = preconditioner @ tensor_projected # k
-    return basis @ update_projected # d
-
-
-class CommonDirectionsWhiten(TensorTransform):
-    """Whitens in subspace spanned by history of gradient differences.
-
-    Args:
-        beta - for preconditioner itself in the basis.
-        basis_beta - how much basis is allowed to change.
-    """
-
-    def __init__(
-        self,
-        k: int = 100,
-        beta: float | None = 0.95,
-        basis_beta=0.95,
-        tol: float = 1e-7,
-        truncate: int | None = None,
-        damping: float = 1e-4,
-        rdamping: float = 0,
-        basis_type: Literal["gradients", "differences"] = "differences",
-        orthogonalize_method: OrthogonalizeMethod | None = 'newtonschulz',
-
-        concat_params: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = locals().copy()
-        for key in ["self", "inner", "concat_params"]:
-            del defaults[key]
-
-        super().__init__(defaults, concat_params=concat_params, inner=inner)
-
-    @torch.no_grad
-    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
-        g = tensor.ravel()
-        k = setting['k']
-        beta = setting['beta']
-        basis_beta = setting['basis_beta']
-        step = state.get("step", 0)
-        state["step"] = step + 1
-
-        # initialize history
-        if 'history' not in state:
-            state['history'] = deque(maxlen=k)
-            state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-            state['basis'] = torch.zeros(g.numel(), k, device=g.device, dtype=g.dtype)
-
-        history: deque = state['history']
-        accumulator = state['accumulator']
-        basis = state['basis']
-        history.append(g)
-
-        # stack history to new basis term, if history isn't full, fill with random vecs
-        if len(history) < k:
-            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-            history_basis = torch.stack(tuple(history), -1)
-            basis_t[:, -len(history):] = history_basis
-
-        else:
-            basis_t = torch.stack(tuple(history), -1)
-
-        # in this case basis uses differences in gradients except last entry is the gradient
-        if setting["basis_type"] == "differences":
-            basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
-
-        # normalize or orthonormalize new basis term
-        if setting["orthogonalize_method"] is not None:
-            basis_t = orthogonalize(basis_t, method = setting["orthogonalize_method"])
-        else:
-            basis_t = (basis_t - basis_t.mean()) / basis_t.std().clip(min=torch.finfo(g.dtype).tiny * 2)
-
-        # lerp basis
-        basis.lerp_(basis_t, 1-basis_beta)
-        basis = basis / (1 - basis_beta ** (step+1)) # correct bias on basis EMA
-        update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-    @torch.no_grad
-    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
-        g = tensor.ravel()
-
-        basis = state['basis']
-        accumulator = state['accumulator']
-        step = state["step"]
-        accumulator = accumulator / (1 - setting["beta"] ** (step+1)) # correct bias on accumulator EMA
-
-        try:
-            preconditioned = apply_subspace_preconditioner(
-                g,
-                basis,
-                accumulator,
-                tol=setting["tol"],
-                truncate=setting["truncate"],
-                damping=setting["damping"],
-                rdamping=setting["rdamping"],
-            )
-        except torch.linalg.LinAlgError:
-            preconditioned = g.clip(-0.1, 0.1)
-
-        return preconditioned.view_as(tensor)
-
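The removed CommonDirectionsWhiten module whitened the update inside a k-dimensional basis built from recent gradients (or gradient differences): project into the basis, accumulate outer products of the projected gradients, and apply the accumulator's inverse square root before projecting back. A rough standalone sketch of that subspace-whitening step, assuming an orthonormal basis and a plain sum accumulator in place of the module's EMA and bias correction (illustrative only, not torchzero code):

import torch

torch.manual_seed(0)
ndim, k = 200, 10
basis = torch.linalg.qr(torch.randn(ndim, k)).Q     # (ndim, k), orthonormal columns
grads = torch.randn(32, ndim)                       # a history of gradients

# accumulate projected outer products: G = sum_i (basis^T g_i)(basis^T g_i)^T
projected = grads @ basis                           # (32, k)
accumulator = projected.T @ projected + 1e-4 * torch.eye(k)   # small damping

# whiten the current gradient inside the subspace: basis @ G^(-1/2) @ basis^T @ g
L, Q = torch.linalg.eigh(accumulator)
inv_sqrt = (Q * L.clamp(min=1e-12).rsqrt()) @ Q.T
g = grads[-1]
update = basis @ (inv_sqrt @ (basis.T @ g))         # (ndim,), the preconditioned direction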
torchzero/modules/experimental/eigen_sr1.py
@@ -1,182 +0,0 @@
-import torch
-
-from ...core import Transform
-from ...linalg.orthogonalize import orthogonalize, OrthogonalizeMethod
-from ...linalg.eigh import eigh_plus_uuT, regularize_eigh
-from ...utils import TensorList, unpack_states, vec_to_tensors_
-from ..opt_utils import safe_clip
-from .eigengrad import _eigengrad_update_state_, eigengrad_apply
-
-
-def sr1_u(L: torch.Tensor, Q: torch.Tensor, s:torch.Tensor, y: torch.Tensor, tol:float):
-    """u from u u^T correction and its sign"""
-    r = y - torch.linalg.multi_dot([Q, L.diag_embed(), Q.T, s]) # pylint:disable=not-callable
-    rs = r.dot(s)
-
-    if rs.abs() < tol * torch.linalg.vector_norm(r) * torch.linalg.vector_norm(s): # pylint:disable=not-callable
-        return None, None
-
-    u = r / rs.abs().sqrt()
-    return u, torch.sign(rs)
-
-class EigenSR1(Transform):
-    def __init__(
-        self,
-        rank: int = 100,
-        tol: float = 1e-32,
-        eig_tol: float | None = None,
-        damping: float = 0,
-        rdamping: float = 0,
-        abs: bool = False,
-        mm_tol: float = 1e-7,
-        mm_truncate: int | None = None,
-        mm_damping: float = 1e-4,
-        mm_rdamping: float = 0,
-        mm_abs: bool = True,
-        id_reg: float | None = None,
-        column_space_tol=1e-9,
-        beta: float = 0.95,
-        balance_tol: float = 10,
-        balance_strength: float = 1e-1,
-
-        eigenbasis_optimizer = None,
-        update_freq: int = 1,
-        init_steps: int = 10,
-        orthogonalize_interval: int | None = 1,
-        orthogonalize_method: OrthogonalizeMethod = 'qr',
-
-        hvp_method = "autograd",
-        h = 1e-3,
-        inner = None,
-
-    ):
-        defaults = locals().copy()
-        for k in ["self", "inner"]:
-            del defaults[k]
-
-        super().__init__(defaults)
-
-    def update_states(self, objective, states, settings):
-        fs = settings[0]
-        step = self.increment_counter("step", 0)
-
-        if step % fs["update_freq"] == 0:
-
-            params = TensorList(objective.params)
-
-            # compute y as hessian-vector product with s (random vecs during init steps)
-            if ("p_prev" not in self.global_state) or (step < fs["init_steps"]):
-                s_list = params.sample_like('rademacher')
-
-            else:
-                p_prev = self.global_state["p_prev"]
-                s_list = params - p_prev
-
-                if s_list.dot(s_list) < torch.finfo(s_list[0].dtype).tiny * 2:
-                    s_list = params.sample_like('rademacher')
-
-            self.global_state["p_prev"] = params
-
-            # compute y as hessian-vector product with s
-            Hz_list, _ = objective.hessian_vector_product(s_list, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])
-
-            s = torch.cat([t.ravel() for t in s_list])
-            y = torch.cat([t.ravel() for t in Hz_list])
-
-            # keep track of exponential moving average of hessian diagonal and balance eigenvalues
-            if (fs["balance_strength"] != 0) and (step > fs["init_steps"]) and ("L" in self.global_state):
-
-                D = s * y # hutchinson estimator
-                exp_avg = self.global_state.get("exp_avg", None)
-
-                if exp_avg is None:
-                    exp_avg = self.global_state["exp_avg"] = D
-
-                else:
-                    exp_avg.lerp_(D, weight=1-fs["beta"])
-
-                L = self.global_state["L"]
-                L_abs = L.abs()
-                tau = L_abs.amax() / exp_avg.abs().amax()
-
-                if tau > fs["balance_tol"]:
-                    L_balanced = L_abs.pow((1 / tau) ** (1 / fs["balance_strength"])).copysign(L)
-                    self.global_state["L"] = torch.where(L_abs > 1, L_balanced, L)
-
-            # initialize L and Q on 1st step
-            if "L" not in self.global_state:
-
-                L = torch.zeros(1, dtype=s.dtype, device=s.device) # rank, rank
-                Q = torch.zeros([s.numel(), 1], dtype=s.dtype, device=s.device) # ndim, rank
-
-                u, sign = sr1_u(L=L, Q=Q, s=s, y=y, tol=0)
-                assert u is not None and sign is not None
-
-                # for uu^T u is eigenvector and u^T u is eigenvalue
-                norm = torch.linalg.vector_norm(u).clip(min=torch.finfo(u.dtype).tiny * 2) # pylint:disable=not-callable
-
-                self.global_state["L"] = self.global_state["L_reg"] = (u.dot(u).unsqueeze(0) / norm) * sign # (rank,)
-                self.global_state["Q"] = self.global_state["Q_reg"] = u.unsqueeze(-1) / norm # (m, rank)
-
-            # update hessian
-            else:
-                try:
-                    L = self.global_state["L"]
-                    Q = self.global_state["Q"]
-
-                    H_step = self.increment_counter("H_step", start=0)
-                    if H_step % fs["orthogonalize_interval"] == 0:
-                        Q = orthogonalize(Q, method=fs["orthogonalize_method"])
-
-                    u, sign = sr1_u(L=L, Q=Q, s=s, y=y, tol=fs["tol"])
-
-                    if (u is not None) and (sign is not None):
-
-                        # compute new factors
-                        L_new, Q_new = eigh_plus_uuT(L, Q, u, tol=fs["column_space_tol"], alpha=sign.item(), retry_float64=True)
-
-                        # truncate/regularize new factors (those go into the accumulator)
-                        L_new, Q_new = regularize_eigh(L=L_new, Q=Q_new, truncate=min(fs["rank"], s.numel()),
-                                                       tol=fs["eig_tol"], damping=fs["damping"], rdamping=fs["rdamping"])
-
-                        _eigengrad_update_state_(state=self.global_state, setting=fs, L_new=L_new, Q_new=Q_new)
-
-                except torch.linalg.LinAlgError:
-                    pass
-
-
-
-    def apply_states(self, objective, states, settings):
-        fs = settings[0]
-        updates = objective.get_updates()
-
-        if "eigenbasis_state" not in self.global_state:
-            self.global_state["eigenbasis_state"] = {}
-
-        step = self.global_state["step"] # starts at 0
-        if step < fs["init_steps"]:
-
-            # skip update first init_steps to let hessian kick-start
-            objective.stop = True
-            objective.skip_update = True
-            return objective
-
-        if "L_reg" not in self.global_state:
-            TensorList(updates).clip_(-0.1, 0.1)
-            return objective
-
-        dir = eigengrad_apply(
-            tensor = torch.cat([t.ravel() for t in updates]),
-            L_reg = self.global_state["L_reg"],
-            Q_reg = self.global_state["Q_reg"],
-            beta = None,
-            step = None,
-            debias = False,
-            id_reg = fs["id_reg"],
-            eigenbasis_optimizer = fs["eigenbasis_optimizer"],
-            eigenbasis_state = self.global_state["eigenbasis_state"],
-            whiten_fn = lambda x: x
-        )
-
-        vec_to_tensors_(dir, updates)
-        return objective
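The removed EigenSR1 module maintained a low-rank eigendecomposition of a Hessian approximation and refreshed it with symmetric rank-1 (SR1) corrections: given a step s and a Hessian-vector product y, the correction is (y - Hs)(y - Hs)^T / ((y - Hs)^T s), skipped when the denominator is negligible (the sr1_u helper above). A dense-matrix sketch of that rule and the secant property it enforces (illustrative only, not torchzero code):

import torch

torch.manual_seed(0)
n = 6
H = torch.eye(n)                    # current symmetric approximation
A = torch.randn(n, n)
A = A @ A.T                         # "true" matrix being approximated
s = torch.randn(n)                  # step
y = A @ s                           # observed curvature pair (s, y)

r = y - H @ s                       # residual of the secant equation
rs = r.dot(s)
if rs.abs() > 1e-8 * r.norm() * s.norm():       # standard SR1 safeguard on the denominator
    H = H + torch.outer(r, r) / rs              # same as sign(rs) * u u^T with u = r / sqrt(|rs|)
    assert torch.allclose(H @ s, y, atol=1e-4)  # the updated H reproduces the curvature pair exactly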