torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/linalg/__init__.py
@@ -0,0 +1,11 @@
+ from . import linear_operator
+
+ from .matrix_power import (
+     matrix_power_eigh,
+     matrix_power_svd,
+     MatrixPowerMethod,
+ )
+ from .orthogonalize import zeropower_via_eigh, zeropower_via_newtonschulz5, zeropower_via_svd, orthogonalize, OrthogonalizeMethod
+ from .qr import qr_householder
+ from .solve import cg, nystrom_sketch_and_solve, nystrom_pcg
+ from .eigh import nystrom_approximation, regularize_eigh
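The new torchzero.linalg package consolidates the helpers that previously lived under torchzero/utils/linalg (removed further down in this diff). A minimal usage sketch, assuming the 0.4.1 wheel is installed; it only uses names from the re-export list above, but the call pattern itself is illustrative rather than taken from the package documentation:

    import torch
    from torchzero.linalg import nystrom_approximation, regularize_eigh

    # rank-10 Nyström approximation of a PSD matrix, given only a matmat callable
    A = torch.randn(100, 100)
    A = A @ A.T
    L, Q = nystrom_approximation(A_mv=None, A_mm=lambda X: A @ X,
                                 ndim=100, rank=10, device="cpu")

    # drop relatively tiny eigenvalues and add a small damping term before reuse
    L, Q = regularize_eigh(L, Q, tol=1e-6, damping=1e-8)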
torchzero/linalg/eigh.py
@@ -0,0 +1,253 @@
+ from collections.abc import Callable
+
+ import torch
+
+ from . import torch_linalg
+ from .linalg_utils import mm
+ from .orthogonalize import OrthogonalizeMethod, orthogonalize
+ from .svd import tall_reduced_svd_via_eigh
+
+
+ # https://arxiv.org/pdf/2110.02820
+ def nystrom_approximation(
+     A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+     A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+     ndim: int,
+     rank: int,
+     device,
+     orthogonalize_method: OrthogonalizeMethod = 'qr',
+     eigv_tol: float = 0,
+     dtype = torch.float32,
+     generator = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Computes the Nyström approximation to a positive-semidefinite A, factored as Q L Q^T (a truncated eigendecomposition),
+     and returns ``(L, Q)``.
+
+     If A is ``(m, m)``, then Q is ``(m, rank)`` and L is a ``(rank,)`` vector - the diagonal of the ``(rank, rank)`` factor."""
+     # basis
+     O = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
+     O = orthogonalize(O, method=orthogonalize_method) # Thin QR decomposition # pylint:disable=not-callable
+
+     # Y = AΩ
+     AO = mm(A_mv=A_mv, A_mm=A_mm, X=O)
+
+     v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(AO, ord='fro') # Compute shift # pylint:disable=not-callable
+     Yv = AO + v*O # Shift for stability
+     C = torch.linalg.cholesky_ex(O.mT @ Yv)[0] # pylint:disable=not-callable
+     B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+
+     # Q, S, _ = torch_linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+     # B is (ndim, rank) so we can use eigendecomp of (rank, rank)
+     Q, S = tall_reduced_svd_via_eigh(B, tol=eigv_tol, retry_float64=True)
+
+     L = S.pow(2) - v
+     return L, Q
+
+
+ def regularize_eigh(
+     L: torch.Tensor,
+     Q: torch.Tensor,
+     truncate: int | None = None,
+     tol: float | None = None,
+     damping: float = 0,
+     rdamping: float = 0,
+ ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
+     """Applies regularization to an eigendecomposition. Returns ``(L, Q)``.
+
+     Args:
+         L (torch.Tensor): eigenvalues, shape ``(rank,)``.
+         Q (torch.Tensor): eigenvectors, shape ``(n, rank)``.
+         truncate (int | None, optional):
+             keeps the top ``truncate`` eigenvalues. Defaults to None.
+         tol (float | None, optional):
+             all eigenvalues smaller than the largest eigenvalue times ``tol`` are removed. Defaults to None.
+         damping (float | None, optional): scalar added to the eigenvalues. Defaults to 0.
+         rdamping (float | None, optional): scalar multiplied by the largest eigenvalue and added to the eigenvalues. Defaults to 0.
+     """
+     # remove non-finite eigenvalues
+     finite = L.isfinite()
+     if finite.any():
+         L = L[finite]
+         Q = Q[:, finite]
+     else:
+         return None, None
+
+     # largest finite eigenvalue
+     L_max = L[-1] # L is sorted in ascending order
+
+     # remove small eigenvalues relative to largest
+     if tol is not None:
+         indices = L > tol * L_max
+         L = L[indices]
+         Q = Q[:, indices]
+
+     # truncate to rank (L is ordered in ascending order)
+     if truncate is not None:
+         L = L[-truncate:]
+         Q = Q[:, -truncate:]
+
+     # damping
+     d = damping + rdamping * L_max
+     if d != 0:
+         L += d
+
+     return L, Q
+
+ def eigh_plus_uuT(
+     L: torch.Tensor,
+     Q: torch.Tensor,
+     u: torch.Tensor,
+     alpha: float = 1,
+     tol: float | None = None,
+     retry_float64: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Computes the eigendecomposition of Q L Q^T + alpha * (u u^T), where Q is ``(m, rank)``, L is ``(rank,)`` and u is ``(m,)``.
+     """
+     if tol is None: tol = torch.finfo(Q.dtype).eps
+     z = Q.T @ u # (rank,)
+
+     # component of u orthogonal to the column space of Q
+     res = u - Q @ z # (m,)
+     beta = torch.linalg.vector_norm(res) # pylint:disable=not-callable
+
+     if beta < tol:
+         # u is already in the column space of Q
+         B = L.diag_embed().add_(z.outer(z), alpha=alpha) # (rank, rank)
+         L_prime, S = torch_linalg.eigh(B, retry_float64=retry_float64)
+         Q_prime = Q @ S
+         return L_prime, Q_prime
+
+     # normalize the orthogonal component to get a new orthonormal vector
+     v = res / beta # (m, )
+
+     # project and compute new eigendecomposition
+     D_diag = torch.cat([L, torch.tensor([0.0], device=Q.device, dtype=Q.dtype)])
+     w = torch.cat([z, beta.unsqueeze(0)]) # Shape: (rank+1,)
+     B = D_diag.diag_embed().add_(w.outer(w), alpha=alpha)
+
+     L_prime, S = torch_linalg.eigh(B, retry_float64=retry_float64)
+
+     # unproject and sort
+     basis = torch.cat([Q, v.unsqueeze(-1)], dim=1) # (m, rank+1)
+     Q_prime = basis @ S # (m, rank+1)
+
+     idx = torch.argsort(L_prime)
+     L_prime = L_prime[idx]
+     Q_prime = Q_prime[:, idx]
+
+     return L_prime, Q_prime
+
+ def eigh_plus_UUT(
+     L: torch.Tensor,
+     Q: torch.Tensor,
+     U: torch.Tensor,
+     alpha: float = 1,
+     tol = None,
+     retry_float64: bool = False,
+ ):
+     """
+     Computes the eigendecomposition of Q L Q^T + alpha * (U U^T), where Q is ``(m, rank)``, L is ``(rank,)``,
+     and U is ``(m, k)`` where k is the rank of the correction.
+     """
+     if U.size(1) == 1:
+         return eigh_plus_uuT(L, Q, U[:,0], alpha=alpha, tol=tol, retry_float64=retry_float64)
+
+     if tol is None: tol = torch.finfo(Q.dtype).eps
+     m, r = Q.shape
+
+     Z = Q.T @ U # (r, k)
+     U_res = U - Q @ Z # (m, k)
+
+     # find cols of U not in col space of Q
+     res_norms = torch.linalg.vector_norm(U_res, dim=0) # pylint:disable=not-callable
+     new_indices = torch.where(res_norms > tol)[0]
+     k_prime = len(new_indices)
+
+     if k_prime == 0:
+         # all cols are in Q
+         B = Q
+         C = Z # (r x k)
+         r_new = r
+     else:
+         # orthonormalize directions that aren't in Q
+         U_new = U_res[:, new_indices]
+         Q_u, _ = torch_linalg.qr(U_new, mode='reduced', retry_float64=retry_float64)
+         B = torch.hstack([Q, Q_u])
+         C = torch.vstack([Z, Q_u.T @ U])
+         r_new = r + k_prime
+
+
+     # project and compute new eigendecomposition
+     A_proj = torch.zeros((r_new, r_new), device=Q.device, dtype=Q.dtype)
+     A_proj[:r, :r] = L.diag_embed()
+     A_proj.addmm_(C, C.T, alpha=alpha)
+
+     L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+
+     # unproject and sort
+     Q_prime = B @ S
+     idx = torch.argsort(L_prime)
+     L_prime = L_prime[idx]
+     Q_prime = Q_prime[:, idx]
+
+     return L_prime, Q_prime
+
+
+ def eigh_plus_UVT_symmetrize(
+     Q: torch.Tensor,
+     L: torch.Tensor,
+     U: torch.Tensor,
+     V: torch.Tensor,
+     alpha: float,
+     retry_float64: bool = False,
+
+ ):
+     """
+     Q is ``(m, rank)``; L is ``(rank,)``; U and V are the low-rank correction such that U V^T is ``(m, m)``.
+
+     This computes the eigendecomposition of A, where
+
+     ``M = Q diag(L) Q^T + alpha * (U V^T)``;
+
+     ``A = (M + M^T) / 2``
+     """
+     m, rank = Q.shape
+     _, k = V.shape
+
+     # project U and V out of the Q subspace via Gram-Schmidt
+     Q_T_U = Q.T @ U
+     U_perp = U - Q @ Q_T_U
+
+     Q_T_V = Q.T @ V
+     V_perp = V - Q @ Q_T_V
+
+     R = torch.hstack([U_perp, V_perp])
+     Q_perp, _ = torch_linalg.qr(R, retry_float64=retry_float64)
+
+     Q_B = torch.hstack([Q, Q_perp])
+     r_B = Q_B.shape[1]
+
+     # project, symmetrize and compute new eigendecomposition
+     A_proj = torch.zeros((r_B, r_B), device=Q.device, dtype=Q.dtype)
+     A_proj[:rank, :rank] = L.diag_embed()
+
+     Q_perp_T_U = Q_perp.T @ U
+     Q_B_T_U = torch.vstack([Q_T_U, Q_perp_T_U])
+
+     Q_perp_T_V = Q_perp.T @ V
+     Q_B_T_V = torch.vstack([Q_T_V, Q_perp_T_V])
+
+     update_proj = Q_B_T_U @ Q_B_T_V.T + Q_B_T_V @ Q_B_T_U.T
+     A_proj.add_(update_proj, alpha=alpha/2)
+
+     L_prime, S = torch_linalg.eigh(A_proj, retry_float64=retry_float64)
+
+     # unproject and sort
+     Q_prime = Q_B @ S
+
+     idx = torch.argsort(L_prime)
+     L_prime = L_prime[idx]
+     Q_prime = Q_prime[:, idx]
+
+     return L_prime, Q_prime
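The eigh_plus_uuT / eigh_plus_UUT / eigh_plus_UVT_symmetrize helpers above update a factored Q L Q^T approximation under low-rank corrections without forming the dense m x m matrix: the correction is split into its component inside span(Q) plus an orthonormalized residual, and a small (rank + k) eigenproblem is solved in that extended basis. A rough sanity check for the rank-one case, assuming the module path added in this diff and an installed 0.4.1 wheel:

    import torch
    from torchzero.linalg.eigh import eigh_plus_uuT  # module path as added in this diff

    m, r = 50, 5
    Q, _ = torch.linalg.qr(torch.randn(m, r, dtype=torch.float64))  # orthonormal (m, r) basis
    L = torch.linspace(0.1, 1.0, r, dtype=torch.float64)            # ascending eigenvalues
    u = torch.randn(m, dtype=torch.float64)

    L_new, Q_new = eigh_plus_uuT(L, Q, u, alpha=1.0)

    # dense reconstruction of the rank-(r+1) matrix the update represents
    A = Q @ torch.diag(L) @ Q.T + torch.outer(u, u)
    print(torch.dist(Q_new @ torch.diag(L_new) @ Q_new.T, A))       # ~0 up to float error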
torchzero/linalg/linalg_utils.py
@@ -0,0 +1,14 @@
+ from collections.abc import Callable
+ import torch
+
+ def mm(
+     A_mv: Callable[[torch.Tensor], torch.Tensor] | None,
+     A_mm: Callable[[torch.Tensor], torch.Tensor] | None,
+     X
+ ):
+     """matrix-matrix product when either a matvec (A_mv) or a matmat (A_mm) callable is given"""
+     if A_mm is not None: return A_mm(X)
+     assert A_mv is not None
+     return torch.stack([A_mv(col) for col in X.unbind(-1)], -1) # rank matvecs
+
+
torchzero/{utils/linalg → linalg}/linear_operator.py
@@ -1,4 +1,6 @@
- """simplified version of https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html. This is used for trust regions."""
+ """This is mainly used for trust regions. In some cases certain operations are relaxed, e.g. an eigenvalue shift instead of
+ adding a diagonal when the latter isn't tractable, to make it work with Levenberg-Marquardt.
+ """
  import math
  from abc import ABC, abstractmethod
  from functools import partial
@@ -7,7 +9,8 @@ from typing import cast, final

  import torch

- from ..torch_tools import tofloat, tonumpy, totensor
+ from ..utils.torch_tools import tofloat, tonumpy, totensor
+ from .solve import nystrom_sketch_and_solve

  if find_spec('scipy') is not None:
      from scipy.sparse.linalg import LinearOperator as _ScipyLinearOperator
@@ -15,7 +18,6 @@ else:
      _ScipyLinearOperator = None

  class LinearOperator(ABC):
-     """this is used for trust region"""
      device: torch.types.Device
      dtype: torch.dtype | None

@@ -25,12 +27,18 @@ class LinearOperator(ABC):
      def rmatvec(self, x: torch.Tensor) -> torch.Tensor:
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatvec")

-     def matmat(self, x: torch.Tensor) -> "LinearOperator":
-         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmul")
+     def matmat(self, X: torch.Tensor) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement matmat")
+
+     def rmatmat(self, X: torch.Tensor) -> "LinearOperator":
+         raise NotImplementedError(f"{self.__class__.__name__} doesn't implement rmatmat")

      def solve(self, b: torch.Tensor) -> torch.Tensor:
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve")

+     def solve_plus_diag(self, b: torch.Tensor, diag: int | float | torch.Tensor) -> torch.Tensor:
+         return self.add_diagonal(diag).solve(b)
+
      def solve_bounded(self, b: torch.Tensor, bound:float, ord:float=2) -> torch.Tensor:
          """solve with a norm bound on x"""
          raise NotImplementedError(f"{self.__class__.__name__} doesn't implement solve_bounded")
@@ -129,8 +137,8 @@ class Dense(LinearOperator):
      def matvec(self, x): return self.A.mv(x)
      def rmatvec(self, x): return self.A.mH.mv(x)

-     def matmat(self, x): return Dense(self.A.mm(x))
-     def rmatmat(self, x): return Dense(self.A.mH.mm(x))
+     def matmat(self, X): return Dense(self.A.mm(X))
+     def rmatmat(self, X): return Dense(self.A.mH.mm(X))

      def solve(self, b): return _solve(self.A, b)

@@ -146,6 +154,12 @@
      def is_dense(self): return True
      def transpose(self): return Dense(self.A.mH)

+ class SPD(Dense):
+     def solve(self, b: torch.Tensor):
+         L, info = torch.linalg.cholesky_ex(self.A) # pylint:disable=not-callable
+         return torch.cholesky_solve(b.unsqueeze(-1), L).squeeze(-1)
+
+
  class DenseInverse(LinearOperator):
      """Represents inverse of a dense matrix A."""
      def __init__(self, A_inv: torch.Tensor):
@@ -156,8 +170,8 @@ class DenseInverse(LinearOperator):
      def matvec(self, x): return _solve(self.A_inv, x) # pylint:disable=not-callable
      def rmatvec(self, x): return _solve(self.A_inv.mH, x) # pylint:disable=not-callable

-     def matmat(self, x): return Dense(_solve(self.A_inv, x)) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(_solve(self.A_inv.mH, x)) # pylint:disable=not-callable
+     def matmat(self, X): return Dense(_solve(self.A_inv, X)) # pylint:disable=not-callable
+     def rmatmat(self, X): return Dense(_solve(self.A_inv.mH, X)) # pylint:disable=not-callable

      def solve(self, b): return self.A_inv.mv(b)

@@ -190,8 +204,8 @@ class Diagonal(LinearOperator):
      def matvec(self, x): return self.A * x
      def rmatvec(self, x): return self.A * x

-     def matmat(self, x): return Dense(x * self.A.unsqueeze(-1))
-     def rmatmat(self, x): return Dense(x * self.A.unsqueeze(-1))
+     def matmat(self, X): return Dense(X * self.A.unsqueeze(-1))
+     def rmatmat(self, X): return Dense(X * self.A.unsqueeze(-1))

      def solve(self, b): return b/self.A

@@ -221,8 +235,8 @@ class ScaledIdentity(LinearOperator):
      def matvec(self, x): return x * self.s
      def rmatvec(self, x): return x * self.s

-     def matmat(self, x): return Dense(x * self.s)
-     def rmatmat(self, x): return Dense(x * self.s)
+     def matmat(self, X): return Dense(X * self.s)
+     def rmatmat(self, X): return Dense(X * self.s)

      def solve(self, b): return b / self.s
      def solve_bounded(self, b, bound, ord = 2):
@@ -263,6 +277,7 @@ class ScaledIdentity(LinearOperator):
      def is_dense(self): return False
      def transpose(self): return ScaledIdentity(self.s, shape=self.shape, device=self.device, dtype=self.dtype)

+
  class AtA(LinearOperator):
      def __init__(self, A: torch.Tensor):
          self.A = A
@@ -270,8 +285,8 @@ class AtA(LinearOperator):
      def matvec(self, x): return self.A.mH.mv(self.A.mv(x))
      def rmatvec(self, x): return self.matvec(x)

-     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, x])) # pylint:disable=not-callable
+     def matmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable
+     def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.A.mH, self.A, X])) # pylint:disable=not-callable

      def is_dense(self): return False
      def to_tensor(self): return self.A.mH @ self.A
@@ -283,51 +298,41 @@
          return Dense(self.to_tensor() + torch.diag_embed(x))

      def solve(self, b):
-         return Dense(self.to_tensor()).solve(b)
+         *_, n, m = self.A.shape
+         if n >= m: return Dense(self.to_tensor()).solve(b)

-     def inv(self):
-         return Dense(self.to_tensor()).inv()
+         A = self.A
+         C = A @ A.mH # (n, n), SPD
+         L, info = torch.linalg.cholesky_ex(C) # pylint:disable=not-callable
+         z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+         return A.mH @ z

-     def diagonal(self):
-         return self.A.pow(2).sum(1)
-
-     def size(self):
-         n = self.A.size(1)
-         return (n,n)
+     def solve_plus_diag(self, b, diag):
+         *_, n, m = self.A.shape
+         if (n >= m) or (isinstance(diag, torch.Tensor) and diag.numel() > 1):
+             return Dense(self.to_tensor()).solve_plus_diag(b, diag)

- class AAT(LinearOperator):
-     def __init__(self, A: torch.Tensor):
-         self.A = A
-         self.device = self.A.device; self.dtype = self.A.dtype
-
-     def matvec(self, x): return self.A.mv(self.A.mH.mv(x))
-     def rmatvec(self, x): return self.matvec(x)
-
-     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.A, self.A.mH, x])) # pylint:disable=not-callable
-
-     def is_dense(self): return False
-     def to_tensor(self): return self.A @ self.A.mH
-     def transpose(self): return AAT(self.A)
-
-     def add_diagonal(self, x):
-         if isinstance(x, torch.Tensor) and x.numel() <= 1: x = x.item()
-         if isinstance(x, (int,float)): x = torch.full((self.shape[0],), fill_value=x, device=self.A.device, dtype=self.A.dtype)
-         return Dense(self.to_tensor() + torch.diag_embed(x))
+         A = self.A
+         I = torch.eye(A.size(-2), device=A.device, dtype=A.dtype)

-     def solve(self, b):
-         return Dense(self.to_tensor()).solve(b)
+         C = (A @ A.mH).add_(I.mul_(diag)) # (n, n), SPD
+         L, info = torch.linalg.cholesky_ex(C + I.mul_(diag)) # pylint:disable=not-callable
+         z = torch.cholesky_solve((A @ b).unsqueeze(-1), L).squeeze(-1)
+         return (1 / diag) * (b - A.mH @ z)

      def inv(self):
          return Dense(self.to_tensor()).inv()

      def diagonal(self):
-         return self.A.pow(2).sum(0)
+         return self.A.pow(2).sum(1)

      def size(self):
          n = self.A.size(1)
          return (n,n)

+ class AAt(AtA):
+     def __init__(self, A: torch.Tensor):
+         super().__init__(A.mH)

  class Sketched(LinearOperator):
      """A projected by sketching matrix S, representing the operator S @ A_proj @ S.T.
@@ -339,7 +344,6 @@
          self.A_proj = A_proj
          self.device = self.A_proj.device; self.dtype = self.A_proj.dtype

-
      def matvec(self, x):
          x_proj = self.S.T @ x
          Ax_proj = self.A_proj @ x_proj
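When A is a wide (n, m) matrix with n < m, the rewritten AtA methods two hunks above work through the small n x n Gram matrix A A^T instead of the m x m A^T A; for the damped solve this rests on the identity (A^T A + λI)^{-1} b = (1/λ)(b - A^T (A A^T + λI)^{-1} A b). A standalone numerical check of that identity in plain torch (it deliberately does not call the operator class, so it assumes nothing about the class API):

    import torch

    torch.manual_seed(0)
    n, m, lam = 5, 20, 0.3                       # wide A: n < m
    A = torch.randn(n, m, dtype=torch.float64)
    b = torch.randn(m, dtype=torch.float64)

    # direct solve with the large m x m matrix
    x_direct = torch.linalg.solve(A.T @ A + lam * torch.eye(m, dtype=torch.float64), b)

    # equivalent solve through the small n x n Gram matrix
    z = torch.linalg.solve(A @ A.T + lam * torch.eye(n, dtype=torch.float64), A @ b)
    x_gram = (b - A.T @ z) / lam

    print(torch.allclose(x_direct, x_gram))      # True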
@@ -351,8 +355,8 @@
          return self.S @ ATx_proj


-     def matmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, x])) # pylint:disable=not-callable
-     def rmatmat(self, x): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, x])) # pylint:disable=not-callable
+     def matmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj, self.S.T, X])) # pylint:disable=not-callable
+     def rmatmat(self, X): return Dense(torch.linalg.multi_dot([self.S, self.A_proj.mH, self.S.T, X])) # pylint:disable=not-callable


      def is_dense(self): return False
@@ -375,3 +379,49 @@
          n = self.S.size(0)
          return (n,n)

+
+ class Eigendecomposition(LinearOperator):
+     """A represented as Q L Q^H. If A is (n,n), then Q is (n, rank); L is a vector - the diagonal of (rank, rank)"""
+     def __init__(self, L: torch.Tensor, Q: torch.Tensor, use_nystrom: bool = True):
+         self.L = L
+         self.Q = Q
+         self.use_nystrom = use_nystrom
+         self.device = self.L.device; self.dtype = self.L.dtype
+
+     def matvec(self, x):
+         return self.Q @ ((self.Q.mH @ x) * self.L)
+
+     def rmatvec(self, x):
+         return self.matvec(x)
+
+     def matmat(self, X):
+         return Dense(self.Q @ (self.L[:, None] * (self.Q.mH @ X)))
+
+     def rmatmat(self, X):
+         return self.matmat(X)
+
+     def is_dense(self): return False
+     def to_tensor(self): return self.Q @ self.L.diag_embed() @ self.Q.mH
+     def transpose(self): return Eigendecomposition(L=self.L, Q=self.Q)
+
+     def add_diagonal(self, x):
+         """this doesn't correspond to adding a diagonal to A, however it still works for LM etc."""
+         if isinstance(x, torch.Tensor) and x.numel() > 1:
+             raise RuntimeError("Eigendecomposition linear operator doesn't support add_diagonal with a vector diag")
+
+         return Eigendecomposition(L=self.L + x, Q = self.Q)
+
+     def solve(self, b):
+         return self.Q @ ((self.Q.mH @ b) / self.L)
+
+     def solve_plus_diag(self, b, diag):
+         if isinstance(diag, torch.Tensor) and diag.numel() > 1: return super().solve_plus_diag(b, diag)
+         if not self.use_nystrom: return super().solve_plus_diag(b, diag)
+         return nystrom_sketch_and_solve(L=self.L, Q=self.Q, b=b, reg=float(diag))
+
+     def inv(self):
+         return Eigendecomposition(L=1 / self.L, Q = self.Q)
+
+     def size(self):
+         n = self.Q.size(0)
+         return (n,n)
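The new Eigendecomposition operator stores A implicitly as Q L Q^H, so matvec, solve and add_diagonal all cost O(n · rank) rather than O(n²). A minimal sketch of how it behaves, constructed directly from the module path shown in this diff (whether it is re-exported anywhere else is not shown here):

    import torch
    from torchzero.linalg.linear_operator import Eigendecomposition  # path per this diff

    n, rank = 30, 4
    Q, _ = torch.linalg.qr(torch.randn(n, rank, dtype=torch.float64))
    L = torch.tensor([0.5, 1.0, 2.0, 4.0], dtype=torch.float64)

    op = Eigendecomposition(L=L, Q=Q)
    x = torch.randn(n, dtype=torch.float64)

    Ax = op.matvec(x)                     # Q ((Q^H x) * L), never forms the n x n matrix
    A = op.to_tensor()                    # dense Q diag(L) Q^H, for comparison only
    print(torch.allclose(Ax, A @ x))      # True

    # add_diagonal shifts the eigenvalues rather than the true diagonal of A,
    # which is the relaxation the module docstring mentions for Levenberg-Marquardt
    damped = op.add_diagonal(1e-2)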
torchzero/linalg/matrix_power.py
@@ -0,0 +1,28 @@
+ from typing import Literal
+ import warnings
+ from collections.abc import Callable
+
+ import torch
+ from . import torch_linalg
+ def matrix_power_eigh(A: torch.Tensor, power:float, abs:bool=False):
+     """this is faster than SVD but only for positive semi-definite symmetric matrices
+     (covariance matrices are always SPD)"""
+
+     L, Q = torch_linalg.eigh(A, retry_float64=True) # pylint:disable=not-callable
+     if abs: L.abs_()
+     if power % 2 != 0: L.clip_(min = torch.finfo(A.dtype).tiny * 2)
+     return (Q * L.pow_(power).unsqueeze(-2)) @ Q.mH
+
+
+ def matrix_power_svd(A: torch.Tensor, power: float) -> torch.Tensor:
+     """for any symmetric matrix"""
+     U, S, Vh = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+     if power % 2 != 0: S.clip_(min = torch.finfo(A.dtype).tiny * 2)
+     return (U * S.pow_(power).unsqueeze(-2)) @ Vh
+
+ MatrixPowerMethod = Literal["eigh", "eigh_abs", "svd"]
+ def matrix_power(A: torch.Tensor, power: float, method: MatrixPowerMethod = "eigh_abs") -> torch.Tensor:
+     if method == "eigh": return matrix_power_eigh(A, power)
+     if method == "eigh_abs": return matrix_power_eigh(A, power, abs=True)
+     if method == "svd": return matrix_power_svd(A, power)
+     raise ValueError(method)
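matrix_power_eigh raises a symmetric positive semi-definite matrix to an arbitrary real power through its eigendecomposition, which is the operation preconditioners such as Shampoo/SOAP (also touched in this release) need for inverse roots. A small self-contained sketch of the same idea written against plain torch.linalg, since torch_linalg above is an internal wrapper with a float64 retry:

    import torch

    def psd_matrix_power(A: torch.Tensor, power: float) -> torch.Tensor:
        # eigendecompose, guard tiny/negative eigenvalues, then rebuild Q diag(L^p) Q^H
        L, Q = torch.linalg.eigh(A)
        L = L.clamp(min=torch.finfo(A.dtype).tiny * 2)
        return (Q * L.pow(power).unsqueeze(-2)) @ Q.mH

    G = torch.randn(6, 6, dtype=torch.float64)
    C = G @ G.T + 0.5 * torch.eye(6, dtype=torch.float64)        # SPD test matrix
    C_root = psd_matrix_power(C, -0.25)                           # inverse fourth root
    M = C_root @ C_root @ C_root @ C_root @ C
    print(torch.allclose(M, torch.eye(6, dtype=torch.float64)))   # True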
torchzero/linalg/orthogonalize.py
@@ -0,0 +1,93 @@
+ from typing import Literal
+
+ import torch
+
+ from ..utils.compile import allow_compile
+ from . import torch_linalg
+
+ # zeropower_via_newtonschulz5 from:
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
+ # and
+ # https://github.com/HomebrewML/HeavyBall/blob/main/heavyball/utils.py#L452
+ _NS_COEFFS = (
+     (4.0848, -6.8946, 2.9270),
+     (3.9505, -6.3029, 2.6377),
+     (3.7418, -5.5913, 2.3037),
+     (2.8769, -3.1427, 1.2046),
+     (2.8366, -3.0525, 1.2012)
+ )
+
+ @allow_compile
+ def zeropower_via_newtonschulz5(G: torch.Tensor, coeffs=_NS_COEFFS) -> torch.Tensor:
+     """
+     Applies to the last 2 dims - so usually reverse_dims should be applied to G before and after.
+
+     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+     zero even beyond the point where the iteration no longer converges all the way to one everywhere
+     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+     performance at all relative to UV^T, where USV^T = G is the SVD.
+     """
+     assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+
+     X = G.bfloat16()
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm(dim=(-2, -1), keepdim=True).clip(min=torch.finfo(X.dtype).tiny * 2))
+
+     # Perform the NS iterations
+     for a,b,c in coeffs:
+         A = X @ X.mT
+         B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+         X = a * X + B @ X
+
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     return X.to(G.dtype)
+
+ def zeropower_via_svd(A: torch.Tensor) -> torch.Tensor:
+     """
+     Applies to the first 2 dims and isn't batched - the rest of the dimensions are flattened.
+     """
+     try:
+         U, S, Vt = torch_linalg.svd(A, full_matrices=False, retry_float64=True) # pylint:disable=not-callable
+     except torch.linalg.LinAlgError:
+         U, S, Vt = torch.svd_lowrank(A, q=1, M=1e-4 * A.mean() * torch.rand_like(A))
+
+     return U @ Vt
+
+ def zeropower_via_eigh(A: torch.Tensor) -> torch.Tensor:
+     """
+     Only valid for SPD matrices - need to check that it is only applied to SPD inputs, since for those it is better than SVD.
+     """
+     L, Q = torch_linalg.eigh(A, retry_float64=True)
+     return Q @ Q.mH
+
+
+ def orthogonalize_via_qr(A: torch.Tensor):
+     *_, m, n = A.shape
+     T = False
+     if m < n:
+         T = True
+         m,n = n,m
+         A = A.mH
+
+     Q = torch_linalg.qr(A, mode='reduced', retry_float64=True).Q
+
+     if T:
+         Q = Q.mH
+
+     return Q
+
+ OrthogonalizeMethod = Literal["newtonschulz", "svd", "qr"]
+ def orthogonalize(A: torch.Tensor, method: OrthogonalizeMethod) -> torch.Tensor:
+     if method == "newtonschulz": return zeropower_via_newtonschulz5(A)
+     if method == "svd": return zeropower_via_svd(A)
+     if method == "qr": return orthogonalize_via_qr(A)
+     if method == "eigh": return zeropower_via_eigh(A)
+     raise ValueError(method)
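orthogonalize dispatches between the Muon-style Newton-Schulz iteration, SVD and QR. A short comparison sketch, assuming the function is importable from the new torchzero.linalg namespace as re-exported in the __init__.py shown earlier; note that the Newton-Schulz path runs in bfloat16 and, as its docstring says, is only approximately orthogonal by design:

    import torch
    from torchzero.linalg import orthogonalize  # re-exported in torchzero/linalg/__init__.py

    G = torch.randn(256, 128)

    for method in ("svd", "qr", "newtonschulz"):
        Q = orthogonalize(G, method=method)
        # deviation of Q^T Q from the identity; "newtonschulz" is approximate by design
        err = torch.linalg.matrix_norm(Q.T @ Q - torch.eye(128))
        print(method, float(err))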