torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/utils/linalg/matrix_funcs.py
ADDED
@@ -0,0 +1,87 @@
+import warnings
+from collections.abc import Callable
+
+import torch
+
+def eigvals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+    L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
+    L = fn(L)
+    return (Q * L.unsqueeze(-2)) @ Q.mH
+
+def singular_vals_func(A: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
+    U, S, V = torch.linalg.svd(A) # pylint:disable=not-callable
+    S = fn(S)
+    return (U * S.unsqueeze(-2)) @ V.mT
+
+def matrix_power_eigh(A: torch.Tensor, pow: float):
+    L, Q = torch.linalg.eigh(A) # pylint:disable=not-callable
+    if pow % 2 != 0: L.clip_(min = torch.finfo(A.dtype).eps)
+    return (Q * L.pow(pow).unsqueeze(-2)) @ Q.mH
+
+
+def inv_sqrt_2x2(A: torch.Tensor, force_pd: bool = False) -> torch.Tensor:
+    """Inverse square root of a possibly batched 2x2 matrix using the closed-form formula for 2x2 matrices, which is much faster than torch.linalg. I tried doing a hierarchical 2x2 preconditioning but it didn't work well."""
+    eps = torch.finfo(A.dtype).eps
+
+    a = A[..., 0, 0]
+    b = A[..., 0, 1]
+    c = A[..., 1, 0]
+    d = A[..., 1, 1]
+
+    det = (a * d).sub_(b * c)
+    trace = a + d
+
+    if force_pd:
+        # add smallest eigenvalue magnitude to diagonal to force PD
+        # could also abs or clip eigenvalues bc there is a formula for eigenvectors
+        term1 = trace/2
+        term2 = (trace.pow(2).div_(4).sub_(det)).clamp_(min=eps).sqrt_()
+        y1 = term1 + term2
+        y2 = term1 - term2
+        smallest_eigval = torch.minimum(y1, y2).neg_().clamp_(min=0) + eps
+        a = a + smallest_eigval
+        d = d + smallest_eigval
+
+        # recalculate det and trace with the new a and d
+        det = (a * d).sub_(b * c)
+        trace = a + d
+
+    s = (det.clamp(min=eps)).sqrt_()
+
+    tau_squared = trace + 2 * s
+    tau = (tau_squared.clamp(min=eps)).sqrt_()
+
+    denom = s * tau
+
+    coeff = (denom.clamp(min=eps)).reciprocal_().unsqueeze(-1).unsqueeze(-1)
+
+    row1 = torch.stack([d + s, -b], dim=-1)
+    row2 = torch.stack([-c, a + s], dim=-1)
+    M = torch.stack([row1, row2], dim=-2)
+
+    return coeff * M
+
+
+def x_inv(diag: torch.Tensor, antidiag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """invert a matrix with diagonal and anti-diagonal non zero elements, with no checks that it is invertible"""
+    n = diag.shape[0]
+    if diag.dim() != 1 or antidiag.dim() != 1 or antidiag.shape[0] != n:
+        raise ValueError("Input tensors must be 1D and have the same size.")
+    if n == 0:
+        return torch.empty_like(diag), torch.empty_like(antidiag)
+
+    # opposite indexes
+    diag_rev = torch.flip(diag, dims=[0])
+    antidiag_rev = torch.flip(antidiag, dims=[0])
+
+    # determinants
+    # det_i = d[i] * d[n-1-i] - a[i] * a[n-1-i]
+    determinant_vec = diag * diag_rev - antidiag * antidiag_rev
+
+    # inverse diagonal elements: y_d[i] = d[n-1-i] / det_i
+    inv_diag_vec = diag_rev / determinant_vec
+
+    # inverse anti-diagonal elements: y_a[i] = -a[i] / det_i
+    inv_anti_diag_vec = -antidiag / determinant_vec
+
+    return inv_diag_vec, inv_anti_diag_vec
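For context, a minimal usage sketch of these spectral helpers, assuming the torchzero.utils.linalg.matrix_funcs import path shown in the file listing above (the snippet is illustrative and not part of the diff):

import torch
from torchzero.utils.linalg.matrix_funcs import eigvals_func, matrix_power_eigh, inv_sqrt_2x2

A = torch.randn(4, 4)
A = A @ A.T + 1e-3 * torch.eye(4)                   # symmetric positive definite test matrix

A_inv_sqrt = eigvals_func(A, lambda L: L.rsqrt())   # A^(-1/2) via an eigendecomposition
A_pow = matrix_power_eigh(A, -0.5)                  # same quantity through matrix_power_eigh

B = torch.randn(10, 2, 2)
B = B @ B.mT + 1e-3 * torch.eye(2)                  # batch of SPD 2x2 matrices
B_inv_sqrt = inv_sqrt_2x2(B, force_pd=True)         # closed-form batched inverse square root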
torchzero/utils/linalg/orthogonalize.py
ADDED
@@ -0,0 +1,11 @@
+from typing import overload
+import torch
+from ..tensorlist import TensorList
+
+@overload
+def gram_schmidt(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ...
+@overload
+def gram_schmidt(x: TensorList, y: TensorList) -> tuple[TensorList, TensorList]: ...
+def gram_schmidt(x, y):
+    """makes two orthogonal vectors, only y is changed"""
+    return x, y - (x*y) / ((x*x) + 1e-8)
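A minimal call sketch (illustrative, not part of the diff): the overloads accept either plain tensors or a TensorList, and only the second argument is adjusted:

import torch
from torchzero.utils.linalg.orthogonalize import gram_schmidt

x = torch.randn(8)
y = torch.randn(8)
x, y_new = gram_schmidt(x, y)   # x is returned unchanged, y_new is the adjusted second vector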
torchzero/utils/linalg/qr.py
ADDED
@@ -0,0 +1,71 @@
+from typing import Literal
+import torch
+from ..compile import enable_compilation
+
+# reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+def _get_w_tau(R: torch.Tensor, i: int, eps: float):
+    R_ii = R[..., i, i]
+    R_below = R[..., i:, i]
+    norm_x = torch.linalg.vector_norm(R_below, dim=-1) # pylint:disable=not-callable
+    degenerate = norm_x < eps
+    s = -torch.sign(R_ii)
+    u1 = R_ii - s*norm_x
+    u1 = torch.where(degenerate, 1, u1)
+    w = R_below / u1.unsqueeze(-1)
+    w[..., 0] = 1
+    tau = -s*u1/norm_x
+    tau = torch.where(degenerate, 1, tau)
+    return w, tau
+
+def _qr_householder_complete(A: torch.Tensor):
+    *b, m, n = A.shape
+    k = min(m, n)
+    eps = torch.finfo(A.dtype).eps
+
+    Q = torch.eye(m, dtype=A.dtype, device=A.device).expand(*b, m, m).clone() # clone because expanded dims refer to same memory
+    R = A.clone()
+
+    for i in range(k):
+        w, tau = _get_w_tau(R, i, eps)
+
+        R[..., i:, :] -= (tau*w).unsqueeze(-1) @ (w.unsqueeze(-2) @ R[..., i:, :])
+        Q[..., :, i:] -= (Q[..., :, i:] @ w).unsqueeze(-1) @ (tau*w).unsqueeze(-2)
+
+    return Q, R
+
+def _qr_householder_reduced(A: torch.Tensor):
+    *b, m, n = A.shape
+    k = min(m, n)
+    eps = torch.finfo(A.dtype).eps
+
+    R = A.clone()
+
+    ws: list = [None for _ in range(k)]
+    taus: list = [None for _ in range(k)]
+
+    for i in range(k):
+        w, tau = _get_w_tau(R, i, eps)
+
+        ws[i] = w
+        taus[i] = tau
+
+        if m - i > 0:
+            R[..., i:, :] -= (tau*w).unsqueeze(-1) @ (w.unsqueeze(-2) @ R[..., i:, :])
+            # Q[..., :, i:] -= (Q[..., :, i:] @ w).unsqueeze(-1) @ (tau*w).unsqueeze(-2)
+
+    R = R[..., :k, :]
+    Q = torch.eye(m, k, dtype=A.dtype, device=A.device).expand(*b, m, k).clone()
+    for i in range(k - 1, -1, -1):
+        if m - i > 0:
+            w = ws[i]
+            tau = taus[i].unsqueeze(-1).unsqueeze(-1)
+            Q_below = Q[..., i:, :]
+            Q[..., i:, :] -= torch.linalg.multi_dot([tau * w.unsqueeze(-1), w.unsqueeze(-2), Q_below]) # pylint:disable=not-callable
+
+    return Q, R
+
+# @enable_compilation
+def qr_householder(A: torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
+    """An attempt at a QR decomposition for very tall and thin matrices that doesn't freeze. It is around n_cols times slower than torch.linalg.qr; compilation makes it faster, but then it has to recompile whenever the shape changes."""
+    if mode == 'reduced': return _qr_householder_reduced(A)
+    return _qr_householder_complete(A)
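A minimal usage sketch for the Householder QR (illustrative, not part of the diff), on a batch of tall-and-thin matrices, which is the case the docstring targets:

import torch
from torchzero.utils.linalg.qr import qr_householder

A = torch.randn(3, 1000, 4)                 # batch of tall-and-thin matrices
Q, R = qr_householder(A, mode='reduced')    # Q: (3, 1000, 4), R: (3, 4, 4)
print(torch.dist(Q @ R, A))                 # reconstruction error, expected to be small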
torchzero/utils/linalg/solve.py
ADDED
@@ -0,0 +1,168 @@
+from collections.abc import Callable
+from typing import overload
+import torch
+
+from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_numel, generic_randn_like, generic_eq
+
+@overload
+def cg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    b: torch.Tensor,
+    x0_: torch.Tensor | None,
+    tol: float | None,
+    maxiter: int | None,
+    reg: float = 0,
+) -> torch.Tensor: ...
+@overload
+def cg(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0_: TensorList | None,
+    tol: float | None,
+    maxiter: int | None,
+    reg: float | list[float] | tuple[float] = 0,
+) -> TensorList: ...
+
+def cg(
+    A_mm: Callable,
+    b: torch.Tensor | TensorList,
+    x0_: torch.Tensor | TensorList | None,
+    tol: float | None,
+    maxiter: int | None,
+    reg: float | list[float] | tuple[float] = 0,
+):
+    def A_mm_reg(x): # A_mm with regularization
+        Ax = A_mm(x)
+        if not generic_eq(reg, 0): Ax += x*reg
+        return Ax
+
+    if maxiter is None: maxiter = generic_numel(b)
+    if x0_ is None: x0_ = generic_zeros_like(b)
+
+    x = x0_
+    residual = b - A_mm_reg(x)
+    p = residual.clone() # search direction
+    r_norm = generic_vector_norm(residual)
+    init_norm = r_norm
+    if tol is not None and r_norm < tol: return x
+    k = 0
+
+    while True:
+        Ap = A_mm_reg(p)
+        step_size = (r_norm**2) / p.dot(Ap)
+        x += step_size * p # update solution
+        residual -= step_size * Ap # update residual
+        new_r_norm = generic_vector_norm(residual)
+
+        k += 1
+        if tol is not None and new_r_norm <= tol * init_norm: return x
+        if k >= maxiter: return x
+
+        beta = (new_r_norm**2) / (r_norm**2)
+        p = residual + beta*p
+        r_norm = new_r_norm
+
+
+# https://arxiv.org/pdf/2110.02820 algorithm 2.1 apparently supposed to be diabolical
+def nystrom_approximation(
+    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    ndim: int,
+    rank: int,
+    device,
+    dtype = torch.float32,
+    generator = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    omega = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
+    omega, _ = torch.linalg.qr(omega) # thin QR decomposition # pylint:disable=not-callable
+
+    # Y = AΩ
+    Y = torch.stack([A_mm(col) for col in omega.unbind(-1)], -1) # rank matvecs
+    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(Y, ord='fro') # compute shift # pylint:disable=not-callable
+    Yv = Y + v*omega # shift for stability
+    C = torch.linalg.cholesky_ex(omega.mT @ Yv)[0] # pylint:disable=not-callable
+    B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
+    U, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
+    lambd = (S.pow(2) - v).clip(min=0) # remove shift, compute eigenvalues
+    return U, lambd
+
+# this one works worse
+def nystrom_sketch_and_solve(
+    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    b: torch.Tensor,
+    rank: int,
+    reg: float,
+    generator=None,
+) -> torch.Tensor:
+    U, lambd = nystrom_approximation(
+        A_mm=A_mm,
+        ndim=b.size(-1),
+        rank=rank,
+        device=b.device,
+        dtype=b.dtype,
+        generator=generator,
+    )
+    b = b.unsqueeze(-1)
+    lambd += reg
+    # x = (A + μI)⁻¹ b
+    # (A + μI)⁻¹ ≈ U(Λ + μI)⁻¹Uᵀ + (1/μ)(I - UUᵀ)
+    # x = U(Λ + μI)⁻¹Uᵀb + (1/μ)(b - UUᵀb)
+    Uᵀb = U.T @ b
+    term1 = U @ ((1/lambd).unsqueeze(-1) * Uᵀb)
+    term2 = (1.0 / reg) * (b - U @ Uᵀb)
+    return (term1 + term2).squeeze(-1)
+
+# this one is insane
+def nystrom_pcg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    b: torch.Tensor,
+    sketch_size: int,
+    reg: float,
+    x0_: torch.Tensor | None,
+    tol: float | None,
+    maxiter: int | None,
+    generator=None,
+) -> torch.Tensor:
+    U, lambd = nystrom_approximation(
+        A_mm=A_mm,
+        ndim=b.size(-1),
+        rank=sketch_size,
+        device=b.device,
+        dtype=b.dtype,
+        generator=generator,
+    )
+    lambd += reg
+
+    def A_mm_reg(x): # A_mm with regularization
+        Ax = A_mm(x)
+        if reg != 0: Ax += x*reg
+        return Ax
+
+    if maxiter is None: maxiter = b.numel()
+    if x0_ is None: x0_ = torch.zeros_like(b)
+
+    x = x0_
+    residual = b - A_mm_reg(x)
+    # z0 = P⁻¹ r0
+    term1 = lambd[..., -1] * U * (1/lambd.unsqueeze(-2)) @ U.mT
+    term2 = torch.eye(U.size(-2), device=U.device, dtype=U.dtype) - U @ U.mT
+    P_inv = term1 + term2
+    z = P_inv @ residual
+    p = z.clone() # search direction
+
+    init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
+    if tol is not None and init_norm < tol: return x
+    k = 0
+    while True:
+        Ap = A_mm_reg(p)
+        rz = residual.dot(z)
+        step_size = rz / p.dot(Ap)
+        x += step_size * p
+        residual -= step_size * Ap
+
+        k += 1
+        if tol is not None and torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
+        if k >= maxiter: return x
+
+        z = P_inv @ residual
+        beta = residual.dot(z) / rz
+        p = z + p*beta
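For context, a minimal matrix-free usage sketch (illustrative, not part of the diff): cg solves A x = b given only a matvec callable, and nystrom_pcg solves the regularized system (A + reg*I) x = b with the Nystrom preconditioner built from the sketch above:

import torch
from torchzero.utils.linalg.solve import cg, nystrom_pcg

n = 100
M = torch.randn(n, n)
A = M @ M.T + torch.eye(n)      # SPD system matrix
b = torch.randn(n)
A_mm = lambda v: A @ v          # only matrix-vector products are required

x_cg = cg(A_mm, b, x0_=None, tol=1e-8, maxiter=None)
x_pcg = nystrom_pcg(A_mm, b, sketch_size=20, reg=1e-3, x0_=None, tol=1e-8, maxiter=None)

print(torch.dist(A @ x_cg, b))  # residual of the plain CG solve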
torchzero/utils/linalg/svd.py
ADDED
@@ -0,0 +1,20 @@
+import torch
+
+# projected svd
+# adapted from https://github.com/smortezavi/Randomized_SVD_GPU
+def randomized_svd(M: torch.Tensor, k: int, driver=None):
+    *_, m, n = M.shape
+    transpose = False
+    if m < n:
+        transpose = True
+        M = M.mT
+        m, n = n, m
+
+    rand_matrix = torch.randn(size=(n, k), device=M.device, dtype=M.dtype)
+    Q, _ = torch.linalg.qr(M @ rand_matrix, mode='reduced') # pylint:disable=not-callable
+    smaller_matrix = Q.mT @ M
+    U_hat, s, V = torch.linalg.svd(smaller_matrix, driver=driver, full_matrices=False) # pylint:disable=not-callable
+    U = Q @ U_hat
+
+    if transpose: return V.mT, s, U.mT
+    return U, s, V
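A minimal usage sketch (illustrative, not part of the diff); note that the third return value is the factor as produced by torch.linalg.svd, i.e. already transposed, so U @ diag(s) @ V reconstructs M:

import torch
from torchzero.utils.linalg.svd import randomized_svd

M = torch.randn(500, 40) @ torch.randn(40, 300)   # a rank-40 matrix
U, s, V = randomized_svd(M, k=40)
print(torch.dist(U @ torch.diag(s) @ V, M))       # reconstruction error, expected to be small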
torchzero/utils/numberlist.py
ADDED
@@ -0,0 +1,132 @@
+"""A lightweight data type for a list of numbers (or anything else) with arithmetic overloads (using basic for-loops).
+Subclasses list so works with torch._foreach_xxx operations."""
+import builtins
+from collections.abc import Callable, Sequence, Iterable, Generator, Iterator
+import math
+import operator
+from typing import Any, Literal, TypedDict
+from typing_extensions import Self, TypeAlias, Unpack
+
+import torch
+from .python_tools import zipmap
+
+def _alpha_add(x, other, alpha):
+    return x + other * alpha
+
+def as_numberlist(x):
+    if isinstance(x, NumberList): return x
+    return NumberList(x)
+
+
+def maybe_numberlist(x):
+    if isinstance(x, (list, tuple)): return as_numberlist(x)
+    return x
+
+def _clamp(x, min, max):
+    if min is not None and x < min: return min
+    if max is not None and x > max: return max
+    return x
+
+class NumberList(list[int | float | Any]):
+    """List of python numbers.
+    Note that this only supports basic arithmetic operations that are overloaded.
+
+    Can't use a numpy array because _foreach methods do not work with it."""
+    # remove torch.Tensor from return values
+    # this is no longer necessary
+    # def __getitem__(self, i) -> Any:
+    #     return super().__getitem__(i)
+
+    # def __iter__(self) -> Iterator[Any]:
+    #     return super().__iter__()
+
+    def __add__(self, other: Any) -> Self: return self.add(other) # type:ignore
+    def __radd__(self, other: Any) -> Self: return self.add(other)
+
+    def __sub__(self, other: Any) -> Self: return self.sub(other)
+    def __rsub__(self, other: Any) -> Self: return self.sub(other).neg()
+
+    def __mul__(self, other: Any) -> Self: return self.mul(other) # type:ignore
+    def __rmul__(self, other: Any) -> Self: return self.mul(other) # type:ignore
+
+    def __truediv__(self, other: Any) -> Self: return self.div(other)
+    def __rtruediv__(self, other: Any):
+        if isinstance(other, (tuple, list)): return self.__class__(o / i for o, i in zip(self, other))
+        return self.__class__(other / i for i in self)
+
+    def __floordiv__(self, other: Any): return self.floor_divide(other)
+    def __mod__(self, other: Any): return self.remainder(other)
+
+
+    def __pow__(self, other: Any): return self.pow(other)
+    def __rpow__(self, other: Any): return self.rpow(other)
+
+    def __neg__(self): return self.neg()
+
+    def __eq__(self, other: Any): return self.eq(other) # type:ignore
+    def __ne__(self, other: Any): return self.ne(other) # type:ignore
+    def __lt__(self, other: Any): return self.lt(other) # type:ignore
+    def __le__(self, other: Any): return self.le(other) # type:ignore
+    def __gt__(self, other: Any): return self.gt(other) # type:ignore
+    def __ge__(self, other: Any): return self.ge(other) # type:ignore
+
+    def __invert__(self): return self.logical_not()
+
+    def __and__(self, other: Any): return self.logical_and(other)
+    def __or__(self, other: Any): return self.logical_or(other)
+    def __xor__(self, other: Any): return self.logical_xor(other)
+
+    def __bool__(self):
+        raise RuntimeError(f'Boolean value of {self.__class__.__name__} is ambiguous')
+
+    def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
+        """If `other` is a list/tuple, applies `fn` to this NumberList zipped with `other`.
+        Otherwise applies `fn` to this NumberList and `other`.
+        Returns a new NumberList with the return values of the callable."""
+        return zipmap(self, fn, other, *args, **kwargs)
+
+    def zipmap_args(self, fn: Callable[..., Any], *others, **kwargs):
+        """If `others` are lists/tuples, applies `fn` to this NumberList zipped with `others`.
+        Otherwise applies `fn` to this NumberList and `others`."""
+        others = [i if isinstance(i, (list, tuple)) else [i]*len(self) for i in others]
+        return self.__class__(fn(*z, **kwargs) for z in zip(self, *others))
+
+    # def _set_to_method_result_(self, method: str, *args, **kwargs):
+    #     """Sets each element of the tensorlist to the result of calling the specified method on the corresponding element.
+    #     This is used to support/mimic in-place operations, although I decided to remove them."""
+    #     res = getattr(self, method)(*args, **kwargs)
+    #     for i, v in enumerate(res): self[i] = v
+    #     return self
+
+    def add(self, other: Any, alpha: int | float = 1):
+        if alpha == 1: return self.zipmap(operator.add, other=other)
+        return self.zipmap(_alpha_add, other=other, alpha=alpha)
+
+    def sub(self, other: Any, alpha: int | float = 1):
+        if alpha == 1: return self.zipmap(operator.sub, other=other)
+        return self.zipmap(_alpha_add, other=other, alpha=-alpha)
+
+    def neg(self): return self.__class__(-i for i in self)
+    def mul(self, other: Any): return self.zipmap(operator.mul, other=other)
+    def div(self, other: Any) -> Self: return self.zipmap(operator.truediv, other=other)
+    def pow(self, exponent: Any): return self.zipmap(math.pow, other=exponent)
+    def floor_divide(self, other: Any): return self.zipmap(operator.floordiv, other=other)
+    def remainder(self, other: Any): return self.zipmap(operator.mod, other=other)
+    def rpow(self, other: Any): return self.zipmap(lambda x, y: y**x, other=other)
+
+    def fill_none(self, value):
+        if isinstance(value, (list, tuple)): return self.__class__(v if s is None else s for s, v in zip(self, value))
+        return self.__class__(value if s is None else s for s in self)
+
+    def logical_not(self): return self.__class__(not i for i in self)
+    def logical_and(self, other: Any): return self.zipmap(operator.and_, other=other)
+    def logical_or(self, other: Any): return self.zipmap(operator.or_, other=other)
+    def logical_xor(self, other: Any): return self.zipmap(operator.xor, other=other)
+
+    def map(self, fn: Callable[..., torch.Tensor], *args, **kwargs):
+        """Applies `fn` to all elements of this NumberList
+        and returns a new NumberList with the return values of the callable."""
+        return self.__class__(fn(i, *args, **kwargs) for i in self)
+
+    def clamp(self, min=None, max=None):
+        return self.zipmap_args(_clamp, min, max)
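A minimal usage sketch (illustrative, not part of the diff): because NumberList subclasses list, it can be passed where torch._foreach_* ops expect a list of scalars, while arithmetic on it returns a new NumberList:

import torch
from torchzero.utils.numberlist import NumberList

lrs = NumberList([1e-2, 1e-3])               # e.g. one value per parameter
grads = [torch.ones(3), torch.ones(2)]

updates = torch._foreach_mul(grads, -lrs)    # NumberList used as the scalar-list argument
half_lrs = (lrs / 2).clamp(min=1e-4)         # arithmetic and clamp return new NumberLists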
torchzero/utils/ops.py
ADDED
@@ -0,0 +1,10 @@
+import torch
+
+
+def maximum_(input: torch.Tensor, other: torch.Tensor):
+    """in-place maximum"""
+    return torch.maximum(input, other, out=input)
+
+def where_(input: torch.Tensor, condition: torch.Tensor, other: torch.Tensor):
+    """in-place where"""
+    return torch.where(condition, input, other, out=input)
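A minimal usage sketch (illustrative, not part of the diff): both helpers write their result into the first argument via the out= parameter:

import torch
from torchzero.utils.ops import maximum_, where_

a = torch.tensor([1., 5., 2.])
b = torch.tensor([3., 4., 6.])
maximum_(a, b)                       # a is now [3., 5., 6.]

mask = torch.tensor([True, False, True])
where_(a, mask, torch.zeros(3))      # keeps a where mask is True, 0 elsewhere, written back into a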