torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/linalg/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from . import linear_operator
+
+from .matrix_power import (
+    matrix_power_eigh,
+    matrix_power_svd,
+)
+from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize
+from .qr import qr_householder
+from .solve import cg, nystrom_sketch_and_solve, nystrom_pcg
+from .eigh import nystrom_approximation
torchzero/linalg/eigh.py
ADDED
@@ -0,0 +1,34 @@
+from collections.abc import Callable
+import torch
+from .linalg_utils import mm
+
+
+
+# https://arxiv.org/pdf/2110.02820
+def nystrom_approximation(
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+    ndim: int,
+    rank: int,
+    device,
+    dtype = torch.float32,
+    generator = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Computes a Nyström approximation to positive semi-definite A, factored as Q L Q^T (truncated eigenvalue decomposition);
+    returns ``(L, Q)``.
+
+    If A is ``(m, m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - the diagonal of ``(rank, rank)``."""
+    # basis
+    O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
+    O, _ = torch.linalg.qr(O) # Thin QR decomposition # pylint:disable=not-callable
+
+    # Y = AΩ
+    AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
+
+    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(AO, ord='fro') # Compute shift # pylint:disable=not-callable
+    Yv = AO + v*O # Shift for stability
+    C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
+    B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+    Q, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+    L = (S.pow(2) - v).clip(min=0) # Remove shift, compute eigs
+    return L, Q
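A minimal usage sketch of the new Nyström helper (not from the package docs), assuming the re-export shown in torchzero/linalg/__init__.py above; the low-rank test matrix and the check are illustrative only:

import torch
from torchzero.linalg import nystrom_approximation  # assumes the 0.4.0 layout shown above

n, true_rank = 200, 10
B = torch.randn(n, true_rank)
A = B @ B.T                                   # PSD matrix of rank 10

# pass the operator as a matrix-matrix closure; A_mv is then unused
L, Q = nystrom_approximation(A_mv=None, A_mm=lambda X: A @ X,
                             ndim=n, rank=20, device="cpu")

A_hat = Q @ torch.diag(L) @ Q.mH              # Q L Q^T reconstruction
print(torch.linalg.matrix_norm(A - A_hat) / torch.linalg.matrix_norm(A))  # small, since rank >= true rank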
torchzero/linalg/linalg_utils.py
ADDED
@@ -0,0 +1,14 @@
+from collections.abc import Callable
+import torch
+
+def mm(
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+    X
+):
+    """matrix-matrix product when either mv or mm is given"""
+    if A_mm is not None: return A_mm(X)
+    assert A_mv is not None
+    return torch.stack([A_mv(col) for col in X.unbind(-1)], -1) # rank matvecs
+
+
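A small usage sketch (assuming the module path added in this diff): mm() materializes A @ X one column at a time when only a matrix-vector closure is available.

import torch
from torchzero.linalg.linalg_utils import mm  # path as shown in the diff above

A = torch.randn(6, 6)
X = torch.randn(6, 3)

out = mm(A_mv=lambda v: A @ v, A_mm=None, X=X)  # 3 matvecs
print(torch.allclose(out, A @ X))               # True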
torchzero/{utils/linalg → linalg}/linear_operator.py
@@ -1,4 +1,6 @@
-"""
+"""This is mainly used for trust regions. In some cases certain operations are relaxed, e.g. eigenvalue shift instead of
+adding diagonal when it isn't tractable, to make it work with Levenberg-Marquardt.
+"""
 import math
 from abc import ABC, abstractmethod
 from functools import partial
@@ -7,7 +9,8 @@ from typing import cast, final
 
 import torch
 
-from ..torch_tools import tofloat, tonumpy, totensor
+from ..utils.torch_tools import tofloat, tonumpy, totensor
+from .solve import nystrom_sketch_and_solve
 
 if find_spec('scipy') is not None:
     from scipy.sparse.linalg import LinearOperator as _ScipyLinearOperator
@@ -15,7 +18,6 @@ else:
     _ScipyLinearOperator = None
 
 class LinearOperator(ABC):
-    """this is used for trust region"""
     device: torch.types.Device
     dtype: torch.dtype | None
 
@@ -25,12 +27,18 @@ class LinearOperator(ABC):
     def rmatvec(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatvec")
 
-    def matmat(self,
-        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement
+    def matmat(self, X: torch.Tensor) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmat")
+
+    def rmatmat(self, X: torch.Tensor) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatmat")
 
     def solve(self, b: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve")
 
+    def solve_plus_diag(self, b: torch.Tensor, diag: int | float | torch.Tensor) -> torch.Tensor:
+        return self.add_diagonal(diag).solve(b)
+
     def solve_bounded(self, b: torch.Tensor, bound:float, ord:float=2) -> torch.Tensor:
         """solve with a norm bound on x"""
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
@@ -129,8 +137,8 @@ class Dense(LinearOperator):
     def matvec(self, x): return self.A.mv(x)
     def rmatvec(self, x): return self.A.mH.mv(x)
 
-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(self.A.mm(X))
+    def rmatmat(self, X): return Dense(self.A.mH.mm(X))
 
     def solve(self, b): return _solve(self.A, b)
 
@@ -146,6 +154,12 @@ class Dense(LinearOperator):
     def is_dense(self): return True
     def transpose(self): return Dense(self.A.mH)
 
+class SPD(Dense):
+    def solve(self, b: torch.Tensor):
+        L, info = torch.linalg.cholesky_ex(self.A) # pylint:disable=not-callable
+        return torch.cholesky_solve(b.unsqueeze(-1), L).squeeze(-1)
+
+
 class DenseInverse(LinearOperator):
     """Represents inverse of a dense matrix A."""
     def __init__(self, A_inv: torch.Tensor):
@@ -156,8 +170,8 @@ class DenseInverse(LinearOperator):
     def matvec(self, x): return _solve(self.A_inv, x) # pylint:disable=not-callable
     def rmatvec(self, x): return _solve(self.A_inv.mH, x) # pylint:disable=not-callable
 
-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(_solve(self.A_inv, X)) # pylint:disable=not-callable
+    def rmatmat(self, X): return Dense(_solve(self.A_inv.mH, X)) # pylint:disable=not-callable
 
     def solve(self, b): return self.A_inv.mv(b)
 
@@ -190,8 +204,8 @@ class Diagonal(LinearOperator):
     def matvec(self, x): return self.A * x
     def rmatvec(self, x): return self.A * x
 
-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(X * self.A.unsqueeze(-1))
+    def rmatmat(self, X): return Dense(X * self.A.unsqueeze(-1))
 
     def solve(self, b): return b/self.A
 
@@ -221,8 +235,8 @@ class ScaledIdentity(LinearOperator):
     def matvec(self, x): return x * self.s
     def rmatvec(self, x): return x * self.s
 
-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(X * self.s)
+    def rmatmat(self, X): return Dense(X * self.s)
 
     def solve(self, b): return b / self.s
     def solve_bounded(self, b, bound, ord = 2):
@@ -263,6 +277,7 @@ class ScaledIdentity(LinearOperator):
     def is_dense(self): return False
     def transpose(self): return ScaledIdentity(self.s, shape=self.shape, device=self.device, dtype=self.dtype)
 
+
 class AtA(LinearOperator):
     def __init__(self, A: torch.Tensor):
         self.A = A
@@ -270,8 +285,8 @@ class AtA(LinearOperator):
     def matvec(self, x): return self.A.mH.mv(self.A.mv(x))
     def rmatvec(self, x): return self.matvec(x)
 
-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable
+    def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable
 
     def is_dense(self): return False
     def to_tensor(self): return self.A.mH @ self.A
@@ -283,51 +298,41 @@ class AtA(LinearOperator):
         return Dense(self.to_tensor() + torch.diag_embed(x))
 
     def solve(self, b):
-
+        *_, n, m = self.A.shape
+        if n >= m: return Dense(self.to_tensor()).solve(b)
 
-
-
+        A = self.A
+        C = A @ A.mH # (n, n), SPD
+        L, info = torch.linalg.cholesky_ex(C) # pylint:disable=not-callable
+        z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+        return A.mH @ z
 
-    def
-
-
-
-        n = self.A.size(1)
-        return (n,n)
+    def solve_plus_diag(self, b, diag):
+        *_, n, m = self.A.shape
+        if (n >= m) or (isinstance(diag, torch.Tensor) and diag.numel() > 1):
+            return Dense(self.to_tensor()).solve_plus_diag(b, diag)
 
-
-
-        self.A = A
-        self.device = self.A.device; self.dtype = self.A.dtype
-
-    def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
-    def rmatvec(self, x): return self.matvec(x)
-
-    def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
-    def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
-
-    def is_dense(self): return False
-    def to_tensor(self): return self.A @ self.A.mH
-    def transpose(self): return AAT(self.A)
-
-    def add_diagonal(self, x):
-        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
-        if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
-        return Dense(self.to_tensor() + torch.diag_embed(x))
+        A = self.A
+        I = torch.eye(A.size(-2), device=A.device, dtype=A.dtype)
 
-
-
+        C = (A @ A.mH).add_(I.mul_(diag)) # (n, n), SPD
+        L, info = torch.linalg.cholesky_ex(C + I.mul_(diag)) # pylint:disable=not-callable
+        z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+        return (1 / diag) * (b - A.mH @ z)
 
     def inv(self):
         return Dense(self.to_tensor()).inv()
 
     def diagonal(self):
-        return self.A.pow(2).sum(
+        return self.A.pow(2).sum(1)
 
     def size(self):
         n = self.A.size(1)
         return (n,n)
 
+class AAt(AtA):
+    def __init__(self, A: torch.Tensor):
+        super().__init__(A.mH)
 
 class Sketched(LinearOperator):
     """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
@@ -339,7 +344,6 @@ class Sketched(LinearOperator):
         self.A_proj = A_proj
         self.device = self.A_proj.device; self.dtype = self.A_proj.dtype
 
-
     def matvec(self, x):
         x_proj = self.S.T @ x
         Ax_proj = self.A_proj @ x_proj
@@ -351,8 +355,8 @@ class Sketched(LinearOperator):
         return self.S @ ATx_proj
 
 
-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, X])) # pylint:disable=not-callable
+    def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, X])) # pylint:disable=not-callable
 
 
     def is_dense(self): return False
@@ -375,3 +379,49 @@ class Sketched(LinearOperator):
         n = self.S.size(0)
         return (n,n)
 
+
+class Eigendecomposition(LinearOperator):
+    """A represented as Q L Q^H. If A is (n,n), then Q is (n, rank); L is a vector - the diagonal of (rank, rank)"""
+    def __init__(self, L: torch.Tensor, Q: torch.Tensor, use_nystrom: bool = True):
+        self.L = L
+        self.Q = Q
+        self.use_nystrom = use_nystrom
+        self.device = self.L.device; self.dtype = self.L.dtype
+
+    def matvec(self, x):
+        return self.Q @ ((self.Q.mH @ x) * self.L)
+
+    def rmatvec(self, x):
+        return self.matvec(x)
+
+    def matmat(self, X):
+        return Dense(self.Q @ (self.L[:, None] * (self.Q.mH @ X)))
+
+    def rmatmat(self, X):
+        return self.matmat(X)
+
+    def is_dense(self): return False
+    def to_tensor(self): return self.Q @ self.L.diag_embed() @ self.Q.mH
+    def transpose(self): return Eigendecomposition(L=self.L, Q=self.Q)
+
+    def add_diagonal(self, x):
+        """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
+        if isinstance(x, torch.Tensor) and x.numel() > 1:
+            raise RuntimeError("Eigendecomposition linear operator doesn't support add_diagonal with a vector diag")
+
+        return Eigendecomposition(L=self.L + x, Q = self.Q)
+
+    def solve(self, b):
+        return self.Q @ ((self.Q.mH @ b) / self.L)
+
+    def solve_plus_diag(self, b, diag):
+        if isinstance(diag, torch.Tensor) and diag.numel() > 1: return super().solve_plus_diag(b, diag)
+        if not self.use_nystrom: return super().solve_plus_diag(b, diag)
+        return nystrom_sketch_and_solve(L=self.L, Q=self.Q, b=b, reg=float(diag))
+
+    def inv(self):
+        return Eigendecomposition(L=1 / self.L, Q = self.Q)
+
+    def size(self):
+        n = self.Q.size(0)
+        return (n,n)
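The rewritten AtA.solve_plus_diag appears to rely on the Woodbury identity (A^T A + λI)^{-1} b = (1/λ)(b - A^T (A A^T + λI)^{-1} A b) when A is wide (n < m), so only an (n, n) system has to be factorized. A standalone check of that identity in plain torch, independent of the package classes:

import torch

torch.manual_seed(0)
n, m, lam = 5, 12, 0.3
A = torch.randn(n, m, dtype=torch.float64)
b = torch.randn(m, dtype=torch.float64)

# direct solve in the large (m, m) space
x_direct = torch.linalg.solve(A.T @ A + lam * torch.eye(m, dtype=torch.float64), b)

# Woodbury route: only an (n, n) system is factorized
C = A @ A.T + lam * torch.eye(n, dtype=torch.float64)
z = torch.linalg.solve(C, A @ b)
x_woodbury = (b - A.T @ z) / lam

print(torch.allclose(x_direct, x_woodbury))  # True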
torchzero/linalg/matrix_power.py
ADDED
@@ -0,0 +1,28 @@
+from typing import Literal
+import warnings
+from collections.abc import Callable
+
+import torch
+from . import torch_linalg
+def matrix_power_eigh(A: torch.Tensor, power:float, abs:bool=False):
+    """this is faster than SVD but only for positive semi-definite symmetric matrices
+    (covariance matrices are always SPD)"""
+
+    L, Q = torch_linalg.eigh(A, retry_float64=True) # pylint:disable=not-callable
+    if abs: L.abs_()
+    if power % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
+    return (Q * L.pow_(power).unsqueeze(-2)) @ Q.mH
+
+
+def matrix_power_svd(A: torch.Tensor, power: float) -> torch.Tensor:
+    """for any symmetric matrix"""
+    U, S, Vh = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+    if power % 2 != 0: S.clip_(min = torch.finfo(A.dtype).tiny * 2)
+    return (U * S.pow_(power).unsqueeze(-2)) @ Vh
+
+MatrixPowerMethod = Literal["eigh", "eigh_abs", "svd"]
+def matrix_power(A: torch.Tensor, power: float, method: MatrixPowerMethod = "eigh_abs") -> torch.Tensor:
+    if method == "eigh": return matrix_power_eigh(A, power)
+    if method == "eigh_abs": return matrix_power_eigh(A, power, abs=True)
+    if method == "svd": return matrix_power_svd(A, power)
+    raise ValueError(method)
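A plain-torch sketch of the eigendecomposition-based fractional power these helpers implement (a re-implementation for illustration only, not the package function, which additionally handles the abs/clipping details above):

import torch

torch.manual_seed(0)
B = torch.randn(8, 8, dtype=torch.float64)
A = B @ B.T + 1e-3 * torch.eye(8, dtype=torch.float64)  # SPD test matrix

def eigh_matrix_power(A, power):
    # same idea as matrix_power_eigh above, without the clipping logic
    L, Q = torch.linalg.eigh(A)
    return (Q * L.pow(power).unsqueeze(-2)) @ Q.mH        # Q diag(L^power) Q^H

A_inv_sqrt = eigh_matrix_power(A, -0.5)
print(torch.allclose(A_inv_sqrt @ A_inv_sqrt, torch.linalg.inv(A)))  # True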
torchzero/linalg/orthogonalize.py
ADDED
@@ -0,0 +1,95 @@
+from typing import Literal
+import torch
+
+from ..utils.compile import allow_compile
+from . import torch_linalg
+
+# zeropower_via_newtonschulz5 from:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+# and
+# https://github.com/HomebrewML/HeavyBall/blob/main/heavyball/utils.py#L452
+_NS_COEFFS = (
+    (4.0848, -6.8946, 2.9270),
+    (3.9505, -6.3029, 2.6377),
+    (3.7418, -5.5913, 2.3037),
+    (2.8769, -3.1427, 1.2046),
+    (2.8366, -3.0525, 1.2012)
+)
+
+@allow_compile
+def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Tensor:
+    """
+    Applies to the last 2 dims - so usually reverse_dims should be applied to G before and after.
+
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True).clip(min=torch.finfo(X.dtype).tiny * 2))
+
+    # Perform the NS iterations
+    for a,b,c in coeffs:
+        A = X @ X.mT
+        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    return X.to(G.dtype)
+
+# code from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
+# Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
+# Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
+def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
+    """
+    Applies to the first 2 dims and isn't batched - the rest of the dimensions are flattened.
+    """
+    try:
+        U, S, Vt = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+    except torch.linalg.LinAlgError:
+        U, S, Vt = torch.svd_lowrank(A, q=1, M=1e-4 * A.mean() * torch.rand_like(A))
+
+    return U @ Vt
+
+def zeropower_via_eigh(A: torch.Tensor) -> torch.Tensor:
+    """
+    Only for SPD matrices - need to check whether the call sites actually pass SPD inputs, because then this is better than SVD.
+    """
+    L, Q = torch_linalg.eigh(A, retry_float64=True)
+    return Q @ Q.mH
+
+
+def orthogonalize_via_qr(A: torch.Tensor):
+    *_, m, n = A.shape
+    T = False
+    if m < n:
+        T = True
+        m,n = n,m
+        A = A.mH
+
+    Q = torch_linalg.qr(A, mode='reduced', retry_float64=True).Q
+
+    if T:
+        Q = Q.mH
+
+    return Q
+
+OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
+def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod = "newtonschulz") -> torch.Tensor:
+    if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
+    if method == "svd": return zeropower_via_svd(A)
+    if method == "qr": return orthogonalize_via_qr(A)
+    if method == "eigh": return zeropower_via_eigh(A)
+    raise ValueError(method)
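A usage sketch (assuming the re-exports in torchzero/linalg/__init__.py above). Per the docstring, Newton-Schulz gives an approximate orthogonalization whose singular values land roughly in (0.5, 1.5), while the SVD path returns exactly U @ V^T:

import torch
from torchzero.linalg import orthogonalize  # assumes the re-export shown above

G = torch.randn(256, 128)
O_ns = orthogonalize(G, method="newtonschulz")
O_svd = orthogonalize(G, method="svd")

print(torch.linalg.svdvals(O_ns).aminmax())   # roughly within (0.5, 1.5)
print(torch.linalg.svdvals(O_svd).aminmax())  # both approximately 1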
torchzero/{utils/linalg → linalg}/qr.py
@@ -1,8 +1,9 @@
 from typing import Literal
 import torch
-from ..compile import
+from ..utils.compile import allow_compile
 
 # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+@allow_compile
 def _get_w_tau(R: torch.Tensor, i: int, eps: float):
     R_ii = R[...,i,i]
     R_below = R[...,i:,i]
@@ -17,6 +18,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
     tau = torch.where(degenerate, 1, tau)
     return w, tau
 
+@allow_compile
 def _qr_householder_complete(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
@@ -33,6 +35,7 @@ def _qr_householder_complete(A:torch.Tensor):
 
     return Q, R
 
+@allow_compile
 def _qr_householder_reduced(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
@@ -64,7 +67,6 @@ def _qr_householder_reduced(A:torch.Tensor):
 
     return Q, R
 
-# @enable_compilation
 def qr_householder(A:torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
     """an attempt at making QR decomposition for very tall and thin matrices that doesn't freeze, but it is around n_cols times slower than torch.linalg.qr, but compilation makes it faster, but it has to recompile when processing different shapes"""
     if mode == 'reduced': return _qr_householder_reduced(A)