torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/utils/linalg/solve.py
@@ -1,69 +1,32 @@
+# pyright: reportArgumentType=false
+import math
+from collections import deque
 from collections.abc import Callable
-from typing import overload
+from typing import Any, NamedTuple, overload
+
 import torch
 
-from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_numel, generic_randn_like, generic_eq
+from .. import (
+    TensorList,
+    generic_eq,
+    generic_finfo_tiny,
+    generic_numel,
+    generic_vector_norm,
+    generic_zeros_like,
+)
 
-@overload
-def cg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
-    b: torch.Tensor,
-    x0_: torch.Tensor | None = None,
-    tol: float | None = 1e-4,
-    maxiter: int | None = None,
-    reg: float = 0,
-) -> torch.Tensor: ...
-@overload
-def cg(
-    A_mm: Callable[[TensorList], TensorList],
-    b: TensorList,
-    x0_: TensorList | None = None,
-    tol: float | None = 1e-4,
-    maxiter: int | None = None,
-    reg: float | list[float] | tuple[float] = 0,
-) -> TensorList: ...
 
-def cg(
-    A_mm: Callable,
-    b: torch.Tensor | TensorList,
-    x0_: torch.Tensor | TensorList | None = None,
-    tol: float | None = 1e-4,
-    maxiter: int | None = None,
-    reg: float | list[float] | tuple[float] = 0,
-):
+def _make_A_mm_reg(A_mm: Callable, reg):
     def A_mm_reg(x): # A_mm with regularization
         Ax = A_mm(x)
         if not generic_eq(reg, 0): Ax += x*reg
         return Ax
+    return A_mm_reg
 
-    if maxiter is None: maxiter = generic_numel(b)
-    if x0_ is None: x0_ = generic_zeros_like(b)
-
-    x = x0_
-    residual = b - A_mm_reg(x)
-    p = residual.clone() # search direction
-    r_norm = generic_vector_norm(residual)
-    init_norm = r_norm
-    if tol is not None and r_norm < tol: return x
-    k = 0
-
-    while True:
-        Ap = A_mm_reg(p)
-        step_size = (r_norm**2) / p.dot(Ap)
-        x += step_size * p # Update solution
-        residual -= step_size * Ap # Update residual
-        new_r_norm = generic_vector_norm(residual)
-
-        k += 1
-        if tol is not None and new_r_norm <= tol * init_norm: return x
-        if k >= maxiter: return x
-
-        beta = (new_r_norm**2) / (r_norm**2)
-        p = residual + beta*p
-        r_norm = new_r_norm
+def _identity(x): return x
 
 
-# https://arxiv.org/pdf/2110.02820 algorithm 2.1 apparently supposed to be diabolical
+# https://arxiv.org/pdf/2110.02820
 def nystrom_approximation(
     A_mm: Callable[[torch.Tensor], torch.Tensor],
     ndim: int,
@@ -85,7 +48,6 @@ def nystrom_approximation(
     lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
     return U, lambd
 
-# this one works worse
 def nystrom_sketch_and_solve(
     A_mm: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
@@ -111,7 +73,6 @@ def nystrom_sketch_and_solve(
     term2 = (1.0 / reg) * (b - U @ Uᵀb)
     return (term1 + term2).squeeze(-1)
 
-# this one is insane
 def nystrom_pcg(
     A_mm: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
@@ -131,6 +92,8 @@ def nystrom_pcg(
         generator=generator,
     )
     lambd += reg
+    eps = torch.finfo(b.dtype).tiny * 2
+    if tol is None: tol = eps
 
     def A_mm_reg(x): # A_mm with regularization
         Ax = A_mm(x)
@@ -150,7 +113,7 @@ def nystrom_pcg(
     p = z.clone() # search direction
 
     init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
-    if tol is not None and init_norm < tol: return x
+    if init_norm < tol: return x
    k = 0
    while True:
        Ap = A_mm_reg(p)
@@ -160,10 +123,358 @@ def nystrom_pcg(
        residual -= step_size * Ap

        k += 1
-        if tol is not None and torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
+        if torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
        if k >= maxiter: return x

        z = P_inv @ residual
        beta = residual.dot(z) / rz
        p = z + p*beta

+
+def _safe_clip(x: torch.Tensor):
+    """makes sure scalar tensor x is not smaller than tiny"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).tiny * 2
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
+def _trust_tau(x, d, trust_radius):
+    xx = x.dot(x)
+    xd = x.dot(d)
+    dd = _safe_clip(d.dot(d))
+
+    rad = (xd**2 - dd * (xx - trust_radius**2)).clip(min=0).sqrt()
+    tau = (-xd + rad) / dd
+
+    return x + tau * d
+
+
+class CG:
+    """Conjugate gradient method.
+
+    Args:
+        A_mm (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
+        b (torch.Tensor): right hand side
+        x0 (torch.Tensor | None, optional): initial guess, defaults to zeros. Defaults to None.
+        tol (float | None, optional): tolerance for convergence. Defaults to 1e-8.
+        maxiter (int | None, optional):
+            maximum number of iterations, if None sets to number of dimensions. Defaults to None.
+        reg (float, optional): regularization. Defaults to 0.
+        trust_radius (float | None, optional):
+            CG is terminated whenever solution exceeds trust region, returning a solution modified to be within it. Defaults to None.
+        npc_terminate (bool, optional):
+            whether to terminate CG whenever negative curavture is detected. Defaults to False.
+        miniter (int, optional):
+            minimal number of iterations even if tolerance is satisfied, this ensures some progress
+            is always made.
+        history_size (int, optional):
+            number of past iterations to store, to re-use them when trust radius is decreased.
+        P_mm (Callable | torch.Tensor | None, optional):
+            Callable that returns inverse preconditioner times vector. Defaults to None.
+    """
+    def __init__(
+        self,
+        A_mm: Callable,
+        b: torch.Tensor | TensorList,
+        x0: torch.Tensor | TensorList | None = None,
+        tol: float | None = 1e-4,
+        maxiter: int | None = None,
+        reg: float = 0,
+        trust_radius: float | None = None,
+        npc_terminate: bool=False,
+        miniter: int = 0,
+        history_size: int = 0,
+        P_mm: Callable | None = None,
+    ):
+        # --------------------------------- set attrs -------------------------------- #
+        self.A_mm = _make_A_mm_reg(A_mm, reg)
+        self.b = b
+        if tol is None: tol = generic_finfo_tiny(b) * 2
+        self.tol = tol
+        self.eps = generic_finfo_tiny(b) * 2
+        if maxiter is None: maxiter = generic_numel(b)
+        self.maxiter = maxiter
+        self.miniter = miniter
+        self.trust_radius = trust_radius
+        self.npc_terminate = npc_terminate
+        self.P_mm = P_mm if P_mm is not None else _identity
+
+        if history_size > 0:
+            self.history = deque(maxlen = history_size)
+            """history of (x, x_norm, d)"""
+        else:
+            self.history = None
+
+        # -------------------------------- initialize -------------------------------- #
+
+        self.iter = 0
+
+        if x0 is None:
+            self.x = generic_zeros_like(b)
+            self.r = b
+        else:
+            self.x = x0
+            self.r = b - A_mm(self.x)
+
+        self.z = self.P_mm(self.r)
+        self.d = self.z
+
+        if self.history is not None:
+            self.history.append((self.x, generic_vector_norm(self.x), self.d))
+
+    def step(self) -> tuple[Any, bool]:
+        """returns ``(solution, should_terminate)``"""
+        x, b, d, r, z = self.x, self.b, self.d, self.r, self.z
+
+        if self.iter >= self.maxiter:
+            return x, True
+
+        Ad = self.A_mm(d)
+        dAd = d.dot(Ad)
+
+        # check negative curvature
+        if dAd <= self.eps:
+            if self.trust_radius is not None: return _trust_tau(x, d, self.trust_radius), True
+            if self.iter == 0: return b * (b.dot(b) / dAd).abs(), True
+            if self.npc_terminate: return x, True
+
+        rz = r.dot(z)
+        alpha = rz / dAd
+        x_next = x + alpha * d
+
+        # check if the step exceeds the trust-region boundary
+        x_next_norm = None
+        if self.trust_radius is not None:
+            x_next_norm = generic_vector_norm(x_next)
+            if x_next_norm >= self.trust_radius:
+                return _trust_tau(x, d, self.trust_radius), True
+
+        # update step, residual and direction
+        r_next = r - alpha * Ad
+
+        # check if r is sufficiently small
+        if self.iter >= self.miniter and generic_vector_norm(r_next) < self.tol:
+            return x_next, True
+
+        # update d, r, z
+        z_next = self.P_mm(r_next)
+        beta = r_next.dot(z_next) / rz
+
+        self.d = z_next + beta * d
+        self.x = x_next
+        self.r = r_next
+        self.z = z_next
+
+        # update history
+        if self.history is not None:
+            if x_next_norm is None: x_next_norm = generic_vector_norm(x_next)
+            self.history.append((self.x, x_next_norm, self.d))
+
+        self.iter += 1
+        return x, False
+
+
+    def solve(self):
+        # return initial guess if it is good enough
+        if self.miniter < 1 and generic_vector_norm(self.r) < self.tol:
+            return self.x
+
+        should_terminate = False
+        sol = None
+
+        while not should_terminate:
+            sol, should_terminate = self.step()
+
+        assert sol is not None
+        return sol
+
+def find_within_trust_radius(history, trust_radius: float):
+    """find first ``x`` in history that exceeds trust radius, if no such ``x`` exists, returns ``None``"""
+    for x, x_norm, d in reversed(tuple(history)):
+        if x_norm <= trust_radius:
+            return _trust_tau(x, d, trust_radius)
+    return None
+
+class _TensorSolution(NamedTuple):
+    x: torch.Tensor
+    solver: CG
+
+class _TensorListSolution(NamedTuple):
+    x: TensorList
+    solver: CG
+
+
+@overload
+def cg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    b: torch.Tensor,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-8,
+    maxiter: int | None = None,
+    reg: float = 0,
+    trust_radius: float | None = None,
+    npc_terminate: bool = False,
+    miniter: int = 0,
+    history_size: int = 0,
+    P_mm: Callable[[torch.Tensor], torch.Tensor] | None = None
+) -> _TensorSolution: ...
+@overload
+def cg(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-8,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    trust_radius: float | None = None,
+    npc_terminate: bool=False,
+    miniter: int = 0,
+    history_size: int = 0,
+    P_mm: Callable[[TensorList], TensorList] | None = None
+) -> _TensorListSolution: ...
+def cg(
+    A_mm: Callable,
+    b: torch.Tensor | TensorList,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-8,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    trust_radius: float | None = None,
+    npc_terminate: bool = False,
+    miniter: int = 0,
+    history_size:int = 0,
+    P_mm: Callable | None = None
+):
+    solver = CG(
+        A_mm=A_mm,
+        b=b,
+        x0=x0,
+        tol=tol,
+        maxiter=maxiter,
+        reg=reg,
+        trust_radius=trust_radius,
+        npc_terminate=npc_terminate,
+        miniter=miniter,
+        history_size=history_size,
+        P_mm=P_mm,
+    )
+
+    x = solver.solve()
+
+    if isinstance(b, torch.Tensor):
+        return _TensorSolution(x, solver)
+
+    return _TensorListSolution(x, solver)
+
+
+# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
+@overload
+def minres(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+    npc_terminate: bool=True,
+    trust_radius: float | None = None,
+) -> torch.Tensor: ...
+@overload
+def minres(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_radius: float | None = None,
+) -> TensorList: ...
+def minres(
+    A_mm,
+    b,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_radius: float | None = None, #trust region is experimental
+):
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = math.sqrt(generic_finfo_tiny(b) * 2)
+    if tol is None: tol = eps
+
+    if maxiter is None: maxiter = generic_numel(b)
+    if x0 is None:
+        R = b
+        x0 = generic_zeros_like(b)
+    else:
+        R = b - A_mm_reg(x0)
+
+    X: Any = x0
+    beta = b_norm = generic_vector_norm(b)
+    if b_norm < eps**2:
+        return generic_zeros_like(b)
+
+
+    V = b / beta
+    V_prev = generic_zeros_like(b)
+    D = generic_zeros_like(b)
+    D_prev = generic_zeros_like(b)
+
+    c = -1
+    phi = tau = beta
+    s = delta1 = e = 0
+
+
+    for _ in range(maxiter):
+
+        P = A_mm_reg(V)
+        alpha = V.dot(P)
+        P -= beta*V_prev
+        P -= alpha*V
+        beta = generic_vector_norm(P)
+
+        delta2 = c*delta1 + s*alpha
+        gamma1 = s*delta1 - c*alpha
+        e_next = s*beta
+        delta1 = -c*beta
+
+        cgamma1 = c*gamma1
+        if trust_radius is not None and cgamma1 >= 0:
+            if npc_terminate: return _trust_tau(X, R, trust_radius)
+            return _trust_tau(X, D, trust_radius)
+
+        if npc_terminate and cgamma1 >= 0:
+            return R
+
+        gamma2 = (gamma1**2 + beta**2)**(1/2)
+
+        if abs(gamma2) <= eps: # singular system
+            # c=0; s=1; tau=0
+            if trust_radius is None: return X
+            return _trust_tau(X, D, trust_radius)
+
+        c = gamma1 / gamma2
+        s = beta/gamma2
+        tau = c*phi
+        phi = s*phi
+
+        D_prev = D
+        D = (V - delta2*D - e*D_prev) / gamma2
+        e = e_next
+        X = X + tau*D
+
+        if trust_radius is not None:
+            if generic_vector_norm(X) > trust_radius:
+                return _trust_tau(X, D, trust_radius)
+
+        if (abs(beta) < eps) or (phi / b_norm <= tol):
+            # R = zeros(R)
+            return X
+
+        V_prev = V
+        V = P/beta
+        R = s**2*R - phi*c*V
+
+    return X
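For orientation, here is a minimal usage sketch of the new `cg` API (illustrative only, not part of the diff; it assumes `cg` is importable from `torchzero.utils.linalg.solve`, matching the file path above, and relies on the `_TensorSolution` named tuple defined there unpacking to `(x, solver)`):

    import torch
    from torchzero.utils.linalg.solve import cg

    # a small symmetric positive definite system A x = b
    A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
    b = torch.tensor([1.0, 2.0], dtype=torch.float64)

    # cg now returns (solution, solver) rather than a bare tensor,
    # so the CG object stays inspectable after the solve
    x, solver = cg(lambda v: A @ v, b, tol=1e-8)
    assert torch.allclose(A @ x, b)

    # with trust_radius, the step that would leave the region is cut
    # back to the boundary (via _trust_tau above)
    x_tr, _ = cg(lambda v: A @ v, b, trust_radius=0.1)
    assert torch.linalg.vector_norm(x_tr) <= 0.1 + 1e-12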
torchzero/utils/metrics.py
@@ -0,0 +1,83 @@
+"""convenience submodule which allows to calculate a metric based on its string name,
+used in many places"""
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Literal, overload
+
+import torch
+
+if TYPE_CHECKING:
+    from .tensorlist import TensorList
+
+
+
+class Metric(ABC):
+    @abstractmethod
+    def evaluate_global(self, x: "TensorList") -> torch.Tensor:
+        """returns a global metric for a tensorlist"""
+
+    @abstractmethod
+    def evaluate_tensor(self, x: torch.Tensor, dim=None, keepdim=False) -> torch.Tensor:
+        """returns metric for a tensor"""
+
+    def evaluate_list(self, x: "TensorList") -> "TensorList":
+        """returns list of metrics for a tensorlist (possibly vectorized)"""
+        return x.map(self.evaluate_tensor)
+
+
+class _MAD(Metric):
+    def evaluate_global(self, x): return x.abs().global_mean()
+    def evaluate_tensor(self, x, dim=None, keepdim=False): return x.abs().mean(dim=dim, keepdim=keepdim)
+    def evaluate_list(self, x): return x.abs().mean()
+
+class _Std(Metric):
+    def evaluate_global(self, x): return x.global_std()
+    def evaluate_tensor(self, x, dim=None, keepdim=False): return x.std(dim=dim, keepdim=keepdim)
+    def evaluate_list(self, x): return x.std()
+
+class _Var(Metric):
+    def evaluate_global(self, x): return x.global_var()
+    def evaluate_tensor(self, x, dim=None, keepdim=False): return x.var(dim=dim, keepdim=keepdim)
+    def evaluate_list(self, x): return x.var()
+
+class _Sum(Metric):
+    def evaluate_global(self, x): return x.global_sum()
+    def evaluate_tensor(self, x, dim=None, keepdim=False): return x.sum(dim=dim, keepdim=keepdim)
+    def evaluate_list(self, x): return x.sum()
+
+class _Norm(Metric):
+    def __init__(self, ord): self.ord = ord
+    def evaluate_global(self, x): return x.global_vector_norm(self.ord)
+    def evaluate_tensor(self, x, dim=None, keepdim=False):
+        return torch.linalg.vector_norm(x, ord=self.ord, dim=dim, keepdim=keepdim) # pylint:disable=not-callable
+    def evaluate_list(self, x): return x.norm(self.ord)
+
+_METRIC_KEYS = Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf']
+_METRICS: dict[_METRIC_KEYS, Metric] = {
+    "mad": _MAD(),
+    "std": _Std(),
+    "var": _Var(),
+    "sum": _Sum(),
+    "l0": _Norm(0),
+    "l1": _Norm(1),
+    "l2": _Norm(2),
+    "l3": _Norm(3),
+    "l4": _Norm(4),
+    "linf": _Norm(torch.inf),
+}
+
+Metrics = _METRIC_KEYS | float | torch.Tensor
+def evaluate_metric(x: "torch.Tensor | TensorList", metric: Metrics) -> torch.Tensor:
+    if isinstance(metric, (int, float, torch.Tensor)):
+        if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=metric) # pylint:disable=not-callable
+        return x.global_vector_norm(ord=float(metric))
+
+    if isinstance(x, torch.Tensor): return _METRICS[metric].evaluate_tensor(x)
+    return _METRICS[metric].evaluate_global(x)
+
+
+def calculate_metric_list(x: "TensorList", metric: Metrics) -> "TensorList":
+    if isinstance(metric, (int, float, torch.Tensor)):
+        return x.norm(ord=float(metric))
+
+    return _METRICS[metric].evaluate_list(x)
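A short sketch of how the new string-keyed metrics are meant to be called (illustrative, assuming the module is importable as `torchzero.utils.metrics` per the file path above):

    import torch
    from torchzero.utils.metrics import evaluate_metric

    g = torch.tensor([3.0, -4.0])
    evaluate_metric(g, "l2")   # tensor(5.) - norms looked up by string name
    evaluate_metric(g, "mad")  # tensor(3.5000) - mean absolute deviation
    evaluate_metric(g, 1)      # tensor(7.) - numeric metrics fall through to vector_norm(ord=metric)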
torchzero/utils/numberlist.py
@@ -129,4 +129,6 @@ class NumberList(list[int | float | Any]):
         return self.__class__(fn(i, *args, **kwargs) for i in self)
 
     def clamp(self, min=None, max=None):
+        return self.zipmap_args(_clamp, min, max)
+    def clip(self, min=None, max=None):
         return self.zipmap_args(_clamp, min, max)
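The two added lines give `NumberList` a torch/NumPy-style `clip` alias for `clamp`. A sketch of the intended equivalence (assuming `NumberList` wraps a plain list and `_clamp` has the usual min/max semantics):

    from torchzero.utils.numberlist import NumberList

    nl = NumberList([0.5, 2.0, -1.0])
    assert nl.clip(min=0, max=1) == nl.clamp(min=0, max=1)  # [0.5, 1, 0]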
torchzero/utils/python_tools.py
@@ -31,6 +31,16 @@ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable
         return all(i==y for i in x)
     return all(i==j for i,j in zip(x,y))
 
+def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
+    """generic not equals function that supports scalars and lists of numbers. Faster than not generic_eq"""
+    if isinstance(x, (int,float)):
+        if isinstance(y, (int,float)): return x!=y
+        return any(i!=x for i in y)
+    if isinstance(y, (int,float)):
+        return any(i!=y for i in x)
+    return any(i!=j for i,j in zip(x,y))
+
+
 def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
     Otherwise applies `fn` to this sequence and `other`.
@@ -51,3 +61,9 @@ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None =
     values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
     if len(values) == 1: return values[0]
     return values
+
+
+def safe_dict_update_(d1_:dict, d2:dict):
+    inter = set(d1_.keys()).intersection(d2.keys())
+    if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
+    d1_.update(d2)
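Finally, a sketch of the contract of the two helpers added here (behavior read off the code above; note `generic_ne` short-circuits on the first differing element rather than negating `generic_eq`; import path assumed from the file location):

    from torchzero.utils.python_tools import generic_ne, safe_dict_update_

    generic_ne(1.0, [1.0, 2.0])   # True - one element differs
    generic_ne([1, 2], [1, 2])    # False

    d = {"lr": 0.1}
    safe_dict_update_(d, {"momentum": 0.9})  # fine, keys are disjoint
    safe_dict_update_(d, {"lr": 0.01})       # RuntimeError: Duplicate keys {'lr'}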