torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +1 -1
- torchzero/__init__.py +3 -1
- torchzero/_minimize/__init__.py +0 -0
- torchzero/_minimize/methods.py +95 -0
- torchzero/_minimize/minimize.py +518 -0
- torchzero/core/__init__.py +5 -5
- torchzero/core/chain.py +2 -1
- torchzero/core/functional.py +2 -1
- torchzero/core/module.py +75 -4
- torchzero/core/transform.py +6 -5
- torchzero/linalg/eigh.py +116 -68
- torchzero/linalg/linear_operator.py +1 -0
- torchzero/linalg/orthogonalize.py +60 -5
- torchzero/linalg/sketch.py +39 -0
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/adaptive/adagrad.py +2 -0
- torchzero/modules/adaptive/adam.py +5 -1
- torchzero/modules/adaptive/adan.py +3 -0
- torchzero/modules/adaptive/ggt.py +20 -18
- torchzero/modules/adaptive/lion.py +3 -1
- torchzero/modules/adaptive/mars.py +6 -5
- torchzero/modules/adaptive/msam.py +3 -0
- torchzero/modules/adaptive/rmsprop.py +2 -0
- torchzero/modules/adaptive/rprop.py +9 -7
- torchzero/modules/adaptive/shampoo.py +9 -1
- torchzero/modules/adaptive/soap.py +32 -29
- torchzero/modules/basis/__init__.py +2 -0
- torchzero/modules/basis/ggt_basis.py +199 -0
- torchzero/modules/basis/soap_basis.py +254 -0
- torchzero/modules/clipping/ema_clipping.py +32 -27
- torchzero/modules/clipping/growth_clipping.py +1 -0
- torchzero/modules/experimental/__init__.py +1 -6
- torchzero/modules/experimental/coordinate_momentum.py +2 -0
- torchzero/modules/experimental/cubic_adam.py +4 -0
- torchzero/modules/grad_approximation/__init__.py +3 -2
- torchzero/modules/least_squares/gn.py +6 -0
- torchzero/modules/misc/gradient_accumulation.py +1 -0
- torchzero/modules/misc/misc.py +6 -0
- torchzero/modules/momentum/averaging.py +6 -0
- torchzero/modules/momentum/momentum.py +13 -9
- torchzero/modules/ops/__init__.py +0 -1
- torchzero/modules/ops/accumulate.py +4 -0
- torchzero/modules/ops/higher_level.py +6 -1
- torchzero/modules/second_order/inm.py +4 -0
- torchzero/modules/second_order/newton.py +11 -3
- torchzero/modules/second_order/newton_cg.py +7 -3
- torchzero/modules/second_order/nystrom.py +14 -19
- torchzero/modules/second_order/rsn.py +37 -6
- torchzero/modules/trust_region/trust_region.py +2 -1
- torchzero/utils/benchmarks/logistic.py +33 -18
- torchzero/utils/optuna_tools.py +1 -1
- torchzero/utils/params.py +13 -1
- torchzero/utils/tensorlist.py +2 -2
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
- torchzero/modules/experimental/adanystrom.py +0 -258
- torchzero/modules/experimental/common_directions_whiten.py +0 -142
- torchzero/modules/experimental/eigen_sr1.py +0 -182
- torchzero/modules/experimental/eigengrad.py +0 -207
- /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
- {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0
torchzero/core/module.py
CHANGED
@@ -2,18 +2,19 @@ import warnings
 from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, overload, TYPE_CHECKING
+from typing import Any, overload, TYPE_CHECKING, Literal
 
 import torch
 
 from ..linalg.linear_operator import LinearOperator
 from ..utils.optimizer import Init, ListLike, get_state_vals
-from ..utils.params import Params, _make_param_groups
+from ..utils.params import Params, _make_param_groups, _set_fake_params_, _empty_fake_param_storage_
 from .functional import step_tensors
 
 if TYPE_CHECKING:
     from .objective import Objective
 
+ProjectedBuffer = Literal["grad", "grad_sq", "grad_cu", "covariance", "inverse"]
 
 class Module(ABC):
     """Abstract base class for an optimizer modules.
@@ -52,6 +53,12 @@ class Module(ABC):
         self._overridden_keys = set()
         """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""
 
+        self._projected_keys: defaultdict[ProjectedBuffer, set[str]] = defaultdict(set)
+        """tracks keys with gradient-like buffers, covariance-like buffers, etc for reprojecting"""
+
+        self._fake_params: dict[str, list[torch.Tensor]] = {}
+        """fake parameters for state keys and shape inference, key is name of child, value is list of fake parameters"""
+
 
     def set_param_groups(self, param_groups: Params):
         """Set custom parameter groups with per-parameter settings that this module will use."""
@@ -123,7 +130,9 @@ class Module(ABC):
            clone (bool):
                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
                If ``key`` doesn't exist, ``tensors`` are always returned without cloning
-            params (Iterable[torch.Tensor] | None, optional):
+            params (Iterable[torch.Tensor] | None, optional):
+                pass None if ``tensors`` have different shape, it will create fake params from tensors
+                for state keys and shape inference. Defaults to None.
            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
            loss (torch.Tensor | None, optional): loss. Defaults to None.
            closure (Callable | None, optional): closure. Defaults to None.
@@ -137,9 +146,26 @@ class Module(ABC):
             return tensors
 
         if clone: tensors = [t.clone() for t in tensors]
-
+
+        # set fake params to same storage as tensors so as to not use any extra memory
+        # while they still refer to same python objects, so they can be used
+        # as state keys and for shape inference when params aren't given.
+        fake = params is None
+        if fake:
+            if key not in self._fake_params:
+                self._fake_params[key] = [torch.empty_like(t) for t in tensors]
+            params = self._fake_params[key]
+            _set_fake_params_(params, tensors)
+
+        update = step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
                               loss=loss, closure=closure, objective=objective)
 
+        # set fake params storage to empty
+        if fake:
+            _empty_fake_param_storage_(params)
+
+        return update
+
 
     def __repr__(self):
         s = self.__class__.__name__
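The comments in the hunk above describe the trick: the fake parameters are separate Python objects (so they can act as per-parameter state keys and carry shapes) but alias the update tensors' storage while the child steps. The actual helpers live in torchzero/utils/params.py and are not shown in this diff, so the snippet below is only a guess at one way such behaviour could be implemented with ``Tensor.set_()``:

```python
# Guessed sketch of the storage-sharing trick; the real _set_fake_params_ /
# _empty_fake_param_storage_ are in torchzero/utils/params.py and may differ.
import torch

def set_fake_params_(fakes: list[torch.Tensor], tensors: list[torch.Tensor]) -> None:
    # each fake param aliases the update tensor's storage: a distinct python object
    # (usable as a state key), but no extra memory
    for fake, t in zip(fakes, tensors):
        fake.set_(t)

def empty_fake_param_storage_(fakes: list[torch.Tensor]) -> None:
    # drop the borrowed storage once the child module has stepped
    for fake in fakes:
        fake.set_(torch.empty(0, dtype=fake.dtype, device=fake.device))

t = torch.randn(4)
fake = torch.empty_like(t)
set_fake_params_([fake], [t])
print(fake.data_ptr() == t.data_ptr(), fake is t)   # True False
empty_fake_param_storage_([fake])
print(fake.numel())                                  # 0
```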
@@ -322,6 +348,48 @@ class Module(ABC):
         self.global_state[key] = value
         return value
 
+    def get_child_projected_buffers(self, key: str, buff: ProjectedBuffer | Sequence[ProjectedBuffer], params:Sequence[torch.Tensor] | None = None) -> list[list[torch.Tensor]]:
+        """if params is None, assumes fake parameters"""
+        if isinstance(buff, str): buff = (buff, )
+
+        child = self.children[key]
+        child.on_get_projected_buffers()
+        if params is None:
+            params = self._fake_params[key]
+
+        vals = []
+        for b in buff:
+            for buff_key in child._projected_keys[b]:
+                state = child.state[params[0]]
+                if buff_key in state:
+                    tensors = [child.state[p][buff_key] for p in params]
+                    if isinstance(tensors[0], torch.Tensor):
+                        vals.append(tensors)
+                    else: # its usually a deque
+                        assert isinstance(tensors[0], Sequence), type(tensors[0])
+                        vals.extend(zip(*tensors))
+
+                elif buff_key in child.global_state:
+                    val = child.global_state[buff_key]
+                    if len(val) == 0: continue
+                    if isinstance(val[0], torch.Tensor):
+                        vals.append(val)
+                    else:
+                        assert isinstance(val[0], Sequence)
+                        vals.extend(zip(*vals))
+
+        # recursively do this on children,
+        # note that if params are fake, children will have same fake params
+        # unless that child steps with something else. I don't think that is feasible to support it
+        for c in child.children:
+            vals.extend(child.get_child_projected_buffers(c, buff, params=params))
+
+        return vals
+
+    def add_projected_keys(self, buffer: ProjectedBuffer, *keys):
+        for k in keys: self._projected_keys[buffer].add(k)
+
+
     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
     def update(self, objective:"Objective") -> None:
         """Updates internal state of this module. This should not modify ``objective.update``.
@@ -394,6 +462,9 @@ class Module(ABC):
         """
         for c in self.children.values(): c.reset_for_online()
 
+    def on_get_projected_buffers(self):
+        """runs before projected buffers are accessed"""
+
     def _extra_pack(self) -> dict:
         """extra information to store in ``state_dict`` of this optimizer.
         Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
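Together, `add_projected_keys` and `get_child_projected_buffers` form a small registry: each module tags which of its state keys hold gradient-like, squared-gradient-like, covariance-like or inverse buffers, and a parent (presumably the new basis modules in this release) later collects every tensor of a given class across parameters and nested children. A stripped-down standalone sketch of that bookkeeping pattern (plain dict/defaultdict, not torchzero's `Module` class):

```python
# Minimal sketch of the "projected buffer registry" pattern, independent of torchzero.
# A module tags which state keys hold gradient-like vs squared-gradient-like buffers;
# a parent can then collect all tensors of a given class across parameters.
from collections import defaultdict
import torch

projected_keys: defaultdict[str, set[str]] = defaultdict(set)   # buffer class -> state keys
state: dict[int, dict[str, torch.Tensor]] = {}                  # per-parameter state

def add_projected_keys(buffer_class: str, *keys: str) -> None:
    for k in keys:
        projected_keys[buffer_class].add(k)

def get_projected_buffers(buffer_class: str, params: list[torch.Tensor]) -> list[list[torch.Tensor]]:
    out = []
    for key in projected_keys[buffer_class]:
        if key in state[id(params[0])]:
            out.append([state[id(p)][key] for p in params])
    return out

params = [torch.zeros(3), torch.zeros(2)]
for p in params:
    state[id(p)] = {"exp_avg": torch.randn_like(p), "exp_avg_sq": torch.rand_like(p)}

add_projected_keys("grad", "exp_avg")          # mirrors Adam later in this diff
add_projected_keys("grad_sq", "exp_avg_sq")
print(len(get_projected_buffers("grad", params)))   # 1 group of per-parameter tensors
```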
torchzero/core/transform.py
CHANGED
@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from operator import itemgetter
-from typing import
+from typing import TYPE_CHECKING, Any, cast, final
 
 import torch
 
+from ..utils import safe_dict_update_, vec_to_tensors
 from .module import Module
-from ..utils import vec_to_tensors, safe_dict_update_
 
 if TYPE_CHECKING:
     from .chain import Chainable
@@ -31,7 +31,7 @@ class Transform(Module):
 
         self._objective = None
         if inner is not None:
-            self.set_child("
+            self.set_child("__inner", inner)
 
     # settings shouldn't mutate, so they are typed as Sequence[Mapping]
     def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
@@ -70,8 +70,8 @@ class Transform(Module):
     def apply(self, objective: "Objective"):
 
         # inner step
-        if "
-            inner = self.children["
+        if "__inner" in self.children:
+            inner = self.children["__inner"]
             objective = inner.step(objective)
 
         # apply and return
@@ -128,6 +128,7 @@ class TensorTransform(Transform):
         self._uses_grad = uses_grad
         self._uses_loss = uses_loss
 
+
     # ------------------------------- single tensor ------------------------------ #
     def single_tensor_initialize(
         self,
torchzero/linalg/eigh.py
CHANGED
@@ -10,30 +10,18 @@ from .svd import tall_reduced_svd_via_eigh
 
 # https://arxiv.org/pdf/2110.02820
 def nystrom_approximation(
-
-
-    ndim: int,
-    rank: int,
-    device,
-    orthogonalize_method: OrthogonalizeMethod = 'qr',
+    Omega: torch.Tensor,
+    AOmega: torch.Tensor,
     eigv_tol: float = 0,
-    dtype = torch.float32,
-    generator = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Computes Nyström approximation to positive-semidefinite A factored as Q L Q^T (truncatd eigenvalue decomp),
     returns ``(L, Q)``.
 
     A is ``(m,m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - diagonal of ``(rank, rank)``"""
-    # basis
-    O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-    O = orthogonalize(O, method=orthogonalize_method) # Thin QR decomposition # pylint:disable=not-callable
-
-    # Y = AΩ
-    AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
 
-    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(
-    Yv =
-    C = torch.linalg.cholesky_ex(
+    v = torch.finfo(AOmega.dtype).eps * torch.linalg.matrix_norm(AOmega, ord='fro') # Compute shift # pylint:disable=not-callable
+    Yv = AOmega + v*Omega # Shift for stability
+    C = torch.linalg.cholesky_ex(Omega.mT @ Yv)[0] # pylint:disable=not-callable
     B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
 
     # Q, S, _ = torch_linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
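The refactor moves sketch construction out of `nystrom_approximation`: the caller now owns the Gaussian test matrix Ω and the product AΩ, so the routine never draws randomness or touches A directly. A rough sketch of how the new entry point would be driven (hypothetical caller; the module path is the one from this diff, and the example assumes nothing beyond the signature shown above):

```python
# Hypothetical driver for the refactored nystrom_approximation: the caller builds
# Omega and A @ Omega itself, so A only needs to support one block matmul.
import torch
from torchzero.linalg.eigh import nystrom_approximation  # module path as in this diff

torch.manual_seed(0)
m, rank = 64, 8
G = torch.randn(m, rank)
A = G @ G.T + 1e-3 * torch.eye(m)                 # nearly rank-8 PSD matrix
Omega, _ = torch.linalg.qr(torch.randn(m, rank))  # orthonormal test matrix
AOmega = A @ Omega                                # the only access to A

L, Q = nystrom_approximation(Omega, AOmega)       # (rank,), (m, rank)
A_hat = (Q * L) @ Q.T                             # low-rank PSD approximation of A
print(torch.linalg.matrix_norm(A - A_hat) / torch.linalg.matrix_norm(A))
```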
@@ -138,26 +126,35 @@ def eigh_plus_uuT(
 
     return L_prime, Q_prime
 
-def
+def eigh_plus_UUt(
     L: torch.Tensor,
     Q: torch.Tensor,
     U: torch.Tensor,
-    alpha: float = 1,
+    alpha: float | torch.Tensor = 1,
     tol = None,
-
-
+    ortho_method: OrthogonalizeMethod = 'qr',
+    retry_float64=True,
+) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
     """
     compute eigendecomposition of Q L Q^T + alpha * (U U^T), where Q is ``(m, rank)`` and L is ``(rank, )``,
     U is ``(m, k)`` where k is rank of correction
+
+    returns ``(L, Q)``
     """
     if U.size(1) == 1:
-        return eigh_plus_uuT(L, Q, U[:,0], alpha=alpha, tol=tol
+        return eigh_plus_uuT(L, Q, U[:,0], alpha=float(alpha), tol=tol)
+
+    # make alpha shape (k, )
+    k = U.size(1)
+    if isinstance(alpha, torch.Tensor):
+        alpha = torch.broadcast_to(alpha, (k, ))
+    else:
+        alpha = torch.full((k,), float(alpha), device=U.device, dtype=U.dtype)
 
     if tol is None: tol = torch.finfo(Q.dtype).eps
     m, r = Q.shape
-
-
-    U_res = U - Q @ Z # (m, k)
+    QtU = Q.T @ U # (r, k)
+    U_res = U - Q @ QtU # (m, k)
 
     # find cols of U not in col space of Q
     res_norms = torch.linalg.vector_norm(U_res, dim=0) # pylint:disable=not-callable
@@ -167,23 +164,26 @@ def eigh_plus_UUT(
     if k_prime == 0:
         # all cols are in Q
         B = Q
-        C =
+        C = QtU # (r x k)
         r_new = r
     else:
         # orthonormalize directions that aren't in Q
         U_new = U_res[:, new_indices]
-        Q_u
+        Q_u = orthogonalize(U_new, method=ortho_method)
         B = torch.hstack([Q, Q_u])
-        C = torch.vstack([
+        C = torch.vstack([QtU, Q_u.T @ U_res])
         r_new = r + k_prime
 
-
     # project and compute new eigendecomposition
     A_proj = torch.zeros((r_new, r_new), device=Q.device, dtype=Q.dtype)
     A_proj[:r, :r] = L.diag_embed()
-    A_proj
+    # A_proj += (C @ C.T).mul_(alpha)
+    A_proj.addmm_(C * alpha, C.T)
 
-
+    try:
+        L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+    except torch.linalg.LinAlgError:
+        return None, None
 
     # unproject and sort
     Q_prime = B @ S
@@ -194,60 +194,108 @@ def eigh_plus_UUT(
     return L_prime, Q_prime
 
 
-def
-Q
+def eigh_plus_UUt_mm(
+    # A1 = Q @ diag(L) @ Q.T
     L: torch.Tensor,
+    Q: torch.Tensor,
+
+    # A2 = U @ U.T
     U: torch.Tensor,
-    V: torch.Tensor,
-    alpha: float,
-    retry_float64: bool = False,
 
-
+    # rhs
+    B: torch.Tensor,
+
+    # weights
+    w1: float,
+    w2: float | torch.Tensor,
+
+) -> torch.Tensor:
     """
-
+    Computes ``(w1 * (Q L Q^T) + (U diag(w2) U^T) @ B``,
 
-
+    Q is ``(m, rank)``, L is ``(rank, rank)``, U is ``(m, z)``, B is ``(m, k)``.
 
-    ``
+    Returns ``(m, k)``
+    """
+    # sketch Q L Q^T
+    QtB = Q.T @ B # (rank, k)
+    LQtB = L.unsqueeze(1) * QtB # (rank, k)
+    sketch1 = Q @ LQtB # (m, k)
+
+    # skecth U U^T
+    UtB = U.T @ B # (z, k)
+    if isinstance(w2, torch.Tensor) and w2.numel() > 1: w2UtB = w2.unsqueeze(-1) * UtB
+    else: w2UtB = w2 * UtB
+    sketch2 = U @ w2UtB # (m, k)
+
+    return w1 * sketch1 + sketch2
 
-
+
+def randomized_eigh_plus_UUt(
+    L1: torch.Tensor,
+    Q1: torch.Tensor,
+    U: torch.Tensor,
+    w1: float,
+    w2: float | torch.Tensor,
+    oversampling_p: int,
+    rank: int,
+    eig_tol: float,
+    damping: float,
+    rdamping: float,
+    ortho_method: OrthogonalizeMethod = 'qr',
+) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+    """
+    compute randomized eigendecomposition of w1 * Q L Q^T + w2 * (U U^T),
+    where Q is ``(m, rank)`` and L is ``(rank, )``,
+    U is ``(m, k)`` where k is rank of correction, returns ``(L, Q)``
     """
-
-
+    n = Q1.shape[0]
+    device = Q1.device
+    dtype = Q1.dtype
+    l = rank + oversampling_p
 
-    #
-
-    U_perp = U - Q @ Q_T_U
+    # gaussian test matrix
+    Omega = torch.randn(n, l, device=device, dtype=dtype)
 
-
-
+    # sketch
+    AOmega = eigh_plus_UUt_mm(L1, Q1, U, Omega, w1, w2)
+    Q = orthogonalize(AOmega, ortho_method)
 
-
-
+    AQ = eigh_plus_UUt_mm(L1, Q1, U, Q, w1, w2)
+    QtAQ = Q.T @ AQ
 
-
-    r_B = Q_B.shape[1]
+    W = (QtAQ + QtAQ.T) / 2.0
 
-    #
-
-
+    # compute new L and Q
+    try:
+        L_prime, S = torch.linalg.eigh(W) # pylint:disable=not-callable
+    except torch.linalg.LinAlgError:
+        return L1, Q1
 
-
-    Q_B_T_U = torch.vstack([Q_T_U, Q_perp_T_U])
+    L_prime, S = regularize_eigh(L=L_prime, Q=S, truncate=rank, tol=eig_tol, damping=damping, rdamping=rdamping)
 
-
-
+    if L_prime is None or S is None:
+        return L1, Q1
 
-
-    A_proj.add_(update_proj, alpha=alpha/2)
+    return L_prime, Q @ S
 
-    L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
 
-
-
+def rank1_eigh(v: torch.Tensor):
+    """returns ``(L, Q)`` of ``(v v^T)``"""
+    vv = v.dot(v)
+    norm = vv.sqrt().clip(min=torch.finfo(vv.dtype).tiny * 2)
 
-
-
-    Q_prime = Q_prime[:, idx]
+    L = vv.unsqueeze(0) # (rank, )
+    Q = v.unsqueeze(-1) / norm # (m, rank)
 
-    return
+    return L, Q
+
+def low_rank_eigh(U: torch.Tensor):
+    """returns ``(L, Q)`` of ``alpha * (U U^T)`` (from GGT)"""
+    M = U.T @ U
+    L, S = torch.linalg.eigh(M) # pylint:disable=not-callable
+
+    Q = U @ S
+    Q /= torch.sqrt(L).clip(min=torch.finfo(L.dtype).tiny * 2)
+
+    return L, Q
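`eigh_plus_UUt_mm` never materializes the m×m matrix; it applies w1·Q L Qᵀ + U diag(w2) Uᵀ to a block of vectors using only thin matmuls. A self-contained check of that identity against the dense formula (plain PyTorch, re-derived from the hunk above rather than imported from torchzero):

```python
# Check that the matrix-free product matches (w1 * Q diag(L) Q^T + U diag(w2) U^T) @ B.
import torch

torch.manual_seed(0)
m, rank, z, k = 50, 6, 4, 3
Q, _ = torch.linalg.qr(torch.randn(m, rank))
L = torch.rand(rank)
U = torch.randn(m, z)
B = torch.randn(m, k)
w1, w2 = 0.9, torch.rand(z)

# matrix-free: thin products only, O(m*rank*k + m*z*k)
matfree = w1 * (Q @ (L.unsqueeze(1) * (Q.T @ B))) + U @ (w2.unsqueeze(-1) * (U.T @ B))

# dense reference: builds the full (m, m) matrix
dense = (w1 * (Q * L) @ Q.T + (U * w2) @ U.T) @ B

print(torch.allclose(matfree, dense, atol=1e-5))  # True
```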
torchzero/linalg/orthogonalize.py
CHANGED
@@ -51,9 +51,6 @@ def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Ten
     return X.to(G.dtype)
 
 def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
-    """
-    Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
-    """
     try:
         U, S, Vt = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
     except torch.linalg.LinAlgError:
@@ -84,9 +81,67 @@ def orthogonalize_via_qr(A: torch.Tensor):
 
     return Q
 
-
+# CODE FROM https://github.com/HomebrewML/HeavyBall/blob/main/heavyball/utils.py:
+
+## Based on https://arxiv.org/pdf/2505.16932v3
+# and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
+# and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
+#
+# under the MIT License
+# Coefficients are from https://arxiv.org/pdf/2505.16932v3
+ABC_LIST: list[tuple[float, float, float]] = [
+    (8.28721201814563, -23.595886519098837, 17.300387312530933),
+    (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
+    (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
+    (3.3184196573706015, -2.488488024314874, 0.51004894012372),
+    (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
+    (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
+    (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
+    (1.875, -1.25, 0.375),
+]
+
+# safety factor for numerical stability (but exclude last polynomial)
+ABC_LIST_STABLE: list[tuple[float, float, float]] = [
+    (a / 1.01, b / 1.01**3, c / 1.01**5) for (a, b, c) in ABC_LIST[:-1]
+] + [ABC_LIST[-1]]
+
+
+def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
+    """
+    Polar Express algorithm for the matrix sign function:
+    https://arxiv.org/abs/2505.16932
+    """
+    assert G.ndim >= 2
+    should_transpose: bool = G.size(-2) > G.size(-1)
+
+    x = G
+    if should_transpose:
+        x = x.mT
+
+    x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
+
+    for step in range(steps):
+        a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
+        s = x @ x.mT
+        # goal is to compute x = a x + b S x + c S^2 x
+        # we can break this up into: x = (a I + (b I + c S) S) x
+        y = c * s
+        y.diagonal(dim1=-2, dim2=-1).add_(b)
+        y = y @ s
+        y.diagonal(dim1=-2, dim2=-1).add_(a)
+        x = y @ x
+
+    if should_transpose:
+        x = x.mT
+    return x.float()
+
+
+###### END
+
+OrthogonalizeMethod = Literal["newtonschulz", "ns5", "polar_express", "svd", "qr", "eigh"]
 def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod) -> torch.Tensor:
-    if method
+    if method in ("newtonschulz", "ns5"): return zeropower_via_newtonschulz5(A)
+    if method == "polar_express": return msign(A)
     if method == "svd": return zeropower_via_svd(A)
    if method == "qr": return orthogonalize_via_qr(A)
    if method == "eigh": return zeropower_via_eigh(A)
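`msign` approximates the matrix sign / polar factor with a fixed polynomial iteration; the target it converges to is the semi-orthogonal factor U Vᵀ of an SVD. A small sanity check of that target (plain PyTorch; computed directly via SVD rather than through torchzero's routines):

```python
# What the orthogonalization routines above aim for: the semi-orthogonal polar factor
# U @ Vt of a matrix, computed here as an SVD-based reference.
import torch

torch.manual_seed(0)
G = torch.randn(128, 64)                      # tall, gradient-like matrix
U, S, Vt = torch.linalg.svd(G, full_matrices=False)
X = U @ Vt                                    # polar / "matrix sign" factor of G

print(torch.allclose(X.T @ X, torch.eye(64), atol=1e-4))  # columns are orthonormal
print(S.min() > 0)                            # full column rank, so the factor is unique
```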
torchzero/linalg/sketch.py
ADDED
@@ -0,0 +1,39 @@
+import math
+
+import torch
+
+from .orthogonalize import orthogonalize_via_qr
+from .linear_operator import LinearOperator, Dense
+
+
+class Permutation(LinearOperator):
+    def __init__(self, indices:torch.Tensor):
+        self.indices = indices
+        self.device = indices.device
+
+    def matvec(self, x):
+        return x[self.indices]
+
+    def matmat(self, X):
+        return Dense(X[:, self.indices])
+
+def orthonormal_sketch(m, k, dtype, device, generator):
+    return orthogonalize_via_qr(torch.randn(m, k, dtype=dtype, device=device, generator=generator))
+
+def rademacher_sketch(m, k, dtype, device, generator):
+    rademacher = torch.bernoulli(torch.full((m, k), 0.5, device=device, dtype=dtype), generator = generator).mul_(2).sub_(1)
+    return rademacher.mul_(1 / math.sqrt(m))
+
+def row_sketch(m, k, dtype, device, generator):
+    weights = torch.ones(m, dtype=dtype, device=device)
+    indices = torch.multinomial(weights, k, replacement=False, generator=generator)
+
+    P = torch.zeros(m, k, dtype=dtype, device=device)
+    P[indices, range(k)] = 1
+    return P
+
+def topk_rows_sketch(v: torch.Tensor, m, k, dtype, device):
+    _, indices = torch.topk(v, k)
+    P = torch.zeros(m, k, dtype=dtype, device=device)
+    P[indices, range(k)] = 1
+    return P
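These sketch matrices compress an m-dimensional problem to k dimensions while touching the operator only through matmuls. A minimal illustration with a Rademacher sketch built inline (plain PyTorch, not the new module above); the trace comparison is the standard Hutchinson-style use of ±1 probes:

```python
# A (m, k) Rademacher sketch accesses A only through k matvecs; among other things it
# yields an unbiased trace estimate: E[trace(S^T A S)] * m / k == trace(A).
import math
import torch

torch.manual_seed(0)
m, k = 256, 32
S = (torch.bernoulli(torch.full((m, k), 0.5)) * 2 - 1) / math.sqrt(m)   # entries +-1/sqrt(m)

G = torch.randn(m, 2 * m)
A = G @ G.T / (2 * m)            # PSD matrix, accessed only via A @ S

A_small = S.T @ (A @ S)          # (k, k) sketched matrix
print(A_small.shape)                          # torch.Size([32, 32])
print(A_small.trace() * m / k, A.trace())     # unbiased estimate vs exact trace
```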
torchzero/modules/__init__.py
CHANGED
torchzero/modules/adaptive/adagrad.py
CHANGED
@@ -40,6 +40,7 @@ class Adagrad(TensorTransform):
         super().__init__(defaults=defaults, inner=inner)
 
         self.set_child('accumulator', accumulator_tfm)
+        self.add_projected_keys("grad", "accumulator")
 
     @torch.no_grad
     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
@@ -235,6 +236,7 @@ class FullMatrixAdagrad(TensorTransform):
         super().__init__(defaults=defaults, inner=inner, concat_params=concat_params)
 
         self.set_child("accumulator", accumulator_tfm)
+        self.add_projected_keys("covariance", "accumulator")
 
     @torch.no_grad
     def single_tensor_update(self, tensor, param, grad, loss, state, setting):
torchzero/modules/adaptive/adam.py
CHANGED
@@ -38,6 +38,9 @@ class Adam(TensorTransform):
         self.set_child('exp_avg', exp_avg_tfm)
         self.set_child('exp_avg_sq', exp_avg_sq_tfm)
 
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
+
     @torch.no_grad
     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         self.increment_counter("step", start=0)
@@ -81,4 +84,5 @@ class Adam(TensorTransform):
         exp_avg = exp_avg * alpha
 
         # ---------------------------------- update ---------------------------------- #
-        return exp_avg / exp_avg_sq.sqrt().add_(eps)
+        return exp_avg / exp_avg_sq.sqrt().add_(eps)
+
torchzero/modules/adaptive/adan.py
CHANGED
@@ -87,6 +87,9 @@ class Adan(TensorTransform):
         self.set_child("v", v_tfm)
         self.set_child("n", n_tfm)
 
+        self.add_projected_keys("grad_sq", "m", "v", "g_prev")
+        self.add_projected_keys("grad", "n")
+
     @torch.no_grad
     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)