torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
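
Reading the listing, the old ``torchzero.utils.linalg`` subpackage has been promoted to a top-level ``torchzero.linalg`` package, and the single ``torchzero/optim/wrappers/scipy.py`` module has been split into a ``scipy/`` subpackage. A minimal sketch (not part of the diff) of how downstream imports move, assuming the submodule names stay as shown in the file list; what ``torchzero/linalg/__init__.py`` re-exports is not visible here:

# torchzero 0.3.x layout (removed in this release)
# from torchzero.utils.linalg.qr import qr_householder
# from torchzero.optim.wrappers import scipy            # a single module

# torchzero 0.4.x layout (per the file list above)
from torchzero.linalg.qr import qr_householder           # utils/linalg/* moved to torchzero/linalg/*
from torchzero.optim.wrappers import scipy                # now a package: minimize.py, basin_hopping.py, ...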
torchzero/{utils/linalg → linalg}/qr.py RENAMED
@@ -1,8 +1,21 @@
 from typing import Literal
 import torch
-from ..compile import
+from ..utils.compile import allow_compile
+
+
+# super slow
+# def cholesky_qr(A):
+# """QR of (m, n) A via cholesky of (n, n) matrix"""
+# AtA = A.T @ A
+
+# L, _ = torch.linalg.cholesky_ex(AtA) # pylint:disable=not-callable
+# R = L.T
+
+# Q = torch.linalg.solve_triangular(R.T, A.T, upper=False).T # pylint:disable=not-callable
+# return Q, R
 
 # reference - https://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf
+@allow_compile
 def _get_w_tau(R: torch.Tensor, i: int, eps: float):
     R_ii = R[...,i,i]
     R_below = R[...,i:,i]
@@ -17,6 +30,7 @@ def _get_w_tau(R: torch.Tensor, i: int, eps: float):
     tau = torch.where(degenerate, 1, tau)
     return w, tau
 
+@allow_compile
 def _qr_householder_complete(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
@@ -33,6 +47,7 @@ def _qr_householder_complete(A:torch.Tensor):
 
     return Q, R
 
+@allow_compile
 def _qr_householder_reduced(A:torch.Tensor):
     *b,m,n = A.shape
     k = min(m,n)
@@ -64,7 +79,6 @@ def _qr_householder_reduced(A:torch.Tensor):
 
     return Q, R
 
-# @enable_compilation
 def qr_householder(A:torch.Tensor, mode: Literal['complete', 'reduced'] = 'reduced'):
     """an attempt at making QR decomposition for very tall and thin matrices that doesn't freeze, but it is around n_cols times slower than torch.linalg.qr, but compilation makes it faster, but it has to recompile when processing different shapes"""
     if mode == 'reduced': return _qr_householder_reduced(A)
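
A minimal usage sketch, not part of the diff; it assumes ``qr_householder`` keeps the ``(Q, R)`` return shown in the hunks above and is imported from its new location:

import torch
from torchzero.linalg.qr import qr_householder

A = torch.randn(10_000, 8)                         # very tall, thin matrix (the intended use case)
Q, R = qr_householder(A, mode='reduced')           # Q: (10000, 8), R: (8, 8)
print(torch.allclose(Q @ R, A, atol=1e-4))         # reconstructs A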
torchzero/{utils/linalg → linalg}/solve.py RENAMED
@@ -1,3 +1,4 @@
+# pylint: disable = non-ascii-name
 # pyright: reportArgumentType=false
 import math
 from collections import deque
@@ -5,8 +6,8 @@ from collections.abc import Callable
 from typing import Any, NamedTuple, overload
 
 import torch
-
-from .. import (
+from .linalg_utils import mm
+from ..utils import (
     TensorList,
     generic_eq,
     generic_finfo_tiny,
@@ -15,88 +16,71 @@ from .. import (
     generic_zeros_like,
 )
 
-
-def
-
-        Ax = A_mm(x)
+def _make_A_mv_reg(A_mv: Callable, reg):
+    def A_mv_reg(x): # A_mm with regularization
+        Ax = A_mv(x)
         if not generic_eq(reg, 0): Ax += x*reg
         return Ax
-    return
+    return A_mv_reg
 
 def _identity(x): return x
 
-
-# https://arxiv.org/pdf/2110.02820
-def nystrom_approximation(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
-    ndim: int,
-    rank: int,
-    device,
-    dtype = torch.float32,
-    generator = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    omega = torch.randn((ndim, rank), device=device, dtype=dtype, generator=generator) # Gaussian test matrix
-    omega, _ = torch.linalg.qr(omega) # Thin QR decomposition # pylint:disable=not-callable
-
-    # Y = AΩ
-    Y = torch.stack([A_mm(col) for col in omega.unbind(-1)], -1) # rank matvecs
-    v = torch.finfo(dtype).eps * torch.linalg.matrix_norm(Y, ord='fro') # Compute shift # pylint:disable=not-callable
-    Yv = Y + v*omega # Shift for stability
-    C = torch.linalg.cholesky_ex(omega.mT @ Yv)[0] # pylint:disable=not-callable
-    B = torch.linalg.solve_triangular(C, Yv.mT, upper=False, unitriangular=False).mT # pylint:disable=not-callable
-    U, S, _ = torch.linalg.svd(B, full_matrices=False) # pylint:disable=not-callable
-    lambd = (S.pow(2) - v).clip(min=0) #Remove shift, compute eigs
-    return U, lambd
-
 def nystrom_sketch_and_solve(
-
+    L: torch.Tensor,
+    Q: torch.Tensor,
     b: torch.Tensor,
-    rank: int,
     reg: float = 1e-3,
-    generator=None,
 ) -> torch.Tensor:
-
-
-
-
-
-
-
-
+    """Solves ``(Q diag(L) Q.T + reg*I)x = b``. Becomes super unstable with reg smaller than like 1e-5.
+
+    Args:
+        L (torch.Tensor): eigenvalues, like from ``nystrom_approximation``
+        Q (torch.Tensor): eigenvectors, like from ``nystrom_approximation``
+        b (torch.Tensor): right hand side
+        reg (float, optional): regularization. Defaults to 1e-3.
+    """
+
     b = b.unsqueeze(-1)
-
+    L += reg
     # x = (A + μI)⁻¹ b
-    # (A + μI)⁻¹ =
-    # x =
-
-    term1 =
-    term2 = (1.0 / reg) * (b -
+    # (A + μI)⁻¹ = Q(L + μI)⁻¹Qᵀ + (1/μ)(b - QQᵀ)
+    # x = Q(L + μI)⁻¹Qᵀb + (1/μ)(b - QQᵀb)
+    Qᵀb = Q.T @ b
+    term1 = Q @ ((1/L).unsqueeze(-1) * Qᵀb)
+    term2 = (1.0 / reg) * (b - Q @ Qᵀb)
     return (term1 + term2).squeeze(-1)
 
 def nystrom_pcg(
-
+    L: torch.Tensor,
+    Q: torch.Tensor,
+    A_mv: Callable[[torch.Tensor], torch.Tensor],
     b: torch.Tensor,
-    sketch_size: int,
     reg: float = 1e-6,
     x0_: torch.Tensor | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
     maxiter: int | None = None,
-    generator=None,
 ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
+    """conjugate gradient preconditioned by nystrom approximation.
+
+    The preconditioner can be computed by one matrix-matrix multiplication with A.
+    If matrix-matrix is efficient, then this is good (e.g. batched hessian-vector products in pytorch)
+
+    Args:
+        L (torch.Tensor): eigenvalues of approximation of A, like from ``nystrom_approximation``
+        Q (torch.Tensor): eigenvectors of approximation of A, like from ``nystrom_approximation``
+        A_mv (Callable[[torch.Tensor], torch.Tensor]): mat-vec func with hessian
+        b (torch.Tensor): right hand side
+        reg (float, optional): regularization. Defaults to 1e-6.
+        x0_ (torch.Tensor | None, optional): initial guess (modified in-place). Defaults to None.
+        tol (float | None, optional): tolerance for convergence. Defaults to 1e-4.
+        maxiter (int | None, optional): maximum number of iterations. Defaults to None.
+    """
+    L += reg
     eps = torch.finfo(b.dtype).tiny * 2
     if tol is None: tol = eps
 
-    def
-        Ax =
+    def A_mv_reg(x): # A_mm with regularization
+        Ax = A_mv(x)
         if reg != 0: Ax += x*reg
         return Ax
 
@@ -104,10 +88,10 @@ def nystrom_pcg(
     if x0_ is None: x0_ = torch.zeros_like(b)
 
     x = x0_
-    residual = b -
+    residual = b - A_mv_reg(x)
     # z0 = P⁻¹ r0
-    term1 =
-    term2 = torch.eye(
+    term1 = L[...,-1] * Q * (1/L.unsqueeze(-2)) @ Q.mT
+    term2 = torch.eye(Q.size(-2), device=Q.device,dtype=Q.dtype) - Q@Q.mT
     P_inv = term1 + term2
     z = P_inv @ residual
     p = z.clone() # search direction
@@ -116,7 +100,7 @@ def nystrom_pcg(
     if init_norm < tol: return x
     k = 0
     while True:
-        Ap =
+        Ap = A_mv_reg(p)
         rz = residual.dot(z)
         step_size = rz / p.dot(Ap)
         x += step_size * p
@@ -138,7 +122,7 @@ def _safe_clip(x: torch.Tensor):
     if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
     return x
 
-def _trust_tau(x,d,trust_radius):
+def _trust_tau(x, d, trust_radius):
     xx = x.dot(x)
     xd = x.dot(d)
     dd = _safe_clip(d.dot(d))
@@ -150,10 +134,10 @@ def _trust_tau(x,d,trust_radius):
 
 
 class CG:
-    """Conjugate gradient method.
+    """Conjugate gradient method optionally with norm constraint.
 
     Args:
-
+        A_mv (Callable[[torch.Tensor], torch.Tensor] | torch.Tensor): Callable that returns matvec ``Ax``.
         b (torch.Tensor): right hand side
         x0 (torch.Tensor | None, optional): initial guess, defaults to zeros. Defaults to None.
         tol (float | None, optional): tolerance for convergence. Defaults to 1e-8.
@@ -174,10 +158,10 @@ class CG:
     """
     def __init__(
         self,
-
+        A_mv: Callable,
        b: torch.Tensor | TensorList,
        x0: torch.Tensor | TensorList | None = None,
-        tol: float | None = 1e-
+        tol: float | None = 1e-8,
        maxiter: int | None = None,
        reg: float = 0,
        trust_radius: float | None = None,
@@ -187,7 +171,7 @@ class CG:
        P_mm: Callable | None = None,
    ):
        # --------------------------------- set attrs -------------------------------- #
-        self.
+        self.A_mv = _make_A_mv_reg(A_mv, reg)
        self.b = b
        if tol is None: tol = generic_finfo_tiny(b) * 2
        self.tol = tol
@@ -214,7 +198,7 @@ class CG:
            self.r = b
        else:
            self.x = x0
-            self.r = b -
+            self.r = b - A_mv(self.x)
 
        self.z = self.P_mm(self.r)
        self.d = self.z
@@ -229,7 +213,7 @@ class CG:
        if self.iter >= self.maxiter:
            return x, True
 
-        Ad = self.
+        Ad = self.A_mv(d)
        dAd = d.dot(Ad)
 
        # check negative curvature
@@ -289,7 +273,8 @@ class CG:
        return sol
 
 def find_within_trust_radius(history, trust_radius: float):
-    """find first ``x`` in history that exceeds trust radius
+    """find first ``x`` in history that exceeds trust radius and returns solution within,
+    if no such ``x`` exists, returns ``None``"""
    for x, x_norm, d in reversed(tuple(history)):
        if x_norm <= trust_radius:
            return _trust_tau(x, d, trust_radius)
@@ -306,7 +291,7 @@ class _TensorListSolution(NamedTuple):
 
 @overload
 def cg(
-
+    A_mv: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    x0: torch.Tensor | None = None,
    tol: float | None = 1e-8,
@@ -320,7 +305,7 @@ def cg(
 ) -> _TensorSolution: ...
 @overload
 def cg(
-
+    A_mv: Callable[[TensorList], TensorList],
    b: TensorList,
    x0: TensorList | None = None,
    tol: float | None = 1e-8,
@@ -333,7 +318,7 @@ def cg(
    P_mm: Callable[[TensorList], TensorList] | None = None
 ) -> _TensorListSolution: ...
 def cg(
-
+    A_mv: Callable,
    b: torch.Tensor | TensorList,
    x0: torch.Tensor | TensorList | None = None,
    tol: float | None = 1e-8,
@@ -346,7 +331,7 @@ def cg(
    P_mm: Callable | None = None
 ):
    solver = CG(
-
+        A_mv=A_mv,
        b=b,
        x0=x0,
        tol=tol,
@@ -370,10 +355,10 @@ def cg(
 # Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
 @overload
 def minres(
-
+    A_mv: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
    b: torch.Tensor,
    x0: torch.Tensor | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float = 0,
    npc_terminate: bool=True,
@@ -381,26 +366,27 @@ def minres(
 ) -> torch.Tensor: ...
 @overload
 def minres(
-
+    A_mv: Callable[[TensorList], TensorList],
    b: TensorList,
    x0: TensorList | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    npc_terminate: bool=True,
    trust_radius: float | None = None,
 ) -> TensorList: ...
 def minres(
-
+    A_mv,
    b,
    x0: torch.Tensor | TensorList | None = None,
-    tol: float | None = 1e-
+    tol: float | None = 1e-8,
    maxiter: int | None = None,
    reg: float | list[float] | tuple[float] = 0,
    npc_terminate: bool=True,
    trust_radius: float | None = None, #trust region is experimental
 ):
-
+    """MINRES (experimental)"""
+    A_mv_reg = _make_A_mv_reg(A_mv, reg)
    eps = math.sqrt(generic_finfo_tiny(b) * 2)
    if tol is None: tol = eps
 
@@ -409,7 +395,7 @@ def minres(
        R = b
        x0 = generic_zeros_like(b)
    else:
-        R = b -
+        R = b - A_mv_reg(x0)
 
    X: Any = x0
    beta = b_norm = generic_vector_norm(b)
@@ -429,7 +415,7 @@ def minres(
 
    for _ in range(maxiter):
 
-        P =
+        P = A_mv_reg(V)
        alpha = V.dot(P)
        P -= beta*V_prev
        P -= alpha*V
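
The visible API change here is that ``nystrom_sketch_and_solve`` and ``nystrom_pcg`` now take a precomputed eigenpair ``(L, Q)`` instead of building the sketch themselves (``nystrom_approximation`` was removed from this module). A hedged sketch of the new call signature, not from the diff: ``(L, Q)`` would normally come from a Nyström approximation elsewhere in ``torchzero.linalg``; here an exact eigendecomposition of a small SPD matrix stands in.

import torch
from torchzero.linalg.solve import nystrom_sketch_and_solve

n, reg = 32, 1e-3
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.T + torch.eye(n, dtype=torch.float64)        # SPD test matrix
b = torch.randn(n, dtype=torch.float64)

L, Q = torch.linalg.eigh(A)                             # eigenvalues / eigenvectors of A
x = nystrom_sketch_and_solve(L.clone(), Q, b, reg=reg)  # pass a clone: L is modified in-place (L += reg)
x_ref = torch.linalg.solve(A + reg * torch.eye(n, dtype=torch.float64), b)
print(torch.allclose(x, x_ref, atol=1e-6))              # matches the direct regularized solve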
torchzero/linalg/svd.py ADDED
@@ -0,0 +1,47 @@
+import torch
+
+from . import torch_linalg
+
+
+def tall_reduced_svd_via_eigh(A: torch.Tensor, tol: float = 0, retry_float64:bool=False):
+    """
+    Given a tall matrix A of size (m, n), computes U and S from the reduced SVD(A)
+    using the eigendecomposition of (n, n) matrix which is faster than direct SVD when m >= n.
+
+    This truncates small singular values that would causes nans,
+    so the returned U and S can have reduced dimension ``k <= n``.
+
+    Returns U of size ``(m, k)`` and S of size ``(k, )``.
+
+    Args:
+        A (torch.Tensor): A tall matrix of size (m, n) with m >= n.
+        tol (float): Tolerance for truncating small singular values. Singular values
+            less than ``tol * max_singular_value`` will be discarded.
+
+
+    """
+    # if m < n, A.T A will be low rank and we can't use eigh
+    m, n = A.size()
+    if m < n:
+        U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+        return U, S
+
+    M = A.mH @ A # n,n
+
+    try:
+        L, Q = torch_linalg.eigh(M, retry_float64=retry_float64)
+    except torch.linalg.LinAlgError:
+        U, S, V = torch_linalg.svd(A, full_matrices=False, retry_float64=retry_float64)
+        return U, S
+
+    L = torch.flip(L, dims=[-1])
+    Q = torch.flip(Q, dims=[-1])
+
+    indices = L > tol * L[0] # L[0] is the max eigenvalue
+    L = L[indices]
+    Q = Q[:, indices]
+
+    S = L.sqrt()
+    U = (A @ Q) / S
+
+    return U, S
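
A short usage sketch (not part of the diff, assuming the function is imported from its new module path): for a tall matrix the eigendecomposition of A^T A recovers the same singular values and left singular vectors as a direct reduced SVD.

import torch
from torchzero.linalg.svd import tall_reduced_svd_via_eigh

A = torch.randn(2000, 32, dtype=torch.float64)          # tall matrix, m >> n
U, S = tall_reduced_svd_via_eigh(A)
print(U.shape, S.shape)                                   # (2000, 32), (32,) when nothing is truncated
print(torch.allclose(S, torch.linalg.svdvals(A)))         # singular values match direct SVD
print(torch.allclose(U.T @ U, torch.eye(S.numel(), dtype=torch.float64), atol=1e-10))  # orthonormal columns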
torchzero/linalg/torch_linalg.py ADDED
@@ -0,0 +1,168 @@
+"""torch linalg with correct typing and retries in float64"""
+from typing import NamedTuple
+
+import torch
+
+
+def cholesky(A: torch.Tensor, *, upper=False, retry_float64:bool=False) -> torch.Tensor:
+    """A - SPD, returns lower triangular L such that ``A = L @ L.mH`` also can pass L to ``torch.cholesky_solve``"""
+    try:
+        return torch.linalg.cholesky(A, upper=upper) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return cholesky(A.to(torch.float64), upper=upper, retry_float64=False).to(dtype)
+
+
+class _QRTuple(NamedTuple):
+    Q: torch.Tensor
+    R: torch.Tensor
+
+def qr(A: torch.Tensor, mode='reduced', retry_float64:bool=False) -> _QRTuple:
+    """A - any matrix ``(*, m, n)`` (for some reason sometimes it takes ages on some matrices)
+
+    ### Returns (if mode = "reduced"):
+
+    Q: ``(*, m, k)`` - orthogonal
+
+    R: ``(*, k, n)`` - upper triangular
+
+    where ``k = min(m,n)``
+    """
+    try:
+        return torch.linalg.qr(A, mode=mode) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        Q, R = qr(A.to(torch.float64), mode=mode, retry_float64=False)
+        return _QRTuple(Q=Q.to(dtype), R=R.to(dtype))
+
+def eigh(A: torch.Tensor, UPLO="L", retry_float64:bool=False) -> tuple[torch.Tensor, torch.Tensor]:
+    """A - symmetric, returns ``(L, Q)``, ``A = Q @ torch.diag(L) @ Q.mH``, this is faster than SVD"""
+    try:
+        return torch.linalg.eigh(A, UPLO=UPLO) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        L, Q = eigh(A.to(torch.float64), UPLO=UPLO, retry_float64=False)
+        return L.to(dtype), Q.to(dtype)
+
+
+
+class _SVDTuple(NamedTuple):
+    U: torch.Tensor
+    S: torch.Tensor
+    Vh: torch.Tensor
+
+def svd(A: torch.Tensor, full_matrices=True, driver=None, retry_float64:bool=False) -> _SVDTuple:
+    """A - any matrix ``(*, n, m)``, but slows down if A isn't well conditioned, ``A = U @ torch.diag(S) @ Vh``
+
+    Don't forget to set ``full_matrices=False``
+
+    ### Returns:
+
+    U: ``(*, m, m)`` or ``(*, m, k)`` - orthogonal
+
+    S: ``(*, k,)`` - singular values
+
+    V^H: ``(*, n, n)`` or ``(*, n, k)`` - orthogonal
+
+    where ``k = min(m,n)``
+
+    ### Drivers
+
+    drivers are only supported on CUDA so A is moved to CUDA by this function if needed
+
+    from docs:
+
+    If A is well-conditioned (its condition number is not too large), or you do not mind some precision loss.
+
+    For a general matrix: ‘gesvdj’ (Jacobi method)
+
+    If A is tall or wide (m >> n or m << n): ‘gesvda’ (Approximate method)
+
+    If A is not well-conditioned or precision is relevant: ‘gesvd’ (QR based)
+
+    By default (driver= None), we call ‘gesvdj’ and, if it fails, we fallback to ‘gesvd’.
+    """
+    # drivers are only for CUDA
+    # also the only one that doesn't freeze is ‘gesvda’
+    device=None
+    if driver is not None:
+        device = A.device
+        A = A.cuda()
+
+    try:
+        U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver) # pylint:disable=not-callable
+        if device is not None:
+            U = U.to(device); S = S.to(device); Vh = Vh.to(device)
+        return _SVDTuple(U=U, S=S, Vh=Vh)
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        U, S, Vh = svd(A.to(torch.float64), full_matrices=full_matrices, driver=driver, retry_float64=False)
+        return _SVDTuple(U=U.to(dtype), S=S.to(dtype), Vh=Vh.to(dtype))
+
+def solve(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> torch.Tensor:
+    """I think this uses LU"""
+    try:
+        return torch.linalg.solve(A, B, left=left) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return solve(A.to(torch.float64), B.to(torch.float64), left=left, retry_float64=False).to(dtype)
+
+class _SolveExTuple(NamedTuple):
+    result: torch.Tensor
+    info: int
+
+def solve_ex(A: torch.Tensor, B: torch.Tensor, left:bool=True, retry_float64:bool=False) -> _SolveExTuple:
+    """I think this uses LU"""
+    result, info = torch.linalg.solve_ex(A, B, left=left) # pylint:disable=not-callable
+
+    if info != 0:
+        if not retry_float64: return _SolveExTuple(result, info)
+        dtype = A.dtype
+        if dtype == torch.float64: return _SolveExTuple(result, info)
+        result, info = solve_ex(A.to(torch.float64), B.to(torch.float64), retry_float64=False)
+        return _SolveExTuple(result.to(dtype), info)
+
+    return _SolveExTuple(result, info)
+
+def inv(A: torch.Tensor, retry_float64:bool=False) -> torch.Tensor:
+    try:
+        return torch.linalg.inv(A) # pylint:disable=not-callable
+
+    except torch.linalg.LinAlgError as e:
+        if not retry_float64: raise e
+        dtype = A.dtype
+        if dtype == torch.float64: raise e
+        return inv(A.to(torch.float64), retry_float64=False).to(dtype)
+
+
+class _InvExTuple(NamedTuple):
+    inverse: torch.Tensor
+    info: int
+
+def inv_ex(A: torch.Tensor, *, check_errors=False, retry_float64:bool=False) -> _InvExTuple:
+    """this retries in float64 but on fail info will be not 0"""
+    inverse, info = torch.linalg.inv_ex(A, check_errors=check_errors) # pylint:disable=not-callable
+
+    if info != 0:
+        if not retry_float64: return _InvExTuple(inverse, info)
+        dtype = A.dtype
+        if dtype == torch.float64: return _InvExTuple(inverse, info)
+        inverse, info = inv_ex(A.to(torch.float64), retry_float64=False)
+        return _InvExTuple(inverse.to(dtype), info)
+
+    return _InvExTuple(inverse, info)
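
Every wrapper above follows the same pattern: try the factorization in the input dtype, and if ``retry_float64=True`` and a ``LinAlgError`` (or nonzero ``info``) occurs, redo it in float64 and cast the result back. A short sketch of that pattern, not part of the diff:

import torch
from torchzero.linalg import torch_linalg

A = torch.randn(64, 64)
A = A @ A.T + torch.eye(64)                          # SPD test matrix
L = torch_linalg.cholesky(A, retry_float64=True)     # on LinAlgError, redone in float64 and cast back
print(L.dtype)                                        # torch.float32 (input dtype is preserved)
print(torch.allclose(L @ L.mH, A, atol=1e-4))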
torchzero/modules/__init__.py CHANGED
@@ -1,4 +1,6 @@
 from . import experimental
+from .adaptive import *
+from .adaptive import lre_optimizers as lre
 from .clipping import *
 from .conjugate_gradient import *
 from .grad_approximation import *
@@ -7,9 +9,9 @@ from .line_search import *
 from .misc import *
 from .momentum import *
 from .ops import *
-from .adaptive import *
 from .projections import *
 from .quasi_newton import *
+from .restarts import *
 from .second_order import *
 from .smoothing import *
 from .step_size import *
@@ -18,5 +20,4 @@ from .trust_region import *
 from .variance_reduction import *
 from .weight_decay import *
 from .wrappers import *
-from .
-from .zeroth_order import *
+from .zeroth_order import *
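
A hedged sketch of what this re-export shuffle makes reachable, assuming the star imports pick up the new public names (this depends on each subpackage's ``__all__``, which is not shown here):

from torchzero.modules import lre                            # alias for torchzero.modules.adaptive.lre_optimizers
from torchzero.modules.adaptive import GGT, PSGDKronWhiten   # new in 0.4.1 (ggt.py and the psgd/ subpackage below)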
torchzero/modules/adaptive/__init__.py CHANGED
@@ -1,4 +1,5 @@
-from .
+from . import lre_optimizers
+from .adagrad import Adagrad, AdagradNorm, FullMatrixAdagrad
 
 # from .curveball import CurveBall
 # from .spectral import SpectralPreconditioner
@@ -8,14 +9,21 @@ from .adan import Adan
 from .adaptive_heavyball import AdaptiveHeavyBall
 from .aegd import AEGD
 from .esgd import ESGD
-from .lmadagrad import LMAdagrad
 from .lion import Lion
+from .ggt import GGT
 from .mars import MARSCorrection
 from .matrix_momentum import MatrixMomentum
-from .msam import MSAM,
+from .msam import MSAM, MSAMMomentum
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
 from .natural_gradient import NaturalGradient
 from .orthograd import OrthoGrad, orthograd_
+from .psgd import (
+    PSGDDenseNewton,
+    PSGDKronNewton,
+    PSGDKronWhiten,
+    PSGDLRANewton,
+    PSGDLRAWhiten,
+)
 from .rmsprop import RMSprop
 from .rprop import (
    BacktrackOnSignChange,