torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -1,99 +1,32 @@
1
1
  # pyright: reportArgumentType=false
2
+ import math
3
+ from collections import deque
2
4
  from collections.abc import Callable
3
- from typing import Any, overload
5
+ from typing import Any, NamedTuple, overload
4
6
 
5
7
  import torch
6
8
 
7
9
  from .. import (
8
10
  TensorList,
9
11
  generic_eq,
10
- generic_finfo_eps,
12
+ generic_finfo_tiny,
11
13
  generic_numel,
12
- generic_randn_like,
13
14
  generic_vector_norm,
14
15
  generic_zeros_like,
15
16
  )
16
17
 
17
18
 
18
def _make_A_mm_reg(A_mm: Callable, reg):
    """Wraps ``A_mm`` so that it computes the regularized product ``(A + reg*I) x``.

    Args:
        A_mm: callable returning the matrix-vector product ``Ax``.
        reg: diagonal regularization; a scalar, or presumably a per-tensor sequence when
            the inputs are TensorLists (see the ``cg`` overloads).

    Returns:
        A callable computing ``A_mm(x) + x*reg``. When ``reg`` is zero, ``A_mm`` itself
        is returned, since no regularization needs to be added.
    """
    # reg does not change between calls, so check it once instead of on every matvec
    if generic_eq(reg, 0): return A_mm

    def A_mm_reg(x): # A_mm with regularization
        # out-of-place addition: the tensor returned by A_mm may be its own input
        # (or otherwise shared), and must not be modified
        return A_mm(x) + x*reg

    return A_mm_reg
34
25
 
26
+ def _identity(x): return x
35
27
 
36
- @overload
37
- def cg(
38
- A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
39
- b: torch.Tensor,
40
- x0_: torch.Tensor | None = None,
41
- tol: float | None = 1e-4,
42
- maxiter: int | None = None,
43
- reg: float = 0,
44
- ) -> torch.Tensor: ...
45
- @overload
46
- def cg(
47
- A_mm: Callable[[TensorList], TensorList],
48
- b: TensorList,
49
- x0_: TensorList | None = None,
50
- tol: float | None = 1e-4,
51
- maxiter: int | None = None,
52
- reg: float | list[float] | tuple[float] = 0,
53
- ) -> TensorList: ...
54
28
 
