torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/utils/linalg/linear_operator.py (new file)
@@ -0,0 +1,329 @@
+ """Simplified version of https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html. This is used for trust regions."""
+ import math
+ from abc import ABC, abstractmethod
+ from functools import partial
+ from importlib.util import find_spec
+ from typing import cast, final
+
+ import torch
+
+ from ..torch_tools import tofloat, tonumpy, totensor
+
+ if find_spec('scipy') is not None:
+     from scipy.sparse.linalg import LinearOperator as _ScipyLinearOperator
+ else:
+     _ScipyLinearOperator = None
+
+ class LinearOperator(ABC):
+     """this is used for trust region"""
+     device: torch.types.Device
+     dtype: torch.dtype | None
+
+     def matvec(self, x: torch.Tensor) -> torch.Tensor:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matvec")
+
+     def rmatvec(self, x: torch.Tensor) -> torch.Tensor:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatvec")
+
+     def matmat(self, x: torch.Tensor) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmat")
+
+     def solve(self, b: torch.Tensor) -> torch.Tensor:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve")
+
+     def solve_bounded(self, b: torch.Tensor, bound: float, ord: float = 2) -> torch.Tensor:
+         """solve with a norm bound on x"""
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
+
+     def update(self, *args, **kwargs) -> None:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
+
+     def add(self, x: torch.Tensor) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")
+
+     def __add__(self, x: torch.Tensor) -> "LinearOperator":
+         return self.add(x)
+
+     def add_diagonal(self, x: torch.Tensor | float) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add_diagonal")
+
+     def diagonal(self) -> torch.Tensor:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement diagonal")
+
+     def inv(self) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement inv")
+
+     def transpose(self) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement transpose")
+
+     @property
+     def T(self): return self.transpose()
+
+     def to_tensor(self) -> torch.Tensor:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement to_tensor")
+
+     def to_dense(self) -> "Dense":
+         return Dense(self) # calls to_tensor
+
+     def size(self) -> tuple[int, ...]:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement size")
+
+     @property
+     def shape(self) -> tuple[int, ...]:
+         return self.size()
+
+     def numel(self) -> int:
+         return math.prod(self.size())
+
+     def ndimension(self) -> int:
+         return len(self.size())
+
+     @property
+     def ndim(self) -> int:
+         return self.ndimension()
+
+     def _numpy_matvec(self, x, dtype=None):
+         """returns Ax ndarray for scipy's LinearOperator"""
+         Ax = self.matvec(totensor(x, device=self.device, dtype=self.dtype))
+         Ax = tonumpy(Ax)
+         if dtype is not None: Ax = Ax.astype(dtype)
+         return Ax
+
+     def _numpy_rmatvec(self, x, dtype=None):
+         """returns Aᵀx ndarray for scipy's LinearOperator"""
+         Ax = self.rmatvec(totensor(x, device=self.device, dtype=self.dtype))
+         Ax = tonumpy(Ax)
+         if dtype is not None: Ax = Ax.astype(dtype)
+         return Ax
+
+     def scipy_linop(self, dtype=None):
+         if _ScipyLinearOperator is None: raise ModuleNotFoundError("Scipy needs to be installed")
+         return _ScipyLinearOperator(
+             dtype=dtype,
+             shape=self.size(),
+             matvec=partial(self._numpy_matvec, dtype=dtype), # pyright:ignore[reportCallIssue]
+             rmatvec=partial(self._numpy_rmatvec, dtype=dtype), # pyright:ignore[reportCallIssue]
+         )
+
+     def is_dense(self) -> bool:
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement is_dense")
+
+ def _solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # should I keep this or separate solve and lstsq?
+     sol, info = torch.linalg.solve_ex(A, b) # pylint:disable=not-callable
+     if info == 0: return sol
+     return torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+ def _inv(A: torch.Tensor) -> torch.Tensor:
+     sol, info = torch.linalg.inv_ex(A) # pylint:disable=not-callable
+     if info == 0: return sol
+     return torch.linalg.pinv(A) # pylint:disable=not-callable
+
+
+ class Dense(LinearOperator):
+     def __init__(self, A: torch.Tensor | LinearOperator):
+         if isinstance(A, LinearOperator): A = A.to_tensor()
+         self.A: torch.Tensor = A
+         self.device = self.A.device
+         self.dtype = self.A.dtype
+
+     def matvec(self, x): return self.A.mv(x)
+     def rmatvec(self, x): return self.A.mH.mv(x)
+
+     def matmat(self, x): return Dense(self.A.mm(x))
+     def rmatmat(self, x): return Dense(self.A.mH.mm(x))
+
+     def solve(self, b): return _solve(self.A, b)
+
+     def add(self, x): return Dense(self.A + x)
+     def add_diagonal(self, x):
+         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+         if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
+         return Dense(self.A + torch.diag_embed(x))
+     def diagonal(self): return self.A.diagonal()
+     def inv(self): return Dense(_inv(self.A))
+     def to_tensor(self): return self.A
+     def size(self): return self.A.size()
+     def is_dense(self): return True
+     def transpose(self): return Dense(self.A.mH)
+
+ class DenseInverse(LinearOperator):
+     """Represents the inverse of a dense matrix A."""
+     def __init__(self, A_inv: torch.Tensor):
+         self.A_inv: torch.Tensor = A_inv
+         self.device = self.A_inv.device
+         self.dtype = self.A_inv.dtype
+
+     def matvec(self, x): return _solve(self.A_inv, x)
+     def rmatvec(self, x): return _solve(self.A_inv.mH, x)
+
+     def matmat(self, x): return Dense(_solve(self.A_inv, x))
+     def rmatmat(self, x): return Dense(_solve(self.A_inv.mH, x))
+
+     def solve(self, b): return self.A_inv.mv(b)
+
+     def inv(self): return Dense(self.A_inv)
+     def to_tensor(self): return _inv(self.A_inv)
+     def size(self): return self.A_inv.size()
+     def is_dense(self): return True
+     def transpose(self): return DenseInverse(self.A_inv.mH)
+
+ class DenseWithInverse(Dense):
+     """Represents a matrix where both the matrix and its inverse are known.
+
+     ``matmat``, ``rmatmat``, ``add`` and ``add_diagonal`` return a Dense matrix; the inverse is lost.
+     """
+     def __init__(self, A: torch.Tensor, A_inv: torch.Tensor):
+         super().__init__(A)
+         self.A_inv: torch.Tensor = A_inv
+
+     def solve(self, b): return self.A_inv.mv(b)
+     def inv(self): return DenseWithInverse(self.A_inv, self.A)
+     def transpose(self): return DenseWithInverse(self.A.mH, self.A_inv.mH)
+
+ class Diagonal(LinearOperator):
+     def __init__(self, x: torch.Tensor):
+         assert x.ndim == 1
+         self.A: torch.Tensor = x
+         self.device = self.A.device
+         self.dtype = self.A.dtype
+
+     def matvec(self, x): return self.A * x
+     def rmatvec(self, x): return self.A * x
+
+     def matmat(self, x): return Dense(x * self.A.unsqueeze(-1))
+     def rmatmat(self, x): return Dense(x * self.A.unsqueeze(-1))
+
+     def solve(self, b): return b / self.A
+
+     def add(self, x): return Dense(x + self.A.diag_embed())
+     def add_diagonal(self, x): return Diagonal(self.A + x)
+     def diagonal(self): return self.A
+     def inv(self): return Diagonal(1 / self.A)
+     def to_tensor(self): return self.A.diag_embed()
+     def size(self): return (self.A.numel(), self.A.numel())
+     def is_dense(self): return False
+     def transpose(self): return Diagonal(self.A)
+
+ class ScaledIdentity(LinearOperator):
+     def __init__(self, s: float | torch.Tensor = 1., shape=None, device=None, dtype=None):
+         self.device = self.dtype = None
+
+         if isinstance(s, torch.Tensor):
+             self.device = s.device
+             self.dtype = s.dtype
+
+         if device is not None: self.device = device
+         if dtype is not None: self.dtype = dtype
+
+         self.s = tofloat(s)
+         self._shape = shape
+
+     def matvec(self, x): return x * self.s
+     def rmatvec(self, x): return x * self.s
+
+     def matmat(self, x): return Dense(x * self.s)
+     def rmatmat(self, x): return Dense(x * self.s)
+
+     def solve(self, b): return b / self.s
+     def solve_bounded(self, b, bound, ord=2):
+         b_norm = torch.linalg.vector_norm(b, ord=ord) # pylint:disable=not-callable
+         sol = b / self.s
+         sol_norm = b_norm / abs(self.s)
+
+         if sol_norm > bound:
+             if not math.isfinite(sol_norm):
+                 if b_norm > bound: return b * (bound / b_norm)
+                 return b
+             return sol * (bound / sol_norm)
+
+         return sol
+
+     def add(self, x): return Dense(x).add_diagonal(self.s) # sI + X adds s on the diagonal only
+     def add_diagonal(self, x):
+         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+         if isinstance(x, (int,float)): return ScaledIdentity(x + self.s, shape=self._shape, device=self.device, dtype=self.dtype)
+         return Diagonal(x + self.s)
+
+     def diagonal(self):
+         if self._shape is None: raise RuntimeError("Shape is None")
+         return torch.full((self._shape[0],), fill_value=self.s, device=self.device, dtype=self.dtype) # the diagonal is a vector of length shape[0]
+
+     def inv(self): return ScaledIdentity(1 / self.s, shape=self._shape, device=self.device, dtype=self.dtype)
+     def to_tensor(self):
+         if self._shape is None: raise RuntimeError("Shape is None")
+         return torch.eye(*self.shape, device=self.device, dtype=self.dtype).mul_(self.s)
+
+     def size(self):
+         if self._shape is None: raise RuntimeError("Shape is None")
+         return self._shape
+
+     def __repr__(self):
+         return f"ScaledIdentity(s={self.s}, shape={self._shape}, dtype={self.dtype}, device={self.device})"
+
+     def is_dense(self): return False
+     def transpose(self): return ScaledIdentity(self.s, shape=self._shape, device=self.device, dtype=self.dtype)
+
+ class AtA(LinearOperator):
+     def __init__(self, A: torch.Tensor):
+         self.A = A
+
+     def matvec(self, x): return self.A.mH.mv(self.A.mv(x))
+     def rmatvec(self, x): return self.matvec(x)
+
+     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
+     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
+
+     def is_dense(self): return False
+     def to_tensor(self): return self.A.mH @ self.A
+     def transpose(self): return AtA(self.A)
+
+     def add_diagonal(self, x):
+         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+         if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
+         return Dense(self.to_tensor() + torch.diag_embed(x))
+
+     def solve(self, b):
+         return Dense(self.to_tensor()).solve(b)
+
+     def inv(self):
+         return Dense(self.to_tensor()).inv()
+
+     def diagonal(self):
+         return self.A.pow(2).sum(0) # diag(AᵀA) holds the squared column norms
+
+     def size(self):
+         n = self.A.size(1)
+         return (n, n)
+
+ class AAT(LinearOperator):
+     def __init__(self, A: torch.Tensor):
+         self.A = A
+
+     def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
+     def rmatvec(self, x): return self.matvec(x)
+
+     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
+     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
+
+     def is_dense(self): return False
+     def to_tensor(self): return self.A @ self.A.mH
+     def transpose(self): return AAT(self.A)
+
+     def add_diagonal(self, x):
+         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
+         if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
+         return Dense(self.to_tensor() + torch.diag_embed(x))
+
+     def solve(self, b):
+         return Dense(self.to_tensor()).solve(b)
+
+     def inv(self):
+         return Dense(self.to_tensor()).inv()
+
+     def diagonal(self):
+         return self.A.pow(2).sum(1) # diag(AAᵀ) holds the squared row norms
+
+     def size(self):
+         m = self.A.size(0) # AAᵀ is m×m for an m×n A
+         return (m, m)
+
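The module docstring says these operators exist to back the new trust-region modules. As a rough usage sketch (not part of the diff; the import path is taken from the files-changed list above and the matrix values are invented), this is how Dense and ScaledIdentity compose into a damped-Newton-style solve:

    import torch
    from torchzero.utils.linalg.linear_operator import Dense, ScaledIdentity

    # Invented SPD "Hessian" and gradient, just to exercise the API.
    H = torch.tensor([[2.0, 0.5],
                      [0.5, 1.0]])
    g = torch.tensor([1.0, -1.0])

    op = Dense(H)
    op.matvec(g)                 # H @ g
    op.diagonal()                # tensor([2., 1.])

    # Levenberg-Marquardt-style damping: solve (H + 0.1*I) x = g.
    step = op.add_diagonal(0.1).solve(g)

    # ScaledIdentity solves trivially and supports the norm-bounded solve
    # used by trust regions: g / 0.5 is rescaled so that ||x|| <= 1.
    x = ScaledIdentity(0.5, shape=(2, 2)).solve_bounded(g, bound=1.0)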
torchzero/utils/linalg/matrix_funcs.py
@@ -15,13 +15,13 @@ def singular_vals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tenso
 
 def matrix_power_eigh(A: torch.Tensor, pow:float):
     L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
-    if pow % 2 != 0: L.clip_(min = torch.finfo(A.dtype).eps)
+    if pow % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
     return (Q * L.pow(pow).unsqueeze(-2)) @ Q.mH
 
 
 def inv_sqrt_2x2(A: torch.Tensor, force_pd: bool=False) -> torch.Tensor:
     """Inverse square root of a possibly batched 2x2 matrix using a general formula for 2x2 matrices so that this is way faster than torch linalg. I tried doing a hierarchical 2x2 preconditioning but it didn't work well."""
-    eps = torch.finfo(A.dtype).eps
+    eps = torch.finfo(A.dtype).tiny * 2
 
     a = A[..., 0, 0]
     b = A[..., 0, 1]
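For context on the hunk above: matrix_power_eigh computes a symmetric matrix power through an eigendecomposition, A^p = Q diag(λ^p) Qᴴ, and the clamp keeps eigenvalues positive for powers where negative values would be invalid. A quick sanity check (not part of the diff; the test matrix is invented and the import assumes the function is exported from the module shown in the files-changed list):

    import torch
    from torchzero.utils.linalg.matrix_funcs import matrix_power_eigh

    A = torch.tensor([[2.0, 0.3],
                      [0.3, 1.0]])    # an invented SPD matrix
    R = matrix_power_eigh(A, -0.5)    # inverse square root A^(-1/2)
    assert torch.allclose(R @ A @ R, torch.eye(2), atol=1e-5)  # recovers identity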
torchzero/utils/linalg/orthogonalize.py
@@ -8,4 +8,5 @@ def gram_schmidt(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.
 def gram_schmidt(x: TensorList, y: TensorList) -> tuple[TensorList, TensorList]: ...
 def gram_schmidt(x, y):
     """makes two orthogonal vectors, only y is changed"""
-    return x, y - (x*y) / ((x*x) + 1e-8)
+    min = torch.finfo(x.dtype).tiny * 2
+    return x, y - (x*y) / (x*x).clip(min=min)
torchzero/utils/linalg/qr.py
@@ -20,7 +20,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
 def _qr_householder_complete(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
-    eps = torch.finfo(A.dtype).eps
+    eps = torch.finfo(A.dtype).tiny * 2
 
     Q = torch.eye(m, dtype=A.dtype, device=A.device).expand(*b, m, m).clone() # clone because expanded dims refer to same memory
     R = A.clone()
@@ -36,7 +36,7 @@ def _qr_householder_complete(A:torch.Tensor):
 def _qr_householder_reduced(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
-    eps = torch.finfo(A.dtype).eps
+    eps = torch.finfo(A.dtype).tiny * 2
 
     R = A.clone()
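The common thread in these hunks is replacing the torch.finfo(...).eps floor with torch.finfo(...).tiny * 2. The two constants differ by roughly 31 orders of magnitude in float32: .eps is the spacing between 1.0 and the next representable float, while .tiny is the smallest positive normal number, so the new floor clamps values to "barely positive" instead of discarding everything below about 1.2e-7. A quick comparison (not from the diff):

    import torch

    for dtype in (torch.float32, torch.float64):
        fi = torch.finfo(dtype)
        print(dtype, fi.eps, fi.tiny * 2)
        # float32: eps ~= 1.19e-07, tiny*2 ~= 2.35e-38
        # float64: eps ~= 2.22e-16, tiny*2 ~= 4.45e-308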