torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/linalg/__init__.py ADDED

@@ -0,0 +1,10 @@
+from . import linear_operator
+
+from .matrix_power import (
+    matrix_power_eigh,
+    matrix_power_svd,
+)
+from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize
+from .qr import qr_householder
+from .solve import cg, nystrom_sketch_and_solve, nystrom_pcg
+from .eigh import nystrom_approximation
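Version 0.4.0 moves these helpers out of `torchzero.utils.linalg` into a flat `torchzero.linalg` namespace. A minimal import sketch, assuming only the re-exports shown in the hunk above (illustrative, not from the package docs):

```python
# Assumes the re-exports listed in the __init__.py hunk above; illustrative only.
from torchzero.linalg import (
    matrix_power_eigh,       # matrix powers of symmetric PSD matrices
    orthogonalize,           # Newton-Schulz / SVD / QR orthogonalization
    qr_householder,          # Householder QR for tall-and-thin (batched) matrices
    cg, nystrom_pcg,         # iterative solvers
    nystrom_approximation,   # randomized low-rank PSD factorization
)
```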
torchzero/linalg/eigh.py ADDED

@@ -0,0 +1,34 @@
+from collections.abc import Callable
+import torch
+from .linalg_utils import mm
+
+
+
+# https://arxiv.org/pdf/2110.02820
+def nystrom_approximation(
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+    ndim: int,
+    rank: int,
+    device,
+    dtype = torch.float32,
+    generator = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Computes Nyström approximation to positive-semidefinite A factored as Q L Q^T (truncatd eigenvalue decomp),
+    returns ``(L, Q)``.
+
+    A is ``(m,m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - diagonal of ``(rank, rank)``"""
+    # basis
+    O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
+    O, _ = torch.linalg.qr(O) # Thin QR decomposition # pylint:disable=not-callable
+
+    # Y = AΩ
+    AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
+
+    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(AO, ord='fro') # Compute shift # pylint:disable=not-callable
+    Yv = AO + v*O # Shift for stability
+    C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
+    B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+    Q, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+    L = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
+    return L, Q
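As a rough usage sketch (my own example, not taken from the package): `nystrom_approximation` only needs matrix-vector products with A, so a PSD operator such as a Hessian reached through HVPs can be factored without ever materializing it.

```python
import torch
from torchzero.linalg import nystrom_approximation  # re-exported from .eigh above

n, rank = 500, 20
B = torch.randn(n, n)
H = B @ B.T / n                       # PSD test matrix standing in for a Hessian

L, Q = nystrom_approximation(
    A_mv=lambda v: H @ v,             # only matvecs with A are required
    A_mm=None,
    ndim=n, rank=rank, device=H.device,
)
H_approx = Q @ torch.diag(L) @ Q.T    # rank-20 model, A ≈ Q diag(L) Q^T
rel_err = torch.linalg.matrix_norm(H - H_approx) / torch.linalg.matrix_norm(H)
```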
torchzero/linalg/linalg_utils.py ADDED

@@ -0,0 +1,14 @@
+from collections.abc import Callable
+import torch
+
+def mm(
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+    X
+):
+    """matrix-matrix when either mv or mm is given"""
+    if A_mm is not None: return A_mm(X)
+    assert A_mv is not None
+    return torch.stack([A_mv(col) for col in X.unbind(-1)], -1) # rank matvecs
+
+
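For illustration (hypothetical example, not from the package tests): `mm` is what lets the Nyström routine above accept either a matvec or a matmat closure; given only a matvec it falls back to one product per column of X.

```python
import torch
from torchzero.linalg.linalg_utils import mm

A = torch.randn(6, 6)
X = torch.randn(6, 3)
out = mm(A_mv=lambda v: A @ v, A_mm=None, X=X)   # three matvecs, one per column of X
assert torch.allclose(out, A @ X, atol=1e-5)
```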
torchzero/{utils/linalg → linalg}/linear_operator.py

@@ -1,4 +1,6 @@
-"""
+"""This is mainly used for trust regions. In some cases certain operations are relaxed, e.g. eigenvalue shift instead of
+adding diagonal when it isn't tractable, to make it work with Levenberg-Marquadt.
+"""
 import math
 from abc import ABC, abstractmethod
 from functools import partial
@@ -7,7 +9,8 @@ from typing import cast, final

 import torch

-from ..torch_tools import tofloat, tonumpy, totensor
+from ..utils.torch_tools import tofloat, tonumpy, totensor
+from .solve import nystrom_sketch_and_solve

 if find_spec('scipy') is not None:
     from scipy.sparse.linalg import LinearOperator as _ScipyLinearOperator
@@ -15,7 +18,6 @@ else:
     _ScipyLinearOperator = None

 class LinearOperator(ABC):
-    """this is used for trust region"""
     device: torch.types.Device
     dtype: torch.dtype | None

@@ -25,18 +27,24 @@ class LinearOperator(ABC):
     def rmatvec(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatvec")

-    def matmat(self,
-        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement
+    def matmat(self, X: torch.Tensor) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmat")
+
+    def rmatmat(self, X: torch.Tensor) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatmat")

     def solve(self, b: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve")

+    def solve_plus_diag(self, b: torch.Tensor, diag: int | float | torch.Tensor) -> torch.Tensor:
+        return self.add_diagonal(diag).solve(b)
+
     def solve_bounded(self, b: torch.Tensor, bound:float, ord:float=2) -> torch.Tensor:
         """solve with a norm bound on x"""
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")

-    def update(self, *args, **kwargs) -> None:
-
+    # def update(self, *args, **kwargs) -> None:
+    #     raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")

     def add(self, x: torch.Tensor) -> "LinearOperator":
         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")
@@ -129,8 +137,8 @@ class Dense(LinearOperator):
     def matvec(self, x): return self.A.mv(x)
     def rmatvec(self, x): return self.A.mH.mv(x)

-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(self.A.mm(X))
+    def rmatmat(self, X): return Dense(self.A.mH.mm(X))

     def solve(self, b): return _solve(self.A, b)

@@ -146,6 +154,12 @@ class Dense(LinearOperator):
     def is_dense(self): return True
     def transpose(self): return Dense(self.A.mH)

+class SPD(Dense):
+    def solve(self, b: torch.Tensor):
+        L, info = torch.linalg.cholesky_ex(self.A) # pylint:disable=not-callable
+        return torch.cholesky_solve(b.unsqueeze(-1), L).squeeze(-1)
+
+
 class DenseInverse(LinearOperator):
     """Represents inverse of a dense matrix A."""
     def __init__(self, A_inv: torch.Tensor):
@@ -156,8 +170,8 @@ class DenseInverse(LinearOperator):
     def matvec(self, x): return _solve(self.A_inv, x) # pylint:disable=not-callable
     def rmatvec(self, x): return _solve(self.A_inv.mH, x) # pylint:disable=not-callable

-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(_solve(self.A_inv, X)) # pylint:disable=not-callable
+    def rmatmat(self, X): return Dense(_solve(self.A_inv.mH, X)) # pylint:disable=not-callable

     def solve(self, b): return self.A_inv.mv(b)

@@ -190,8 +204,8 @@ class Diagonal(LinearOperator):
     def matvec(self, x): return self.A * x
     def rmatvec(self, x): return self.A * x

-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(X * self.A.unsqueeze(-1))
+    def rmatmat(self, X): return Dense(X * self.A.unsqueeze(-1))

     def solve(self, b): return b/self.A

@@ -221,8 +235,8 @@ class ScaledIdentity(LinearOperator):
     def matvec(self, x): return x * self.s
     def rmatvec(self, x): return x * self.s

-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(X * self.s)
+    def rmatmat(self, X): return Dense(X * self.s)

     def solve(self, b): return b / self.s
     def solve_bounded(self, b, bound, ord = 2):
@@ -263,6 +277,7 @@ class ScaledIdentity(LinearOperator):
     def is_dense(self): return False
     def transpose(self): return ScaledIdentity(self.s, shape=self.shape, device=self.device, dtype=self.dtype)

+
 class AtA(LinearOperator):
     def __init__(self, A: torch.Tensor):
         self.A = A
@@ -270,8 +285,8 @@ class AtA(LinearOperator):
     def matvec(self, x): return self.A.mH.mv(self.A.mv(x))
     def rmatvec(self, x): return self.matvec(x)

-    def matmat(self,
-    def rmatmat(self,
+    def matmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable
+    def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable

     def is_dense(self): return False
     def to_tensor(self): return self.A.mH @ self.A
@@ -283,7 +298,27 @@ class AtA(LinearOperator):
         return Dense(self.to_tensor() + torch.diag_embed(x))

     def solve(self, b):
-
+        *_, n, m = self.A.shape
+        if n >= m: return Dense(self.to_tensor()).solve(b)
+
+        A = self.A
+        C = A @ A.mH # (n, n), SPD
+        L, info = torch.linalg.cholesky_ex(C) # pylint:disable=not-callable
+        z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+        return A.mH @ z
+
+    def solve_plus_diag(self, b, diag):
+        *_, n, m = self.A.shape
+        if (n >= m) or (isinstance(diag, torch.Tensor) and diag.numel() > 1):
+            return Dense(self.to_tensor()).solve_plus_diag(b, diag)
+
+        A = self.A
+        I = torch.eye(A.size(-2), device=A.device, dtype=A.dtype)
+
+        C = (A @ A.mH).add_(I.mul_(diag)) # (n, n), SPD
+        L, info = torch.linalg.cholesky_ex(C + I.mul_(diag)) # pylint:disable=not-callable
+        z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+        return (1 / diag) * (b - A.mH @ z)

     def inv(self):
         return Dense(self.to_tensor()).inv()
@@ -295,35 +330,98 @@
         n = self.A.size(1)
         return (n,n)

-class
+class AAt(AtA):
     def __init__(self, A: torch.Tensor):
-
+        super().__init__(A.mH)

-
-
+class Sketched(LinearOperator):
+    """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
+
+    Where A is (n, n) and S is (n, sketch_size).
+    """
+    def __init__(self, S: torch.Tensor, A_proj: torch.Tensor):
+        self.S = S
+        self.A_proj = A_proj
+        self.device = self.A_proj.device; self.dtype = self.A_proj.dtype
+
+    def matvec(self, x):
+        x_proj = self.S.T @ x
+        Ax_proj = self.A_proj @ x_proj
+        return self.S @ Ax_proj
+
+    def rmatvec(self, x):
+        x_proj = self.S.T @ x
+        ATx_proj = self.A_proj.mH @ x_proj
+        return self.S @ ATx_proj
+
+
+    def matmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, X])) # pylint:disable=not-callable
+    def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, X])) # pylint:disable=not-callable

-    def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
-    def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable

     def is_dense(self): return False
-    def to_tensor(self): return self.
-    def transpose(self): return
+    def to_tensor(self): return self.S @ self.A_proj @ self.S.T
+    def transpose(self): return Sketched(self.S, self.A_proj.mH)

     def add_diagonal(self, x):
+        """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
-        if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.
-        return
+        if isinstance(x, (int,float)): x = torch.full((self.A_proj.shape[0],), fill_value=x, device=self.A_proj.device, dtype=self.A_proj.dtype)
+        return Sketched(S=self.S, A_proj=self.A_proj + x.diag_embed())

     def solve(self, b):
-        return
+        return self.S @ torch.linalg.lstsq(self.A_proj, self.S.T @ b).solution # pylint:disable=not-callable

     def inv(self):
-        return
-
-    def diagonal(self):
-        return self.A.pow(2).sum(0)
+        return Sketched(S=self.S, A_proj=torch.linalg.pinv(self.A_proj)) # pylint:disable=not-callable

     def size(self):
-        n = self.
+        n = self.S.size(0)
         return (n,n)

+
+class Eigendecomposition(LinearOperator):
+    """A represented as Q L Q^H. If A is (n,n), then Q is (n, rank); L is a vector - diagonal of (rank, rank)"""
+    def __init__(self, L: torch.Tensor, Q: torch.Tensor, use_nystrom: bool = True):
+        self.L = L
+        self.Q = Q
+        self.use_nystrom = use_nystrom
+        self.device = self.L.device; self.dtype = self.L.dtype
+
+    def matvec(self, x):
+        return self.Q @ ((self.Q.mH @ x) * self.L)
+
+    def rmatvec(self, x):
+        return self.matvec(x)
+
+    def matmat(self, X):
+        return Dense(self.Q @ (self.L[:, None] * (self.Q.mH @ X)))
+
+    def rmatmat(self, X):
+        return self.matmat(X)
+
+    def is_dense(self): return False
+    def to_tensor(self): return self.Q @ self.L.diag_embed() @ self.Q.mH
+    def transpose(self): return Eigendecomposition(L=self.L, Q=self.Q)
+
+    def add_diagonal(self, x):
+        """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
+        if isinstance(x, torch.Tensor) and x.numel() > 1:
+            raise RuntimeError("Eigendecomposition linear operator doesn't support add_diagonal with a vector diag")
+
+        return Eigendecomposition(L=self.L + x, Q = self.Q)
+
+    def solve(self, b):
+        return self.Q @ ((self.Q.mH @ b) / self.L)
+
+    def solve_plus_diag(self, b, diag):
+        if isinstance(diag, torch.Tensor) and diag.numel() > 1: return super().solve_plus_diag(b, diag)
+        if not self.use_nystrom: return super().solve_plus_diag(b, diag)
+        return nystrom_sketch_and_solve(L=self.L, Q=self.Q, b=b, reg=float(diag))
+
+    def inv(self):
+        return Eigendecomposition(L=1 / self.L, Q = self.Q)
+
+    def size(self):
+        n = self.Q.size(0)
+        return (n,n)
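Two notes on the hunks above, with hedged sketches rather than package code. First, the wide-matrix branch of `AtA.solve` / `solve_plus_diag` avoids factorizing the (m, m) Gram matrix by solving an (n, n) system instead; the identity it leans on is the matrix inversion lemma:

```python
# Sketch (my own, illustrative): for A of shape (n, m) with n < m and lam > 0,
#   (A^T A + lam*I)^{-1} b = (1/lam) * (b - A^T (A A^T + lam*I)^{-1} A b),
# so only an (n, n) SPD system has to be solved.
import torch

n, m, lam = 10, 50, 0.5
A = torch.randn(n, m, dtype=torch.float64)
b = torch.randn(m, dtype=torch.float64)

direct = torch.linalg.solve(A.T @ A + lam * torch.eye(m, dtype=torch.float64), b)

z = torch.linalg.solve(A @ A.T + lam * torch.eye(n, dtype=torch.float64), A @ b)
woodbury = (b - A.T @ z) / lam

assert torch.allclose(direct, woodbury, atol=1e-8)
```

Second, the new `Eigendecomposition` operator pairs naturally with the Nyström factorization from `eigh.py`; a hypothetical pairing (import path assumed from the rename above, not an official recipe) for damped low-rank solves:

```python
import torch
from torchzero.linalg import nystrom_approximation
from torchzero.linalg.linear_operator import Eigendecomposition  # assumed module path

n, rank = 300, 30
B = torch.randn(n, n)
H = B @ B.T / n                                   # PSD operator, probed only via matvecs

L, Q = nystrom_approximation(A_mv=lambda v: H @ v, A_mm=None,
                             ndim=n, rank=rank, device=H.device)
H_op = Eigendecomposition(L, Q)                   # low-rank model of H

g = torch.randn(n)
step = H_op.add_diagonal(1e-1).solve(g)           # damped Newton-like step on the model
```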
torchzero/linalg/matrix_power.py ADDED

@@ -0,0 +1,28 @@
+from typing import Literal
+import warnings
+from collections.abc import Callable
+
+import torch
+from . import torch_linalg
+def matrix_power_eigh(A: torch.Tensor, power:float, abs:bool=False):
+    """this is faster than SVD but only for positive semi-definite symmetric matrices
+    (covariance matrices are always SPD)"""
+
+    L, Q = torch_linalg.eigh(A, retry_float64=True) # pylint:disable=not-callable
+    if abs: L.abs_()
+    if power % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
+    return (Q * L.pow_(power).unsqueeze(-2)) @ Q.mH
+
+
+def matrix_power_svd(A: torch.Tensor, power: float) -> torch.Tensor:
+    """for any symmetric matrix"""
+    U, S, Vh = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+    if power % 2 != 0: S.clip_(min = torch.finfo(A.dtype).tiny * 2)
+    return (U * S.pow_(power).unsqueeze(-2)) @ Vh
+
+MatrixPowerMethod = Literal["eigh", "eigh_abs", "svd"]
+def matrix_power(A: torch.Tensor, power: float, method: MatrixPowerMethod = "eigh_abs") -> torch.Tensor:
+    if method == "eigh": return matrix_power_eigh(A, power)
+    if method == "eigh_abs": return matrix_power_eigh(A, power, abs=True)
+    if method == "svd": return matrix_power_svd(A, power)
+    raise ValueError(method)
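A quick check of `matrix_power_eigh` on an SPD matrix (my own example, illustrative): the inverse square root used by Shampoo/SOAP-style preconditioners agrees with a plain eigendecomposition.

```python
import torch
from torchzero.linalg import matrix_power_eigh

B = torch.randn(50, 50, dtype=torch.float64)
C = B @ B.T + 1e-3 * torch.eye(50, dtype=torch.float64)   # SPD "covariance"

C_inv_sqrt = matrix_power_eigh(C, power=-0.5)              # C^{-1/2}

L, Q = torch.linalg.eigh(C)
reference = Q @ torch.diag(L.pow(-0.5)) @ Q.T
assert torch.allclose(C_inv_sqrt, reference, atol=1e-6)
```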
torchzero/linalg/orthogonalize.py ADDED

@@ -0,0 +1,95 @@
+from typing import Literal
+import torch
+
+from ..utils.compile import allow_compile
+from . import torch_linalg
+
+# zeropower_via_newtonschulz5 from:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+# and
+# https://github.com/HomebrewML/HeavyBall/blob/main/heavyball/utils.py#L452
+_NS_COEFFS = (
+    (4.0848, -6.8946, 2.9270),
+    (3.9505, -6.3029, 2.6377),
+    (3.7418, -5.5913, 2.3037),
+    (2.8769, -3.1427, 1.2046),
+    (2.8366, -3.0525, 1.2012)
+)
+
+@allow_compile
+def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Tensor:
+    """
+    Applies to last 2 dims - so usually reverse_dims should be applied to G before and after.
+
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True).clip(min=torch.finfo(X.dtype).tiny * 2))
+
+    # Perform the NS iterations
+    for a,b,c in coeffs:
+        A = X @ X.mT
+        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    return X.to(G.dtype)
+
+# code from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
+# Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
+# Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
+def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
+    """
+    Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
+    """
+    try:
+        U, S, Vt = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+    except torch.linalg.LinAlgError:
+        U, S, Vt = torch.svd_lowrank(A, q=1, M=1e-4 * A.mean() * torch.rand_like(A))
+
+    return U @ Vt
+
+def zeropower_via_eigh(A: torch.Tensor) -> torch.Tensor:
+    """
+    Only SPD and I need to check if I apply those to SPD because this is better than SVD.
+    """
+    L, Q = torch_linalg.eigh(A, retry_float64=True)
+    return Q @ Q.mH
+
+
+def orthogonalize_via_qr(A: torch.Tensor):
+    *_, m, n = A.shape
+    T = False
+    if m < n:
+        T = True
+        m,n = n,m
+        A = A.mH
+
+    Q = torch_linalg.qr(A, mode='reduced', retry_float64=True).Q
+
+    if T:
+        Q = Q.mH
+
+    return Q
+
+OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
+def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod = "newtonschulz") -> torch.Tensor:
+    if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
+    if method == "svd": return zeropower_via_svd(A)
+    if method == "qr": return orthogonalize_via_qr(A)
+    if method == "eigh": return zeropower_via_eigh(A)
+    raise ValueError(method)
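A short comparison sketch (mine, illustrative): the Newton-Schulz path runs in bfloat16 and, as the docstring above notes, only approximately orthogonalizes, while the SVD path returns an exact U V^T factor.

```python
import torch
from torchzero.linalg import orthogonalize

G = torch.randn(256, 128)

Q_ns = orthogonalize(G, method="newtonschulz")   # fast, approximate (bfloat16 iterations)
Q_svd = orthogonalize(G, method="svd")           # exact orthogonal factor U @ V^T

eye = torch.eye(128)
print((Q_ns.T @ Q_ns - eye).norm())    # noticeably nonzero: singular values land near, not at, 1
print((Q_svd.T @ Q_svd - eye).norm())  # ~0 up to float32 roundoff
```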
torchzero/{utils/linalg → linalg}/qr.py

@@ -1,8 +1,9 @@
 from typing import Literal
 import torch
-from ..compile import
+from ..utils.compile import allow_compile

 # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+@allow_compile
 def _get_w_tau(R: torch.Tensor, i: int, eps: float):
     R_ii = R[...,i,i]
     R_below = R[...,i:,i]
@@ -17,6 +18,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
     tau = torch.where(degenerate, 1, tau)
     return w, tau

+@allow_compile
 def _qr_householder_complete(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
@@ -33,6 +35,7 @@ def _qr_householder_complete(A:torch.Tensor):

     return Q, R

+@allow_compile
 def _qr_householder_reduced(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
@@ -64,7 +67,6 @@ def _qr_householder_reduced(A:torch.Tensor):

     return Q, R

-# @enable_compilation
 def qr_householder(A:torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
     """an attempt at making QR decomposition for very tall and thin matrices that doesn't freeze, but it is around n_cols times slower than torch.linalg.qr, but compilation makes it faster, but it has to recompile when processing different shapes"""
     if mode == 'reduced': return _qr_householder_reduced(A)