torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/{utils/linalg → linalg}/qr.py
@@ -1,8 +1,21 @@
  from typing import Literal
  import torch
- from ..compile import enable_compilation
+ from ..utils.compile import allow_compile
+
+
+ # super slow
+ # def cholesky_qr(A):
+ #     """QR of (m, n) A via cholesky of (n, n) matrix"""
+ #     AtA = A.T @ A
+
+ #     L, _ = torch.linalg.cholesky_ex(AtA) # pylint:disable=not-callable
+ #     R = L.T
+
+ #     Q = torch.linalg.solve_triangular(R.T, A.T, upper=False).T # pylint:disable=not-callable
+ #     return Q, R

  # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+ @allow_compile
  def _get_w_tau(R: torch.Tensor, i: int, eps: float):
      R_ii = R[...,i,i]
      R_below = R[...,i:,i]
@@ -17,6 +30,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
      tau = torch.where(degenerate, 1, tau)
      return w, tau

+ @allow_compile
  def _qr_householder_complete(A:torch.Tensor):
      *b,m,n = A.shape
      k = min(m,n)
@@ -33,6 +47,7 @@ def _qr_householder_complete(A:torch.Tensor):

      return Q, R

+ @allow_compile
  def _qr_householder_reduced(A:torch.Tensor):
      *b,m,n = A.shape
      k = min(m,n)
@@ -64,7 +79,6 @@ def _qr_householder_reduced(A:torch.Tensor):

      return Q, R

- # @enable_compilation
  def qr_householder(A:torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
      """an attempt at a QR decomposition for very tall and thin matrices that doesn't freeze; it is around n_cols times slower than torch.linalg.qr, though compilation makes it faster at the cost of recompiling for each new shape"""
      if mode == 'reduced': return _qr_householder_reduced(A)
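
For context, a minimal usage sketch of the compiled Householder QR above (illustrative, not part of the diff; the import path assumes the new torchzero.linalg layout shown in the file list):

    import torch
    from torchzero.linalg.qr import qr_householder

    A = torch.randn(1000, 8)                  # very tall and thin
    Q, R = qr_householder(A, mode='reduced')  # Q: (1000, 8), R: (8, 8)

    # Q has orthonormal columns and Q @ R reconstructs A
    assert torch.allclose(Q @ R, A, atol=1e-4)
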
torchzero/{utils/linalg → linalg}/solve.py
@@ -1,3 +1,4 @@
+ # pylint: disable = non-ascii-name
  # pyright: reportArgumentType=false
  import math
  from collections import deque
@@ -5,8 +6,8 @@ from collections.abc import Callable
  from typing import Any, NamedTuple, overload

  import torch
-
- from .. import (
+ from .linalg_utils import mm
+ from ..utils import (
      TensorList,
      generic_eq,
      generic_finfo_tiny,
@@ -15,88 +16,71 @@ from .. import (
      generic_zeros_like,
  )

-
- def _make_A_mm_reg(A_mm: Callable, reg):
-     def A_mm_reg(x): # A_mm with regularization
-         Ax = A_mm(x)
+ def _make_A_mv_reg(A_mv: Callable, reg):
+     def A_mv_reg(x): # A_mv with regularization
+         Ax = A_mv(x)
          if not generic_eq(reg, 0): Ax += x*reg
          return Ax
-     return A_mm_reg
+     return A_mv_reg

  def _identity(x): return x

-
- # https://arxiv.org/pdf/2110.02820
- def nystrom_approximation(
-     A_mm: Callable[[torch.Tensor], torch.Tensor],
-     ndim: int,
-     rank: int,
-     device,
-     dtype = torch.float32,
-     generator = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     omega = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-     omega, _ = torch.linalg.qr(omega) # Thin QR decomposition # pylint:disable=not-callable
-
-     # Y = AΩ
-     Y = torch.stack([A_mm(col) for col in omega.unbind(-1)], -1) # rank matvecs
-     v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(Y, ord='fro') # Compute shift # pylint:disable=not-callable
-     Yv = Y + v*omega # Shift for stability
-     C = torch.linalg.cholesky_ex(omega.mT @ Yv)[0] # pylint:disable=not-callable
-     B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
-     U, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
-     lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
-     return U, lambd
-
  def nystrom_sketch_and_solve(
-     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     L: torch.Tensor,
+     Q: torch.Tensor,
      b: torch.Tensor,
-     rank: int,
      reg: float = 1e-3,
-     generator=None,
  ) -> torch.Tensor:
-     U, lambd = nystrom_approximation(
-         A_mm=A_mm,
-         ndim=b.size(-1),
-         rank=rank,
-         device=b.device,
-         dtype=b.dtype,
-         generator=generator,
-     )
+     """Solves ``(Q diag(L) Q.T + reg*I)x = b``. Becomes very unstable with reg smaller than about 1e-5.
+
+     Args:
+         L (torch.Tensor): eigenvalues, like from ``nystrom_approximation``
+         Q (torch.Tensor): eigenvectors, like from ``nystrom_approximation``
+         b (torch.Tensor): right hand side
+         reg (float, optional): regularization. Defaults to 1e-3.
+     """
+
      b = b.unsqueeze(-1)
-     lambd += reg
+     L += reg
      # x = (A + μI)⁻¹ b
-     # (A + μI)⁻¹ = U(Λ + μI)⁻¹Uᵀ + (1/μ)(b - UUᵀ)
-     # x = U(Λ + μI)⁻¹Uᵀb + (1/μ)(b - UUᵀb)
-     Uᵀb = U.T @ b
-     term1 = U @ ((1/lambd).unsqueeze(-1) * Uᵀb)
-     term2 = (1.0 / reg) * (b - U @ Uᵀb)
+     # (A + μI)⁻¹ = Q(L + μI)⁻¹Qᵀ + (1/μ)(I - QQᵀ)
+     # x = Q(L + μI)⁻¹Qᵀb + (1/μ)(b - QQᵀb)
+     Qᵀb = Q.T @ b
+     term1 = Q @ ((1/L).unsqueeze(-1) * Qᵀb)
+     term2 = (1.0 / reg) * (b - Q @ Qᵀb)
      return (term1 + term2).squeeze(-1)

  def nystrom_pcg(
-     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     L: torch.Tensor,
+     Q: torch.Tensor,
+     A_mv: Callable[[torch.Tensor], torch.Tensor],
      b: torch.Tensor,
-     sketch_size: int,
      reg: float = 1e-6,
      x0_: torch.Tensor | None = None,
-     tol: float | None = 1e-4,
+     tol: float | None = 1e-8,
      maxiter: int | None = None,
-     generator=None,
  ) -> torch.Tensor:
-     U, lambd = nystrom_approximation(
-         A_mm=A_mm,
-         ndim=b.size(-1),
-         rank=sketch_size,
-         device=b.device,
-         dtype=b.dtype,
-         generator=generator,
-     )
-     lambd += reg
+     """Conjugate gradient preconditioned by a Nystrom approximation.
+
+     The preconditioner can be computed with a single matrix-matrix multiplication with A,
+     so this is a good choice when matrix-matrix products are efficient (e.g. batched hessian-vector products in pytorch).
+
+     Args:
+         L (torch.Tensor): eigenvalues of approximation of A, like from ``nystrom_approximation``
+         Q (torch.Tensor): eigenvectors of approximation of A, like from ``nystrom_approximation``
+         A_mv (Callable[[torch.Tensor], torch.Tensor]): mat-vec func with hessian
+         b (torch.Tensor): right hand side
+         reg (float, optional): regularization. Defaults to 1e-6.
+         x0_ (torch.Tensor | None, optional): initial guess (modified in-place). Defaults to None.
+         tol (float | None, optional): tolerance for convergence. Defaults to 1e-8.
+         maxiter (int | None, optional): maximum number of iterations. Defaults to None.
+     """
+     L += reg
      eps = torch.finfo(b.dtype).tiny * 2
      if tol is None: tol = eps

-     def A_mm_reg(x): # A_mm with regularization
-         Ax = A_mm(x)
+     def A_mv_reg(x): # A_mv with regularization
+         Ax = A_mv(x)
          if reg != 0: Ax += x*reg
          return Ax

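To make the API change above concrete: ``nystrom_sketch_and_solve`` no longer builds the sketch internally, it takes precomputed eigenpairs ``(L, Q)``. A minimal sketch (illustrative, not part of the diff), using an exact ``eigh`` as a stand-in for ``nystrom_approximation`` and assuming the module lives at ``torchzero.linalg.solve``:

    import torch
    from torchzero.linalg.solve import nystrom_sketch_and_solve

    n = 100
    M = torch.randn(n, n)
    A = M @ M.T + torch.eye(n)    # SPD test matrix
    b = torch.randn(n)

    L, Q = torch.linalg.eigh(A)   # stand-in for nystrom_approximation
    x = nystrom_sketch_and_solve(L, Q, b, reg=1e-3)  # note: modifies L in-place (L += reg)

    # x approximately solves (A + reg*I) x = b
    print(torch.linalg.vector_norm(A @ x + 1e-3 * x - b))
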
@@ -104,10 +88,10 @@ def nystrom_pcg(
      if x0_ is None: x0_ = torch.zeros_like(b)

      x = x0_
-     residual = b - A_mm_reg(x)
+     residual = b - A_mv_reg(x)
      # z0 = P⁻¹ r0
-     term1 = lambd[...,-1] * U * (1/lambd.unsqueeze(-2)) @ U.mT
-     term2 = torch.eye(U.size(-2), device=U.device,dtype=U.dtype) - U@U.mT
+     term1 = L[...,-1] * Q * (1/L.unsqueeze(-2)) @ Q.mT
+     term2 = torch.eye(Q.size(-2), device=Q.device,dtype=Q.dtype) - Q@Q.mT
      P_inv = term1 + term2
      z = P_inv @ residual
      p = z.clone() # search direction
@@ -116,7 +100,7 @@
      if init_norm < tol: return x
      k = 0
      while True:
-         Ap = A_mm_reg(p)
+         Ap = A_mv_reg(p)
          rz = residual.dot(z)
          step_size = rz / p.dot(Ap)
          x += step_size * p
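
Likewise for ``nystrom_pcg``, the caller now supplies the low-rank factors plus a matvec closure. A sketch (illustrative, not part of the diff), with a truncated exact eigendecomposition standing in for the randomized Nystrom sketch:

    import torch
    from torchzero.linalg.solve import nystrom_pcg

    n, rank = 500, 10
    M = torch.randn(n, n)
    A = M @ M.T / n + torch.eye(n)
    b = torch.randn(n)

    L_full, Q_full = torch.linalg.eigh(A)      # eigenvalues in ascending order
    L, Q = L_full[-rank:], Q_full[:, -rank:]   # keep the top-`rank` eigenpairs

    x = nystrom_pcg(L, Q, A_mv=lambda v: A @ v, b=b, reg=1e-6, maxiter=n)
    print(torch.linalg.vector_norm(A @ x + 1e-6 * x - b))  # small residual
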
@@ -138,7 +122,7 @@ def _safe_clip(x: torch.Tensor):
      if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
      return x

- def _trust_tau(x,d,trust_radius):
+ def _trust_tau(x, d, trust_radius):
      xx = x.dot(x)
      xd = x.dot(d)
      dd = _safe_clip(d.dot(d))
@@ -150,10 +134,10 @@


  class CG:
-     """Conjugate gradient method.
+     """Conjugate gradient method, optionally with a norm constraint.

      Args:
-         A_mm (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
+         A_mv (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
          b (torch.Tensor): right hand side
          x0 (torch.Tensor | None, optional): initial guess, defaults to zeros. Defaults to None.
          tol (float | None, optional): tolerance for convergence. Defaults to 1e-8.
@@ -174,10 +158,10 @@ class CG:
      """
      def __init__(
          self,
-         A_mm: Callable,
+         A_mv: Callable,
          b: torch.Tensor | TensorList,
          x0: torch.Tensor | TensorList | None = None,
-         tol: float | None = 1e-4,
+         tol: float | None = 1e-8,
          maxiter: int | None = None,
          reg: float = 0,
          trust_radius: float | None = None,
@@ -187,7 +171,7 @@ class CG:
          P_mm: Callable | None = None,
      ):
          # --------------------------------- set attrs -------------------------------- #
-         self.A_mm = _make_A_mm_reg(A_mm, reg)
+         self.A_mv = _make_A_mv_reg(A_mv, reg)
          self.b = b
          if tol is None: tol = generic_finfo_tiny(b) * 2
          self.tol = tol
@@ -214,7 +198,7 @@ class CG:
              self.r = b
          else:
              self.x = x0
-             self.r = b - A_mm(self.x)
+             self.r = b - A_mv(self.x)

          self.z = self.P_mm(self.r)
          self.d = self.z
@@ -229,7 +213,7 @@ class CG:
          if self.iter >= self.maxiter:
              return x, True

-         Ad = self.A_mm(d)
+         Ad = self.A_mv(d)
          dAd = d.dot(Ad)

          # check negative curvature
@@ -289,7 +273,8 @@ class CG:
          return sol

  def find_within_trust_radius(history, trust_radius: float):
-     """find first ``x`` in history that exceeds trust radius, if no such ``x`` exists, returns ``None``"""
+     """find the first ``x`` in history that exceeds the trust radius and return the solution within it;
+     if no such ``x`` exists, returns ``None``"""
      for x, x_norm, d in reversed(tuple(history)):
          if x_norm <= trust_radius:
              return _trust_tau(x, d, trust_radius)
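
Background on the trust-radius helpers above (an illustrative derivation, not part of the diff): judging by its use in ``find_within_trust_radius``, ``_trust_tau`` returns the boundary point along the current direction, i.e. the step ``x + tau*d`` with the ``tau >= 0`` satisfying ``||x + tau*d|| = trust_radius``, the positive root of a quadratic:

    import torch

    # ||x + tau*d||^2 = Delta^2
    # <=> (d.d)*tau^2 + 2*(x.d)*tau + (x.x - Delta^2) = 0
    x = torch.tensor([1.0, 0.0])
    d = torch.tensor([1.0, 1.0])
    Delta = 2.0

    xx, xd, dd = x.dot(x), x.dot(d), d.dot(d)
    tau = (-xd + torch.sqrt(xd**2 + dd * (Delta**2 - xx))) / dd
    print(torch.linalg.vector_norm(x + tau * d))  # equals Delta
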
@@ -306,7 +291,7 @@ class _TensorListSolution(NamedTuple):

  @overload
  def cg(
-     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     A_mv: Callable[[torch.Tensor], torch.Tensor],
      b: torch.Tensor,
      x0: torch.Tensor | None = None,
      tol: float | None = 1e-8,
@@ -320,7 +305,7 @@ def cg(
  ) -> _TensorSolution: ...
  @overload
  def cg(
-     A_mm: Callable[[TensorList], TensorList],
+     A_mv: Callable[[TensorList], TensorList],
      b: TensorList,
      x0: TensorList | None = None,
      tol: float | None = 1e-8,
@@ -333,7 +318,7 @@ def cg(
      P_mm: Callable[[TensorList], TensorList] | None = None
  ) -> _TensorListSolution: ...
  def cg(
-     A_mm: Callable,
+     A_mv: Callable,
      b: torch.Tensor | TensorList,
      x0: torch.Tensor | TensorList | None = None,
      tol: float | None = 1e-8,
@@ -346,7 +331,7 @@ def cg(
      P_mm: Callable | None = None
  ):
      solver = CG(
-         A_mm=A_mm,
+         A_mv=A_mv,
          b=b,
          x0=x0,
          tol=tol,
@@ -370,10 +355,10 @@ def cg(
  # Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
  @overload
  def minres(
-     A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+     A_mv: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
      b: torch.Tensor,
      x0: torch.Tensor | None = None,
-     tol: float | None = 1e-4,
+     tol: float | None = 1e-8,
      maxiter: int | None = None,
      reg: float = 0,
      npc_terminate: bool=True,
@@ -381,26 +366,27 @@ def minres(
  ) -> torch.Tensor: ...
  @overload
  def minres(
-     A_mm: Callable[[TensorList], TensorList],
+     A_mv: Callable[[TensorList], TensorList],
      b: TensorList,
      x0: TensorList | None = None,
-     tol: float | None = 1e-4,
+     tol: float | None = 1e-8,
      maxiter: int | None = None,
      reg: float | list[float] | tuple[float] = 0,
      npc_terminate: bool=True,
      trust_radius: float | None = None,
  ) -> TensorList: ...
  def minres(
-     A_mm,
+     A_mv,
      b,
      x0: torch.Tensor | TensorList | None = None,
-     tol: float | None = 1e-4,
+     tol: float | None = 1e-8,
      maxiter: int | None = None,
      reg: float | list[float] | tuple[float] = 0,
      npc_terminate: bool=True,
      trust_radius: float | None = None, # trust region is experimental
  ):
-     A_mm_reg = _make_A_mm_reg(A_mm, reg)
+     """MINRES (experimental)"""
+     A_mv_reg = _make_A_mv_reg(A_mv, reg)
      eps = math.sqrt(generic_finfo_tiny(b) * 2)
      if tol is None: tol = eps
@@ -409,7 +395,7 @@ def minres(
          R = b
          x0 = generic_zeros_like(b)
      else:
-         R = b - A_mm_reg(x0)
+         R = b - A_mv_reg(x0)

      X: Any = x0
      beta = b_norm = generic_vector_norm(b)
@@ -429,7 +415,7 @@

      for _ in range(maxiter):

-         P = A_mm_reg(V)
+         P = A_mv_reg(V)
          alpha = V.dot(P)
          P -= beta*V_prev
          P -= alpha*V
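
A minimal sketch (illustrative, not part of the diff) of the renamed solver entry points; ``A_mv`` emphasizes that the callable is a matrix-vector product:

    import torch
    from torchzero.linalg.solve import cg, minres

    n = 50
    M = torch.randn(n, n)
    A = M @ M.T + torch.eye(n)   # SPD, so plain CG applies
    b = torch.randn(n)

    sol = cg(A_mv=lambda v: A @ v, b=b, tol=1e-8)    # returns a solution NamedTuple
    x = minres(A_mv=lambda v: A @ v, b=b, tol=1e-8)  # returns the solution directly
    print(torch.linalg.vector_norm(A @ x - b))
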
torchzero/linalg/svd.py
@@ -0,0 +1,47 @@
+ import torch
+
+ from . import torch_linalg
+
+
+ def tall_reduced_svd_via_eigh(A: torch.Tensor, tol: float = 0, retry_float64:bool=False):
+     """
+     Given a tall matrix A of size (m, n), computes U and S from the reduced SVD(A)
+     using the eigendecomposition of an (n, n) matrix, which is faster than direct SVD when m >= n.
+
+     This truncates small singular values that would cause NaNs,
+     so the returned U and S can have reduced dimension ``k <= n``.
+
+     Returns U of size ``(m, k)`` and S of size ``(k, )``.
+
+     Args:
+         A (torch.Tensor): A tall matrix of size (m, n) with m >= n.
+         tol (float): Tolerance for truncating small singular values. Singular values
+             less than ``tol * max_singular_value`` will be discarded.
+     """
+     # if m < n, A.T A will be low rank and we can't use eigh
+     m, n = A.size()
+     if m < n:
+         U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+         return U, S
+
+     M = A.mH @ A # n,n
+
+     try:
+         L, Q = torch_linalg.eigh(M, retry_float64=retry_float64)
+     except torch.linalg.LinAlgError:
+         U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+         return U, S
+
+     L = torch.flip(L, dims=[-1])
+     Q = torch.flip(Q, dims=[-1])
+
+     indices = L > tol * L[0] # L[0] is the max eigenvalue
+     L = L[indices]
+     Q = Q[:, indices]
+
+     S = L.sqrt()
+     U = (A @ Q) / S
+
+     return U, S
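
A quick check (illustrative, not part of the diff) that the eigh-based path agrees with direct SVD on a well-conditioned tall matrix; U is compared via its projector because singular vectors are only determined up to sign:

    import torch
    from torchzero.linalg.svd import tall_reduced_svd_via_eigh

    A = torch.randn(200, 10, dtype=torch.float64)
    U, S = tall_reduced_svd_via_eigh(A)

    U_ref, S_ref, Vh_ref = torch.linalg.svd(A, full_matrices=False)
    print(torch.allclose(S, S_ref, atol=1e-8))                    # same singular values
    print(torch.allclose(U @ U.mT, U_ref @ U_ref.mT, atol=1e-8))  # same column space
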
torchzero/linalg/torch_linalg.py
@@ -0,0 +1,168 @@
+ """torch linalg with correct typing and retries in float64"""
+ from typing import NamedTuple
+
+ import torch
+
+
+ def cholesky(A: torch.Tensor, *, upper=False, retry_float64:bool=False) -> torch.Tensor:
+     """A - SPD, returns lower triangular L such that ``A = L @ L.mH``; L can also be passed to ``torch.cholesky_solve``"""
+     try:
+         return torch.linalg.cholesky(A, upper=upper) # pylint:disable=not-callable
+
+     except torch.linalg.LinAlgError as e:
+         if not retry_float64: raise e
+         dtype = A.dtype
+         if dtype == torch.float64: raise e
+         return cholesky(A.to(torch.float64), upper=upper, retry_float64=False).to(dtype)
+
+
+ class _QRTuple(NamedTuple):
+     Q: torch.Tensor
+     R: torch.Tensor
+
+ def qr(A: torch.Tensor, mode='reduced', retry_float64:bool=False) -> _QRTuple:
+     """A - any matrix ``(*, m, n)`` (for some reason it sometimes takes ages on some matrices)
+
+     ### Returns (if mode = "reduced"):
+
+     Q: ``(*, m, k)`` - orthogonal
+
+     R: ``(*, k, n)`` - upper triangular
+
+     where ``k = min(m,n)``
+     """
+     try:
+         return torch.linalg.qr(A, mode=mode) # pylint:disable=not-callable
+
+     except torch.linalg.LinAlgError as e:
+         if not retry_float64: raise e
+         dtype = A.dtype
+         if dtype == torch.float64: raise e
+         Q, R = qr(A.to(torch.float64), mode=mode, retry_float64=False)
+         return _QRTuple(Q=Q.to(dtype), R=R.to(dtype))
+
+ def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Tensor, torch.Tensor]:
+     """A - symmetric, returns ``(L, Q)`` with ``A = Q @ torch.diag(L) @ Q.mH``; this is faster than SVD"""
+     try:
+         return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable
+
+     except torch.linalg.LinAlgError as e:
+         if not retry_float64: raise e
+         dtype = A.dtype
+         if dtype == torch.float64: raise e
+         L, Q = eigh(A.to(torch.float64), UPLO=UPLO, retry_float64=False)
+         return L.to(dtype), Q.to(dtype)
+
+
+
+ class _SVDTuple(NamedTuple):
+     U: torch.Tensor
+     S: torch.Tensor
+     Vh: torch.Tensor
+
+ def svd(A: torch.Tensor, full_matrices=True, driver=None, retry_float64:bool=False) -> _SVDTuple:
+     """A - any matrix ``(*, m, n)``, but slows down if A isn't well conditioned; ``A = U @ torch.diag(S) @ Vh``
+
+     Don't forget to set ``full_matrices=False``
+
+     ### Returns:
+
+     U: ``(*, m, m)`` or ``(*, m, k)`` - orthogonal
+
+     S: ``(*, k,)`` - singular values
+
+     V^H: ``(*, n, n)`` or ``(*, k, n)`` - orthogonal
+
+     where ``k = min(m,n)``
+
+     ### Drivers
+
+     drivers are only supported on CUDA, so A is moved to CUDA by this function if needed
+
+     from docs:
+
+     If A is well-conditioned (its condition number is not too large), or you do not mind some precision loss:
+
+     For a general matrix: ‘gesvdj’ (Jacobi method)
+
+     If A is tall or wide (m >> n or m << n): ‘gesvda’ (Approximate method)
+
+     If A is not well-conditioned or precision is relevant: ‘gesvd’ (QR based)
+
+     By default (driver=None), we call ‘gesvdj’ and, if it fails, we fall back to ‘gesvd’.
+     """
+     # drivers are only for CUDA
+     # also the only one that doesn't freeze is ‘gesvda’
+     device = None
+     if driver is not None:
+         device = A.device
+         A = A.cuda()
+
+     try:
+         U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver) # pylint:disable=not-callable
+         if device is not None:
+             U = U.to(device); S = S.to(device); Vh = Vh.to(device)
+         return _SVDTuple(U=U, S=S, Vh=Vh)
+
+     except torch.linalg.LinAlgError as e:
+         if not retry_float64: raise e
+         dtype = A.dtype
+         if dtype == torch.float64: raise e
+         U, S, Vh = svd(A.to(torch.float64), full_matrices=full_matrices, driver=driver, retry_float64=False)
+         return _SVDTuple(U=U.to(dtype), S=S.to(dtype), Vh=Vh.to(dtype))
+
+ def solve(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> torch.Tensor:
+     """I think this uses LU"""
+     try:
+         return torch.linalg.solve(A, B, left=left) # pylint:disable=not-callable
+
+     except torch.linalg.LinAlgError as e:
+         if not retry_float64: raise e
+         dtype = A.dtype
+         if dtype == torch.float64: raise e
+         return solve(A.to(torch.float64), B.to(torch.float64), left=left, retry_float64=False).to(dtype)
+
+ class _SolveExTuple(NamedTuple):
+     result: torch.Tensor
+     info: int
+
+ def solve_ex(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> _SolveExTuple:
+     """I think this uses LU"""
+     result, info = torch.linalg.solve_ex(A, B, left=left) # pylint:disable=not-callable
+
+     if info != 0:
+         if not retry_float64: return _SolveExTuple(result, info)
+         dtype = A.dtype
+         if dtype == torch.float64: return _SolveExTuple(result, info)
+         result, info = solve_ex(A.to(torch.float64), B.to(torch.float64), left=left, retry_float64=False)
+         return _SolveExTuple(result.to(dtype), info)
+
+     return _SolveExTuple(result, info)
+
+ def inv(A: torch.Tensor, retry_float64:bool=False) -> torch.Tensor:
+     try:
+         return torch.linalg.inv(A) # pylint:disable=not-callable
+
+     except torch.linalg.LinAlgError as e:
+         if not retry_float64: raise e
+         dtype = A.dtype
+         if dtype == torch.float64: raise e
+         return inv(A.to(torch.float64), retry_float64=False).to(dtype)
+
+
+ class _InvExTuple(NamedTuple):
+     inverse: torch.Tensor
+     info: int
+
+ def inv_ex(A: torch.Tensor, *, check_errors=False, retry_float64:bool=False) -> _InvExTuple:
+     """retries in float64; if that also fails, ``info`` will be nonzero"""
+     inverse, info = torch.linalg.inv_ex(A, check_errors=check_errors) # pylint:disable=not-callable
+
+     if info != 0:
+         if not retry_float64: return _InvExTuple(inverse, info)
+         dtype = A.dtype
+         if dtype == torch.float64: return _InvExTuple(inverse, info)
+         inverse, info = inv_ex(A.to(torch.float64), retry_float64=False)
+         return _InvExTuple(inverse.to(dtype), info)
+
+     return _InvExTuple(inverse, info)
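
All of the wrappers above follow the same retry pattern: attempt the decomposition in the input dtype, and on ``LinAlgError`` (or a nonzero ``info`` for the ``_ex`` variants) redo it in float64 and cast the result back. A usage sketch (illustrative, not part of the diff):

    import torch
    from torchzero.linalg import torch_linalg

    A = torch.randn(64, 64)
    A = A @ A.T + 1e-3 * torch.eye(64)   # SPD, but possibly borderline in float32

    # torch.linalg.cholesky would raise LinAlgError if the float32 factorization
    # fails; with retry_float64=True it is redone in float64 and cast back
    L = torch_linalg.cholesky(A, retry_float64=True)
    print(L.dtype)  # torch.float32

    evals, evecs = torch_linalg.eigh(A, retry_float64=True)
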
torchzero/modules/__init__.py
@@ -1,4 +1,6 @@
  from . import experimental
+ from .adaptive import *
+ from .adaptive import lre_optimizers as lre
  from .clipping import *
  from .conjugate_gradient import *
  from .grad_approximation import *
@@ -7,9 +9,9 @@ from .line_search import *
  from .misc import *
  from .momentum import *
  from .ops import *
- from .adaptive import *
  from .projections import *
  from .quasi_newton import *
+ from .restarts import *
  from .second_order import *
  from .smoothing import *
  from .step_size import *
@@ -18,5 +20,4 @@ from .trust_region import *
  from .variance_reduction import *
  from .weight_decay import *
  from .wrappers import *
- from .restarts import *
- from .zeroth_order import *
+ from .zeroth_order import *
torchzero/modules/adaptive/__init__.py
@@ -1,4 +1,5 @@
- from .adagrad import Adagrad, FullMatrixAdagrad, AdagradNorm
+ from . import lre_optimizers
+ from .adagrad import Adagrad, AdagradNorm, FullMatrixAdagrad

  # from .curveball import CurveBall
  # from .spectral import SpectralPreconditioner
@@ -8,14 +9,21 @@ from .adan import Adan
  from .adaptive_heavyball import AdaptiveHeavyBall
  from .aegd import AEGD
  from .esgd import ESGD
- from .lmadagrad import LMAdagrad
  from .lion import Lion
+ from .ggt import GGT
  from .mars import MARSCorrection
  from .matrix_momentum import MatrixMomentum
- from .msam import MSAM, MSAMObjective
+ from .msam import MSAM, MSAMMomentum
  from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
  from .natural_gradient import NaturalGradient
  from .orthograd import OrthoGrad, orthograd_
+ from .psgd import (
+     PSGDDenseNewton,
+     PSGDKronNewton,
+     PSGDKronWhiten,
+     PSGDLRANewton,
+     PSGDLRAWhiten,
+ )
  from .rmsprop import RMSprop
  from .rprop import (
      BacktrackOnSignChange,