torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/utils/linalg/solve.py
CHANGED
|
@@ -1,69 +1,32 @@
|
|
|
1
|
+
# pyright: reportArgumentType=false
|
|
2
|
+
import math
|
|
3
|
+
from collections import deque
|
|
1
4
|
from collections.abc import Callable
|
|
2
|
-
from typing import overload
|
|
5
|
+
from typing import Any, NamedTuple, overload
|
|
6
|
+
|
|
3
7
|
import torch
|
|
4
8
|
|
|
5
|
-
from .. import
|
|
9
|
+
from .. import (
|
|
10
|
+
TensorList,
|
|
11
|
+
generic_eq,
|
|
12
|
+
generic_finfo_tiny,
|
|
13
|
+
generic_numel,
|
|
14
|
+
generic_vector_norm,
|
|
15
|
+
generic_zeros_like,
|
|
16
|
+
)
|
|
6
17
|
|
|
7
|
-
@overload
|
|
8
|
-
def cg(
|
|
9
|
-
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
10
|
-
b: torch.Tensor,
|
|
11
|
-
x0_: torch.Tensor | None = None,
|
|
12
|
-
tol: float | None = 1e-4,
|
|
13
|
-
maxiter: int | None = None,
|
|
14
|
-
reg: float = 0,
|
|
15
|
-
) -> torch.Tensor: ...
|
|
16
|
-
@overload
|
|
17
|
-
def cg(
|
|
18
|
-
A_mm: Callable[[TensorList], TensorList],
|
|
19
|
-
b: TensorList,
|
|
20
|
-
x0_: TensorList | None = None,
|
|
21
|
-
tol: float | None = 1e-4,
|
|
22
|
-
maxiter: int | None = None,
|
|
23
|
-
reg: float | list[float] | tuple[float] = 0,
|
|
24
|
-
) -> TensorList: ...
|
|
25
18
|
|
|
26
|
-
def
|
|
27
|
-
A_mm: Callable,
|
|
28
|
-
b: torch.Tensor | TensorList,
|
|
29
|
-
x0_: torch.Tensor | TensorList | None = None,
|
|
30
|
-
tol: float | None = 1e-4,
|
|
31
|
-
maxiter: int | None = None,
|
|
32
|
-
reg: float | list[float] | tuple[float] = 0,
|
|
33
|
-
):
|
|
19
|
+
def _make_A_mm_reg(A_mm: Callable, reg):
    """Wrap ``A_mm`` so that the returned callable computes ``A_mm(x) + x * reg``.

    ``reg`` may be a scalar, or (for TensorList inputs) a list/tuple of per-tensor values.
    When ``reg`` equals 0 (checked with ``generic_eq``), the product from ``A_mm`` is
    returned unchanged.
    """
    def A_mm_reg(x): # A_mm with regularization
        Ax = A_mm(x)
        # out-of-place on purpose: A_mm may return its input (e.g. an identity operator)
        # or a buffer the caller still holds, and an in-place add would silently modify it
        if not generic_eq(reg, 0): Ax = Ax + x * reg
        return Ax
    return A_mm_reg
