torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/utils/linalg/solve.py
CHANGED
|
@@ -1,99 +1,32 @@
|
|
|
1
1
|
# pyright: reportArgumentType=false
|
|
2
|
+
import math
|
|
3
|
+
from collections import deque
|
|
2
4
|
from collections.abc import Callable
|
|
3
|
-
from typing import Any, overload
|
|
5
|
+
from typing import Any, NamedTuple, overload
|
|
4
6
|
|
|
5
7
|
import torch
|
|
6
8
|
|
|
7
9
|
from .. import (
|
|
8
10
|
TensorList,
|
|
9
11
|
generic_eq,
|
|
10
|
-
|
|
12
|
+
generic_finfo_tiny,
|
|
11
13
|
generic_numel,
|
|
12
|
-
generic_randn_like,
|
|
13
14
|
generic_vector_norm,
|
|
14
15
|
generic_zeros_like,
|
|
15
16
|
)
|
|
16
17
|
|
|
17
18
|
|
|
18
|
-
def _make_A_mm_reg(A_mm: Callable
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
if not generic_eq(reg, 0): Ax += x*reg
|
|
23
|
-
return Ax
|
|
24
|
-
return A_mm_reg
|
|
25
|
-
|
|
26
|
-
if not isinstance(A_mm, torch.Tensor): raise TypeError(type(A_mm))
|
|
27
|
-
|
|
28
|
-
def Ax_reg(x): # A_mm with regularization
|
|
29
|
-
if A_mm.ndim == 1: Ax = A_mm * x
|
|
30
|
-
else: Ax = A_mm @ x
|
|
31
|
-
if reg != 0: Ax += x*reg
|
|
19
|
+
def _make_A_mm_reg(A_mm: Callable, reg):
|
|
20
|
+
def A_mm_reg(x): # A_mm with regularization
|
|
21
|
+
Ax = A_mm(x)
|
|
22
|
+
if not generic_eq(reg, 0): Ax += x*reg
|
|
32
23
|
return Ax
|
|
33
|
-
return
|
|
24
|
+
return A_mm_reg
|
|
34
25
|
|
|
26
|
+
def _identity(x): return x
|
|
35
27
|
|
|
36
|
-
@overload
|
|
37
|
-
def cg(
|
|
38
|
-
A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
|
|
39
|
-
b: torch.Tensor,
|
|
40
|
-
x0_: torch.Tensor | None = None,
|
|
41
|
-
tol: float | None = 1e-4,
|
|
42
|
-
maxiter: int | None = None,
|
|
43
|
-
reg: float = 0,
|
|
44
|
-
) -> torch.Tensor: ...
|
|
45
|
-
@overload
|
|
46
|
-
def cg(
|
|
47
|
-
A_mm: Callable[[TensorList], TensorList],
|
|
48
|
-
b: TensorList,
|
|
49
|
-
x0_: TensorList | None = None,
|
|
50
|
-
tol: float | None = 1e-4,
|
|
51
|
-
maxiter: int | None = None,
|
|
52
|
-
reg: float | list[float] | tuple[float] = 0,
|
|
53
|
-
) -> TensorList: ...
|
|
54
28
|
|
|
55
|
-
|
|
56
|
-
A_mm: Callable | torch.Tensor,
|
|
57
|
-
b: torch.Tensor | TensorList,
|
|
58
|
-
x0_: torch.Tensor | TensorList | None = None,
|
|
59
|
-
tol: float | None = 1e-4,
|
|
60
|
-
maxiter: int | None = None,
|
|
61
|
-
reg: float | list[float] | tuple[float] = 0,
|
|
62
|
-
):
|
|
63
|
-
A_mm_reg = _make_A_mm_reg(A_mm, reg)
|
|
64
|
-
eps = generic_finfo_eps(b)
|
|
65
|
-
|
|
66
|
-
if tol is None: tol = eps
|
|
67
|
-
|
|
68
|
-
if maxiter is None: maxiter = generic_numel(b)
|
|
69
|
-
if x0_ is None: x0_ = generic_zeros_like(b)
|
|
70
|
-
|
|
71
|
-
x = x0_
|
|
72
|
-
residual = b - A_mm_reg(x)
|
|
73
|
-
p = residual.clone() # search direction
|
|
74
|
-
r_norm = generic_vector_norm(residual)
|
|
75
|
-
init_norm = r_norm
|
|
76
|
-
if r_norm < tol: return x
|
|
77
|
-
k = 0
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
while True:
|
|
81
|
-
Ap = A_mm_reg(p)
|
|
82
|
-
step_size = (r_norm**2) / p.dot(Ap)
|
|
83
|
-
x += step_size * p # Update solution
|
|
84
|
-
residual -= step_size * Ap # Update residual
|
|
85
|
-
new_r_norm = generic_vector_norm(residual)
|
|
86
|
-
|
|
87
|
-
k += 1
|
|
88
|
-
if new_r_norm <= tol * init_norm: return x
|
|
89
|
-
if k >= maxiter: return x
|
|
90
|
-
|
|
91
|
-
beta = (new_r_norm**2) / (r_norm**2)
|
|
92
|
-
p = residual + beta*p
|
|
93
|
-
r_norm = new_r_norm
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
# https://arxiv.org/pdf/2110.02820 algorithm 2.1 apparently supposed to be diabolical
|
|
29
|
+
# https://arxiv.org/pdf/2110.02820
|
|
97
30
|
def nystrom_approximation(
|
|
98
31
|
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
99
32
|
ndim: int,
|
|
@@ -115,7 +48,6 @@ def nystrom_approximation(
|
|
|
115
48
|
lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
|
|
116
49
|
return U, lambd
|
|
117
50
|
|
|
118
|
-
# this one works worse
|
|
119
51
|
def nystrom_sketch_and_solve(
|
|
120
52
|
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
121
53
|
b: torch.Tensor,
|
|
@@ -141,7 +73,6 @@ def nystrom_sketch_and_solve(
|
|
|
141
73
|
term2 = (1.0 / reg) * (b - U @ Uᵀb)
|
|
142
74
|
return (term1 + term2).squeeze(-1)
|
|
143
75
|
|
|
144
|
-
# this one is insane
|
|
145
76
|
def nystrom_pcg(
|
|
146
77
|
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
147
78
|
b: torch.Tensor,
|
|
@@ -161,7 +92,7 @@ def nystrom_pcg(
|
|
|
161
92
|
generator=generator,
|
|
162
93
|
)
|
|
163
94
|
lambd += reg
|
|
164
|
-
eps = torch.finfo(b.dtype).
|
|
95
|
+
eps = torch.finfo(b.dtype).tiny * 2
|
|
165
96
|
if tol is None: tol = eps
|
|
166
97
|
|
|
167
98
|
def A_mm_reg(x): # A_mm with regularization
|
|
@@ -201,98 +132,239 @@ def nystrom_pcg(
|
|
|
201
132
|
|
|
202
133
|
|
|
203
134
|
def _safe_clip(x: torch.Tensor):
|
|
204
|
-
"""makes sure scalar tensor x is not smaller than
|
|
135
|
+
"""makes sure scalar tensor x is not smaller than tiny"""
|
|
205
136
|
assert x.numel() == 1, x.shape
|
|
206
|
-
eps = torch.finfo(x.dtype).
|
|
137
|
+
eps = torch.finfo(x.dtype).tiny * 2
|
|
207
138
|
if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
|
|
208
139
|
return x
|
|
209
140
|
|
|
210
|
-
def _trust_tau(x,d,
|
|
141
|
+
def _trust_tau(x,d,trust_radius):
|
|
211
142
|
xx = x.dot(x)
|
|
212
143
|
xd = x.dot(d)
|
|
213
144
|
dd = _safe_clip(d.dot(d))
|
|
214
145
|
|
|
215
|
-
rad = (xd**2 - dd * (xx -
|
|
146
|
+
rad = (xd**2 - dd * (xx - trust_radius**2)).clip(min=0).sqrt()
|
|
216
147
|
tau = (-xd + rad) / dd
|
|
217
148
|
|
|
218
149
|
return x + tau * d
|
|
219
150
|
|
|
220
151
|
|
|
152
|
+
class CG:
|
|
153
|
+
"""Conjugate gradient method.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
A_mm (Callable[[torch.Tensor], torch.Tensor]): Callable that returns matvec ``Ax``.
|
|
157
|
+
b (torch.Tensor): right hand side
|
|
158
|
+
x0 (torch.Tensor | None, optional): initial guess, defaults to zeros. Defaults to None.
|
|
159
|
+
tol (float | None, optional): tolerance for convergence. Defaults to 1e-4.
|
|
160
|
+
maxiter (int | None, optional):
|
|
161
|
+
maximum number of iterations, if None sets to number of dimensions. Defaults to None.
|
|
162
|
+
reg (float, optional): regularization. Defaults to 0.
|
|
163
|
+
trust_radius (float | None, optional):
|
|
164
|
+
CG is terminated whenever solution exceeds trust region, returning a solution modified to be within it. Defaults to None.
|
|
165
|
+
npc_terminate (bool, optional):
|
|
166
|
+
whether to terminate CG whenever negative curvature is detected. Defaults to False.
|
|
167
|
+
miniter (int, optional):
|
|
168
|
+
minimal number of iterations even if tolerance is satisfied, this ensures some progress
|
|
169
|
+
is always made.
|
|
170
|
+
history_size (int, optional):
|
|
171
|
+
number of past iterations to store, to re-use them when trust radius is decreased.
|
|
172
|
+
P_mm (Callable | None, optional):
|
|
173
|
+
Callable that returns inverse preconditioner times vector. Defaults to None.
|
|
174
|
+
"""
|
|
175
|
+
def __init__(
|
|
176
|
+
self,
|
|
177
|
+
A_mm: Callable,
|
|
178
|
+
b: torch.Tensor | TensorList,
|
|
179
|
+
x0: torch.Tensor | TensorList | None = None,
|
|
180
|
+
tol: float | None = 1e-4,
|
|
181
|
+
maxiter: int | None = None,
|
|
182
|
+
reg: float = 0,
|
|
183
|
+
trust_radius: float | None = None,
|
|
184
|
+
npc_terminate: bool=False,
|
|
185
|
+
miniter: int = 0,
|
|
186
|
+
history_size: int = 0,
|
|
187
|
+
P_mm: Callable | None = None,
|
|
188
|
+
):
|
|
189
|
+
# --------------------------------- set attrs -------------------------------- #
|
|
190
|
+
self.A_mm = _make_A_mm_reg(A_mm, reg)
|
|
191
|
+
self.b = b
|
|
192
|
+
if tol is None: tol = generic_finfo_tiny(b) * 2
|
|
193
|
+
self.tol = tol
|
|
194
|
+
self.eps = generic_finfo_tiny(b) * 2
|
|
195
|
+
if maxiter is None: maxiter = generic_numel(b)
|
|
196
|
+
self.maxiter = maxiter
|
|
197
|
+
self.miniter = miniter
|
|
198
|
+
self.trust_radius = trust_radius
|
|
199
|
+
self.npc_terminate = npc_terminate
|
|
200
|
+
self.P_mm = P_mm if P_mm is not None else _identity
|
|
201
|
+
|
|
202
|
+
if history_size > 0:
|
|
203
|
+
self.history = deque(maxlen = history_size)
|
|
204
|
+
"""history of (x, x_norm, d)"""
|
|
205
|
+
else:
|
|
206
|
+
self.history = None
|
|
207
|
+
|
|
208
|
+
# -------------------------------- initialize -------------------------------- #
|
|
209
|
+
|
|
210
|
+
self.iter = 0
|
|
211
|
+
|
|
212
|
+
if x0 is None:
|
|
213
|
+
self.x = generic_zeros_like(b)
|
|
214
|
+
self.r = b
|
|
215
|
+
else:
|
|
216
|
+
self.x = x0
|
|
217
|
+
self.r = b - A_mm(self.x)
|
|
218
|
+
|
|
219
|
+
self.z = self.P_mm(self.r)
|
|
220
|
+
self.d = self.z
|
|
221
|
+
|
|
222
|
+
if self.history is not None:
|
|
223
|
+
self.history.append((self.x, generic_vector_norm(self.x), self.d))
|
|
224
|
+
|
|
225
|
+
def step(self) -> tuple[Any, bool]:
|
|
226
|
+
"""returns ``(solution, should_terminate)``"""
|
|
227
|
+
x, b, d, r, z = self.x, self.b, self.d, self.r, self.z
|
|
228
|
+
|
|
229
|
+
if self.iter >= self.maxiter:
|
|
230
|
+
return x, True
|
|
231
|
+
|
|
232
|
+
Ad = self.A_mm(d)
|
|
233
|
+
dAd = d.dot(Ad)
|
|
234
|
+
|
|
235
|
+
# check negative curvature
|
|
236
|
+
if dAd <= self.eps:
|
|
237
|
+
if self.trust_radius is not None: return _trust_tau(x, d, self.trust_radius), True
|
|
238
|
+
if self.iter == 0: return b * (b.dot(b) / dAd).abs(), True
|
|
239
|
+
if self.npc_terminate: return x, True
|
|
240
|
+
|
|
241
|
+
rz = r.dot(z)
|
|
242
|
+
alpha = rz / dAd
|
|
243
|
+
x_next = x + alpha * d
|
|
244
|
+
|
|
245
|
+
# check if the step exceeds the trust-region boundary
|
|
246
|
+
x_next_norm = None
|
|
247
|
+
if self.trust_radius is not None:
|
|
248
|
+
x_next_norm = generic_vector_norm(x_next)
|
|
249
|
+
if x_next_norm >= self.trust_radius:
|
|
250
|
+
return _trust_tau(x, d, self.trust_radius), True
|
|
251
|
+
|
|
252
|
+
# update step, residual and direction
|
|
253
|
+
r_next = r - alpha * Ad
|
|
254
|
+
|
|
255
|
+
# check if r is sufficiently small
|
|
256
|
+
if self.iter >= self.miniter and generic_vector_norm(r_next) < self.tol:
|
|
257
|
+
return x_next, True
|
|
258
|
+
|
|
259
|
+
# update d, r, z
|
|
260
|
+
z_next = self.P_mm(r_next)
|
|
261
|
+
beta = r_next.dot(z_next) / rz
|
|
262
|
+
|
|
263
|
+
self.d = z_next + beta * d
|
|
264
|
+
self.x = x_next
|
|
265
|
+
self.r = r_next
|
|
266
|
+
self.z = z_next
|
|
267
|
+
|
|
268
|
+
# update history
|
|
269
|
+
if self.history is not None:
|
|
270
|
+
if x_next_norm is None: x_next_norm = generic_vector_norm(x_next)
|
|
271
|
+
self.history.append((self.x, x_next_norm, self.d))
|
|
272
|
+
|
|
273
|
+
self.iter += 1
|
|
274
|
+
return x, False
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def solve(self):
|
|
278
|
+
# return initial guess if it is good enough
|
|
279
|
+
if self.miniter < 1 and generic_vector_norm(self.r) < self.tol:
|
|
280
|
+
return self.x
|
|
281
|
+
|
|
282
|
+
should_terminate = False
|
|
283
|
+
sol = None
|
|
284
|
+
|
|
285
|
+
while not should_terminate:
|
|
286
|
+
sol, should_terminate = self.step()
|
|
287
|
+
|
|
288
|
+
assert sol is not None
|
|
289
|
+
return sol
|
|
290
|
+
|
|
291
|
+
def find_within_trust_radius(history, trust_radius: float):
|
|
292
|
+
"""find the most recent ``x`` in history whose norm does not exceed the trust radius and return it moved along its direction ``d`` via ``_trust_tau``; if no such ``x`` exists, returns ``None``"""
|
|
293
|
+
for x, x_norm, d in reversed(tuple(history)):
|
|
294
|
+
if x_norm <= trust_radius:
|
|
295
|
+
return _trust_tau(x, d, trust_radius)
|
|
296
|
+
return None
|
|
297
|
+
|
|
298
|
+
class _TensorSolution(NamedTuple):
|
|
299
|
+
x: torch.Tensor
|
|
300
|
+
solver: CG
|
|
301
|
+
|
|
302
|
+
class _TensorListSolution(NamedTuple):
|
|
303
|
+
x: TensorList
|
|
304
|
+
solver: CG
|
|
305
|
+
|
|
306
|
+
|
|
221
307
|
@overload
|
|
222
|
-
def
|
|
223
|
-
A_mm: Callable[[torch.Tensor], torch.Tensor]
|
|
308
|
+
def cg(
|
|
309
|
+
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
224
310
|
b: torch.Tensor,
|
|
225
|
-
trust_region: float,
|
|
226
311
|
x0: torch.Tensor | None = None,
|
|
227
|
-
tol: float | None = 1e-
|
|
312
|
+
tol: float | None = 1e-8,
|
|
228
313
|
maxiter: int | None = None,
|
|
229
314
|
reg: float = 0,
|
|
230
|
-
|
|
315
|
+
trust_radius: float | None = None,
|
|
316
|
+
npc_terminate: bool = False,
|
|
317
|
+
miniter: int = 0,
|
|
318
|
+
history_size: int = 0,
|
|
319
|
+
P_mm: Callable[[torch.Tensor], torch.Tensor] | None = None
|
|
320
|
+
) -> _TensorSolution: ...
|
|
231
321
|
@overload
|
|
232
|
-
def
|
|
322
|
+
def cg(
|
|
233
323
|
A_mm: Callable[[TensorList], TensorList],
|
|
234
324
|
b: TensorList,
|
|
235
|
-
trust_region: float,
|
|
236
325
|
x0: TensorList | None = None,
|
|
237
|
-
tol: float | None = 1e-
|
|
326
|
+
tol: float | None = 1e-8,
|
|
238
327
|
maxiter: int | None = None,
|
|
239
328
|
reg: float | list[float] | tuple[float] = 0,
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
329
|
+
trust_radius: float | None = None,
|
|
330
|
+
npc_terminate: bool=False,
|
|
331
|
+
miniter: int = 0,
|
|
332
|
+
history_size: int = 0,
|
|
333
|
+
P_mm: Callable[[TensorList], TensorList] | None = None
|
|
334
|
+
) -> _TensorListSolution: ...
|
|
335
|
+
def cg(
|
|
336
|
+
A_mm: Callable,
|
|
243
337
|
b: torch.Tensor | TensorList,
|
|
244
|
-
trust_region: float,
|
|
245
338
|
x0: torch.Tensor | TensorList | None = None,
|
|
246
|
-
tol: float | None = 1e-
|
|
339
|
+
tol: float | None = 1e-8,
|
|
247
340
|
maxiter: int | None = None,
|
|
248
341
|
reg: float | list[float] | tuple[float] = 0,
|
|
342
|
+
trust_radius: float | None = None,
|
|
343
|
+
npc_terminate: bool = False,
|
|
344
|
+
miniter: int = 0,
|
|
345
|
+
history_size:int = 0,
|
|
346
|
+
P_mm: Callable | None = None
|
|
249
347
|
):
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
if generic_vector_norm(r) < tol:
|
|
264
|
-
return x
|
|
265
|
-
|
|
266
|
-
if maxiter is None:
|
|
267
|
-
maxiter = generic_numel(b)
|
|
268
|
-
|
|
269
|
-
for _ in range(maxiter):
|
|
270
|
-
Ad = A_mm_reg(d)
|
|
271
|
-
|
|
272
|
-
d_Ad = d.dot(Ad)
|
|
273
|
-
if d_Ad <= eps:
|
|
274
|
-
return _trust_tau(x, d, trust_region)
|
|
275
|
-
|
|
276
|
-
alpha = r.dot(r) / d_Ad
|
|
277
|
-
p_next = x + alpha * d
|
|
278
|
-
|
|
279
|
-
# check if the step exceeds the trust-region boundary
|
|
280
|
-
if generic_vector_norm(p_next) >= trust_region:
|
|
281
|
-
return _trust_tau(x, d, trust_region)
|
|
282
|
-
|
|
283
|
-
# update step, residual and direction
|
|
284
|
-
x = p_next
|
|
285
|
-
r_next = r - alpha * Ad
|
|
286
|
-
|
|
287
|
-
if generic_vector_norm(r_next) < tol:
|
|
288
|
-
return x
|
|
348
|
+
solver = CG(
|
|
349
|
+
A_mm=A_mm,
|
|
350
|
+
b=b,
|
|
351
|
+
x0=x0,
|
|
352
|
+
tol=tol,
|
|
353
|
+
maxiter=maxiter,
|
|
354
|
+
reg=reg,
|
|
355
|
+
trust_radius=trust_radius,
|
|
356
|
+
npc_terminate=npc_terminate,
|
|
357
|
+
miniter=miniter,
|
|
358
|
+
history_size=history_size,
|
|
359
|
+
P_mm=P_mm,
|
|
360
|
+
)
|
|
289
361
|
|
|
290
|
-
|
|
291
|
-
d = r_next + beta * d
|
|
292
|
-
r = r_next
|
|
362
|
+
x = solver.solve()
|
|
293
363
|
|
|
294
|
-
|
|
364
|
+
if isinstance(b, torch.Tensor):
|
|
365
|
+
return _TensorSolution(x, solver)
|
|
295
366
|
|
|
367
|
+
return _TensorListSolution(x, solver)
|
|
296
368
|
|
|
297
369
|
|
|
298
370
|
# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
|
|
@@ -305,7 +377,7 @@ def minres(
|
|
|
305
377
|
maxiter: int | None = None,
|
|
306
378
|
reg: float = 0,
|
|
307
379
|
npc_terminate: bool=True,
|
|
308
|
-
|
|
380
|
+
trust_radius: float | None = None,
|
|
309
381
|
) -> torch.Tensor: ...
|
|
310
382
|
@overload
|
|
311
383
|
def minres(
|
|
@@ -316,7 +388,7 @@ def minres(
|
|
|
316
388
|
maxiter: int | None = None,
|
|
317
389
|
reg: float | list[float] | tuple[float] = 0,
|
|
318
390
|
npc_terminate: bool=True,
|
|
319
|
-
|
|
391
|
+
trust_radius: float | None = None,
|
|
320
392
|
) -> TensorList: ...
|
|
321
393
|
def minres(
|
|
322
394
|
A_mm,
|
|
@@ -326,11 +398,11 @@ def minres(
|
|
|
326
398
|
maxiter: int | None = None,
|
|
327
399
|
reg: float | list[float] | tuple[float] = 0,
|
|
328
400
|
npc_terminate: bool=True,
|
|
329
|
-
|
|
401
|
+
trust_radius: float | None = None, #trust region is experimental
|
|
330
402
|
):
|
|
331
403
|
A_mm_reg = _make_A_mm_reg(A_mm, reg)
|
|
332
|
-
eps =
|
|
333
|
-
if tol is None: tol = eps
|
|
404
|
+
eps = math.sqrt(generic_finfo_tiny(b) * 2)
|
|
405
|
+
if tol is None: tol = eps
|
|
334
406
|
|
|
335
407
|
if maxiter is None: maxiter = generic_numel(b)
|
|
336
408
|
if x0 is None:
|
|
@@ -369,9 +441,9 @@ def minres(
|
|
|
369
441
|
delta1 = -c*beta
|
|
370
442
|
|
|
371
443
|
cgamma1 = c*gamma1
|
|
372
|
-
if
|
|
373
|
-
if npc_terminate: return _trust_tau(X, R,
|
|
374
|
-
return _trust_tau(X, D,
|
|
444
|
+
if trust_radius is not None and cgamma1 >= 0:
|
|
445
|
+
if npc_terminate: return _trust_tau(X, R, trust_radius)
|
|
446
|
+
return _trust_tau(X, D, trust_radius)
|
|
375
447
|
|
|
376
448
|
if npc_terminate and cgamma1 >= 0:
|
|
377
449
|
return R
|
|
@@ -380,8 +452,8 @@ def minres(
|
|
|
380
452
|
|
|
381
453
|
if abs(gamma2) <= eps: # singular system
|
|
382
454
|
# c=0; s=1; tau=0
|
|
383
|
-
if
|
|
384
|
-
return _trust_tau(X, D,
|
|
455
|
+
if trust_radius is None: return X
|
|
456
|
+
return _trust_tau(X, D, trust_radius)
|
|
385
457
|
|
|
386
458
|
c = gamma1 / gamma2
|
|
387
459
|
s = beta/gamma2
|
|
@@ -393,9 +465,9 @@ def minres(
|
|
|
393
465
|
e = e_next
|
|
394
466
|
X = X + tau*D
|
|
395
467
|
|
|
396
|
-
if
|
|
397
|
-
if generic_vector_norm(X) >
|
|
398
|
-
return _trust_tau(X, D,
|
|
468
|
+
if trust_radius is not None:
|
|
469
|
+
if generic_vector_norm(X) > trust_radius:
|
|
470
|
+
return _trust_tau(X, D, trust_radius)
|
|
399
471
|
|
|
400
472
|
if (abs(beta) < eps) or (phi / b_norm <= tol):
|
|
401
473
|
# R = zeros(R)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""convenience submodule which allows to calculate a metric based on its string name,
|
|
2
|
+
used in many places"""
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, overload
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .tensorlist import TensorList
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Metric(ABC):
|
|
15
|
+
@abstractmethod
|
|
16
|
+
def evaluate_global(self, x: "TensorList") -> torch.Tensor:
|
|
17
|
+
"""returns a global metric for a tensorlist"""
|
|
18
|
+
|
|
19
|
+
@abstractmethod
|
|
20
|
+
def evaluate_tensor(self, x: torch.Tensor, dim=None, keepdim=False) -> torch.Tensor:
|
|
21
|
+
"""returns metric for a tensor"""
|
|
22
|
+
|
|
23
|
+
def evaluate_list(self, x: "TensorList") -> "TensorList":
|
|
24
|
+
"""returns list of metrics for a tensorlist (possibly vectorized)"""
|
|
25
|
+
return x.map(self.evaluate_tensor)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _MAD(Metric):
|
|
29
|
+
def evaluate_global(self, x): return x.abs().global_mean()
|
|
30
|
+
def evaluate_tensor(self, x, dim=None, keepdim=False): return x.abs().mean(dim=dim, keepdim=keepdim)
|
|
31
|
+
def evaluate_list(self, x): return x.abs().mean()
|
|
32
|
+
|
|
33
|
+
class _Std(Metric):
|
|
34
|
+
def evaluate_global(self, x): return x.global_std()
|
|
35
|
+
def evaluate_tensor(self, x, dim=None, keepdim=False): return x.std(dim=dim, keepdim=keepdim)
|
|
36
|
+
def evaluate_list(self, x): return x.std()
|
|
37
|
+
|
|
38
|
+
class _Var(Metric):
|
|
39
|
+
def evaluate_global(self, x): return x.global_var()
|
|
40
|
+
def evaluate_tensor(self, x, dim=None, keepdim=False): return x.var(dim=dim, keepdim=keepdim)
|
|
41
|
+
def evaluate_list(self, x): return x.var()
|
|
42
|
+
|
|
43
|
+
class _Sum(Metric):
|
|
44
|
+
def evaluate_global(self, x): return x.global_sum()
|
|
45
|
+
def evaluate_tensor(self, x, dim=None, keepdim=False): return x.sum(dim=dim, keepdim=keepdim)
|
|
46
|
+
def evaluate_list(self, x): return x.sum()
|
|
47
|
+
|
|
48
|
+
class _Norm(Metric):
|
|
49
|
+
def __init__(self, ord): self.ord = ord
|
|
50
|
+
def evaluate_global(self, x): return x.global_vector_norm(self.ord)
|
|
51
|
+
def evaluate_tensor(self, x, dim=None, keepdim=False):
|
|
52
|
+
return torch.linalg.vector_norm(x, ord=self.ord, dim=dim, keepdim=keepdim) # pylint:disable=not-callable
|
|
53
|
+
def evaluate_list(self, x): return x.norm(self.ord)
|
|
54
|
+
|
|
55
|
+
_METRIC_KEYS = Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf']
|
|
56
|
+
_METRICS: dict[_METRIC_KEYS, Metric] = {
|
|
57
|
+
"mad": _MAD(),
|
|
58
|
+
"std": _Std(),
|
|
59
|
+
"var": _Var(),
|
|
60
|
+
"sum": _Sum(),
|
|
61
|
+
"l0": _Norm(0),
|
|
62
|
+
"l1": _Norm(1),
|
|
63
|
+
"l2": _Norm(2),
|
|
64
|
+
"l3": _Norm(3),
|
|
65
|
+
"l4": _Norm(4),
|
|
66
|
+
"linf": _Norm(torch.inf),
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
Metrics = _METRIC_KEYS | float | torch.Tensor
|
|
70
|
+
def evaluate_metric(x: "torch.Tensor | TensorList", metric: Metrics) -> torch.Tensor:
|
|
71
|
+
if isinstance(metric, (int, float, torch.Tensor)):
|
|
72
|
+
if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=metric) # pylint:disable=not-callable
|
|
73
|
+
return x.global_vector_norm(ord=float(metric))
|
|
74
|
+
|
|
75
|
+
if isinstance(x, torch.Tensor): return _METRICS[metric].evaluate_tensor(x)
|
|
76
|
+
return _METRICS[metric].evaluate_global(x)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def calculate_metric_list(x: "TensorList", metric: Metrics) -> "TensorList":
|
|
80
|
+
if isinstance(metric, (int, float, torch.Tensor)):
|
|
81
|
+
return x.norm(ord=float(metric))
|
|
82
|
+
|
|
83
|
+
return _METRICS[metric].evaluate_list(x)
|
torchzero/utils/python_tools.py
CHANGED
|
@@ -61,3 +61,9 @@ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None =
|
|
|
61
61
|
values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
|
|
62
62
|
if len(values) == 1: return values[0]
|
|
63
63
|
return values
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def safe_dict_update_(d1_:dict, d2:dict):
|
|
67
|
+
inter = set(d1_.keys()).intersection(d2.keys())
|
|
68
|
+
if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
|
|
69
|
+
d1_.update(d2)
|