torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
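The most visible structural change in the listing above is the move of `torchzero/modules/optimizers` to `torchzero/modules/adaptive` (with `ladagrad.py` becoming `lmadagrad.py`). A minimal sketch of what the rename means for direct imports, using only the module paths shown in the listing; the contents of each module are deliberately not assumed here:

```python
# Hypothetical illustration of the optimizers -> adaptive rename in 0.3.13.
# Module file paths come from the listing above; what each module exports is
# not assumed, so we just import the modules and inspect them.
import importlib

adam_mod = importlib.import_module("torchzero.modules.adaptive.adam")        # was torchzero.modules.optimizers.adam
shampoo_mod = importlib.import_module("torchzero.modules.adaptive.shampoo")  # was torchzero.modules.optimizers.shampoo
print([name for name in dir(adam_mod) if not name.startswith("_")])          # see what the moved module exports
```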
@@ -1,41 +1,18 @@
-"""
-from .absoap import ABSOAP
-from .adadam import Adadam
-from .adam_lambertw import AdamLambertW
-from .adamY import AdamY
-from .adaptive_step_size import AdaptiveStepSize
-from .adasoap import AdaSOAP
-from .cosine import (
-    AdaptiveDifference,
-    AdaptiveDifferenceEMA,
-    CosineDebounce,
-    CosineMomentum,
-    CosineStepSize,
-    ScaledAdaptiveDifference,
-)
-from .cubic_adam import CubicAdam
+"""Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
 from .curveball import CurveBall

 # from dct import DCTProjection
-from .eigendescent import EigenDescent
-from .etf import (
-    ExponentialTrajectoryFit,
-    ExponentialTrajectoryFitV2,
-    PointwiseExponential,
-)
-from .exp_adam import ExpAdam
-from .expanded_lbfgs import ExpandedLBFGS
 from .fft import FFTProjection
 from .gradmin import GradMin
-from .
-from .
+from .l_infinity import InfinityNormTrustRegion
+from .momentum import (
+    CoordinateMomentum,
+    NesterovEMASquared,
+    PrecenteredEMASquared,
+    SqrtNesterovEMASquared,
+)
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
-from .parabolic_search import CubicParabolaSearch, ParabolaSearch
 from .reduce_outward_lr import ReduceOutwardLR
+from .scipy_newton_cg import ScipyNewtonCG
 from .structural_projections import BlockPartition, TensorizeProjection
-from .subspace_preconditioners import (
-    HistorySubspacePreconditioning,
-    RandomSubspacePreconditioning,
-)
-from .tensor_adagrad import TensorAdagrad
@@ -54,8 +54,8 @@ class DCTProjection(ProjectionBase):
         return projected

     @torch.no_grad
-    def unproject(self, projected_tensors, params, grads, loss,
-        settings =
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         dims = settings['dims']
         norm = settings['norm']

@@ -60,8 +60,8 @@ class FFTProjection(ProjectionBase):
         return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable

     @torch.no_grad
-    def unproject(self, projected_tensors, params, grads, loss,
-        settings =
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         one_d = settings['one_d']
         norm = settings['norm']

@@ -5,11 +5,11 @@ from typing import Literal

 import torch

-from ...core import Module, Var
+from ...core import Module, Var, Chainable
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
-from ..smoothing.
+from ..smoothing.sampling import Reformulation



@@ -28,6 +28,7 @@ class GradMin(Reformulation):
     """
     def __init__(
         self,
+        modules: Chainable,
         loss_term: float | None = 0,
         relative: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
         graft: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
@@ -39,7 +40,7 @@ class GradMin(Reformulation):
     ):
         if (relative is not None) and (graft is not None): warnings.warn('both relative and graft loss are True, they will clash with each other')
         defaults = dict(loss_term=loss_term, relative=relative, graft=graft, square=square, mean=mean, maximize_grad=maximize_grad, create_graph=create_graph, modify_loss=modify_loss)
-        super().__init__(defaults)
+        super().__init__(defaults, modules=modules)

     @torch.no_grad
     def closure(self, backward, closure, params, var):
@@ -0,0 +1,111 @@
+
+import numpy as np
+import torch
+from scipy.optimize import lsq_linear
+
+from ...core import Chainable, Module
+from ..trust_region.trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+
+class InfinityNormTrustRegion(TrustRegionBase):
+    """Trust region with L-infinity norm via ``scipy.optimize.lsq_linear``.
+
+    Args:
+        hess_module (Module | None, optional):
+            A module that maintains a hessian approximation (not hessian inverse!).
+            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+            When using quasi-newton methods, set `inverse=False` when constructing them.
+        eta (float, optional):
+            if ratio of actual to predicted reduction is larger than this, step is accepted.
+            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+        rho_good (float, optional):
+            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
+        init (float, optional): Initial trust region value. Defaults to 1.
+        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+        max_attempts (int, optional):
+            maximum number of trust region size reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 10.
+        boundary_tol (float | None, optional):
+            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.
+        tol (float | None, optional): tolerance for least squares solver.
+        fallback (bool, optional):
+            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
+            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
+        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.
+
+    Examples:
+        BFGS with infinity-norm trust region
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
+            )
+    """
+    def __init__(
+        self,
+        hess_module: Module,
+        prefer_dense: bool = True,
+        tol: float = 1e-10,
+        eta: float = 0.0,
+        nplus: float = 3.5,
+        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
+        boundary_tol: float | None = None,
+        init: float = 1,
+        max_attempts: int = 10,
+        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(tol=tol, prefer_dense=prefer_dense)
+        super().__init__(
+            defaults=defaults,
+            hess_module=hess_module,
+            eta=eta,
+            nplus=nplus,
+            nminus=nminus,
+            rho_good=rho_good,
+            rho_bad=rho_bad,
+            boundary_tol=boundary_tol,
+            init=init,
+            max_attempts=max_attempts,
+            radius_strategy=radius_strategy,
+            update_freq=update_freq,
+            inner=inner,
+
+            radius_fn=torch.amax,
+        )
+
+    def trust_solve(self, f, g, H, radius, params, closure, settings):
+        if settings['prefer_dense'] and H.is_dense():
+            # convert to array if possible to avoid many conversions
+            # between torch and numpy, plus it seems that it uses
+            # a better solver
+            A = H.to_tensor().numpy(force=True).astype(np.float64)
+        else:
+            # memory efficient linear operator (is this still faster on CUDA?)
+            A = H.scipy_linop()
+
+        try:
+            d_np = lsq_linear(
+                A,
+                g.numpy(force=True).astype(np.float64),
+                tol=settings['bounds'],
+                bounds=(-radius, radius),
+            ).x
+            return torch.as_tensor(d_np, device=g.device, dtype=g.dtype)
+
+        except np.linalg.LinAlgError:
+            self.children['hess_module'].reset()
+            g_max = g.amax()
+            if g_max > radius:
+                g = g * (radius / g_max)
+            return g
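For readers unfamiliar with `scipy.optimize.lsq_linear`: `trust_solve` above phrases the L-infinity trust-region subproblem as a bound-constrained least-squares problem, minimizing ||H d - g|| subject to |d_i| <= radius. A standalone sketch with toy numpy data (not torchzero code) showing that call pattern:

```python
# Standalone sketch (toy data, not torchzero code): the same call pattern as
# trust_solve above, i.e. min ||H d - g|| subject to the box |d_i| <= radius.
import numpy as np
from scipy.optimize import lsq_linear

H = np.array([[2.0, 0.3],
              [0.3, 1.0]])               # toy Hessian approximation
g = np.array([1.0, -0.5])                # toy gradient
radius = 0.4                             # L-infinity trust-region radius

res = lsq_linear(H, g, bounds=(-radius, radius))
d = res.x                                # Newton-like direction clipped to the box
print(d, np.abs(d).max() <= radius + 1e-12)
```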
@@ -6,10 +6,10 @@ from typing import Literal
 import torch

 from ...core import Target, Transform
-from ...utils import NumberList, TensorList,
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import ema_, ema_sq_, sqrt_ema_sq_
-from .
-from .
+from ..momentum.momentum import nag_
+from ..ops.higher_level import EMASquared, SqrtEMASquared


 def precentered_ema_sq_(
@@ -158,40 +158,3 @@ class CoordinateMomentum(Transform):
         p = NumberList(s['p'] for s in settings)
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
-
-
-# def multiplicative_momentum_(
-#     tensors_: TensorList,
-#     velocity_: TensorList,
-#     momentum: float | NumberList,
-#     dampening: float | NumberList,
-#     normalize_velocity: bool = True,
-#     abs: bool = False,
-#     lerp: bool = False,
-# ):
-#     """
-#     abs: if True, tracks momentum of absolute magnitudes.
-
-#     returns `tensors_`.
-#     """
-#     tensors_into_velocity = tensors_.abs() if abs else tensors_
-#     ema_(tensors_into_velocity, exp_avg_=velocity_, beta=momentum, dampening=0, lerp=lerp)
-
-#     if normalize_velocity: velocity_ = velocity_ / velocity_.std().add_(1e-8)
-#     return tensors_.mul_(velocity_.lazy_mul(1-dampening) if abs else velocity_.abs().lazy_mul_(1-dampening))
-
-
-# class MultiplicativeMomentum(Transform):
-#     """sucks"""
-#     def __init__(self, momentum: float = 0.9, dampening: float = 0,normalize_velocity: bool = True, abs: bool = False, lerp: bool = False):
-#         defaults = dict(momentum=momentum, dampening=dampening, normalize_velocity=normalize_velocity,abs=abs, lerp=lerp)
-#         super().__init__(defaults, uses_grad=False)
-
-#     @torch.no_grad
-#     def apply(self, tensors, params, grads, loss, states, settings):
-#         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
-#         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
-#         velocity = self.get_state('velocity', params=params, cls=TensorList)
-#         return multiplicative_momentum_(TensorList(target), velocity_=velocity, momentum=momentum, dampening=dampening,
-#             normalize_velocity=normalize_velocity,abs=abs,lerp=lerp)
-
@@ -3,28 +3,36 @@ from typing import Any, Literal, overload

 import torch

-from ...core import Chainable, Module, apply_transform
+from ...core import Chainable, Modular, Module, apply_transform
 from ...utils import TensorList, as_tensorlist
-from ...utils.derivatives import hvp
+from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
 from ..quasi_newton import LBFGS

+
 class NewtonSolver(Module):
-    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
+    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
     def __init__(
         self,
         solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
         maxiter=None,
-
+        maxiter1=None,
+        tol:float | None=1e-3,
         reg: float = 0,
         warm_start=True,
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        reset_solver: bool = False,
+        h: float= 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+        defaults = dict(tol=tol, h=h,reset_solver=reset_solver, maxiter=maxiter, maxiter1=maxiter1, reg=reg, warm_start=warm_start, solver=solver, hvp_method=hvp_method)
         super().__init__(defaults,)

         if inner is not None:
             self.set_child('inner', inner)

+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
@@ -34,19 +42,49 @@ class NewtonSolver(Module):
         settings = self.settings[params[0]]
         solver_cls = settings['solver']
         maxiter = settings['maxiter']
+        maxiter1 = settings['maxiter1']
         tol = settings['tol']
         reg = settings['reg']
+        hvp_method = settings['hvp_method']
         warm_start = settings['warm_start']
+        h = settings['h']
+        reset_solver = settings['reset_solver']

+        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
-
+        if hvp_method == 'autograd':
+            grad = var.get_grad(create_graph=True)

-
-
-
+            def H_mm(x):
+                self._num_hvps_last_step += 1
+                with torch.enable_grad():
+                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
                 if reg != 0: Hvp = Hvp + (x*reg)
                 return Hvp

+        else:
+
+            with torch.enable_grad():
+                grad = var.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x):
+                    self._num_hvps_last_step += 1
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                    if reg != 0: Hvp = Hvp + (x*reg)
+                    return Hvp
+
+            elif hvp_method == 'central':
+                def H_mm(x):
+                    self._num_hvps_last_step += 1
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                    if reg != 0: Hvp = Hvp + (x*reg)
+                    return Hvp
+
+            else:
+                raise ValueError(hvp_method)
+
+
         # -------------------------------- inner step -------------------------------- #
         b = as_tensorlist(grad)
         if 'inner' in self.children:
@@ -58,23 +96,46 @@ class NewtonSolver(Module):
         if x0 is None: x = b.zeros_like().requires_grad_(True)
         else: x = x0.clone().requires_grad_(True)

-
+
+        if 'solver' not in self.global_state:
+            if maxiter1 is not None: maxiter = maxiter1
+            solver = self.global_state['solver'] = solver_cls(x)
+            self.global_state['x'] = x
+
+        else:
+            if reset_solver:
+                solver = self.global_state['solver'] = solver_cls(x)
+            else:
+                solver_params = self.global_state['x']
+                solver_params.set_(x)
+                x = solver_params
+                solver = self.global_state['solver']
+
         def lstsq_closure(backward=True):
-            Hx = H_mm(x)
-            loss = (Hx-b).pow(2).global_mean()
+            Hx = H_mm(x).detach()
+            # loss = (Hx-b).pow(2).global_mean()
+            # if backward:
+            #     solver.zero_grad()
+            #     loss.backward(inputs=x)
+
+            residual = Hx - b
+            loss = residual.pow(2).global_mean()
             if backward:
-
-
+                with torch.no_grad():
+                    H_residual = H_mm(residual)
+                    n = residual.global_numel()
+                    x.set_grad_((2.0 / n) * H_residual)
+
             return loss

         if maxiter is None: maxiter = b.global_numel()
         loss = None
-        initial_loss = lstsq_closure(False)
-        if initial_loss >
+        initial_loss = lstsq_closure(False) if tol is not None else None # skip unnecessary closure if tol is None
+        if initial_loss is None or initial_loss > torch.finfo(b[0].dtype).eps:
             for i in range(maxiter):
                 loss = solver.step(lstsq_closure)
                 assert loss is not None
-                if
+                if initial_loss is not None and loss/initial_loss < tol: break

         # print(f'{loss = }')

@@ -83,6 +144,7 @@ class NewtonSolver(Module):
             x0.copy_(x)

         var.update = x.detach()
+        self._num_hvps += self._num_hvps_last_step
         return var

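The `hvp_method` options added here select between autograd and finite-difference Hessian-vector products. A minimal sketch of the two finite-difference variants on a toy objective (plain PyTorch, not torchzero's `hvp_fd_forward`/`hvp_fd_central` helpers, whose signatures operate on closures and parameter lists):

```python
# Illustrative sketch of finite-difference Hessian-vector products (plain PyTorch;
# names and signatures here are hypothetical, not torchzero's API).
import torch

def grad_at(f, x):
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(f(x), x)[0]

def hvp_forward(f, x, v, h=1e-3):
    # Hv ~= (grad(x + h*v) - grad(x)) / h : one extra gradient evaluation, O(h) error
    return (grad_at(f, x + h * v) - grad_at(f, x)) / h

def hvp_central(f, x, v, h=1e-3):
    # Hv ~= (grad(x + h*v) - grad(x - h*v)) / (2h) : two extra gradients, O(h^2) error
    return (grad_at(f, x + h * v) - grad_at(f, x - h * v)) / (2 * h)

f = lambda x: (x ** 4).sum()          # toy objective, Hessian = diag(12 * x**2)
x = torch.tensor([1.0, 2.0])
v = torch.tensor([1.0, 0.0])
print(hvp_forward(f, x, v), hvp_central(f, x, v))   # both approximately [12., 0.]
```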
@@ -10,16 +10,16 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
-
+    flatten_jacobian,
     jacobian_wrt,
 )
 from ..second_order.newton import (
-
-
-
-
+    _cholesky_solve,
+    _eigh_solve,
+    _least_squares_solve,
+    _lu_solve,
 )
-
+from ...utils.linalg.linear_operator import Dense

 class NewtonNewton(Module):
     """Applies Newton-like preconditioning to Newton step.
@@ -51,10 +51,10 @@ class NewtonNewton(Module):
         super().__init__(defaults)

     @torch.no_grad
-    def
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
-        if closure is None: raise RuntimeError('
+        if closure is None: raise RuntimeError('NewtonNewton requires closure')

         settings = self.settings[params[0]]
         reg = settings['reg']
@@ -64,6 +64,7 @@ class NewtonNewton(Module):
         eigval_tfm = settings['eigval_tfm']

         # ------------------------ calculate grad and hessian ------------------------ #
+        Hs = []
         with torch.enable_grad():
             loss = var.loss = var.loss_approx = closure(False)
             g_list = torch.autograd.grad(loss, params, create_graph=True)
@@ -76,17 +77,29 @@ class NewtonNewton(Module):
             is_last = o == order
             H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
             with torch.no_grad() if is_last else nullcontext():
-                H =
+                H = flatten_jacobian(H_list)
                 if reg != 0: H = H + I * reg
+                Hs.append(H)

                 x = None
                 if search_negative or (is_last and eigval_tfm is not None):
-                    x =
-                if x is None: x =
-                if x is None: x =
-                if x is None: x =
+                    x = _eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                if x is None: x = _cholesky_solve(H, xp)
+                if x is None: x = _lu_solve(H, xp)
+                if x is None: x = _least_squares_solve(H, xp)
                 xp = x.squeeze()

-
+        self.global_state["Hs"] = Hs
+        self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+
+    @torch.no_grad
+    def apply(self, var):
+        params = var.params
+        xp = self.global_state['xp']
+        var.update = vec_to_tensors(xp, params)
         return var

+    def get_H(self, var):
+        Hs = self.global_state["Hs"]
+        if len(Hs) == 1: return Dense(Hs[0])
+        return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
@@ -0,0 +1,105 @@
+from typing import Literal, overload
+
+import torch
+from scipy.sparse.linalg import LinearOperator, gcrotmk
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, generic_vector_norm, vec_to_tensors
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+from ...utils.linalg.solve import cg, minres
+
+
+class ScipyNewtonCG(Module):
+    """NewtonCG with scipy solvers (any from scipy.sparse.linalg)"""
+    def __init__(
+        self,
+        solver = gcrotmk,
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        h: float = 1e-3,
+        warm_start=False,
+        inner: Chainable | None = None,
+        kwargs: dict | None = None,
+    ):
+        defaults = dict(hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
+        if kwargs is None: kwargs = {}
+        self._kwargs = kwargs
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        solver = settings['solver']
+        h = settings['h']
+        warm_start = settings['warm_start']
+
+        self._num_hvps_last_step = 0
+        # ---------------------- Hessian vector product function --------------------- #
+        device = params[0].device; dtype=params[0].dtype
+        if hvp_method == 'autograd':
+            grad = var.get_grad(create_graph=True)
+
+            def H_mm(x_np):
+                self._num_hvps_last_step += 1
+                x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                with torch.enable_grad():
+                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
+                return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+        else:
+
+            with torch.enable_grad():
+                grad = var.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x_np):
+                    self._num_hvps_last_step += 1
+                    x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                    return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+            elif hvp_method == 'central':
+                def H_mm(x_np):
+                    self._num_hvps_last_step += 1
+                    x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                    return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+            else:
+                raise ValueError(hvp_method)
+
+        ndim = sum(p.numel() for p in params)
+        H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore
+
+        # -------------------------------- inner step -------------------------------- #
+        b = var.get_update()
+        if 'inner' in self.children:
+            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        b = as_tensorlist(b)
+
+        # ---------------------------------- run cg ---------------------------------- #
+        x0 = None
+        if warm_start: x0 = self.global_state.get('x_prev', None) # initialized to 0 which is default anyway
+
+        x_np = solver(H, b.to_vec().nan_to_num().numpy(force=True), x0=x0, **self._kwargs)
+        if isinstance(x_np, tuple): x_np = x_np[0]
+
+        if warm_start:
+            self.global_state['x_prev'] = x_np
+
+        var.update = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
+
+        self._num_hvps += self._num_hvps_last_step
+        return var
+
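`ScipyNewtonCG` exposes the Hessian to SciPy only through matrix-vector products, wrapped in a `LinearOperator` and handed to an iterative solver such as `gcrotmk`. A standalone numpy sketch of that pattern (toy 2x2 system, unrelated to torchzero internals):

```python
# Standalone sketch (toy data, not torchzero code): expose a matrix only through
# matvec via LinearOperator and solve H x = b with scipy's gcrotmk.
import numpy as np
from scipy.sparse.linalg import LinearOperator, gcrotmk

H_dense = np.array([[3.0, 1.0],
                    [1.0, 2.0]])        # stand-in for a Hessian we never materialize

def matvec(v):
    return H_dense @ v                  # the solver only ever asks for H @ v

H = LinearOperator((2, 2), matvec=matvec)
b = np.array([1.0, -1.0])

x, info = gcrotmk(H, b)                 # info == 0 signals convergence
print(x, np.linalg.norm(H_dense @ x - b))
```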