torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/utils/linalg/matrix_funcs.py ADDED
@@ -0,0 +1,87 @@
+ import warnings
+ from collections.abc import Callable
+
+ import torch
+
+ def eigvals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+     L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
+     L = fn(L)
+     return (Q * L.unsqueeze(-2)) @ Q.mH
+
+ def singular_vals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+     # torch.linalg.svd already returns the transposed factor Vh, so no extra transpose is needed
+     U, S, Vh = torch.linalg.svd(A) # pylint:disable=not-callable
+     S = fn(S)
+     return (U * S.unsqueeze(-2)) @ Vh
+
+ def matrix_power_eigh(A: torch.Tensor, pow: float):
+     L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
+     # non-even powers (odd, fractional or negative) need strictly positive eigenvalues
+     if pow % 2 != 0: L.clip_(min=torch.finfo(A.dtype).eps)
+     return (Q * L.pow(pow).unsqueeze(-2)) @ Q.mH
+
+
+ def inv_sqrt_2x2(A: torch.Tensor, force_pd: bool = False) -> torch.Tensor:
+     """Inverse square root of a possibly batched 2x2 matrix, using the closed-form formula for 2x2 matrices,
+     which is much faster than torch.linalg. I tried using this for hierarchical 2x2 preconditioning, but it didn't work well."""
+     eps = torch.finfo(A.dtype).eps
+
+     a = A[..., 0, 0]
+     b = A[..., 0, 1]
+     c = A[..., 1, 0]
+     d = A[..., 1, 1]
+
+     det = (a * d).sub_(b * c)
+     trace = a + d
+
+     if force_pd:
+         # add smallest eigenvalue magnitude to diagonal to force PD
+         # could also abs or clip eigenvalues because there is a formula for the eigenvectors
+         term1 = trace/2
+         term2 = (trace.pow(2).div_(4).sub_(det)).clamp_(min=eps).sqrt_()
+         y1 = term1 + term2
+         y2 = term1 - term2
+         smallest_eigval = torch.minimum(y1, y2).neg_().clamp_(min=0) + eps
+         a = a + smallest_eigval
+         d = d + smallest_eigval
+
+         # recalculate det and trace with the new a and d
+         det = (a * d).sub_(b * c)
+         trace = a + d
+
+     s = (det.clamp(min=eps)).sqrt_()
+
+     tau_squared = trace + 2 * s
+     tau = (tau_squared.clamp(min=eps)).sqrt_()
+
+     denom = s * tau
+
+     coeff = (denom.clamp(min=eps)).reciprocal_().unsqueeze(-1).unsqueeze(-1)
+
+     row1 = torch.stack([d + s, -b], dim=-1)
+     row2 = torch.stack([-c, a + s], dim=-1)
+     M = torch.stack([row1, row2], dim=-2)
+
+     return coeff * M
+
+
+ def x_inv(diag: torch.Tensor, antidiag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     """Inverts a matrix whose non-zero elements lie on the diagonal and anti-diagonal, with no check that it is invertible."""
+     n = diag.shape[0]
+     if diag.dim() != 1 or antidiag.dim() != 1 or antidiag.shape[0] != n:
+         raise ValueError("Input tensors must be 1D and have the same size.")
+     if n == 0:
+         return torch.empty_like(diag), torch.empty_like(antidiag)
+
+     # opposite indexes
+     diag_rev = torch.flip(diag, dims=[0])
+     antidiag_rev = torch.flip(antidiag, dims=[0])
+
+     # determinants
+     # det_i = d[i] * d[n-1-i] - a[i] * a[n-1-i]
+     determinant_vec = diag * diag_rev - antidiag * antidiag_rev
+
+     # inverse diagonal elements: y_d[i] = d[n-1-i] / det_i
+     inv_diag_vec = diag_rev / determinant_vec
+
+     # inverse anti-diagonal elements: y_a[i] = -a[i] / det_i
+     inv_anti_diag_vec = -antidiag / determinant_vec
+
+     return inv_diag_vec, inv_anti_diag_vec
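A minimal usage sketch for these helpers (not part of the diff; the import path is assumed from the file list above). eigvals_func applies a function to the eigenvalues of a symmetric matrix, so an inverse square root can be computed either through it or through matrix_power_eigh:

import torch
from torchzero.utils.linalg.matrix_funcs import eigvals_func, matrix_power_eigh  # assumed path

A = torch.randn(4, 4, dtype=torch.float64)
A = A @ A.mT + 1e-3 * torch.eye(4, dtype=torch.float64)   # symmetric positive definite

inv_sqrt_1 = eigvals_func(A, lambda L: L.rsqrt())          # f(A) through the eigendecomposition
inv_sqrt_2 = matrix_power_eigh(A, -0.5)                    # same result through matrix_power_eigh

assert torch.allclose(inv_sqrt_1, inv_sqrt_2)
assert torch.allclose(inv_sqrt_1 @ A @ inv_sqrt_1, torch.eye(4, dtype=torch.float64), atol=1e-8)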
torchzero/utils/linalg/orthogonalize.py ADDED
@@ -0,0 +1,11 @@
+ from typing import overload
+ import torch
+ from ..tensorlist import TensorList
+
+ @overload
+ def gram_schmidt(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ...
+ @overload
+ def gram_schmidt(x: TensorList, y: TensorList) -> tuple[TensorList, TensorList]: ...
+ def gram_schmidt(x, y):
+     """makes two orthogonal vectors, only y is changed"""
+     return x, y - (x*y) / ((x*x) + 1e-8)
torchzero/utils/linalg/qr.py ADDED
@@ -0,0 +1,71 @@
+ from typing import Literal
+ import torch
+ from ..compile import enable_compilation
+
+ # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
+     R_ii = R[..., i, i]
+     R_below = R[..., i:, i]
+     norm_x = torch.linalg.vector_norm(R_below, dim=-1) # pylint:disable=not-callable
+     degenerate = norm_x < eps
+     s = -torch.sign(R_ii)
+     u1 = R_ii - s*norm_x
+     u1 = torch.where(degenerate, 1, u1)
+     w = R_below / u1.unsqueeze(-1)
+     w[..., 0] = 1
+     tau = -s*u1/norm_x
+     tau = torch.where(degenerate, 1, tau)
+     return w, tau
+
+ def _qr_householder_complete(A: torch.Tensor):
+     *b, m, n = A.shape
+     k = min(m, n)
+     eps = torch.finfo(A.dtype).eps
+
+     Q = torch.eye(m, dtype=A.dtype, device=A.device).expand(*b, m, m).clone() # clone because expanded dims refer to the same memory
+     R = A.clone()
+
+     for i in range(k):
+         w, tau = _get_w_tau(R, i, eps)
+
+         R[..., i:, :] -= (tau*w).unsqueeze(-1) @ (w.unsqueeze(-2) @ R[..., i:, :])
+         Q[..., :, i:] -= (Q[..., :, i:] @ w).unsqueeze(-1) @ (tau*w).unsqueeze(-2)
+
+     return Q, R
+
+ def _qr_householder_reduced(A: torch.Tensor):
+     *b, m, n = A.shape
+     k = min(m, n)
+     eps = torch.finfo(A.dtype).eps
+
+     R = A.clone()
+
+     ws: list = [None for _ in range(k)]
+     taus: list = [None for _ in range(k)]
+
+     for i in range(k):
+         w, tau = _get_w_tau(R, i, eps)
+
+         ws[i] = w
+         taus[i] = tau
+
+         if m - i > 0:
+             R[..., i:, :] -= (tau*w).unsqueeze(-1) @ (w.unsqueeze(-2) @ R[..., i:, :])
+             # Q[..., :,i:] -= (Q[..., :,i:]@w).unsqueeze(-1) @ (tau*w).unsqueeze(-2)
+
+     R = R[..., :k, :]
+     Q = torch.eye(m, k, dtype=A.dtype, device=A.device).expand(*b, m, k).clone()
+     for i in range(k - 1, -1, -1):
+         if m - i > 0:
+             w = ws[i]
+             tau = taus[i].unsqueeze(-1).unsqueeze(-1)
+             Q_below = Q[..., i:, :]
+             Q[..., i:, :] -= torch.linalg.multi_dot([tau * w.unsqueeze(-1), w.unsqueeze(-2), Q_below]) # pylint:disable=not-callable
+
+     return Q, R
+
+ # @enable_compilation
+ def qr_householder(A: torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
+     """An attempt at a QR decomposition for very tall and thin matrices that doesn't freeze.
+     It is around n_cols times slower than torch.linalg.qr; compilation makes it faster,
+     but it has to recompile when processing different shapes."""
+     if mode == 'reduced': return _qr_householder_reduced(A)
+     return _qr_householder_complete(A)
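A quick sanity-check sketch of the reduced mode (not part of the diff; the import path is assumed from the file list above): the factors should reproduce A and Q should have orthonormal columns.

import torch
from torchzero.utils.linalg.qr import qr_householder  # assumed path

A = torch.randn(1000, 8, dtype=torch.float64)          # very tall and thin
Q, R = qr_householder(A, mode='reduced')                # Q: (1000, 8), R: (8, 8)

assert torch.allclose(Q @ R, A, atol=1e-8)
assert torch.allclose(Q.mT @ Q, torch.eye(8, dtype=torch.float64), atol=1e-8)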
torchzero/utils/linalg/solve.py ADDED
@@ -0,0 +1,168 @@
+ from collections.abc import Callable
+ from typing import overload
+ import torch
+
+ from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_numel, generic_randn_like, generic_eq
+
+ @overload
+ def cg(
+     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     b: torch.Tensor,
+     x0_: torch.Tensor | None,
+     tol: float | None,
+     maxiter: int | None,
+     reg: float = 0,
+ ) -> torch.Tensor: ...
+ @overload
+ def cg(
+     A_mm: Callable[[TensorList], TensorList],
+     b: TensorList,
+     x0_: TensorList | None,
+     tol: float | None,
+     maxiter: int | None,
+     reg: float | list[float] | tuple[float] = 0,
+ ) -> TensorList: ...
+
+ def cg(
+     A_mm: Callable,
+     b: torch.Tensor | TensorList,
+     x0_: torch.Tensor | TensorList | None,
+     tol: float | None,
+     maxiter: int | None,
+     reg: float | list[float] | tuple[float] = 0,
+ ):
+     def A_mm_reg(x): # A_mm with regularization
+         Ax = A_mm(x)
+         if not generic_eq(reg, 0): Ax += x*reg
+         return Ax
+
+     if maxiter is None: maxiter = generic_numel(b)
+     if x0_ is None: x0_ = generic_zeros_like(b)
+
+     x = x0_
+     residual = b - A_mm_reg(x)
+     p = residual.clone() # search direction
+     r_norm = generic_vector_norm(residual)
+     init_norm = r_norm
+     if tol is not None and r_norm < tol: return x
+     k = 0
+
+     while True:
+         Ap = A_mm_reg(p)
+         step_size = (r_norm**2) / p.dot(Ap)
+         x += step_size * p # update solution
+         residual -= step_size * Ap # update residual
+         new_r_norm = generic_vector_norm(residual)
+
+         k += 1
+         if tol is not None and new_r_norm <= tol * init_norm: return x
+         if k >= maxiter: return x
+
+         beta = (new_r_norm**2) / (r_norm**2)
+         p = residual + beta*p
+         r_norm = new_r_norm
+
+
+ # https://arxiv.org/pdf/2110.02820 algorithm 2.1, apparently supposed to be diabolical
+ def nystrom_approximation(
+     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     ndim: int,
+     rank: int,
+     device,
+     dtype = torch.float32,
+     generator = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     omega = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
+     omega, _ = torch.linalg.qr(omega) # thin QR decomposition # pylint:disable=not-callable
+
+     # Y = AΩ
+     Y = torch.stack([A_mm(col) for col in omega.unbind(-1)], -1) # rank matvecs
+     v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(Y, ord='fro') # compute shift # pylint:disable=not-callable
+     Yv = Y + v*omega # shift for stability
+     C = torch.linalg.cholesky_ex(omega.mT @ Yv)[0] # pylint:disable=not-callable
+     B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+     U, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+     lambd = (S.pow(2) - v).clip(min=0) # remove shift, compute eigenvalues
+     return U, lambd
+
+ # this one works worse
+ def nystrom_sketch_and_solve(
+     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     b: torch.Tensor,
+     rank: int,
+     reg: float,
+     generator=None,
+ ) -> torch.Tensor:
+     U, lambd = nystrom_approximation(
+         A_mm=A_mm,
+         ndim=b.size(-1),
+         rank=rank,
+         device=b.device,
+         dtype=b.dtype,
+         generator=generator,
+     )
+     b = b.unsqueeze(-1)
+     lambd += reg
+     # x = (A + μI)⁻¹ b
+     # (A + μI)⁻¹ ≈ U(Λ + μI)⁻¹Uᵀ + (1/μ)(I - UUᵀ)
+     # x = U(Λ + μI)⁻¹Uᵀb + (1/μ)(b - UUᵀb)
+     Uᵀb = U.T @ b
+     term1 = U @ ((1/lambd).unsqueeze(-1) * Uᵀb)
+     term2 = (1.0 / reg) * (b - U @ Uᵀb)
+     return (term1 + term2).squeeze(-1)
+
+ # this one is insane
+ def nystrom_pcg(
+     A_mm: Callable[[torch.Tensor], torch.Tensor],
+     b: torch.Tensor,
+     sketch_size: int,
+     reg: float,
+     x0_: torch.Tensor | None,
+     tol: float | None,
+     maxiter: int | None,
+     generator=None,
+ ) -> torch.Tensor:
+     U, lambd = nystrom_approximation(
+         A_mm=A_mm,
+         ndim=b.size(-1),
+         rank=sketch_size,
+         device=b.device,
+         dtype=b.dtype,
+         generator=generator,
+     )
+     lambd += reg
+
+     def A_mm_reg(x): # A_mm with regularization
+         Ax = A_mm(x)
+         if reg != 0: Ax += x*reg
+         return Ax
+
+     if maxiter is None: maxiter = b.numel()
+     if x0_ is None: x0_ = torch.zeros_like(b)
+
+     x = x0_
+     residual = b - A_mm_reg(x)
+     # z0 = P⁻¹ r0
+     term1 = lambd[..., -1] * U * (1/lambd.unsqueeze(-2)) @ U.mT
+     term2 = torch.eye(U.size(-2), device=U.device, dtype=U.dtype) - U @ U.mT
+     P_inv = term1 + term2
+     z = P_inv @ residual
+     p = z.clone() # search direction
+
+     init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
+     if tol is not None and init_norm < tol: return x
+     k = 0
+     while True:
+         Ap = A_mm_reg(p)
+         rz = residual.dot(z)
+         step_size = rz / p.dot(Ap)
+         x += step_size * p
+         residual -= step_size * Ap
+
+         k += 1
+         if tol is not None and torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
+         if k >= maxiter: return x
+
+         z = P_inv @ residual
+         beta = residual.dot(z) / rz
+         p = z + p*beta
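A small sketch of how these solvers might be called (not part of the diff; the import path is assumed from the file list above). Both take a matvec closure instead of an explicit matrix, so A never has to be materialized when only Hessian-vector products are available.

import torch
from torchzero.utils.linalg.solve import cg, nystrom_pcg  # assumed path

torch.manual_seed(0)
A = torch.randn(20, 20, dtype=torch.float64)
A = A @ A.mT + 20 * torch.eye(20, dtype=torch.float64)    # well-conditioned SPD system
b = torch.randn(20, dtype=torch.float64)

x = cg(lambda v: A @ v, b, x0_=None, tol=1e-12, maxiter=None)
assert torch.allclose(A @ x, b, atol=1e-6)

# same system, preconditioned with a rank-10 Nyström sketch of A
x = nystrom_pcg(lambda v: A @ v, b, sketch_size=10, reg=0.0, x0_=None, tol=1e-12, maxiter=None)
assert torch.allclose(A @ x, b, atol=1e-6)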
torchzero/utils/linalg/svd.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+
+ # projected svd
+ # adapted from https://github.com/smortezavi/Randomized_SVD_GPU
+ def randomized_svd(M: torch.Tensor, k: int, driver=None):
+     *_, m, n = M.shape
+     transpose = False
+     if m < n:
+         transpose = True
+         M = M.mT
+         m, n = n, m
+
+     rand_matrix = torch.randn(size=(n, k), device=M.device, dtype=M.dtype)
+     Q, _ = torch.linalg.qr(M @ rand_matrix, mode='reduced') # pylint:disable=not-callable
+     smaller_matrix = Q.mT @ M
+     # V here is the transposed factor Vh, following the torch.linalg.svd return convention
+     U_hat, s, V = torch.linalg.svd(smaller_matrix, driver=driver, full_matrices=False) # pylint:disable=not-callable
+     U = Q @ U_hat
+
+     if transpose: return V.mT, s, U.mT
+     return U, s, V
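A usage sketch (not part of the diff; the import path is assumed from the file list above). The return convention mirrors torch.linalg.svd, so the third factor is already the transposed Vh:

import torch
from torchzero.utils.linalg.svd import randomized_svd  # assumed path

torch.manual_seed(0)
M = torch.randn(200, 10, dtype=torch.float64) @ torch.randn(10, 50, dtype=torch.float64)  # rank 10

U, S, Vh = randomized_svd(M, k=10)
assert torch.allclose((U * S.unsqueeze(-2)) @ Vh, M, atol=1e-8)   # rank-10 factors recover M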
torchzero/utils/numberlist.py ADDED
@@ -0,0 +1,132 @@
+ """A lightweight data type for a list of numbers (or anything else) with arithmetic overloads (implemented with basic for-loops).
+ Subclasses list, so it works with torch._foreach_xxx operations."""
+ import builtins
+ from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
+ import math
+ import operator
+ from typing import Any, Literal, TypedDict
+ from typing_extensions import Self, TypeAlias, Unpack
+
+ import torch
+ from .python_tools import zipmap
+
+ def _alpha_add(x, other, alpha):
+     return x + other * alpha
+
+ def as_numberlist(x):
+     if isinstance(x, NumberList): return x
+     return NumberList(x)
+
+
+ def maybe_numberlist(x):
+     if isinstance(x, (list, tuple)): return as_numberlist(x)
+     return x
+
+ def _clamp(x, min, max):
+     if min is not None and x < min: return min
+     if max is not None and x > max: return max
+     return x
+
+ class NumberList(list[int | float | Any]):
+     """List of python numbers.
+     Note that this only supports the basic arithmetic operations that are overloaded.
+
+     Can't use a numpy array because _foreach methods do not work with it."""
+     # remove torch.Tensor from return values
+     # this is no longer necessary
+     # def __getitem__(self, i) -> Any:
+     #     return super().__getitem__(i)
+
+     # def __iter__(self) -> Iterator[Any]:
+     #     return super().__iter__()
+
+     def __add__(self, other: Any) -> Self: return self.add(other) # type:ignore
+     def __radd__(self, other: Any) -> Self: return self.add(other)
+
+     def __sub__(self, other: Any) -> Self: return self.sub(other)
+     def __rsub__(self, other: Any) -> Self: return self.sub(other).neg()
+
+     def __mul__(self, other: Any) -> Self: return self.mul(other) # type:ignore
+     def __rmul__(self, other: Any) -> Self: return self.mul(other) # type:ignore
+
+     def __truediv__(self, other: Any) -> Self: return self.div(other)
+     def __rtruediv__(self, other: Any):
+         if isinstance(other, (tuple, list)): return self.__class__(o / i for o, i in zip(self, other))
+         return self.__class__(other / i for i in self)
+
+     def __floordiv__(self, other: Any): return self.floor_divide(other)
+     def __mod__(self, other: Any): return self.remainder(other)
+
+
+     def __pow__(self, other: Any): return self.pow(other)
+     def __rpow__(self, other: Any): return self.rpow(other)
+
+     def __neg__(self): return self.neg()
+
+     def __eq__(self, other: Any): return self.eq(other) # type:ignore
+     def __ne__(self, other: Any): return self.ne(other) # type:ignore
+     def __lt__(self, other: Any): return self.lt(other) # type:ignore
+     def __le__(self, other: Any): return self.le(other) # type:ignore
+     def __gt__(self, other: Any): return self.gt(other) # type:ignore
+     def __ge__(self, other: Any): return self.ge(other) # type:ignore
+
+     def __invert__(self): return self.logical_not()
+
+     def __and__(self, other: Any): return self.logical_and(other)
+     def __or__(self, other: Any): return self.logical_or(other)
+     def __xor__(self, other: Any): return self.logical_xor(other)
+
+     def __bool__(self):
+         raise RuntimeError(f'Boolean value of {self.__class__.__name__} is ambiguous')
+
+     def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
+         """If `other` is a list/tuple, applies `fn` to this NumberList zipped with `other`.
+         Otherwise applies `fn` to this NumberList and `other`.
+         Returns a new NumberList with the return values of the callable."""
+         return zipmap(self, fn, other, *args, **kwargs)
+
+     def zipmap_args(self, fn: Callable[..., Any], *others, **kwargs):
+         """If an element of `others` is a list/tuple, applies `fn` to this NumberList zipped with it.
+         Otherwise that element is broadcast to the length of this NumberList."""
+         others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
+         return self.__class__(fn(*z, **kwargs) for z in zip(self, *others))
+
+     # def _set_to_method_result_(self, method: str, *args, **kwargs):
+     #     """Sets each element of the numberlist to the result of calling the specified method on the corresponding element.
+     #     This is used to support/mimic in-place operations, although I decided to remove them."""
+     #     res = getattr(self, method)(*args, **kwargs)
+     #     for i,v in enumerate(res): self[i] = v
+     #     return self
+
+     def add(self, other: Any, alpha: int | float = 1):
+         if alpha == 1: return self.zipmap(operator.add, other=other)
+         return self.zipmap(_alpha_add, other=other, alpha=alpha)
+
+     def sub(self, other: Any, alpha: int | float = 1):
+         if alpha == 1: return self.zipmap(operator.sub, other=other)
+         return self.zipmap(_alpha_add, other=other, alpha=-alpha)
+
+     def neg(self): return self.__class__(-i for i in self)
+     def mul(self, other: Any): return self.zipmap(operator.mul, other=other)
+     def div(self, other: Any) -> Self: return self.zipmap(operator.truediv, other=other)
+     def pow(self, exponent: Any): return self.zipmap(math.pow, other=exponent)
+     def floor_divide(self, other: Any): return self.zipmap(operator.floordiv, other=other)
+     def remainder(self, other: Any): return self.zipmap(operator.mod, other=other)
+     def rpow(self, other: Any): return self.zipmap(lambda x, y: y**x, other=other)
+
+     def fill_none(self, value):
+         if isinstance(value, (list, tuple)): return self.__class__(v if s is None else s for s, v in zip(self, value))
+         return self.__class__(value if s is None else s for s in self)
+
+     def logical_not(self): return self.__class__(not i for i in self)
+     def logical_and(self, other: Any): return self.zipmap(operator.and_, other=other)
+     def logical_or(self, other: Any): return self.zipmap(operator.or_, other=other)
+     def logical_xor(self, other: Any): return self.zipmap(operator.xor, other=other)
+
+     def map(self, fn: Callable[..., torch.Tensor], *args, **kwargs):
+         """Applies `fn` to all elements of this NumberList
+         and returns a new NumberList with the return values of the callable."""
+         return self.__class__(fn(i, *args, **kwargs) for i in self)
+
+     def clamp(self, min=None, max=None):
+         return self.zipmap_args(_clamp, min, max)
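A small sketch of the intended use (not part of the diff; the import path and the scalar-broadcast behavior of zipmap are assumed from the docstrings above): a NumberList typically holds one scalar hyperparameter per parameter tensor, supports elementwise arithmetic through the overloads, and, because it subclasses list, can be passed directly as the scalar list of a torch._foreach_* call.

import torch
from torchzero.utils.numberlist import NumberList  # assumed path

params = [torch.zeros(3), torch.zeros(2)]
grads = [torch.ones(3), torch.ones(2)]
lrs = NumberList([1e-2, 1e-3])             # one learning rate per tensor

halved = lrs / 2                            # elementwise arithmetic -> [0.005, 0.0005]
torch._foreach_add_(params, torch._foreach_mul(grads, lrs))   # list subclass works with _foreach ops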
torchzero/utils/ops.py ADDED
@@ -0,0 +1,10 @@
+ import torch
+
+
+ def maximum_(input: torch.Tensor, other: torch.Tensor):
+     """in-place maximum"""
+     return torch.maximum(input, other, out=input)
+
+ def where_(input: torch.Tensor, condition: torch.Tensor, other: torch.Tensor):
+     """in-place where"""
+     return torch.where(condition, input, other, out=input)
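A tiny usage sketch (not part of the diff; the import path is assumed from the file list above): both helpers write the result back into the first argument.

import torch
from torchzero.utils.ops import maximum_, where_  # assumed path

buf = torch.tensor([1., 5., 3.])
maximum_(buf, torch.tensor([4., 2., 2.]))          # buf is now [4., 5., 3.]
where_(buf, buf > 3., torch.zeros_like(buf))       # keeps entries where the mask is True -> [4., 5., 0.]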