|
|
38
25
|
|
|
39
|
-
|
|
40
|
-
if x0_ is None: x0_ = generic_zeros_like(b)
|
|
41
|
-
|
|
42
|
-
x = x0_
|
|
43
|
-
residual = b - A_mm_reg(x)
|
|
44
|
-
p = residual.clone() # search direction
|
|
45
|
-
r_norm = generic_vector_norm(residual)
|
|
46
|
-
init_norm = r_norm
|
|
47
|
-
if tol is not None and r_norm < tol: return x
|
|
48
|
-
k = 0
|
|
49
|
-
|
|
50
|
-
while True:
|
|
51
|
-
Ap = A_mm_reg(p)
|
|
52
|
-
step_size = (r_norm**2) / p.dot(Ap)
|
|
53
|
-
x += step_size * p # Update solution
|
|
54
|
-
residual -= step_size * Ap # Update residual
|
|
55
|
-
new_r_norm = generic_vector_norm(residual)
|
|
56
|
-
|
|
57
|
-
k += 1
|
|
58
|
-
if tol is not None and new_r_norm <= tol * init_norm: return x
|
|
59
|
-
if k >= maxiter: return x
|
|
60
|
-
|
|
61
|
-
beta = (new_r_norm**2) / (r_norm**2)
|
|
62
|
-
p = residual + beta*p
|
|
63
|
-
r_norm = new_r_norm
|
|
26
|
+
def _identity(x): return x
|
|
64
27
|
|
|
65
28
|
|
|
66
|
-
# https://arxiv.org/pdf/2110.02820
|
|
29
|
+
# https://arxiv.org/pdf/2110.02820
|
|
67
30
|
def nystrom_approximation(
|
|
68
31
|
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
69
32
|
ndim: int,
|
|
@@ -85,7 +48,6 @@ def nystrom_approximation(
|
|
|
85
48
|
lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
|
|
86
49
|
return U, lambd
|
|
87
50
|
|
|
88
|
-
# this one works worse
|
|
89
51
|
def nystrom_sketch_and_solve(
|
|
90
52
|
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
91
53
|
b: torch.Tensor,
|
|
@@ -111,7 +73,6 @@ def nystrom_sketch_and_solve(
|
|
|
111
73
|
term2 = (1.0 / reg) * (b - U @ Uᵀb)
|
|
112
74
|
return (term1 + term2).squeeze(-1)
|
|
113
75
|
|
|
114
|
-
# this one is insane
|
|
115
76
|
def nystrom_pcg(
|
|
116
77
|
A_mm: Callable[[torch.Tensor], torch.Tensor],
|
|
117
78
|
b: torch.Tensor,
|
|
@@ -131,6 +92,8 @@ def nystrom_pcg(
|
|
|
131
92
|
generator=generator,
|
|
132
93
|
)
|
|
133
94
|
lambd += reg
|
|
95
|
+
eps = torch.finfo(b.dtype).tiny * 2
|
|
96
|
+
if tol is None: tol = eps
|
|
134
97
|
|
|
135
98
|
def A_mm_reg(x): # A_mm with regularization
|
|
136
99
|
Ax = A_mm(x)
|
|
@@ -150,7 +113,7 @@ def nystrom_pcg(
|
|
|
150
113
|
p = z.clone() # search direction
|
|
151
114
|
|
|
152
115
|
init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
|
|
153
|
-
if
|
|
116
|
+
if init_norm < tol: return x
|
|
154
117
|
k = 0
|
|
155
118
|
while True:
|
|
156
119
|
Ap = A_mm_reg(p)
|
|
@@ -160,10 +123,358 @@ def nystrom_pcg(
|
|
|
160
123
|
residual -= step_size * Ap
|
|
161
124
|
|
|
162
125
|
k += 1
|
|
163
|
-
if
|
|
126
|
+
if torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
|
|
164
127
|
if k >= maxiter: return x
|
|
165
128
|
|
|
166
129
|
z = P_inv @ residual
|
|
167
130
|
beta = residual.dot(z) / rz
|
|
168
131
|
p = z + p*beta
|
|
169
132
|
|
|
133
|
+
|
|
134
|
+
def _safe_clip(x: torch.Tensor):
|
|
135
|
+
"""makes sure scalar tensor x is not smaller than tiny"""
|
|
136
|
+
assert x.numel() == 1, x.shape
|
|
137
|
+
eps = torch.finfo(x.dtype).tiny * 2
|
|
138
|
+
if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
|
|
139
|
+
return x
|
|
140
|
+
|
|
141
|
+
def _trust_tau(x, d, trust_radius):
    """Step from ``x`` along ``d`` onto the trust-region boundary.

    Returns ``x + tau * d`` where ``tau`` is the larger root of
    ``||x + tau*d||^2 = trust_radius^2``:
    ``tau = (-<x,d> + sqrt(<x,d>^2 - <d,d> (<x,x> - r^2))) / <d,d>``.
    The discriminant is clipped at 0 and ``<d,d>`` is kept away from zero with ``_safe_clip``.
    """
    dd = _safe_clip(d.dot(d))
    xd = x.dot(d)
    xx = x.dot(x)

    discriminant = xd**2 - dd * (xx - trust_radius**2)
    root = discriminant.clip(min=0).sqrt()
    tau = (-xd + root) / dd

    return x + tau * d
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class CG:
    """Conjugate gradient method, optionally preconditioned and trust-region constrained.

    Solves ``(A + reg*I) x = b``. Call :meth:`solve` to run until termination, or
    :meth:`step` to advance one iteration at a time.

    Args:
        A_mm (Callable): Callable that returns matvec ``Ax``.
        b (torch.Tensor | TensorList): right hand side.
        x0 (torch.Tensor | TensorList | None, optional): initial guess, zeros if None. Defaults to None.
        tol (float | None, optional):
            absolute tolerance on the residual norm. If None, set to ``2 * generic_finfo_tiny(b)``.
            Defaults to 1e-4.
        maxiter (int | None, optional):
            maximum number of iterations, if None sets to ``generic_numel(b)``. Defaults to None.
        reg (float, optional): regularization added to the diagonal of ``A``. Defaults to 0.
        trust_radius (float | None, optional):
            CG is terminated whenever solution exceeds trust region, returning a solution modified
            to be within it. Defaults to None.
        npc_terminate (bool, optional):
            whether to terminate CG whenever negative curvature is detected. Defaults to False.
        miniter (int, optional):
            minimal number of iterations even if tolerance is satisfied, this ensures some progress
            is always made. Defaults to 0.
        history_size (int, optional):
            number of past iterations to store, to re-use them when trust radius is decreased.
            Defaults to 0.
        P_mm (Callable | None, optional):
            Callable that returns inverse preconditioner times vector. Defaults to None
            (no preconditioning).
    """
    def __init__(
        self,
        A_mm: Callable,
        b: torch.Tensor | TensorList,
        x0: torch.Tensor | TensorList | None = None,
        tol: float | None = 1e-4,
        maxiter: int | None = None,
        reg: float = 0,
        trust_radius: float | None = None,
        npc_terminate: bool=False,
        miniter: int = 0,
        history_size: int = 0,
        P_mm: Callable | None = None,
    ):
        # --------------------------------- set attrs -------------------------------- #
        self.A_mm = _make_A_mm_reg(A_mm, reg)
        self.b = b
        self.eps = generic_finfo_tiny(b) * 2
        if tol is None: tol = self.eps
        self.tol = tol
        if maxiter is None: maxiter = generic_numel(b)
        self.maxiter = maxiter
        self.miniter = miniter
        self.trust_radius = trust_radius
        self.npc_terminate = npc_terminate
        self.P_mm = P_mm if P_mm is not None else _identity

        # history of (x, x_norm, d) tuples, or None when history_size is 0
        self.history: deque | None = deque(maxlen=history_size) if history_size > 0 else None

        # -------------------------------- initialize -------------------------------- #
        self.iter = 0

        if x0 is None:
            self.x = generic_zeros_like(b)
            self.r = b
        else:
            self.x = x0
            # use the regularized operator so the initial residual matches (A + reg*I) x = b
            self.r = b - self.A_mm(self.x)

        self.z = self.P_mm(self.r)
        self.d = self.z

        if self.history is not None:
            self.history.append((self.x, generic_vector_norm(self.x), self.d))

    def step(self) -> tuple[Any, bool]:
        """Performs a single CG iteration.

        Returns:
            ``(solution, should_terminate)``. ``solution`` is the iterate after this step.
            When ``should_terminate`` is True it is the final answer, possibly moved onto
            the trust-region boundary.
        """
        x, b, d, r, z = self.x, self.b, self.d, self.r, self.z

        if self.iter >= self.maxiter:
            return x, True

        Ad = self.A_mm(d)
        dAd = d.dot(Ad)

        # check negative (or numerically zero) curvature
        if dAd <= self.eps:
            # with a trust region, follow d to the boundary
            if self.trust_radius is not None: return _trust_tau(x, d, self.trust_radius), True
            # on the very first iteration fall back to a scaled right-hand side
            if self.iter == 0: return b * (b.dot(b) / dAd).abs(), True
            if self.npc_terminate: return x, True
            # otherwise keep iterating as usual

        rz = r.dot(z)
        alpha = rz / dAd
        x_next = x + alpha * d

        # check if the step exceeds the trust-region boundary
        x_next_norm = None
        if self.trust_radius is not None:
            x_next_norm = generic_vector_norm(x_next)
            if x_next_norm >= self.trust_radius:
                return _trust_tau(x, d, self.trust_radius), True

        # update residual
        r_next = r - alpha * Ad

        # check if r is sufficiently small
        if self.iter >= self.miniter and generic_vector_norm(r_next) < self.tol:
            return x_next, True

        # update d, r, z
        z_next = self.P_mm(r_next)
        beta = r_next.dot(z_next) / rz

        self.d = z_next + beta * d
        self.x = x_next
        self.r = r_next
        self.z = z_next

        # update history
        if self.history is not None:
            if x_next_norm is None: x_next_norm = generic_vector_norm(x_next)
            self.history.append((self.x, x_next_norm, self.d))

        self.iter += 1
        # return the updated iterate rather than the one from before this step
        return x_next, False

    def solve(self):
        """Runs :meth:`step` until it signals termination and returns the final solution."""
        # return initial guess if it is good enough
        if self.miniter < 1 and generic_vector_norm(self.r) < self.tol:
            return self.x

        should_terminate = False
        sol = None

        while not should_terminate:
            sol, should_terminate = self.step()

        assert sol is not None
        return sol
|
|
290
|
+
|
|
291
|
+
def find_within_trust_radius(history, trust_radius: float):
    """Re-use a stored CG iteration after the trust radius has shrunk.

    Scans ``history`` (an iterable of ``(x, x_norm, d)`` tuples, as stored by ``CG``) from the
    most recent entry backwards. For the first entry with ``x_norm <= trust_radius``, returns
    the point reached by moving from ``x`` along ``d`` onto the trust-region boundary
    (``_trust_tau``). Returns ``None`` if no stored ``x`` lies within the trust radius.
    """
    for x, x_norm, d in reversed(tuple(history)):
        if x_norm <= trust_radius:
            return _trust_tau(x, d, trust_radius)
    return None
|
|
297
|
+
|
|
298
|
+
class _TensorSolution(NamedTuple):
    """Result of ``cg`` when the right-hand side is a ``torch.Tensor``."""
    x: torch.Tensor  # approximate solution
    solver: CG       # solver instance that produced ``x``

class _TensorListSolution(NamedTuple):
    """Result of ``cg`` when the right-hand side is a ``TensorList``."""
    x: TensorList    # approximate solution
    solver: CG       # solver instance that produced ``x``
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
@overload
def cg(
    A_mm: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    x0: torch.Tensor | None = None,
    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float = 0,
    trust_radius: float | None = None,
    npc_terminate: bool = False,
    miniter: int = 0,
    history_size: int = 0,
    P_mm: Callable[[torch.Tensor], torch.Tensor] | None = None
) -> _TensorSolution: ...
@overload
def cg(
    A_mm: Callable[[TensorList], TensorList],
    b: TensorList,
    x0: TensorList | None = None,
    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    trust_radius: float | None = None,
    npc_terminate: bool=False,
    miniter: int = 0,
    history_size: int = 0,
    P_mm: Callable[[TensorList], TensorList] | None = None
) -> _TensorListSolution: ...
def cg(
    A_mm: Callable,
    b: torch.Tensor | TensorList,
    x0: torch.Tensor | TensorList | None = None,
    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    trust_radius: float | None = None,
    npc_terminate: bool = False,
    miniter: int = 0,
    history_size:int = 0,
    P_mm: Callable | None = None
):
    """Solve ``(A + reg*I) x = b`` with (preconditioned) conjugate gradient.

    Builds a ``CG`` solver, runs it to termination and returns a named tuple
    ``(x, solver)``, so callers can also inspect the solver state (e.g. ``history``).
    The arguments have the same meaning as in ``CG``; note the default ``tol`` here is 1e-8.
    """
    options = dict(
        tol=tol,
        maxiter=maxiter,
        reg=reg,
        trust_radius=trust_radius,
        npc_terminate=npc_terminate,
        miniter=miniter,
        history_size=history_size,
        P_mm=P_mm,
    )
    solver = CG(A_mm, b, x0, **options)
    solution = solver.solve()

    result_cls = _TensorSolution if isinstance(b, torch.Tensor) else _TensorListSolution
    return result_cls(solution, solver)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
@overload
def minres(
    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
    b: torch.Tensor,
    x0: torch.Tensor | None = None,
    tol: float | None = 1e-4,
    maxiter: int | None = None,
    reg: float = 0,
    npc_terminate: bool=True,
    trust_radius: float | None = None,
) -> torch.Tensor: ...
@overload
def minres(
    A_mm: Callable[[TensorList], TensorList],
    b: TensorList,
    x0: TensorList | None = None,
    tol: float | None = 1e-4,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    npc_terminate: bool=True,
    trust_radius: float | None = None,
) -> TensorList: ...
def minres(
    A_mm,
    b,
    x0: torch.Tensor | TensorList | None = None,
    tol: float | None = 1e-4,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    npc_terminate: bool=True,
    trust_radius: float | None = None, #trust region is experimental
):
    """MINRES for symmetric systems ``(A + reg*I) x = b`` with negative-curvature detection.

    Follows the MINRES variant of Liu & Roosta (2022) cited above.

    Args:
        A_mm: callable returning the matvec ``Ax``; ``x -> A_mm(x) + x*reg`` is used throughout.
        b: right hand side.
        x0: initial guess; zeros if None.
        tol: relative tolerance, iteration stops once ``phi / ||b|| <= tol`` where ``phi`` is the
            residual norm tracked by the recurrence. If None, ``sqrt(2 * generic_finfo_tiny(b))``.
        maxiter: maximum number of iterations; ``generic_numel(b)`` if None.
        reg: regularization added to the diagonal.
        npc_terminate: if True, stop as soon as negative curvature is detected
            (``c * gamma1 >= 0``) and return the current residual ``R``.
        trust_radius: experimental. If set, returns a point on the trust-region boundary
            (``_trust_tau``) when the iterate leaves the region, when negative curvature is
            detected, or when the system is detected to be singular.

    Returns:
        The approximate solution, or ``R`` / a boundary point in the early-exit cases above.
        Zeros are returned if ``||b||`` is numerically zero.
    """
    A_mm_reg = _make_A_mm_reg(A_mm, reg)
    eps = math.sqrt(generic_finfo_tiny(b) * 2)
    if tol is None: tol = eps

    if maxiter is None: maxiter = generic_numel(b)
    if x0 is None:
        R = b
        x0 = generic_zeros_like(b)
    else:
        R = b - A_mm_reg(x0)

    X: Any = x0
    b_norm = generic_vector_norm(b)
    if b_norm < eps**2:
        return generic_zeros_like(b)

    # Lanczos must start from the initial residual r0 = b - A x0, not from b,
    # otherwise V, phi and R are inconsistent for a nonzero x0 (for x0=None, r0 = b).
    beta = generic_vector_norm(R)
    if beta < eps**2: # initial guess already solves the system
        return X

    V = R / beta
    V_prev = generic_zeros_like(b)
    D = generic_zeros_like(b)
    D_prev = generic_zeros_like(b)

    c = -1
    phi = tau = beta
    s = delta1 = e = 0

    for _ in range(maxiter):

        # Lanczos step; out-of-place so the output of A_mm is never modified
        P = A_mm_reg(V)
        alpha = V.dot(P)
        P = P - beta*V_prev - alpha*V
        beta = generic_vector_norm(P)

        # apply the previous Givens rotation
        delta2 = c*delta1 + s*alpha
        gamma1 = s*delta1 - c*alpha
        e_next = s*beta
        delta1 = -c*beta

        # negative curvature test
        cgamma1 = c*gamma1
        if trust_radius is not None and cgamma1 >= 0:
            if npc_terminate: return _trust_tau(X, R, trust_radius)
            return _trust_tau(X, D, trust_radius)

        if npc_terminate and cgamma1 >= 0:
            return R

        gamma2 = (gamma1**2 + beta**2)**(1/2)

        if abs(gamma2) <= eps: # singular system
            # c=0; s=1; tau=0
            if trust_radius is None: return X
            return _trust_tau(X, D, trust_radius)

        # new Givens rotation
        c = gamma1 / gamma2
        s = beta/gamma2
        tau = c*phi
        phi = s*phi

        # d_k = (v_k - delta2 * d_{k-1} - e_k * d_{k-2}) / gamma2.
        # Compute it before shifting D into D_prev so the e-term really uses d_{k-2}.
        D_next = (V - delta2*D - e*D_prev) / gamma2
        D_prev = D
        D = D_next
        e = e_next
        X = X + tau*D

        if trust_radius is not None:
            if generic_vector_norm(X) > trust_radius:
                return _trust_tau(X, D, trust_radius)

        if (abs(beta) < eps) or (phi / b_norm <= tol):
            # R = zeros(R)
            return X

        V_prev = V
        V = P/beta
        R = s**2*R - phi*c*V

    return X
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""convenience submodule which allows to calculate a metric based on its string name,
|
|
2
|
+
used in many places"""
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, overload
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .tensorlist import TensorList
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Metric(ABC):
    """Interface for a named reduction that can be applied to tensors and TensorLists."""

    @abstractmethod
    def evaluate_global(self, x: "TensorList") -> torch.Tensor:
        """Reduce the whole TensorList ``x`` to a single value."""

    @abstractmethod
    def evaluate_tensor(self, x: torch.Tensor, dim=None, keepdim=False) -> torch.Tensor:
        """Reduce a single tensor ``x``, optionally along ``dim``."""

    def evaluate_list(self, x: "TensorList") -> "TensorList":
        """Return one value per tensor of ``x``.

        By default this maps :meth:`evaluate_tensor` over ``x``; subclasses may
        override it with a vectorized implementation.
        """
        per_tensor = self.evaluate_tensor
        return x.map(per_tensor)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _MAD(Metric):
    """Mean of absolute values (key ``'mad'``)."""

    def evaluate_global(self, x):
        magnitudes = x.abs()
        return magnitudes.global_mean()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        magnitudes = x.abs()
        return magnitudes.mean(dim=dim, keepdim=keepdim)

    def evaluate_list(self, x):
        magnitudes = x.abs()
        return magnitudes.mean()
|
|
32
|
+
|
|
33
|
+
class _Std(Metric):
    """Standard deviation (key ``'std'``)."""

    def evaluate_global(self, x):
        return x.global_std()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        result = x.std(dim=dim, keepdim=keepdim)
        return result

    def evaluate_list(self, x):
        result = x.std()
        return result
|
|
37
|
+
|
|
38
|
+
class _Var(Metric):
    """Variance (key ``'var'``)."""

    def evaluate_global(self, x):
        return x.global_var()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        result = x.var(dim=dim, keepdim=keepdim)
        return result

    def evaluate_list(self, x):
        result = x.var()
        return result
|
|
42
|
+
|
|
43
|
+
class _Sum(Metric):
    """Sum of elements (key ``'sum'``)."""

    def evaluate_global(self, x):
        return x.global_sum()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        result = x.sum(dim=dim, keepdim=keepdim)
        return result

    def evaluate_list(self, x):
        result = x.sum()
        return result
|
|
47
|
+
|
|
48
|
+
class _Norm(Metric):
    """Vector norm of order ``ord`` (keys ``'l0'`` ... ``'l4'`` and ``'linf'``)."""

    def __init__(self, ord):
        self.ord = ord

    def evaluate_global(self, x):
        order = self.ord
        return x.global_vector_norm(order)

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        order = self.ord
        return torch.linalg.vector_norm(x, ord=order, dim=dim, keepdim=keepdim) # pylint:disable=not-callable

    def evaluate_list(self, x):
        order = self.ord
        return x.norm(order)
|
|
54
|
+
|
|
55
|
+
_METRIC_KEYS = Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf']

# metric name -> Metric instance; keys must stay in sync with _METRIC_KEYS
_METRICS: dict[_METRIC_KEYS, Metric] = dict(
    mad=_MAD(),
    std=_Std(),
    var=_Var(),
    sum=_Sum(),
    l0=_Norm(0),
    l1=_Norm(1),
    l2=_Norm(2),
    l3=_Norm(3),
    l4=_Norm(4),
    linf=_Norm(torch.inf),
)

# anything accepted as a metric: a known name, or a number / tensor used as a norm order
Metrics = _METRIC_KEYS | float | torch.Tensor
|
|
70
|
+
def evaluate_metric(x: "torch.Tensor | TensorList", metric: Metrics) -> torch.Tensor:
    """Evaluate ``metric`` on ``x`` as a single (global) value.

    Args:
        x: a tensor, or a TensorList (reduced globally across all of its tensors).
        metric: a metric name from ``_METRICS`` (e.g. ``'mad'``, ``'l2'``), or a number /
            scalar tensor used as the ``ord`` of a vector norm.

    Raises:
        KeyError: if ``metric`` is a name that is not in ``_METRICS``.
    """
    if isinstance(metric, (int, float, torch.Tensor)):
        # convert once so tensors and TensorLists receive the same plain-float order;
        # float() also accepts one-element or grad-requiring tensors that torch's ``ord`` may reject
        p = float(metric)
        if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=p) # pylint:disable=not-callable
        return x.global_vector_norm(ord=p)

    try:
        m = _METRICS[metric]
    except KeyError:
        raise KeyError(f"Unknown metric {metric!r}, expected one of {list(_METRICS)} or a number") from None

    if isinstance(x, torch.Tensor): return m.evaluate_tensor(x)
    return m.evaluate_global(x)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def calculate_metric_list(x: "TensorList", metric: Metrics) -> "TensorList":
    """Evaluate ``metric`` separately for each tensor of ``x``.

    ``metric`` is either a name from ``_METRICS`` or a number / tensor used as the
    order of a per-tensor norm.
    """
    if not isinstance(metric, (int, float, torch.Tensor)):
        return _METRICS[metric].evaluate_list(x)
    return x.norm(ord=float(metric))
|
torchzero/utils/numberlist.py
CHANGED
|
@@ -129,4 +129,6 @@ class NumberList(list[int | float | Any]):
|
|
|
129
129
|
return self.__class__(fn(i, *args, **kwargs) for i in self)
|
|
130
130
|
|
|
131
131
|
def clamp(self, min=None, max=None):
    """Apply ``_clamp`` to every element through ``zipmap_args`` with bounds ``min``/``max``.

    NOTE(review): whether list-valued bounds are paired element-wise depends on
    ``zipmap_args``, which is not shown here.
    """
    lower, upper = min, max
    return self.zipmap_args(_clamp, lower, upper)
|
|
133
|
+
def clip(self, min=None, max=None):
    """Alias of ``clamp``: applies ``_clamp`` with bounds ``min``/``max`` via ``zipmap_args``."""
    lower, upper = min, max
    return self.zipmap_args(_clamp, lower, upper)
|
torchzero/utils/python_tools.py
CHANGED
|
@@ -31,6 +31,16 @@ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable
|
|
|
31
31
|
return all(i==y for i in x)
|
|
32
32
|
return all(i==j for i,j in zip(x,y))
|
|
33
33
|
|
|
34
|
+
def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
|
|
35
|
+
"""generic not equals function that supports scalars and lists of numbers. Faster than not generic_eq"""
|
|
36
|
+
if isinstance(x, (int,float)):
|
|
37
|
+
if isinstance(y, (int,float)): return x!=y
|
|
38
|
+
return any(i!=x for i in y)
|
|
39
|
+
if isinstance(y, (int,float)):
|
|
40
|
+
return any(i!=y for i in x)
|
|
41
|
+
return any(i!=j for i,j in zip(x,y))
|
|
42
|
+
|
|
43
|
+
|
|
34
44
|
def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
|
|
35
45
|
"""If `other` is list/tuple, applies `fn` to self zipped with `other`.
|
|
36
46
|
Otherwise applies `fn` to this sequence and `other`.
|
|
@@ -51,3 +61,9 @@ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None =
|
|
|
51
61
|
values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
|
|
52
62
|
if len(values) == 1: return values[0]
|
|
53
63
|
return values
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def safe_dict_update_(d1_:dict, d2:dict):
    """Update ``d1_`` in place with ``d2``.

    Raises:
        RuntimeError: if any key is present in both dicts; ``d1_`` is left unchanged.
    """
    overlapping = set(d1_.keys()).intersection(d2.keys())
    if overlapping:
        raise RuntimeError(f"Duplicate keys {overlapping}")
    d1_.update(d2)
|