torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/utils/linalg/linear_operator.py
ADDED

@@ -0,0 +1,329 @@
+"""simplified version of https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html. This is used for trust regions."""
+import math
+from abc import ABC, abstractmethod
+from functools import partial
+from importlib.util import find_spec
+from typing import cast, final
+
+import torch
+
+from ..torch_tools import tofloat, tonumpy, totensor
+
+if find_spec('scipy') is not None:
+    from scipy.sparse.linalg import LinearOperator as _ScipyLinearOperator
+else:
+    _ScipyLinearOperator = None
+
+class LinearOperator(ABC):
+    """this is used for trust region"""
+    device: torch.types.Device
+    dtype: torch.dtype | None
+
+    def matvec(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matvec")
+
+    def rmatvec(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatvec")
+
+    def matmat(self, x: torch.Tensor) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmul")
+
+    def solve(self, b: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve")
+
+    def solve_bounded(self, b: torch.Tensor, bound:float, ord:float=2) -> torch.Tensor:
+        """solve with a norm bound on x"""
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
+
+    def update(self, *args, **kwargs) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
+
+    def add(self, x: torch.Tensor) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")
+
+    def __add__(self, x: torch.Tensor) -> "LinearOperator":
+        return self.add(x)
+
+    def add_diagonal(self, x: torch.Tensor | float) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add_diagonal")
+
+    def diagonal(self) -> torch.Tensor:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement diagonal")
+
+    def inv(self) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement inverse")
+
+    def transpose(self) -> "LinearOperator":
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement transpose")
+
+    @property
+    def T(self): return self.transpose()
+
+    def to_tensor(self) -> torch.Tensor:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement to_tensor")
+
+    def to_dense(self) -> "Dense":
+        return Dense(self) # calls to_tensor
+
+    def size(self) -> tuple[int, ...]:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement size")
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self.size()
+
+    def numel(self) -> int:
+        return math.prod(self.size())
+
+    def ndimension(self) -> int:
+        return len(self.size())
+
+    @property
+    def ndim(self) -> int:
+        return self.ndimension()
+
+    def _numpy_matvec(self, x, dtype=None):
+        """returns Ax ndarray for scipy's LinearOperator"""
+        Ax = self.matvec(totensor(x, device=self.device, dtype=self.dtype))
+        Ax = tonumpy(Ax)
+        if dtype is not None: Ax = Ax.astype(dtype)
+        return Ax
+
+    def _numpy_rmatvec(self, x, dtype=None):
+        """returns Ax ndarray for scipy's LinearOperator"""
+        Ax = self.rmatvec(totensor(x, device=self.device, dtype=self.dtype))
+        Ax = tonumpy(Ax)
+        if dtype is not None: Ax = Ax.astype(dtype)
+        return Ax
+
+    def scipy_linop(self, dtype=None):
+        if _ScipyLinearOperator is None: raise ModuleNotFoundError("Scipy needs to be installed")
+        return _ScipyLinearOperator(
+            dtype=dtype,
+            shape=self.size(),
+            matvec=partial(self._numpy_matvec, dtype=dtype), # pyright:ignore[reportCallIssue]
+            rmatvec=partial(self._numpy_rmatvec, dtype=dtype), # pyright:ignore[reportCallIssue]
+        )
+
+    def is_dense(self) -> bool:
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement is_dense")
+
+def _solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # should I keep this or separate solve and lstsq?
+    sol, info = torch.linalg.solve_ex(A, b) # pylint:disable=not-callable
+    if info == 0: return sol
+    return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+def _inv(A: torch.Tensor) -> torch.Tensor:
+    sol, info = torch.linalg.inv_ex(A) # pylint:disable=not-callable
+    if info == 0: return sol
+    return torch.linalg.pinv(A) # pylint:disable=not-callable
+
+
+class Dense(LinearOperator):
+    def __init__(self, A: torch.Tensor | LinearOperator):
+        if isinstance(A, LinearOperator): A = A.to_tensor()
+        self.A: torch.Tensor = A
+        self.device = self.A.device
+        self.dtype = self.A.dtype
+
+    def matvec(self, x): return self.A.mv(x)
+    def rmatvec(self, x): return self.A.mH.mv(x)
+
+    def matmat(self, x): return Dense(self.A.mm(x))
+    def rmatmat(self, x): return Dense(self.A.mH.mm(x))
+
+    def solve(self, b): return _solve(self.A, b)
+
+    def add(self, x): return Dense(self.A + x)
+    def add_diagonal(self, x):
+        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+        if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
+        return Dense(self.A + torch.diag_embed(x))
+    def diagonal(self): return self.A.diagonal()
+    def inv(self): return Dense(_inv(self.A)) # pylint:disable=not-callable
+    def to_tensor(self): return self.A
+    def size(self): return self.A.size()
+    def is_dense(self): return True
+    def transpose(self): return Dense(self.A.mH)
+
+class DenseInverse(LinearOperator):
+    """Represents inverse of a dense matrix A."""
+    def __init__(self, A_inv: torch.Tensor):
+        self.A_inv: torch.Tensor = A_inv
+        self.device = self.A_inv.device
+        self.dtype = self.A_inv.dtype
+
+    def matvec(self, x): return _solve(self.A_inv, x) # pylint:disable=not-callable
+    def rmatvec(self, x): return _solve(self.A_inv.mH, x) # pylint:disable=not-callable
+
+    def matmat(self, x): return Dense(_solve(self.A_inv, x)) # pylint:disable=not-callable
+    def rmatmat(self, x): return Dense(_solve(self.A_inv.mH, x)) # pylint:disable=not-callable
+
+    def solve(self, b): return self.A_inv.mv(b)
+
+    def inv(self): return Dense(self.A_inv) # pylint:disable=not-callable
+    def to_tensor(self): return _inv(self.A_inv) # pylint:disable=not-callable
+    def size(self): return self.A_inv.size()
+    def is_dense(self): return True
+    def transpose(self): return DenseInverse(self.A_inv.mH)
+
+class DenseWithInverse(Dense):
+    """Represents a matrix where both the matrix and the inverse are known.
+
+    ``matmat``, ``rmatmat``, ``add`` and ``add_diagonal`` will return a Dense matrix, inverse will be lost.
+    """
+    def __init__(self, A: torch.Tensor, A_inv: torch.Tensor):
+        super().__init__(A)
+        self.A_inv: torch.Tensor = A_inv
+
+    def solve(self, b): return self.A_inv.mv(b)
+    def inv(self): return DenseWithInverse(self.A_inv, self.A) # pylint:disable=not-callable
+    def transpose(self): return DenseWithInverse(self.A.mH, self.A_inv.mH)
+
+class Diagonal(LinearOperator):
+    def __init__(self, x: torch.Tensor):
+        assert x.ndim == 1
+        self.A: torch.Tensor = x
+        self.device = self.A.device
+        self.dtype = self.A.dtype
+
+    def matvec(self, x): return self.A * x
+    def rmatvec(self, x): return self.A * x
+
+    def matmat(self, x): return Dense(x * self.A.unsqueeze(-1))
+    def rmatmat(self, x): return Dense(x * self.A.unsqueeze(-1))
+
+    def solve(self, b): return b/self.A
+
+    def add(self, x): return Dense(x + self.A.diag_embed())
+    def add_diagonal(self, x): return Diagonal(self.A + x)
+    def diagonal(self): return self.A
+    def inv(self): return Diagonal(1/self.A)
+    def to_tensor(self): return self.A.diag_embed()
+    def size(self): return (self.A.numel(), self.A.numel())
+    def is_dense(self): return False
+    def transpose(self): return Diagonal(self.A)
+
+class ScaledIdentity(LinearOperator):
+    def __init__(self, s: float | torch.Tensor = 1., shape=None, device=None, dtype=None):
+        self.device = self.dtype = None
+
+        if isinstance(s, torch.Tensor):
+            self.device = s.device
+            self.dtype = s.dtype
+
+        if device is not None: self.device = device
+        if dtype is not None: self.dtype = dtype
+
+        self.s = tofloat(s)
+        self._shape = shape
+
+    def matvec(self, x): return x * self.s
+    def rmatvec(self, x): return x * self.s
+
+    def matmat(self, x): return Dense(x * self.s)
+    def rmatmat(self, x): return Dense(x * self.s)
+
+    def solve(self, b): return b / self.s
+    def solve_bounded(self, b, bound, ord = 2):
+        b_norm = torch.linalg.vector_norm(b, ord=ord) # pylint:disable=not-callable
+        sol = b / self.s
+        sol_norm = b_norm / abs(self.s)
+
+        if sol_norm > bound:
+            if not math.isfinite(sol_norm):
+                if b_norm > bound: return b * (bound / b_norm)
+                return b
+            return sol * (bound / sol_norm)
+
+        return sol
+
+    def add(self, x): return Dense(x + self.s)
+    def add_diagonal(self, x):
+        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+        if isinstance(x, (int,float)): return ScaledIdentity(x + self.s, shape=self._shape, device=self.device, dtype=self.dtype)
+        return Diagonal(x + self.s)
+
+    def diagonal(self):
+        if self._shape is None: raise RuntimeError("Shape is None")
+        return torch.full(self._shape, fill_value=self.s, device=self.device, dtype=self.dtype)
+
+    def inv(self): return ScaledIdentity(1 / self.s, shape=self._shape, device=self.device, dtype=self.dtype)
+    def to_tensor(self):
+        if self._shape is None: raise RuntimeError("Shape is None")
+        return torch.eye(*self.shape, device=self.device, dtype=self.dtype).mul_(self.s)
+
+    def size(self):
+        if self._shape is None: raise RuntimeError("Shape is None")
+        return self._shape
+
+    def __repr__(self):
+        return f"ScaledIdentity(s={self.s}, shape={self._shape}, dtype={self.dtype}, device={self.device})"
+
+    def is_dense(self): return False
+    def transpose(self): return ScaledIdentity(self.s, shape=self.shape, device=self.device, dtype=self.dtype)
+
+class AtA(LinearOperator):
+    def __init__(self, A: torch.Tensor):
+        self.A = A
+
+    def matvec(self, x): return self.A.mH.mv(self.A.mv(x))
+    def rmatvec(self, x): return self.matvec(x)
+
+    def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
+    def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
+
+    def is_dense(self): return False
+    def to_tensor(self): return self.A.mH @ self.A
+    def transpose(self): return AtA(self.A)
+
+    def add_diagonal(self, x):
+        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+        if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
+        return Dense(self.to_tensor() + torch.diag_embed(x))
+
+    def solve(self, b):
+        return Dense(self.to_tensor()).solve(b)
+
+    def inv(self):
+        return Dense(self.to_tensor()).inv()
+
+    def diagonal(self):
+        return self.A.pow(2).sum(1)
+
+    def size(self):
+        n = self.A.size(1)
+        return (n,n)
+
+class AAT(LinearOperator):
+    def __init__(self, A: torch.Tensor):
+        self.A = A
+
+    def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
+    def rmatvec(self, x): return self.matvec(x)
+
+    def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
+    def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
+
+    def is_dense(self): return False
+    def to_tensor(self): return self.A @ self.A.mH
+    def transpose(self): return AAT(self.A)
+
+    def add_diagonal(self, x):
+        if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+        if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
+        return Dense(self.to_tensor() + torch.diag_embed(x))
+
+    def solve(self, b):
+        return Dense(self.to_tensor()).solve(b)
+
+    def inv(self):
+        return Dense(self.to_tensor()).inv()
+
+    def diagonal(self):
+        return self.A.pow(2).sum(0)
+
+    def size(self):
+        n = self.A.size(1)
+        return (n,n)
+
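For orientation, a minimal sketch of how the new operators compose (the matrix H, vector g, and damping value below are illustrative and not taken from the package; the import path is inferred from the file listing above): a Dense wrapper exposes matvec/solve/diagonal, add_diagonal applies a Levenberg-Marquardt-style shift, and ScaledIdentity stands in for s*I without materializing a matrix.

import torch
from torchzero.utils.linalg.linear_operator import Dense, ScaledIdentity

# illustrative inputs
H = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])      # small SPD "Hessian"
g = torch.tensor([1.0, 2.0])        # "gradient"

op = Dense(H)
print(op.matvec(g))                 # H @ g
print(op.diagonal())                # tensor([4., 3.])

# (H + 0.1*I) x = g, the shape of a damped trust-region / LM solve
x = op.add_diagonal(0.1).solve(g)

# ScaledIdentity acts like s*I; solve simply divides by s
identity = ScaledIdentity(2.0, shape=(2, 2))
print(identity.solve(g))            # g / 2
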
torchzero/utils/linalg/matrix_funcs.py
CHANGED

@@ -15,13 +15,13 @@ def singular_vals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tenso
 
 def matrix_power_eigh(A: torch.Tensor, pow:float):
     L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
-    if pow % 2 != 0: L.clip_(min = torch.finfo(A.dtype).
+    if pow % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
     return (Q * L.pow(pow).unsqueeze(-2)) @ Q.mH
 
 
 def inv_sqrt_2x2(A: torch.Tensor, force_pd: bool=False) -> torch.Tensor:
     """Inverse square root of a possibly batched 2x2 matrix using a general formula for 2x2 matrices so that this is way faster than torch linalg. I tried doing a hierarchical 2x2 preconditioning but it didn't work well."""
-    eps = torch.finfo(A.dtype).
+    eps = torch.finfo(A.dtype).tiny * 2
 
     a = A[..., 0, 0]
     b = A[..., 0, 1]
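To illustrate why the eigenvalue floor matters (the function body is copied from the hunk above; the matrix A is a made-up example): for fractional or negative powers of a nearly singular symmetric matrix, clamping the eigenvalues to torch.finfo(A.dtype).tiny * 2 keeps L.pow(pow) finite instead of producing infs from a zero or slightly negative eigenvalue.

import torch

def matrix_power_eigh(A: torch.Tensor, pow: float):
    # same body as the changed function above: eigendecompose, floor the
    # eigenvalues for non-even powers, then reassemble Q diag(L**pow) Q^H
    L, Q = torch.linalg.eigh(A)
    if pow % 2 != 0: L.clip_(min=torch.finfo(A.dtype).tiny * 2)
    return (Q * L.pow(pow).unsqueeze(-2)) @ Q.mH

# singular SPD matrix (made-up example); one eigenvalue is (numerically) zero
A = torch.tensor([[1.0, 1.0],
                  [1.0, 1.0]])
A_inv_sqrt = matrix_power_eigh(A, -0.5)
print(torch.isfinite(A_inv_sqrt).all())   # finite even for singular A; an exact
                                          # zero eigenvalue would otherwise give inf
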
torchzero/utils/linalg/orthogonalize.py
CHANGED

@@ -8,4 +8,5 @@ def gram_schmidt(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.
 def gram_schmidt(x: TensorList, y: TensorList) -> tuple[TensorList, TensorList]: ...
 def gram_schmidt(x, y):
     """makes two orthogonal vectors, only y is changed"""
-
+    min = torch.finfo(x.dtype).tiny * 2
+    return x, y - (x*y) / (x*x).clip(min=min)
torchzero/utils/linalg/qr.py
CHANGED

@@ -20,7 +20,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
 def _qr_householder_complete(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
-    eps = torch.finfo(A.dtype).
+    eps = torch.finfo(A.dtype).tiny * 2
 
     Q = torch.eye(m, dtype=A.dtype, device=A.device).expand(*b, m, m).clone() # clone because expanded dims refer to same memory
     R = A.clone()

@@ -36,7 +36,7 @@ def _qr_householder_complete(A:torch.Tensor):
 def _qr_householder_reduced(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
-    eps = torch.finfo(A.dtype).
+    eps = torch.finfo(A.dtype).tiny * 2
 
     R = A.clone()
 
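The eps in these two hunks guards the Householder steps against zero columns. Below is a self-contained sketch of that kind of guard; it is not torchzero's _get_w_tau, just the textbook reflection with the same tiny floor.

import torch

def householder_vector(a: torch.Tensor, eps: float) -> torch.Tensor:
    # textbook Householder reflector: returns a unit vector v such that
    # (I - 2 v v^T) a is proportional to e_1
    norm_a = torch.linalg.vector_norm(a)
    v = a.clone()
    v[0] = v[0] + norm_a if a[0] >= 0 else v[0] - norm_a
    # without the eps floor, a (numerically) zero column would divide by zero here
    return v / v.dot(v).clamp(min=eps).sqrt()

A = torch.tensor([[0.0, 1.0],
                  [0.0, 2.0]])            # first column is exactly zero (made-up example)
eps = torch.finfo(A.dtype).tiny * 2
v = householder_vector(A[:, 0], eps)
print(v)                                  # tensor([0., 0.]) -- finite, no NaNs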