torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
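The biggest structural change in 0.4.0 is the promotion of the linear-algebra utilities to a top-level package: items 16-25 (together with the removals in items 161-164, and the unchanged move of benchmark.py in item 167) show `torchzero/utils/linalg/` becoming `torchzero/linalg/`, with `linear_operator.py`, `qr.py`, and `solve.py` carried over. A minimal migration sketch for downstream code, assuming the moved modules keep their public names (only the package paths are taken from the file list above):

    # before, torchzero 0.3.14 (old path, per items 161-164)
    # from torchzero.utils.linalg.solve import cg

    # after, torchzero 0.4.0 (new path, per items 19-25)
    from torchzero.linalg.solve import cg

The diffs below cover a subset of these files: `solve.py` (item 23), the new `svd.py` (item 24) and `torch_linalg.py` (item 25), and the two `__init__.py` changes in items 26-27.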
--- a/torchzero/utils/linalg/solve.py
+++ b/torchzero/linalg/solve.py
@@ -1,3 +1,4 @@
+# pylint: disable = non-ascii-name
 # pyright: reportArgumentType=false
 import math
 from collections import deque
@@ -5,8 +6,8 @@ from collections.abc import Callable
 from typing import Any, NamedTuple, overload
 
 import torch
-
-from .. import (
+from .linalg_utils import mm
+from ..utils import (
     TensorList,
     generic_eq,
     generic_finfo_tiny,
@@ -15,88 +16,73 @@ from .. import (
     generic_zeros_like,
 )
 
-
-def _make_A_mm_reg(A_mm: Callable, reg):
-    def A_mm_reg(x): # A_mm with regularization
-        Ax = A_mm(x)
+def _make_A_mv_reg(A_mv: Callable, reg):
+    def A_mv_reg(x): # A_mm with regularization
+        Ax = A_mv(x)
         if not generic_eq(reg, 0): Ax += x*reg
         return Ax
-    return A_mm_reg
+    return A_mv_reg
 
 def _identity(x): return x
 
-
-# https://arxiv.org/pdf/2110.02820
-def nystrom_approximation(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
-    ndim: int,
-    rank: int,
-    device,
-    dtype = torch.float32,
-    generator = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    omega = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-    omega, _ = torch.linalg.qr(omega) # Thin QR decomposition # pylint:disable=not-callable
-
-    # Y = AΩ
-    Y = torch.stack([A_mm(col) for col in omega.unbind(-1)], -1) # rank matvecs
-    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(Y, ord='fro') # Compute shift # pylint:disable=not-callable
-    Yv = Y + v*omega # Shift for stability
-    C = torch.linalg.cholesky_ex(omega.mT @ Yv)[0] # pylint:disable=not-callable
-    B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
-    U, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
-    lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
-    return U, lambd
-
+# TODO this is used in NystromSketchAndSolve
+# I need to add alternative to it where it just shifts eigenvalues by reg and uses their reciprocal
 def nystrom_sketch_and_solve(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    L: torch.Tensor,
+    Q: torch.Tensor,
     b: torch.Tensor,
-    rank: int,
     reg: float = 1e-3,
-    generator=None,
 ) -> torch.Tensor:
-    U, lambd = nystrom_approximation(
-        A_mm=A_mm,
-        ndim=b.size(-1),
-        rank=rank,
-        device=b.device,
-        dtype=b.dtype,
-        generator=generator,
-    )
+    """Solves (Q diag(L) Q.T + reg*I)x = b. Becomes super unstable with reg smaller than like 1e-5.
+
+    Args:
+        L (torch.Tensor): eigenvalues, like from ``nystrom_approximation``
+        Q (torch.Tensor): eigenvectors, like from ``nystrom_approximation``
+        b (torch.Tensor): right hand side
+        reg (float, optional): regularization. Defaults to 1e-3.
+    """
+
     b = b.unsqueeze(-1)
-    lambd += reg
+    L += reg
     # x = (A + μI)⁻¹ b
-    # (A + μI)⁻¹ = U(Λ + μI)⁻¹Uᵀ + (1/μ)(b - UUᵀ)
-    # x = U(Λ + μI)⁻¹Uᵀb + (1/μ)(b - UUᵀb)
-    Uᵀb = U.T @ b
-    term1 = U @ ((1/lambd).unsqueeze(-1) * Uᵀb)
-    term2 = (1.0 / reg) * (b - U @ Uᵀb)
+    # (A + μI)⁻¹ = Q(L + μI)⁻¹Qᵀ + (1/μ)(b - QQᵀ)
+    # x = Q(L + μI)⁻¹Qᵀb + (1/μ)(b - QQᵀb)
+    Qᵀb = Q.T @ b
+    term1 = Q @ ((1/L).unsqueeze(-1) * Qᵀb)
+    term2 = (1.0 / reg) * (b - Q @ Qᵀb)
     return (term1 + term2).squeeze(-1)
 
 def nystrom_pcg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    L: torch.Tensor,
+    Q: torch.Tensor,
+    A_mv: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
-    sketch_size: int,
     reg: float = 1e-6,
     x0_: torch.Tensor | None = None,
-    tol: float | None = 1e-4,
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
-    generator=None,
 ) -> torch.Tensor:
-    U, lambd = nystrom_approximation(
-        A_mm=A_mm,
-        ndim=b.size(-1),
-        rank=sketch_size,
-        device=b.device,
-        dtype=b.dtype,
-        generator=generator,
-    )
-    lambd += reg
+    """conjugate gradient preconditioned by nystrom approximation.
+
+    The preconditioner can be computed by one matrix-matrix multiplication with A.
+    If matrix-matrix is efficient, then this is good (e.g. batched hessian-vector products in pytorch)
+
+    Args:
+        L (torch.Tensor): eigenvalues of approximation of A, like from ``nystrom_approximation``
+        Q (torch.Tensor): eigenvectors of approximation of A, like from ``nystrom_approximation``
+        A_mv (Callable[[torch.Tensor], torch.Tensor]): mat-vec func with hessian
+        b (torch.Tensor): right hand side
+        reg (float, optional): regularization. Defaults to 1e-6.
+        x0_ (torch.Tensor | None, optional): initial guess (modified in-place). Defaults to None.
+        tol (float | None, optional): tolerance for convergence. Defaults to 1e-4.
+        maxiter (int | None, optional): maximum number of iterations. Defaults to None.
+    """
+    L += reg
     eps = torch.finfo(b.dtype).tiny * 2
     if tol is None: tol = eps
 
-    def A_mm_reg(x): # A_mm with regularization
-        Ax = A_mm(x)
+    def A_mv_reg(x): # A_mm with regularization
+        Ax = A_mv(x)
         if reg != 0: Ax += x*reg
         return Ax
 
@@ -104,10 +90,10 @@ def nystrom_pcg(
     if x0_ is None: x0_ = torch.zeros_like(b)
 
     x = x0_
-    residual = b - A_mm_reg(x)
+    residual = b - A_mv_reg(x)
     # z0 = P⁻¹ r0
-    term1 = lambd[...,-1] * U * (1/lambd.unsqueeze(-2)) @ U.mT
-    term2 = torch.eye(U.size(-2), device=U.device,dtype=U.dtype) - U@U.mT
+    term1 = L[...,-1] * Q * (1/L.unsqueeze(-2)) @ Q.mT
+    term2 = torch.eye(Q.size(-2), device=Q.device,dtype=Q.dtype) - Q@Q.mT
     P_inv = term1 + term2
     z = P_inv @ residual
     p = z.clone() # search direction
@@ -116,7 +102,7 @@ def nystrom_pcg(
     if init_norm < tol: return x
     k = 0
     while True:
-        Ap = A_mm_reg(p)
+        Ap = A_mv_reg(p)
         rz = residual.dot(z)
         step_size = rz / p.dot(Ap)
         x += step_size * p
@@ -138,7 +124,7 @@ def _safe_clip(x: torch.Tensor):
     if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
     return x
 
-def _trust_tau(x,d,trust_radius):
+def _trust_tau(x, d, trust_radius):
     xx = x.dot(x)
     xd = x.dot(d)
     dd = _safe_clip(d.dot(d))
@@ -150,10 +136,10 @@ def _trust_tau(x,d,trust_radius):
 
 
 class CG:
-    """Conjugate gradient method.
+    """Conjugate gradient method optionally with norm constraint.
 
     Args:
-        A_mm (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
+        A_mv (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
         b (torch.Tensor): right hand side
         x0 (torch.Tensor | None, optional): initial guess, defaults to zeros. Defaults to None.
         tol (float | None, optional): tolerance for convergence. Defaults to 1e-8.
@@ -174,10 +160,10 @@ class CG:
     """
    def __init__(
        self,
-        A_mm: Callable,
+        A_mv: Callable,
        b: torch.Tensor | TensorList,
        x0: torch.Tensor | TensorList | None = None,
-        tol: float | None = 1e-4,
+        tol: float | None = 1e-8,
        maxiter: int | None = None,
        reg: float = 0,
        trust_radius: float | None = None,
@@ -187,7 +173,7 @@ class CG:
        P_mm: Callable | None = None,
    ):
        # --------------------------------- set attrs -------------------------------- #
-        self.A_mm = _make_A_mm_reg(A_mm, reg)
+        self.A_mv = _make_A_mv_reg(A_mv, reg)
        self.b = b
        if tol is None: tol = generic_finfo_tiny(b) * 2
        self.tol = tol
@@ -214,7 +200,7 @@ class CG:
            self.r = b
        else:
            self.x = x0
-            self.r = b - A_mm(self.x)
+            self.r = b - A_mv(self.x)
 
        self.z = self.P_mm(self.r)
        self.d = self.z
@@ -229,7 +215,7 @@ class CG:
        if self.iter >= self.maxiter:
            return x, True
 
-        Ad = self.A_mm(d)
+        Ad = self.A_mv(d)
        dAd = d.dot(Ad)
 
        # check negative curvature
@@ -289,7 +275,8 @@ class CG:
        return sol
 
 def find_within_trust_radius(history, trust_radius: float):
-    """find first ``x`` in history that exceeds trust radius, if no such ``x`` exists, returns ``None``"""
+    """find first ``x`` in history that exceeds trust radius and returns solution within,
+    if no such ``x`` exists, returns ``None``"""
    for x, x_norm, d in reversed(tuple(history)):
        if x_norm <= trust_radius:
            return _trust_tau(x, d, trust_radius)
@@ -306,7 +293,7 @@ class _TensorListSolution(NamedTuple):
 
 @overload
 def cg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    A_mv: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
     x0: torch.Tensor | None = None,
     tol: float | None = 1e-8,
@@ -320,7 +307,7 @@ def cg(
 ) -> _TensorSolution: ...
 @overload
 def cg(
-    A_mm: Callable[[TensorList], TensorList],
+    A_mv: Callable[[TensorList], TensorList],
     b: TensorList,
     x0: TensorList | None = None,
     tol: float | None = 1e-8,
@@ -333,7 +320,7 @@ def cg(
     P_mm: Callable[[TensorList], TensorList] | None = None
 ) -> _TensorListSolution: ...
 def cg(
-    A_mm: Callable,
+    A_mv: Callable,
     b: torch.Tensor | TensorList,
     x0: torch.Tensor | TensorList | None = None,
     tol: float | None = 1e-8,
@@ -346,7 +333,7 @@ def cg(
     P_mm: Callable | None = None
 ):
     solver = CG(
-        A_mm=A_mm,
+        A_mv=A_mv,
         b=b,
         x0=x0,
         tol=tol,
@@ -370,10 +357,10 @@ def cg(
 # Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
 @overload
 def minres(
-    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
     b: torch.Tensor,
     x0: torch.Tensor | None = None,
-    tol: float | None = 1e-4,
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
     reg: float = 0,
     npc_terminate: bool=True,
@@ -381,26 +368,27 @@ def minres(
 ) -> torch.Tensor: ...
 @overload
 def minres(
-    A_mm: Callable[[TensorList], TensorList],
+    A_mv: Callable[[TensorList], TensorList],
     b: TensorList,
     x0: TensorList | None = None,
-    tol: float | None = 1e-4,
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
     npc_terminate: bool=True,
     trust_radius: float | None = None,
 ) -> TensorList: ...
 def minres(
-    A_mm,
+    A_mv,
     b,
     x0: torch.Tensor | TensorList | None = None,
-    tol: float | None = 1e-4,
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
     npc_terminate: bool=True,
     trust_radius: float | None = None, #trust region is experimental
 ):
-    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    """MINRES (experimental)"""
+    A_mv_reg = _make_A_mv_reg(A_mv, reg)
     eps = math.sqrt(generic_finfo_tiny(b) * 2)
     if tol is None: tol = eps
 
@@ -409,7 +397,7 @@ def minres(
         R = b
         x0 = generic_zeros_like(b)
     else:
-        R = b - A_mm_reg(x0)
+        R = b - A_mv_reg(x0)
 
     X: Any = x0
     beta = b_norm = generic_vector_norm(b)
@@ -429,7 +417,7 @@ def minres(
 
     for _ in range(maxiter):
 
-        P = A_mm_reg(V)
+        P = A_mv_reg(V)
         alpha = V.dot(P)
         P -= beta*V_prev
         P -= alpha*V
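The hunks above are the diff of item 23, `torchzero/utils/linalg/solve.py` moving to `torchzero/linalg/solve.py`: the matrix argument is renamed from `A_mm` to `A_mv` (it is a matrix-vector product, not a matrix-matrix product), the default `tol` tightens from 1e-4 to 1e-8, and `nystrom_approximation` leaves this module, so `nystrom_sketch_and_solve` and `nystrom_pcg` now take precomputed eigenpairs `(L, Q)` instead of computing the sketch themselves. A short call sketch against the new `cg` signature shown above (the fields of the returned solution object are not visible in this diff, so only the call is illustrated):

    import torch
    from torchzero.linalg.solve import cg  # new module path, per item 23

    A = torch.tensor([[4.0, 1.0],
                      [1.0, 3.0]])   # small SPD system for illustration
    b = torch.tensor([1.0, 2.0])

    # A_mv is the renamed matvec callable (called A_mm in 0.3.14)
    sol = cg(A_mv=lambda v: A @ v, b=b, tol=1e-8)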
--- /dev/null
+++ b/torchzero/linalg/svd.py
@@ -0,0 +1,20 @@
+# import torch
+
+# # projected svd
+# # adapted from https://github.com/smortezavi/Randomized_SVD_GPU
+# def randomized_svd(M: torch.Tensor, k: int, driver=None):
+#     *_, m, n = M.shape
+#     transpose = False
+#     if m < n:
+#         transpose = True
+#         M = M.mT
+#         m,n = n,m
+
+#     rand_matrix = torch.randn(size=(n, k), device=M.device, dtype=M.dtype)
+#     Q, _ = torch.linalg.qr(M @ rand_matrix, mode='reduced') # pylint:disable=not-callable
+#     smaller_matrix = Q.mT @ M
+#     U_hat, s, V = torch.linalg.svd(smaller_matrix, driver=driver, full_matrices=False) # pylint:disable=not-callable
+#     U = Q @ U_hat
+
+#     if transpose: return V.mT, s, U.mT
+#     return U, s, V
--- /dev/null
+++ b/torchzero/linalg/torch_linalg.py
@@ -0,0 +1,168 @@
+"""torch linalg with correct typing and retries in float64"""
+from typing import NamedTuple
+
+import torch
+
+
+def cholesky(A: torch.Tensor, *, upper=False, retry_float64:bool=False) -> torch.Tensor:
+    """A - SPD, returns lower triangular L such that ``A = L @ L.mH`` also can pass L to ``torch.cholesky_solve``"""
+    try:
+        return torch.linalg.cholesky(A, upper=upper) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return cholesky(A.to(torch.float64), upper=upper, retry_float64=False).to(dtype)
+
+
+class _QRTuple(NamedTuple):
+    Q: torch.Tensor
+    R: torch.Tensor
+
+def qr(A: torch.Tensor, mode='reduced', retry_float64:bool=False) -> _QRTuple:
+    """A - any matrix ``(*, m, n)`` (for some reason sometimes it takes ages on some matrices)
+
+    ### Returns (if mode = "reduced"):
+
+    Q: ``(*, m, k)`` - orthogonal
+
+    R: ``(*, k, n)`` - upper triangular
+
+    where ``k = min(m,n)``
+    """
+    try:
+        return torch.linalg.qr(A, mode=mode) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        Q, R = qr(A.to(torch.float64), mode=mode, retry_float64=False)
+        return _QRTuple(Q=Q.to(dtype), R=R.to(dtype))
+
+def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Tensor, torch.Tensor]:
+    """A - symmetric, returns ``(L, Q)``, ``A = Q @ torch.diag(L) @ Q.mH``, this is faster than SVD"""
+    try:
+        return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        L, Q = eigh(A.to(torch.float64), UPLO=UPLO, retry_float64=False)
+        return L.to(dtype), Q.to(dtype)
+
+
+
+class _SVDTuple(NamedTuple):
+    U: torch.Tensor
+    S: torch.Tensor
+    Vh: torch.Tensor
+
+def svd(A: torch.Tensor, full_matrices=True, driver=None, retry_float64:bool=False) -> _SVDTuple:
+    """A - any matrix ``(*, n, m)``, but slows down if A isn't well conditioned, ``A = U @ torch.diag(S) @ Vh``
+
+    Don't forget to set ``full_matrices=False``
+
+    ### Returns:
+
+    U: ``(*, m, m)`` or ``(*, m, k)`` - orthogonal
+
+    S: ``(*, k,)`` - singular values
+
+    V^H: ``(*, n, n)`` or ``(*, n, k)`` - orthogonal
+
+    where ``k = min(m,n)``
+
+    ### Drivers
+
+    drivers are only supported on CUDA so A is moved to CUDA by this function if needed
+
+    from docs:
+
+    If A is well-conditioned (its condition number is not too large), or you do not mind some precision loss.
+
+    For a general matrix: ‘gesvdj’ (Jacobi method)
+
+    If A is tall or wide (m >> n or m << n): ‘gesvda’ (Approximate method)
+
+    If A is not well-conditioned or precision is relevant: ‘gesvd’ (QR based)
+
+    By default (driver= None), we call ‘gesvdj’ and, if it fails, we fallback to ‘gesvd’.
+    """
+    # drivers are only for CUDA
+    # also the only one that doesn't freeze is ‘gesvda’
+    device=None
+    if driver is not None:
+        device = A.device
+        A = A.cuda()
+
+    try:
+        U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver) # pylint:disable=not-callable
+        if device is not None:
+            U = U.to(device); S = S.to(device); Vh = Vh.to(device)
+        return _SVDTuple(U=U, S=S, Vh=Vh)
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        U, S, Vh = svd(A.to(torch.float64), full_matrices=full_matrices, driver=driver, retry_float64=False)
+        return _SVDTuple(U=U.to(dtype), S=S.to(dtype), Vh=Vh.to(dtype))
+
+def solve(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> torch.Tensor:
+    """I think this uses LU"""
+    try:
+        return torch.linalg.solve(A, B, left=left) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return solve(A.to(torch.float64), B.to(torch.float64), left=left, retry_float64=False).to(dtype)
+
+class _SolveExTuple(NamedTuple):
+    result: torch.Tensor
+    info: int
+
+def solve_ex(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> _SolveExTuple:
+    """I think this uses LU"""
+    result, info = torch.linalg.solve_ex(A, B, left=left) # pylint:disable=not-callable
+
+    if info != 0:
+        if not retry_float64: return _SolveExTuple(result, info)
+        dtype = A.dtype
+        if dtype == torch.float64: return _SolveExTuple(result, info)
+        result, info = solve_ex(A.to(torch.float64), B.to(torch.float64), retry_float64=False)
+        return _SolveExTuple(result.to(dtype), info)
+
+    return _SolveExTuple(result, info)
+
+def inv(A: torch.Tensor, retry_float64:bool=False) -> torch.Tensor:
+    try:
+        return torch.linalg.inv(A) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return inv(A.to(torch.float64), retry_float64=False).to(dtype)
+
+
+class _InvExTuple(NamedTuple):
+    inverse: torch.Tensor
+    info: int
+
+def inv_ex(A: torch.Tensor, *, check_errors=False, retry_float64:bool=False) -> _InvExTuple:
+    """this retries in float64 but on fail info will be not 0"""
+    inverse, info = torch.linalg.inv_ex(A, check_errors=check_errors) # pylint:disable=not-callable
+
+    if info != 0:
+        if not retry_float64: return _InvExTuple(inverse, info)
+        dtype = A.dtype
+        if dtype == torch.float64: return _InvExTuple(inverse, info)
+        inverse, info = inv_ex(A.to(torch.float64), retry_float64=False)
+        return _InvExTuple(inverse.to(dtype), info)
+
+    return _InvExTuple(inverse, info)
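Every wrapper in this new module (item 25, `torchzero/linalg/torch_linalg.py`) follows the same pattern: run the `torch.linalg` routine in the input dtype and, when `retry_float64=True` and the input is not already float64, retry once in float64 on `LinAlgError` (or on nonzero `info` for the `_ex` variants) and cast the result back. A usage sketch of that pattern with the `cholesky` wrapper defined above (the test matrix is made up for illustration):

    import torch
    from torchzero.linalg.torch_linalg import cholesky  # path per item 25

    G = torch.randn(8, 8)
    A = G @ G.mT + 1e-6 * torch.eye(8)   # SPD, but borderline in float32

    # raises immediately by default; with retry_float64=True a float32
    # failure is retried in float64 and the factor is cast back to float32
    L = cholesky(A, retry_float64=True)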
--- a/torchzero/modules/__init__.py
+++ b/torchzero/modules/__init__.py
@@ -2,7 +2,6 @@ from . import experimental
 from .clipping import *
 from .conjugate_gradient import *
 from .grad_approximation import *
-from .higher_order import *
 from .least_squares import *
 from .line_search import *
 from .misc import *
--- a/torchzero/modules/adaptive/__init__.py
+++ b/torchzero/modules/adaptive/__init__.py
@@ -12,7 +12,7 @@ from .lmadagrad import LMAdagrad
 from .lion import Lion
 from .mars import MARSCorrection
 from .matrix_momentum import MatrixMomentum
-from .msam import MSAM, MSAMObjective
+from .msam import MSAMMomentum, MSAM
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
 from .natural_gradient import NaturalGradient
 from .orthograd import OrthoGrad, orthograd_
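This last hunk (item 27, `torchzero/modules/adaptive/__init__.py`) renames the MSAM exports. Reading the import line positionally, 0.3.14's `MSAM` appears to become `MSAMMomentum` and `MSAMObjective` appears to become the new `MSAM`; the rename itself happens inside `msam.py` (item 39), which this diff does not show, so the mapping is an assumption. A hypothetical downstream update under that assumption:

    # torchzero 0.3.14
    # from torchzero.modules.adaptive import MSAM, MSAMObjective

    # torchzero 0.4.0
    from torchzero.modules.adaptive import MSAMMomentum, MSAM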