torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+# pylint: disable = non-ascii-name
 # pyright: reportArgumentType=false
 import math
 from collections import deque
@@ -5,8 +6,8 @@ from collections.abc import Callable
 from typing import Any, NamedTuple, overload
 
 import torch
-
-from .. import (
+from .linalg_utils import mm
+from ..utils import (
     TensorList,
     generic_eq,
     generic_finfo_tiny,
@@ -15,88 +16,73 @@ from .. import (
     generic_zeros_like,
 )
 
-
-def
-
-    Ax = A_mm(x)
+def _make_A_mv_reg(A_mv: Callable, reg):
+    def A_mv_reg(x): # A_mm with regularization
+        Ax = A_mv(x)
         if not generic_eq(reg, 0): Ax += x*reg
         return Ax
-    return
+    return A_mv_reg
 
 def _identity(x): return x
 
-
-#
-def nystrom_approximation(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
-    ndim: int,
-    rank: int,
-    device,
-    dtype = torch.float32,
-    generator = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    omega = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-    omega, _ = torch.linalg.qr(omega) # Thin QR decomposition # pylint:disable=not-callable
-
-    # Y = AΩ
-    Y = torch.stack([A_mm(col) for col in omega.unbind(-1)], -1) # rank matvecs
-    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(Y, ord='fro') # Compute shift # pylint:disable=not-callable
-    Yv = Y + v*omega # Shift for stability
-    C = torch.linalg.cholesky_ex(omega.mT @ Yv)[0] # pylint:disable=not-callable
-    B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
-    U, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
-    lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
-    return U, lambd
-
+# TODO this is used in NystromSketchAndSolve
+# I need to add alternative to it where it just shifts eigenvalues by reg and uses their reciprocal
 def nystrom_sketch_and_solve(
-
+    L: torch.Tensor,
+    Q: torch.Tensor,
     b: torch.Tensor,
-    rank: int,
     reg: float = 1e-3,
-    generator=None,
 ) -> torch.Tensor:
-
-
-
-
-
-
-
-
+    """Solves (Q diag(L) Q.T + reg*I)x = b. Becomes super unstable with reg smaller than like 1e-5.
+
+    Args:
+        L (torch.Tensor): eigenvalues, like from ``nystrom_approximation``
+        Q (torch.Tensor): eigenvectors, like from ``nystrom_approximation``
+        b (torch.Tensor): right hand side
+        reg (float, optional): regularization. Defaults to 1e-3.
+    """
+
     b = b.unsqueeze(-1)
-
+    L += reg
     # x = (A + μI)⁻¹ b
-    # (A + μI)⁻¹ =
-    # x =
-
-    term1 =
-    term2 = (1.0 / reg) * (b -
+    # (A + μI)⁻¹ = Q(L + μI)⁻¹Qᵀ + (1/μ)(b - QQᵀ)
+    # x = Q(L + μI)⁻¹Qᵀb + (1/μ)(b - QQᵀb)
+    Qᵀb = Q.T @ b
+    term1 = Q @ ((1/L).unsqueeze(-1) * Qᵀb)
+    term2 = (1.0 / reg) * (b - Q @ Qᵀb)
     return (term1 + term2).squeeze(-1)
 
 def nystrom_pcg(
-
+    L: torch.Tensor,
+    Q: torch.Tensor,
+    A_mv: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
-    sketch_size: int,
     reg: float = 1e-6,
     x0_: torch.Tensor | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
-    generator=None,
 ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
+    """conjugate gradient preconditioned by nystrom approximation.
+
+    The preconditioner can be computed by one matrix-matrix multiplication with A.
+    If matrix-matrix is efficient, then this is good (e.g. batched hessian-vector products in pytorch)
+
+    Args:
+        L (torch.Tensor): eigenvalues of approximation of A, like from ``nystrom_approximation``
+        Q (torch.Tensor): eigenvectors of approximation of A, like from ``nystrom_approximation``
+        A_mv (Callable[[torch.Tensor], torch.Tensor]): mat-vec func with hessian
+        b (torch.Tensor): right hand side
+        reg (float, optional): regularization. Defaults to 1e-6.
+        x0_ (torch.Tensor | None, optional): initial guess (modified in-place). Defaults to None.
+        tol (float | None, optional): tolerance for convergence. Defaults to 1e-4.
+        maxiter (int | None, optional): maximum number of iterations. Defaults to None.
+    """
+    L += reg
     eps = torch.finfo(b.dtype).tiny * 2
     if tol is None: tol = eps
 
-    def
-    Ax =
+    def A_mv_reg(x): # A_mm with regularization
+        Ax = A_mv(x)
        if reg != 0: Ax += x*reg
        return Ax
 
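The rewritten nystrom_sketch_and_solve above takes a precomputed eigenpair approximation ``(L, Q)`` and applies the shifted-inverse identity spelled out in its comments. A small self-contained check of that identity (illustrative only, not the library API; every name below is local to the snippet):

    import torch

    torch.set_default_dtype(torch.float64)
    n, rank, reg = 50, 10, 1e-3

    Q, _ = torch.linalg.qr(torch.randn(n, rank))   # orthonormal columns (n x rank)
    L = torch.rand(rank) + 1.0                     # positive approximate eigenvalues
    b = torch.randn(n)

    # reference: dense solve of (Q diag(L) Q^T + reg*I) x = b
    A = Q @ torch.diag(L) @ Q.T + reg * torch.eye(n)
    x_ref = torch.linalg.solve(A, b)

    # formula used by the hunk above: x = Q (L + reg)^-1 Q^T b + (1/reg) (b - Q Q^T b)
    Qtb = Q.T @ b
    x = Q @ (Qtb / (L + reg)) + (b - Q @ Qtb) / reg

    print(torch.allclose(x, x_ref))  # True

The identity is exact only because Q has orthonormal columns, and the docstring's warning about very small ``reg`` is consistent with the 1/reg factor amplifying rounding error in the second term.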
@@ -104,10 +90,10 @@ def nystrom_pcg(
     if x0_ is None: x0_ = torch.zeros_like(b)
 
     x = x0_
-    residual = b -
+    residual = b - A_mv_reg(x)
     # z0 = P⁻¹ r0
-    term1 =
-    term2 = torch.eye(
+    term1 = L[...,-1] * Q * (1/L.unsqueeze(-2)) @ Q.mT
+    term2 = torch.eye(Q.size(-2), device=Q.device,dtype=Q.dtype) - Q@Q.mT
     P_inv = term1 + term2
     z = P_inv @ residual
     p = z.clone() # search direction
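``term1``/``term2`` above assemble the Nyström preconditioner P⁻¹ = L[-1] · Q diag(1/L) Qᵀ + (I − QQᵀ) as a dense matrix from the shifted eigenvalues. A sketch of the effect, assuming ``L`` is ordered so that ``L[-1]`` is the smallest retained (shifted) eigenvalue: on span(Q) the preconditioned operator collapses to a constant multiple of the identity, so CG effort goes only to the part of the spectrum the sketch missed (illustrative, not the library API):

    import torch

    torch.set_default_dtype(torch.float64)
    n, rank, reg = 50, 10, 1e-6

    Q, _ = torch.linalg.qr(torch.randn(n, rank))
    lam = torch.sort(torch.rand(rank), descending=True).values   # approximate eigenvalues of A
    L = lam + reg                                                 # shifted, as done by `L += reg` above

    A_reg = Q @ torch.diag(lam) @ Q.T + reg * torch.eye(n)        # regularized operator

    # P^-1 = L[-1] * Q diag(1/L) Q^T + (I - Q Q^T), matching term1 + term2 in the hunk
    P_inv = L[-1] * (Q * (1.0 / L)) @ Q.T + torch.eye(n) - Q @ Q.T

    # on span(Q), P^-1 (A + reg*I) maps every eigenvector to L[-1] times itself
    print(torch.allclose(P_inv @ (A_reg @ Q), L[-1] * Q))  # True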
@@ -116,7 +102,7 @@ def nystrom_pcg(
     if init_norm < tol: return x
     k = 0
     while True:
-        Ap =
+        Ap = A_mv_reg(p)
         rz = residual.dot(z)
         step_size = rz / p.dot(Ap)
         x += step_size * p
@@ -138,7 +124,7 @@ def _safe_clip(x: torch.Tensor):
     if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
     return x
 
-def _trust_tau(x,d,trust_radius):
+def _trust_tau(x, d, trust_radius):
     xx = x.dot(x)
     xd = x.dot(d)
     dd = _safe_clip(d.dot(d))
@@ -150,10 +136,10 @@ def _trust_tau(x,d,trust_radius):
 
 
 class CG:
-    """Conjugate gradient method.
+    """Conjugate gradient method optionally with norm constraint.
 
     Args:
-
+        A_mv (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
         b (torch.Tensor): right hand side
         x0 (torch.Tensor | None, optional): initial guess, defaults to zeros. Defaults to None.
         tol (float | None, optional): tolerance for convergence. Defaults to 1e-8.
@@ -174,10 +160,10 @@ class CG:
     """
     def __init__(
         self,
-
+        A_mv: Callable,
         b: torch.Tensor | TensorList,
         x0: torch.Tensor | TensorList | None = None,
-        tol: float | None = 1e-
+        tol: float | None = 1e-8,
         maxiter: int | None = None,
         reg: float = 0,
         trust_radius: float | None = None,
@@ -187,7 +173,7 @@ class CG:
         P_mm: Callable | None = None,
     ):
         # --------------------------------- set attrs -------------------------------- #
-        self.
+        self.A_mv = _make_A_mv_reg(A_mv, reg)
         self.b = b
         if tol is None: tol = generic_finfo_tiny(b) * 2
         self.tol = tol
@@ -214,7 +200,7 @@ class CG:
             self.r = b
         else:
             self.x = x0
-            self.r = b -
+            self.r = b - A_mv(self.x)
 
         self.z = self.P_mm(self.r)
         self.d = self.z
@@ -229,7 +215,7 @@ class CG:
         if self.iter >= self.maxiter:
             return x, True
 
-        Ad = self.
+        Ad = self.A_mv(d)
         dAd = d.dot(Ad)
 
         # check negative curvature
@@ -289,7 +275,8 @@ class CG:
         return sol
 
 def find_within_trust_radius(history, trust_radius: float):
-    """find first ``x`` in history that exceeds trust radius
+    """find first ``x`` in history that exceeds trust radius and returns solution within,
+    if no such ``x`` exists, returns ``None``"""
     for x, x_norm, d in reversed(tuple(history)):
         if x_norm <= trust_radius:
             return _trust_tau(x, d, trust_radius)
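The ``CG`` class above is driven entirely through the ``A_mv`` callable: the solver only ever needs matrix-vector products, never the matrix itself. A minimal plain-torch CG loop showing that interface (a sketch under that assumption, not the library's implementation; ``cg_sketch`` is a made-up name):

    import torch

    torch.set_default_dtype(torch.float64)
    n = 30
    M = torch.randn(n, n)
    A = M @ M.T + n * torch.eye(n)   # SPD test matrix
    b = torch.randn(n)

    def A_mv(x):                     # only matvecs are required
        return A @ x

    def cg_sketch(A_mv, b, tol=1e-10, maxiter=200):
        x = torch.zeros_like(b)
        r = b - A_mv(x)              # residual
        d = r.clone()                # search direction
        rr = r.dot(r)
        for _ in range(maxiter):
            if rr.sqrt() < tol:
                break
            Ad = A_mv(d)
            alpha = rr / d.dot(Ad)   # step size along d
            x = x + alpha * d
            r = r - alpha * Ad
            rr_new = r.dot(r)
            d = r + (rr_new / rr) * d
            rr = rr_new
        return x

    x = cg_sketch(A_mv, b)
    print(torch.allclose(A @ x, b))  # True

The trust-radius handling in the hunks above (``_trust_tau``, ``find_within_trust_radius``) sits on top of this basic recurrence: the solver keeps a history of iterates and their norms and, once the step would leave the trust region, returns the point where the search direction crosses the boundary.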
@@ -306,7 +293,7 @@ class _TensorListSolution(NamedTuple):
 
 @overload
 def cg(
-
+    A_mv: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
     x0: torch.Tensor | None = None,
     tol: float | None = 1e-8,
@@ -320,7 +307,7 @@ def cg(
 ) -> _TensorSolution: ...
 @overload
 def cg(
-
+    A_mv: Callable[[TensorList], TensorList],
     b: TensorList,
     x0: TensorList | None = None,
     tol: float | None = 1e-8,
@@ -333,7 +320,7 @@ def cg(
     P_mm: Callable[[TensorList], TensorList] | None = None
 ) -> _TensorListSolution: ...
 def cg(
-
+    A_mv: Callable,
     b: torch.Tensor | TensorList,
     x0: torch.Tensor | TensorList | None = None,
     tol: float | None = 1e-8,
@@ -346,7 +333,7 @@ def cg(
     P_mm: Callable | None = None
 ):
     solver = CG(
-
+        A_mv=A_mv,
         b=b,
         x0=x0,
         tol=tol,
@@ -370,10 +357,10 @@ def cg(
 # Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
 @overload
 def minres(
-
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
     b: torch.Tensor,
     x0: torch.Tensor | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
     reg: float = 0,
     npc_terminate: bool=True,
@@ -381,26 +368,27 @@ def minres(
 ) -> torch.Tensor: ...
 @overload
 def minres(
-
+    A_mv: Callable[[TensorList], TensorList],
     b: TensorList,
     x0: TensorList | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
     npc_terminate: bool=True,
     trust_radius: float | None = None,
 ) -> TensorList: ...
 def minres(
-
+    A_mv,
     b,
     x0: torch.Tensor | TensorList | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
     npc_terminate: bool=True,
     trust_radius: float | None = None, #trust region is experimental
 ):
-
+    """MINRES (experimental)"""
+    A_mv_reg = _make_A_mv_reg(A_mv, reg)
     eps = math.sqrt(generic_finfo_tiny(b) * 2)
     if tol is None: tol = eps
 
@@ -409,7 +397,7 @@ def minres(
         R = b
         x0 = generic_zeros_like(b)
     else:
-        R = b -
+        R = b - A_mv_reg(x0)
 
     X: Any = x0
     beta = b_norm = generic_vector_norm(b)
@@ -429,7 +417,7 @@ def minres(
 
     for _ in range(maxiter):
 
-        P =
+        P = A_mv_reg(V)
         alpha = V.dot(P)
         P -= beta*V_prev
         P -= alpha*V
torchzero/linalg/svd.py
ADDED
@@ -0,0 +1,20 @@
+# import torch
+
+# # projected svd
+# # adapted from https://github.com/smortezavi/Randomized_SVD_GPU
+# def randomized_svd(M: torch.Tensor, k: int, driver=None):
+#     *_, m, n = M.shape
+#     transpose = False
+#     if m < n:
+#         transpose = True
+#         M = M.mT
+#         m,n = n,m
+
+#     rand_matrix = torch.randn(size=(n, k), device=M.device, dtype=M.dtype)
+#     Q, _ = torch.linalg.qr(M @ rand_matrix, mode='reduced') # pylint:disable=not-callable
+#     smaller_matrix = Q.mT @ M
+#     U_hat, s, V = torch.linalg.svd(smaller_matrix, driver=driver, full_matrices=False) # pylint:disable=not-callable
+#     U = Q @ U_hat
+
+#     if transpose: return V.mT, s, U.mT
+#     return U, s, V
torchzero/linalg/torch_linalg.py
ADDED
@@ -0,0 +1,168 @@
+"""torch linalg with correct typing and retries in float64"""
+from typing import NamedTuple
+
+import torch
+
+
+def cholesky(A: torch.Tensor, *, upper=False, retry_float64:bool=False) -> torch.Tensor:
+    """A - SPD, returns lower triangular L such that ``A = L @ L.mH`` also can pass L to ``torch.cholesky_solve``"""
+    try:
+        return torch.linalg.cholesky(A, upper=upper) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return cholesky(A.to(torch.float64), upper=upper, retry_float64=False).to(dtype)
+
+
+class _QRTuple(NamedTuple):
+    Q: torch.Tensor
+    R: torch.Tensor
+
+def qr(A: torch.Tensor, mode='reduced', retry_float64:bool=False) -> _QRTuple:
+    """A - any matrix ``(*, m, n)`` (for some reason sometimes it takes ages on some matrices)
+
+    ### Returns (if mode = "reduced"):
+
+    Q: ``(*, m, k)`` - orthogonal
+
+    R: ``(*, k, n)`` - upper triangular
+
+    where ``k = min(m,n)``
+    """
+    try:
+        return torch.linalg.qr(A, mode=mode) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        Q, R = qr(A.to(torch.float64), mode=mode, retry_float64=False)
+        return _QRTuple(Q=Q.to(dtype), R=R.to(dtype))
+
+def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Tensor, torch.Tensor]:
+    """A - symmetric, returns ``(L, Q)``, ``A = Q @ torch.diag(L) @ Q.mH``, this is faster than SVD"""
+    try:
+        return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        L, Q = eigh(A.to(torch.float64), UPLO=UPLO, retry_float64=False)
+        return L.to(dtype), Q.to(dtype)
+
+
+
+class _SVDTuple(NamedTuple):
+    U: torch.Tensor
+    S: torch.Tensor
+    Vh: torch.Tensor
+
+def svd(A: torch.Tensor, full_matrices=True, driver=None, retry_float64:bool=False) -> _SVDTuple:
+    """A - any matrix ``(*, n, m)``, but slows down if A isn't well conditioned, ``A = U @ torch.diag(S) @ Vh``
+
+    Don't forget to set ``full_matrices=False``
+
+    ### Returns:
+
+    U: ``(*, m, m)`` or ``(*, m, k)`` - orthogonal
+
+    S: ``(*, k,)`` - singular values
+
+    V^H: ``(*, n, n)`` or ``(*, n, k)`` - orthogonal
+
+    where ``k = min(m,n)``
+
+    ### Drivers
+
+    drivers are only supported on CUDA so A is moved to CUDA by this function if needed
+
+    from docs:
+
+    If A is well-conditioned (its condition number is not too large), or you do not mind some precision loss.
+
+    For a general matrix: ‘gesvdj’ (Jacobi method)
+
+    If A is tall or wide (m >> n or m << n): ‘gesvda’ (Approximate method)
+
+    If A is not well-conditioned or precision is relevant: ‘gesvd’ (QR based)
+
+    By default (driver= None), we call ‘gesvdj’ and, if it fails, we fallback to ‘gesvd’.
+    """
+    # drivers are only for CUDA
+    # also the only one that doesn't freeze is ‘gesvda’
+    device=None
+    if driver is not None:
+        device = A.device
+        A = A.cuda()
+
+    try:
+        U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver) # pylint:disable=not-callable
+        if device is not None:
+            U = U.to(device); S = S.to(device); Vh = Vh.to(device)
+        return _SVDTuple(U=U, S=S, Vh=Vh)
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        U, S, Vh = svd(A.to(torch.float64), full_matrices=full_matrices, driver=driver, retry_float64=False)
+        return _SVDTuple(U=U.to(dtype), S=S.to(dtype), Vh=Vh.to(dtype))
+
+def solve(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> torch.Tensor:
+    """I think this uses LU"""
+    try:
+        return torch.linalg.solve(A, B, left=left) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return solve(A.to(torch.float64), B.to(torch.float64), left=left, retry_float64=False).to(dtype)
+
+class _SolveExTuple(NamedTuple):
+    result: torch.Tensor
+    info: int
+
+def solve_ex(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> _SolveExTuple:
+    """I think this uses LU"""
+    result, info = torch.linalg.solve_ex(A, B, left=left) # pylint:disable=not-callable
+
+    if info != 0:
+        if not retry_float64: return _SolveExTuple(result, info)
+        dtype = A.dtype
+        if dtype == torch.float64: return _SolveExTuple(result, info)
+        result, info = solve_ex(A.to(torch.float64), B.to(torch.float64), retry_float64=False)
+        return _SolveExTuple(result.to(dtype), info)
+
+    return _SolveExTuple(result, info)
+
+def inv(A: torch.Tensor, retry_float64:bool=False) -> torch.Tensor:
+    try:
+        return torch.linalg.inv(A) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return inv(A.to(torch.float64), retry_float64=False).to(dtype)
+
+
+class _InvExTuple(NamedTuple):
+    inverse: torch.Tensor
+    info: int
+
+def inv_ex(A: torch.Tensor, *, check_errors=False, retry_float64:bool=False) -> _InvExTuple:
+    """this retries in float64 but on fail info will be not 0"""
+    inverse, info = torch.linalg.inv_ex(A, check_errors=check_errors) # pylint:disable=not-callable
+
+    if info != 0:
+        if not retry_float64: return _InvExTuple(inverse, info)
+        dtype = A.dtype
+        if dtype == torch.float64: return _InvExTuple(inverse, info)
+        inverse, info = inv_ex(A.to(torch.float64), retry_float64=False)
+        return _InvExTuple(inverse.to(dtype), info)
+
+    return _InvExTuple(inverse, info)
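Every routine in the new module above follows the same fallback: run the factorization in the input dtype and, on ``torch.linalg.LinAlgError``, redo it in float64 and cast the result back (unless the input already was float64). A condensed sketch of that pattern; the generic ``_with_float64_retry`` wrapper below is hypothetical, the module instead spells the retry out per function:

    import torch

    def _with_float64_retry(fn, A: torch.Tensor, **kwargs):
        try:
            return fn(A, **kwargs)
        except torch.linalg.LinAlgError:
            if A.dtype == torch.float64:
                raise
            # retry the factorization in double precision, cast back to the input dtype
            return fn(A.to(torch.float64), **kwargs).to(A.dtype)

    # example usage with a float32 matrix
    A = torch.eye(4, dtype=torch.float32) * 1e-7
    L = _with_float64_retry(torch.linalg.cholesky, A)
    print(L.dtype)  # torch.float32

This sketch only covers single-tensor results; the ``_ex`` variants in the module handle multi-output results and non-zero ``info`` codes explicitly instead of relying on an exception.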
torchzero/modules/adaptive/__init__.py
@@ -12,7 +12,7 @@ from .lmadagrad import LMAdagrad
 from .lion import Lion
 from .mars import MARSCorrection
 from .matrix_momentum import MatrixMomentum
-from .msam import
+from .msam import MSAMMomentum, MSAM
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
 from .natural_gradient import NaturalGradient
 from .orthograd import OrthoGrad, orthograd_