55
- def cg(
56
- A_mm: Callable | torch.Tensor,
57
- b: torch.Tensor | TensorList,
58
- x0_: torch.Tensor | TensorList | None = None,
59
- tol: float | None = 1e-4,
60
- maxiter: int | None = None,
61
- reg: float | list[float] | tuple[float] = 0,
62
- ):
63
- A_mm_reg = _make_A_mm_reg(A_mm, reg)
64
- eps = generic_finfo_eps(b)
65
-
66
- if tol is None: tol = eps
67
-
68
- if maxiter is None: maxiter = generic_numel(b)
69
- if x0_ is None: x0_ = generic_zeros_like(b)
70
-
71
- x = x0_
72
- residual = b - A_mm_reg(x)
73
- p = residual.clone() # search direction
74
- r_norm = generic_vector_norm(residual)
75
- init_norm = r_norm
76
- if r_norm < tol: return x
77
- k = 0
78
-
79
-
80
- while True:
81
- Ap = A_mm_reg(p)
82
- step_size = (r_norm**2) / p.dot(Ap)
83
- x += step_size * p # Update solution
84
- residual -= step_size * Ap # Update residual
85
- new_r_norm = generic_vector_norm(residual)
86
-
87
- k += 1
88
- if new_r_norm <= tol * init_norm: return x
89
- if k >= maxiter: return x
90
-
91
- beta = (new_r_norm**2) / (r_norm**2)
92
- p = residual + beta*p
93
- r_norm = new_r_norm
94
-
95
-
96
- # https://arxiv.org/pdf/2110.02820 algorithm 2.1 apparently supposed to be diabolical
29
+ # https://arxiv.org/pdf/2110.02820
97
30
  def nystrom_approximation(
98
31
  A_mm: Callable[[torch.Tensor], torch.Tensor],
99
32
  ndim: int,
@@ -115,7 +48,6 @@ def nystrom_approximation(
115
48
  lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
116
49
  return U, lambd
117
50
 
118
- # this one works worse
119
51
  def nystrom_sketch_and_solve(
120
52
  A_mm: Callable[[torch.Tensor], torch.Tensor],
121
53
  b: torch.Tensor,
@@ -141,7 +73,6 @@ def nystrom_sketch_and_solve(
141
73
  term2 = (1.0 / reg) * (b - U @ Uᵀb)
142
74
  return (term1 + term2).squeeze(-1)
143
75
 
144
- # this one is insane
145
76
  def nystrom_pcg(
146
77
  A_mm: Callable[[torch.Tensor], torch.Tensor],
147
78
  b: torch.Tensor,
@@ -161,7 +92,7 @@ def nystrom_pcg(
161
92
  generator=generator,
162
93
  )
163
94
  lambd += reg
164
- eps = torch.finfo(b.dtype).eps ** 2
95
+ eps = torch.finfo(b.dtype).tiny * 2
165
96
  if tol is None: tol = eps
166
97
 
167
98
  def A_mm_reg(x): # A_mm with regularization
@@ -201,98 +132,239 @@ def nystrom_pcg(
201
132
 
202
133
 
203
134
  def _safe_clip(x: torch.Tensor):
204
- """makes sure scalar tensor x is not smaller than epsilon"""
135
+ """makes sure scalar tensor x is not smaller than tiny"""
205
136
  assert x.numel() == 1, x.shape
206
- eps = torch.finfo(x.dtype).eps
137
+ eps = torch.finfo(x.dtype).tiny * 2
207
138
  if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
208
139
  return x
209
140
 
210
def _trust_tau(x,d,trust_radius):
    """Moves from ``x`` along direction ``d`` onto the trust region boundary.

    Solves ``||x + tau*d|| = trust_radius`` for ``tau`` and returns ``x + tau*d``.
    The root with the ``+sqrt`` sign is used. A negative discriminant is clipped to zero,
    and ``d.dot(d)`` is passed through ``_safe_clip`` to avoid dividing by zero.
    """
    x_sq = x.dot(x)
    x_dot_d = x.dot(d)
    d_sq = _safe_clip(d.dot(d))

    discriminant = x_dot_d**2 - d_sq * (x_sq - trust_radius**2)
    root = discriminant.clip(min=0).sqrt()
    tau = (root - x_dot_d) / d_sq

    return x + tau * d
219
150
 
220
151
 
152
class CG:
    """Conjugate gradient method for solving ``(A + reg*I) x = b``.

    The solver is stateful: either call ``solve`` to run until termination, or call
    ``step`` repeatedly.

    Args:
        A_mm (Callable): callable that returns the matrix-vector product ``Ax``.
        b (torch.Tensor | TensorList): right hand side.
        x0 (torch.Tensor | TensorList | None, optional): initial guess, defaults to zeros. Defaults to None.
        tol (float | None, optional):
            tolerance on the residual norm for convergence. If None, uses twice the smallest
            positive normal number of ``b``'s dtype. Defaults to 1e-4.
        maxiter (int | None, optional):
            maximum number of iterations, if None sets to number of elements in ``b``. Defaults to None.
        reg (float, optional): regularization added to the diagonal of ``A``. Defaults to 0.
        trust_radius (float | None, optional):
            CG is terminated whenever the solution exceeds the trust region, returning a solution
            moved onto the trust region boundary. Defaults to None.
        npc_terminate (bool, optional):
            whether to terminate CG whenever negative curvature is detected. Defaults to False.
        miniter (int, optional):
            minimal number of iterations performed even if tolerance is satisfied, this ensures
            some progress is always made. Defaults to 0.
        history_size (int, optional):
            number of past iterations to store, to re-use them when trust radius is decreased. Defaults to 0.
        P_mm (Callable | None, optional):
            callable that returns inverse preconditioner times vector. Defaults to None (no preconditioning).
    """
    def __init__(
        self,
        A_mm: Callable,
        b: torch.Tensor | TensorList,
        x0: torch.Tensor | TensorList | None = None,
        tol: float | None = 1e-4,
        maxiter: int | None = None,
        reg: float = 0,
        trust_radius: float | None = None,
        npc_terminate: bool=False,
        miniter: int = 0,
        history_size: int = 0,
        P_mm: Callable | None = None,
    ):
        # --------------------------------- set attrs -------------------------------- #
        self.A_mm = _make_A_mm_reg(A_mm, reg)
        self.b = b
        # curvature d^T A d at or below this value is treated as non-positive
        self.eps = generic_finfo_tiny(b) * 2
        if tol is None: tol = self.eps
        self.tol = tol
        if maxiter is None: maxiter = generic_numel(b)
        self.maxiter = maxiter
        self.miniter = miniter
        self.trust_radius = trust_radius
        self.npc_terminate = npc_terminate
        self.P_mm = P_mm if P_mm is not None else _identity

        # history of (x, x_norm, d) tuples, or None when history_size is 0
        self.history = deque(maxlen = history_size) if history_size > 0 else None

        # -------------------------------- initialize -------------------------------- #

        self.iter = 0

        if x0 is None:
            self.x = generic_zeros_like(b)
            self.r = b
        else:
            self.x = x0
            # use the regularized operator so the initial residual belongs to the same
            # system (A + reg*I) that every subsequent step works with
            self.r = b - self.A_mm(self.x)

        self.z = self.P_mm(self.r)
        self.d = self.z

        if self.history is not None:
            self.history.append((self.x, generic_vector_norm(self.x), self.d))

    def step(self) -> tuple[Any, bool]:
        """Performs one CG iteration.

        Returns:
            ``(solution, should_terminate)``. ``solution`` is the current iterate, or the
            final solution when ``should_terminate`` is True.
        """
        x, b, d, r, z = self.x, self.b, self.d, self.r, self.z

        if self.iter >= self.maxiter:
            return x, True

        Ad = self.A_mm(d)
        dAd = d.dot(Ad)

        # check negative (or vanishing) curvature
        if dAd <= self.eps:
            if self.trust_radius is not None: return _trust_tau(x, d, self.trust_radius), True
            # on the very first iteration, return b scaled by |b^T b / d^T A d|
            if self.iter == 0: return b * (b.dot(b) / dAd).abs(), True
            if self.npc_terminate: return x, True

        rz = r.dot(z)
        alpha = rz / dAd
        x_next = x + alpha * d

        # check if the step exceeds the trust-region boundary
        x_next_norm = None
        if self.trust_radius is not None:
            x_next_norm = generic_vector_norm(x_next)
            if x_next_norm >= self.trust_radius:
                return _trust_tau(x, d, self.trust_radius), True

        # update residual
        r_next = r - alpha * Ad

        # check if r is sufficiently small; x_next completes iteration number self.iter + 1,
        # so that is the count compared against miniter
        if self.iter + 1 >= self.miniter and generic_vector_norm(r_next) < self.tol:
            return x_next, True

        # update d, r, z
        z_next = self.P_mm(r_next)
        beta = r_next.dot(z_next) / rz

        self.d = z_next + beta * d
        self.x = x_next
        self.r = r_next
        self.z = z_next

        # update history
        if self.history is not None:
            if x_next_norm is None: x_next_norm = generic_vector_norm(x_next)
            self.history.append((self.x, x_next_norm, self.d))

        self.iter += 1
        return self.x, False


    def solve(self):
        """Runs ``step`` until it signals termination and returns the final solution.

        If ``miniter < 1`` and the initial residual norm is already below ``tol``,
        the initial guess is returned without iterating.
        """
        # return initial guess if it is good enough
        if self.miniter < 1 and generic_vector_norm(self.r) < self.tol:
            return self.x

        should_terminate = False
        sol = None

        while not should_terminate:
            sol, should_terminate = self.step()

        assert sol is not None
        return sol
290
+
291
def find_within_trust_radius(history, trust_radius: float):
    """Recovers a trust-region solution from a stored CG history.

    Walks ``history`` from the most recent entry backwards. For the first stored iterate
    ``x`` whose norm does not exceed ``trust_radius``, it returns the point where
    ``x + tau * d`` reaches the trust region boundary, with ``d`` being the direction
    stored alongside ``x``. If every stored iterate lies outside the trust region
    (or ``history`` is empty), it returns ``None``.

    Args:
        history: iterable of ``(x, x_norm, d)`` tuples, such as ``CG.history``.
        trust_radius: radius of the trust region.
    """
    for x, x_norm, d in reversed(tuple(history)):
        if x_norm <= trust_radius:
            return _trust_tau(x, d, trust_radius)
    return None
297
+
298
class _TensorSolution(NamedTuple):
    """Result of ``cg`` for a ``torch.Tensor`` right hand side."""
    # the solution returned by the solver
    x: torch.Tensor
    # the CG solver instance that produced ``x``
    solver: CG

class _TensorListSolution(NamedTuple):
    """Result of ``cg`` for a ``TensorList`` right hand side."""
    # the solution returned by the solver
    x: TensorList
    # the CG solver instance that produced ``x``
    solver: CG
305
+
306
+
221
307
@overload
def cg(
    A_mm: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    x0: torch.Tensor | None = None,
    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float = 0,
    trust_radius: float | None = None,
    npc_terminate: bool = False,
    miniter: int = 0,
    history_size: int = 0,
    P_mm: Callable[[torch.Tensor], torch.Tensor] | None = None
) -> _TensorSolution: ...
@overload
def cg(
    A_mm: Callable[[TensorList], TensorList],
    b: TensorList,
    x0: TensorList | None = None,
    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    trust_radius: float | None = None,
    npc_terminate: bool=False,
    miniter: int = 0,
    history_size: int = 0,
    P_mm: Callable[[TensorList], TensorList] | None = None
) -> _TensorListSolution: ...
def cg(
    A_mm: Callable,
    b: torch.Tensor | TensorList,
    x0: torch.Tensor | TensorList | None = None,
    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    trust_radius: float | None = None,
    npc_terminate: bool = False,
    miniter: int = 0,
    history_size:int = 0,
    P_mm: Callable | None = None
):
    """Solves ``(A + reg*I) x = b`` with the conjugate gradient method.

    A convenience wrapper that builds a ``CG`` solver from the given arguments (see ``CG``
    for their meaning), runs it to termination, and returns a named tuple ``(x, solver)``.
    The tuple type is ``_TensorSolution`` when ``b`` is a ``torch.Tensor`` and
    ``_TensorListSolution`` otherwise.
    """
    solver = CG(
        A_mm=A_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg,
        trust_radius=trust_radius, npc_terminate=npc_terminate,
        miniter=miniter, history_size=history_size, P_mm=P_mm,
    )
    solution = solver.solve()

    # the result type mirrors the type of the right hand side
    if not isinstance(b, torch.Tensor):
        return _TensorListSolution(solution, solver)
    return _TensorSolution(solution, solver)
296
368
 
297
369
 
298
370
  # Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
@@ -305,7 +377,7 @@ def minres(
305
377
  maxiter: int | None = None,
306
378
  reg: float = 0,
307
379
  npc_terminate: bool=True,
308
- trust_region: float | None = None,
380
+ trust_radius: float | None = None,
309
381
  ) -> torch.Tensor: ...
310
382
  @overload
311
383
  def minres(
@@ -316,7 +388,7 @@ def minres(
316
388
  maxiter: int | None = None,
317
389
  reg: float | list[float] | tuple[float] = 0,
318
390
  npc_terminate: bool=True,
319
- trust_region: float | None = None,
391
+ trust_radius: float | None = None,
320
392
  ) -> TensorList: ...
321
393
  def minres(
322
394
  A_mm,
@@ -326,11 +398,11 @@ def minres(
326
398
  maxiter: int | None = None,
327
399
  reg: float | list[float] | tuple[float] = 0,
328
400
  npc_terminate: bool=True,
329
- trust_region: float | None = None,
401
+ trust_radius: float | None = None, #trust region is experimental
330
402
  ):
331
403
  A_mm_reg = _make_A_mm_reg(A_mm, reg)
332
- eps = generic_finfo_eps(b)
333
- if tol is None: tol = eps**2
404
+ eps = math.sqrt(generic_finfo_tiny(b) * 2)
405
+ if tol is None: tol = eps
334
406
 
335
407
  if maxiter is None: maxiter = generic_numel(b)
336
408
  if x0 is None:
@@ -369,9 +441,9 @@ def minres(
369
441
  delta1 = -c*beta
370
442
 
371
443
  cgamma1 = c*gamma1
372
- if trust_region is not None and cgamma1 >= 0:
373
- if npc_terminate: return _trust_tau(X, R, trust_region)
374
- return _trust_tau(X, D, trust_region)
444
+ if trust_radius is not None and cgamma1 >= 0:
445
+ if npc_terminate: return _trust_tau(X, R, trust_radius)
446
+ return _trust_tau(X, D, trust_radius)
375
447
 
376
448
  if npc_terminate and cgamma1 >= 0:
377
449
  return R
@@ -380,8 +452,8 @@ def minres(
380
452
 
381
453
  if abs(gamma2) <= eps: # singular system
382
454
  # c=0; s=1; tau=0
383
- if trust_region is None: return X
384
- return _trust_tau(X, D, trust_region)
455
+ if trust_radius is None: return X
456
+ return _trust_tau(X, D, trust_radius)
385
457
 
386
458
  c = gamma1 / gamma2
387
459
  s = beta/gamma2
@@ -393,9 +465,9 @@ def minres(
393
465
  e = e_next
394
466
  X = X + tau*D
395
467
 
396
- if trust_region is not None:
397
- if generic_vector_norm(X) > trust_region:
398
- return _trust_tau(X, D, trust_region)
468
+ if trust_radius is not None:
469
+ if generic_vector_norm(X) > trust_radius:
470
+ return _trust_tau(X, D, trust_radius)
399
471
 
400
472
  if (abs(beta) < eps) or (phi / b_norm <= tol):
401
473
  # R = zeros(R)
@@ -0,0 +1,83 @@
1
+ """convenience submodule which allows to calculate a metric based on its string name,
2
+ used in many places"""
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Sequence
5
+ from typing import TYPE_CHECKING, Any, Literal, overload
6
+
7
+ import torch
8
+
9
+ if TYPE_CHECKING:
10
+ from .tensorlist import TensorList
11
+
12
+
13
+
14
class Metric(ABC):
    """Interface for a statistic that can be evaluated on a tensor or on a TensorList."""

    @abstractmethod
    def evaluate_global(self, x: "TensorList") -> torch.Tensor:
        """Evaluates the metric jointly over all tensors of ``x``, returning a single value."""

    @abstractmethod
    def evaluate_tensor(self, x: torch.Tensor, dim=None, keepdim=False) -> torch.Tensor:
        """Evaluates the metric on one tensor, optionally reducing only along ``dim``."""

    def evaluate_list(self, x: "TensorList") -> "TensorList":
        """Evaluates the metric separately on every tensor of ``x``.

        The default maps ``evaluate_tensor`` over the tensors. Subclasses may override it
        with a vectorized implementation.
        """
        per_tensor = self.evaluate_tensor
        return x.map(per_tensor)
26
+
27
+
28
class _MAD(Metric):
    """Mean of absolute values (not centered around the mean)."""

    def evaluate_global(self, x):
        return x.abs().global_mean()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        magnitudes = x.abs()
        return magnitudes.mean(dim=dim, keepdim=keepdim)

    def evaluate_list(self, x):
        return x.abs().mean()
32
+
33
class _Std(Metric):
    """Standard deviation."""

    def evaluate_global(self, x):
        return x.global_std()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        return x.std(dim=dim, keepdim=keepdim)

    def evaluate_list(self, x):
        return x.std()
37
+
38
class _Var(Metric):
    """Variance."""

    def evaluate_global(self, x):
        return x.global_var()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        return x.var(dim=dim, keepdim=keepdim)

    def evaluate_list(self, x):
        return x.var()
42
+
43
class _Sum(Metric):
    """Sum of elements."""

    def evaluate_global(self, x):
        return x.global_sum()

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        return x.sum(dim=dim, keepdim=keepdim)

    def evaluate_list(self, x):
        return x.sum()
47
+
48
class _Norm(Metric):
    """Vector norm of order ``ord``."""

    def __init__(self, ord):
        self.ord = ord

    def evaluate_global(self, x):
        return x.global_vector_norm(self.ord)

    def evaluate_tensor(self, x, dim=None, keepdim=False):
        order = self.ord
        return torch.linalg.vector_norm(x, ord=order, dim=dim, keepdim=keepdim) # pylint:disable=not-callable

    def evaluate_list(self, x):
        return x.norm(self.ord)
54
+
55
# names accepted wherever a metric can be given as a string
_METRIC_KEYS = Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf']

# lookup table from metric name to its implementation
_METRICS: dict[_METRIC_KEYS, Metric] = {
    # statistics
    "mad": _MAD(),
    "std": _Std(),
    "var": _Var(),
    "sum": _Sum(),
    # vector norms, in increasing order
    "l0": _Norm(0),
    "l1": _Norm(1),
    "l2": _Norm(2),
    "l3": _Norm(3),
    "l4": _Norm(4),
    "linf": _Norm(torch.inf),
}
68
+
69
Metrics = _METRIC_KEYS | float | torch.Tensor
def evaluate_metric(x: "torch.Tensor | TensorList", metric: Metrics) -> torch.Tensor:
    """Evaluates ``metric`` on ``x``, returning a single value.

    Args:
        x: a tensor, or a TensorList, which is evaluated jointly via the metric's
            ``evaluate_global``.
        metric: either one of the names in ``_METRICS`` (``"mad"``, ``"std"``, ``"var"``,
            ``"sum"``, ``"l0"`` ... ``"l4"``, ``"linf"``), or a number (or scalar tensor)
            used as the order of a vector norm.

    Raises:
        KeyError: if ``metric`` is a string that is not a known metric name.
    """
    if isinstance(metric, (int, float, torch.Tensor)):
        # convert once so the tensor and TensorList branches receive the same float order
        order = float(metric)
        if isinstance(x, torch.Tensor): return torch.linalg.vector_norm(x, ord=order) # pylint:disable=not-callable
        return x.global_vector_norm(ord=order)

    try:
        impl = _METRICS[metric]
    except KeyError:
        raise KeyError(f"Unknown metric {metric!r}, valid metrics are {list(_METRICS)}") from None

    if isinstance(x, torch.Tensor): return impl.evaluate_tensor(x)
    return impl.evaluate_global(x)
77
+
78
+
79
def calculate_metric_list(x: "TensorList", metric: Metrics) -> "TensorList":
    """Evaluates ``metric`` separately on every tensor in ``x``.

    Numbers (and scalar tensors) are interpreted as vector norm orders; strings are
    looked up in ``_METRICS``.
    """
    is_norm_order = isinstance(metric, (int, float, torch.Tensor))
    if not is_norm_order:
        return _METRICS[metric].evaluate_list(x)
    return x.norm(ord=float(metric))
@@ -61,3 +61,9 @@ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None =
61
61
  values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
62
62
  if len(values) == 1: return values[0]
63
63
  return values
64
+
65
+
66
def safe_dict_update_(d1_:dict, d2:dict):
    """Updates ``d1_`` in place with ``d2``, refusing to overwrite existing keys.

    Raises:
        RuntimeError: if any key of ``d2`` is already in ``d1_``; ``d1_`` is then left unmodified.
    """
    duplicates = d1_.keys() & d2.keys()
    if duplicates:
        raise RuntimeError(f"Duplicate keys {duplicates}")
    d1_.update(d2)