torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,10 @@
+ from . import linear_operator
+
+ from .matrix_power import (
+     matrix_power_eigh,
+     matrix_power_svd,
+ )
+ from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize
+ from .qr import qr_householder
+ from .solve import cg, nystrom_sketch_and_solve, nystrom_pcg
+ from .eigh import nystrom_approximation
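Context for the hunk above (not part of the diff): the 0.3.x `torchzero/utils/linalg` subpackage is removed in 0.4.0 and its contents move to the new top-level `torchzero.linalg` package, whose `__init__` re-exports the helpers shown here. A minimal import sketch against the 0.4.0 layout (paths taken from the file list above):

    # hypothetical usage sketch, not part of the package diff
    from torchzero.linalg import (
        qr_householder,                             # torchzero/linalg/qr.py
        orthogonalize,                              # torchzero/linalg/orthogonalize.py
        nystrom_approximation,                      # torchzero/linalg/eigh.py
        cg, nystrom_sketch_and_solve, nystrom_pcg,  # torchzero/linalg/solve.py
    )
    from torchzero.linalg import linear_operator    # submodule with the LinearOperator classes
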
@@ -0,0 +1,34 @@
+ from collections.abc import Callable
+ import torch
+ from .linalg_utils import mm
+
+
+
+ # https://arxiv.org/pdf/2110.02820
+ def nystrom_approximation(
+     A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+     A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+     ndim: int,
+     rank: int,
+     device,
+     dtype = torch.float32,
+     generator = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Computes Nyström approximation to positive-semidefinite A factored as Q L Q^T (truncated eigenvalue decomp),
+     returns ``(L, Q)``.
+
+     A is ``(m,m)``, then Q is ``(m, rank)``; L is a ``(rank, )`` vector - diagonal of ``(rank, rank)``"""
+     # basis
+     O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
+     O, _ = torch.linalg.qr(O) # Thin QR decomposition # pylint:disable=not-callable
+
+     # Y = AΩ
+     AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
+
+     v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(AO, ord='fro') # Compute shift # pylint:disable=not-callable
+     Yv = AO + v*O # Shift for stability
+     C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
+     B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+     Q, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+     L = (S.pow(2) - v).clip(min=0) # Remove shift, compute eigs
+     return L, Q
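A usage sketch for the new helper above (not part of the diff; the SPD test matrix and the matvec closure are made up for illustration):

    import torch
    from torchzero.linalg import nystrom_approximation

    n, rank = 100, 10
    M = torch.randn(n, n)
    A = M @ M.T + 1e-3 * torch.eye(n)          # SPD matrix to approximate

    # Only a matvec closure is needed; A_mm may be None.
    L, Q = nystrom_approximation(
        A_mv=lambda v: A @ v,
        A_mm=None,
        ndim=n,
        rank=rank,
        device="cpu",
    )
    A_lowrank = Q @ torch.diag(L) @ Q.T        # rank-10 approximation of A
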
@@ -0,0 +1,14 @@
+ from collections.abc import Callable
+ import torch
+
+ def mm(
+     A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+     A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+     X
+ ):
+     """matrix-matrix when either mv or mm is given"""
+     if A_mm is not None: return A_mm(X)
+     assert A_mv is not None
+     return torch.stack([A_mv(col) for col in X.unbind(-1)], -1) # rank matvecs
+
+
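A small sketch of how the mm helper above behaves when only a matvec is available (illustrative only; not part of the diff):

    import torch
    from torchzero.linalg.linalg_utils import mm

    A = torch.randn(5, 5)
    X = torch.randn(5, 3)

    Y = mm(A_mv=lambda v: A @ v, A_mm=None, X=X)   # applies the matvec column by column
    # Y matches A @ X up to floating-point error
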
@@ -1,4 +1,6 @@
- """simplified version of https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html. This is used for trust regions."""
+ """This is mainly used for trust regions. In some cases certain operations are relaxed, e.g. eigenvalue shift instead of
+ adding diagonal when it isn't tractable, to make it work with Levenberg-Marquardt.
+ """
  import math
  from abc import ABC, abstractmethod
  from functools import partial
@@ -7,7 +9,8 @@ from typing import cast, final

  import torch

- from ..torch_tools import tofloat, tonumpy, totensor
+ from ..utils.torch_tools import tofloat, tonumpy, totensor
+ from .solve import nystrom_sketch_and_solve

  if find_spec('scipy') is not None:
      from scipy.sparse.linalg import LinearOperator as _ScipyLinearOperator
@@ -15,7 +18,6 @@ else:
      _ScipyLinearOperator = None

  class LinearOperator(ABC):
-     """this is used for trust region"""
      device: torch.types.Device
      dtype: torch.dtype | None

@@ -25,18 +27,24 @@ class LinearOperator(ABC):
      def rmatvec(self, x: torch.Tensor) -> torch.Tensor:
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatvec")

-     def matmat(self, x: torch.Tensor) -> "LinearOperator":
-         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmul")
+     def matmat(self, X: torch.Tensor) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmat")
+
+     def rmatmat(self, X: torch.Tensor) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatmat")

      def solve(self, b: torch.Tensor) -> torch.Tensor:
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve")

+     def solve_plus_diag(self, b: torch.Tensor, diag: int | float | torch.Tensor) -> torch.Tensor:
+         return self.add_diagonal(diag).solve(b)
+
      def solve_bounded(self, b: torch.Tensor, bound:float, ord:float=2) -> torch.Tensor:
          """solve with a norm bound on x"""
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")

-     def update(self, *args, **kwargs) -> None:
-         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
+     # def update(self, *args, **kwargs) -> None:
+     #     raise NotImplementedError(f"{self.__class__.__name__} doesn't implement update")
      def add(self, x: torch.Tensor) -> "LinearOperator":
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement add")

@@ -129,8 +137,8 @@ class Dense(LinearOperator):
      def matvec(self, x): return self.A.mv(x)
      def rmatvec(self, x): return self.A.mH.mv(x)

-     def matmat(self, x): return Dense(self.A.mm(x))
-     def rmatmat(self, x): return Dense(self.A.mH.mm(x))
+     def matmat(self, X): return Dense(self.A.mm(X))
+     def rmatmat(self, X): return Dense(self.A.mH.mm(X))

      def solve(self, b): return _solve(self.A, b)

@@ -146,6 +154,12 @@ class Dense(LinearOperator):
      def is_dense(self): return True
      def transpose(self): return Dense(self.A.mH)

+ class SPD(Dense):
+     def solve(self, b: torch.Tensor):
+         L, info = torch.linalg.cholesky_ex(self.A) # pylint:disable=not-callable
+         return torch.cholesky_solve(b.unsqueeze(-1), L).squeeze(-1)
+
+
  class DenseInverse(LinearOperator):
      """Represents inverse of a dense matrix A."""
      def __init__(self, A_inv: torch.Tensor):
@@ -156,8 +170,8 @@ class DenseInverse(LinearOperator):
      def matvec(self, x): return _solve(self.A_inv, x) # pylint:disable=not-callable
      def rmatvec(self, x): return _solve(self.A_inv.mH, x) # pylint:disable=not-callable

-     def matmat(self, x): return Dense(_solve(self.A_inv, x)) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(_solve(self.A_inv.mH, x)) # pylint:disable=not-callable
+     def matmat(self, X): return Dense(_solve(self.A_inv, X)) # pylint:disable=not-callable
+     def rmatmat(self, X): return Dense(_solve(self.A_inv.mH, X)) # pylint:disable=not-callable

      def solve(self, b): return self.A_inv.mv(b)

@@ -190,8 +204,8 @@ class Diagonal(LinearOperator):
      def matvec(self, x): return self.A * x
      def rmatvec(self, x): return self.A * x

-     def matmat(self, x): return Dense(x * self.A.unsqueeze(-1))
-     def rmatmat(self, x): return Dense(x * self.A.unsqueeze(-1))
+     def matmat(self, X): return Dense(X * self.A.unsqueeze(-1))
+     def rmatmat(self, X): return Dense(X * self.A.unsqueeze(-1))

      def solve(self, b): return b/self.A

@@ -221,8 +235,8 @@ class ScaledIdentity(LinearOperator):
      def matvec(self, x): return x * self.s
      def rmatvec(self, x): return x * self.s

-     def matmat(self, x): return Dense(x * self.s)
-     def rmatmat(self, x): return Dense(x * self.s)
+     def matmat(self, X): return Dense(X * self.s)
+     def rmatmat(self, X): return Dense(X * self.s)

      def solve(self, b): return b / self.s
      def solve_bounded(self, b, bound, ord = 2):
@@ -263,6 +277,7 @@ class ScaledIdentity(LinearOperator):
      def is_dense(self): return False
      def transpose(self): return ScaledIdentity(self.s, shape=self.shape, device=self.device, dtype=self.dtype)

+
  class AtA(LinearOperator):
      def __init__(self, A: torch.Tensor):
          self.A = A
@@ -270,8 +285,8 @@ class AtA(LinearOperator):
      def matvec(self, x): return self.A.mH.mv(self.A.mv(x))
      def rmatvec(self, x): return self.matvec(x)

-     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
+     def matmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable
+     def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable

      def is_dense(self): return False
      def to_tensor(self): return self.A.mH @ self.A
@@ -283,7 +298,27 @@ class AtA(LinearOperator):
          return Dense(self.to_tensor() + torch.diag_embed(x))

      def solve(self, b):
-         return Dense(self.to_tensor()).solve(b)
+         *_, n, m = self.A.shape
+         if n >= m: return Dense(self.to_tensor()).solve(b)
+
+         A = self.A
+         C = A @ A.mH # (n, n), SPD
+         L, info = torch.linalg.cholesky_ex(C) # pylint:disable=not-callable
+         z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+         return A.mH @ z
+
+     def solve_plus_diag(self, b, diag):
+         *_, n, m = self.A.shape
+         if (n >= m) or (isinstance(diag, torch.Tensor) and diag.numel() > 1):
+             return Dense(self.to_tensor()).solve_plus_diag(b, diag)
+
+         A = self.A
+         I = torch.eye(A.size(-2), device=A.device, dtype=A.dtype)
+
+         C = (A @ A.mH).add_(I.mul_(diag)) # (n, n), SPD
+         L, info = torch.linalg.cholesky_ex(C + I.mul_(diag)) # pylint:disable=not-callable
+         z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+         return (1 / diag) * (b - A.mH @ z)

      def inv(self):
          return Dense(self.to_tensor()).inv()
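Note (not part of the diff): the n < m branch of AtA.solve_plus_diag above appears to follow the push-through form of the Woodbury identity, which replaces the (m, m) regularized normal-equations solve with an (n, n) one when A is (n, m) with n < m:

    (A^T A + λ I_m)^{-1} b  =  (1/λ) * ( b - A^T (A A^T + λ I_n)^{-1} A b )

so only a Cholesky factorization of the small (n, n) matrix A A^T + λ I_n is needed.
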
@@ -295,35 +330,98 @@ class AtA(LinearOperator):
          n = self.A.size(1)
          return (n,n)

- class AAT(LinearOperator):
+ class AAt(AtA):
      def __init__(self, A: torch.Tensor):
-         self.A = A
+         super().__init__(A.mH)

-     def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
-     def rmatvec(self, x): return self.matvec(x)
+ class Sketched(LinearOperator):
+     """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
+
+     Where A is (n, n) and S is (n, sketch_size).
+     """
+     def __init__(self, S: torch.Tensor, A_proj: torch.Tensor):
+         self.S = S
+         self.A_proj = A_proj
+         self.device = self.A_proj.device; self.dtype = self.A_proj.dtype
+
+     def matvec(self, x):
+         x_proj = self.S.T @ x
+         Ax_proj = self.A_proj @ x_proj
+         return self.S @ Ax_proj
+
+     def rmatvec(self, x):
+         x_proj = self.S.T @ x
+         ATx_proj = self.A_proj.mH @ x_proj
+         return self.S @ ATx_proj
+
+
+     def matmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, X])) # pylint:disable=not-callable
+     def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, X])) # pylint:disable=not-callable

-     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable

      def is_dense(self): return False
-     def to_tensor(self): return self.A @ self.A.mH
-     def transpose(self): return AAT(self.A)
+     def to_tensor(self): return self.S @ self.A_proj @ self.S.T
+     def transpose(self): return Sketched(self.S, self.A_proj.mH)

      def add_diagonal(self, x):
+         """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
          if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
-         if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
-         return Dense(self.to_tensor() + torch.diag_embed(x))
+         if isinstance(x, (int,float)): x = torch.full((self.A_proj.shape[0],), fill_value=x, device=self.A_proj.device, dtype=self.A_proj.dtype)
+         return Sketched(S=self.S, A_proj=self.A_proj + x.diag_embed())

      def solve(self, b):
-         return Dense(self.to_tensor()).solve(b)
+         return self.S @ torch.linalg.lstsq(self.A_proj, self.S.T @ b).solution # pylint:disable=not-callable

      def inv(self):
-         return Dense(self.to_tensor()).inv()
-
-     def diagonal(self):
-         return self.A.pow(2).sum(0)
+         return Sketched(S=self.S, A_proj=torch.linalg.pinv(self.A_proj)) # pylint:disable=not-callable

      def size(self):
-         n = self.A.size(1)
+         n = self.S.size(0)
          return (n,n)

+
+ class Eigendecomposition(LinearOperator):
+     """A represented as Q L Q^H. If A is (n,n), then Q is (n, rank); L is a vector - diagonal of (rank, rank)"""
+     def __init__(self, L: torch.Tensor, Q: torch.Tensor, use_nystrom: bool = True):
+         self.L = L
+         self.Q = Q
+         self.use_nystrom = use_nystrom
+         self.device = self.L.device; self.dtype = self.L.dtype
+
+     def matvec(self, x):
+         return self.Q @ ((self.Q.mH @ x) * self.L)
+
+     def rmatvec(self, x):
+         return self.matvec(x)
+
+     def matmat(self, X):
+         return Dense(self.Q @ (self.L[:, None] * (self.Q.mH @ X)))
+
+     def rmatmat(self, X):
+         return self.matmat(X)
+
+     def is_dense(self): return False
+     def to_tensor(self): return self.Q @ self.L.diag_embed() @ self.Q.mH
+     def transpose(self): return Eigendecomposition(L=self.L, Q=self.Q)
+
+     def add_diagonal(self, x):
+         """this doesn't correspond to adding diagonal to A, however it still works for LM etc."""
+         if isinstance(x, torch.Tensor) and x.numel() > 1:
+             raise RuntimeError("Eigendecomposition linear operator doesn't support add_diagonal with a vector diag")
+
+         return Eigendecomposition(L=self.L + x, Q = self.Q)
+
+     def solve(self, b):
+         return self.Q @ ((self.Q.mH @ b) / self.L)
+
+     def solve_plus_diag(self, b, diag):
+         if isinstance(diag, torch.Tensor) and diag.numel() > 1: return super().solve_plus_diag(b, diag)
+         if not self.use_nystrom: return super().solve_plus_diag(b, diag)
+         return nystrom_sketch_and_solve(L=self.L, Q=self.Q, b=b, reg=float(diag))
+
+     def inv(self):
+         return Eigendecomposition(L=1 / self.L, Q = self.Q)
+
+     def size(self):
+         n = self.Q.size(0)
+         return (n,n)
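A usage sketch combining the new Eigendecomposition operator with nystrom_approximation (not part of the diff; the import path for the operator classes is assumed from the file layout above):

    import torch
    from torchzero.linalg import nystrom_approximation
    from torchzero.linalg.linear_operator import Eigendecomposition

    n = 50
    M = torch.randn(n, n)
    A = M @ M.T                                   # PSD matrix

    L, Q = nystrom_approximation(A_mv=None, A_mm=lambda X: A @ X, ndim=n, rank=10, device="cpu")
    op = Eigendecomposition(L, Q)

    v = torch.randn(n)
    Av = op.matvec(v)                             # low-rank approximation of A @ v
    x = op.solve_plus_diag(v, diag=1e-2)          # regularized solve, e.g. for a Levenberg-Marquardt step
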
@@ -0,0 +1,28 @@
+ from typing import Literal
+ import warnings
+ from collections.abc import Callable
+
+ import torch
+ from . import torch_linalg
+ def matrix_power_eigh(A: torch.Tensor, power:float, abs:bool=False):
+     """this is faster than SVD but only for positive semi-definite symmetric matrices
+     (covariance matrices are always SPD)"""
+
+     L, Q = torch_linalg.eigh(A, retry_float64=True) # pylint:disable=not-callable
+     if abs: L.abs_()
+     if power % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
+     return (Q * L.pow_(power).unsqueeze(-2)) @ Q.mH
+
+
+ def matrix_power_svd(A: torch.Tensor, power: float) -> torch.Tensor:
+     """for any symmetric matrix"""
+     U, S, Vh = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+     if power % 2 != 0: S.clip_(min = torch.finfo(A.dtype).tiny * 2)
+     return (U * S.pow_(power).unsqueeze(-2)) @ Vh
+
+ MatrixPowerMethod = Literal["eigh", "eigh_abs", "svd"]
+ def matrix_power(A: torch.Tensor, power: float, method: MatrixPowerMethod = "eigh_abs") -> torch.Tensor:
+     if method == "eigh": return matrix_power_eigh(A, power)
+     if method == "eigh_abs": return matrix_power_eigh(A, power, abs=True)
+     if method == "svd": return matrix_power_svd(A, power)
+     raise ValueError(method)
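A usage sketch for the helpers above (not part of the diff; the SPD input matrix is made up for illustration):

    import torch
    from torchzero.linalg.matrix_power import matrix_power

    M = torch.randn(20, 20)
    C = M @ M.T + 1e-3 * torch.eye(20)              # SPD, e.g. a covariance/preconditioner matrix

    C_inv_root = matrix_power(C, -0.5, method="eigh")        # inverse square root via eigendecomposition
    C_inv_quarter = matrix_power(C, -0.25, method="eigh_abs")
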
@@ -0,0 +1,95 @@
+ from typing import Literal
+ import torch
+
+ from ..utils.compile import allow_compile
+ from . import torch_linalg
+
+ # zeropower_via_newtonschulz5 from:
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
+ # and
+ # https://github.com/HomebrewML/HeavyBall/blob/main/heavyball/utils.py#L452
+ _NS_COEFFS = (
+     (4.0848, -6.8946, 2.9270),
+     (3.9505, -6.3029, 2.6377),
+     (3.7418, -5.5913, 2.3037),
+     (2.8769, -3.1427, 1.2046),
+     (2.8366, -3.0525, 1.2012)
+ )
+
+ @allow_compile
+ def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Tensor:
+     """
+     Applies to last 2 dims - so usually reverse_dims should be applied to G before and after.
+
+     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+     zero even beyond the point where the iteration no longer converges all the way to one everywhere
+     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+     performance at all relative to UV^T, where USV^T = G is the SVD.
+     """
+     assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+
+     X = G.bfloat16()
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm(dim=(-2, -1), keepdim=True).clip(min=torch.finfo(X.dtype).tiny * 2))
+
+     # Perform the NS iterations
+     for a,b,c in coeffs:
+         A = X @ X.mT
+         B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+         X = a * X + B @ X
+
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     return X.to(G.dtype)
+
+ # code from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
+ # Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
+ # Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
+ def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
+     """
+     Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
+     """
+     try:
+         U, S, Vt = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+     except torch.linalg.LinAlgError:
+         U, S, Vt = torch.svd_lowrank(A, q=1, M=1e-4 * A.mean() * torch.rand_like(A))
+
+     return U @ Vt
+
+ def zeropower_via_eigh(A: torch.Tensor) -> torch.Tensor:
+     """
+     Only SPD and I need to check if I apply those to SPD because this is better than SVD.
+     """
+     L, Q = torch_linalg.eigh(A, retry_float64=True)
+     return Q @ Q.mH
+
+
+ def orthogonalize_via_qr(A: torch.Tensor):
+     *_, m, n = A.shape
+     T = False
+     if m < n:
+         T = True
+         m,n = n,m
+         A = A.mH
+
+     Q = torch_linalg.qr(A, mode='reduced', retry_float64=True).Q
+
+     if T:
+         Q = Q.mH
+
+     return Q
+
+ OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
+ def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod = "newtonschulz") -> torch.Tensor:
+     if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
+     if method == "svd": return zeropower_via_svd(A)
+     if method == "qr": return orthogonalize_via_qr(A)
+     if method == "eigh": return zeropower_via_eigh(A)
+     raise ValueError(method)
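A usage sketch for orthogonalize above (not part of the diff; the gradient matrix is made up for illustration):

    import torch
    from torchzero.linalg import orthogonalize

    G = torch.randn(128, 64)                        # e.g. the gradient of a weight matrix
    W_ns = orthogonalize(G, method="newtonschulz")  # approximate U V^T of G (Muon-style direction)
    W_svd = orthogonalize(G, method="svd")          # exact orthogonalization via SVD
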
@@ -1,8 +1,9 @@
  from typing import Literal
  import torch
- from ..compile import enable_compilation
+ from ..utils.compile import allow_compile

  # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+ @allow_compile
  def _get_w_tau(R: torch.Tensor, i: int, eps: float):
      R_ii = R[...,i,i]
      R_below = R[...,i:,i]
@@ -17,6 +18,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
      tau = torch.where(degenerate, 1, tau)
      return w, tau

+ @allow_compile
  def _qr_householder_complete(A:torch.Tensor):
      *b,m,n = A.shape
      k = min(m,n)
@@ -33,6 +35,7 @@ def _qr_householder_complete(A:torch.Tensor):

      return Q, R

+ @allow_compile
  def _qr_householder_reduced(A:torch.Tensor):
      *b,m,n = A.shape
      k = min(m,n)
@@ -64,7 +67,6 @@ def _qr_householder_reduced(A:torch.Tensor):

      return Q, R

- # @enable_compilation
  def qr_householder(A:torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
      """an attempt at making QR decomposition for very tall and thin matrices that doesn't freeze, but it is around n_cols times slower than torch.linalg.qr, but compilation makes it faster, but it has to recompile when processing different shapes"""
      if mode == 'reduced': return _qr_householder_reduced(A)
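A usage sketch for qr_householder (not part of the diff):

    import torch
    from torchzero.linalg import qr_householder

    A = torch.randn(10_000, 8)                    # very tall and thin matrix
    Q, R = qr_householder(A, mode="reduced")      # reduced QR: Q has orthonormal columns, Q @ R ≈ A
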