torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# pylint:disable=not-callable
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from ...core import Chainable, Module
|
|
5
|
+
from ...utils.linalg import linear_operator
|
|
6
|
+
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LevenbergMarquardt(TrustRegionBase):
    """Levenberg-Marquardt trust region algorithm.

    Solves ``(H + (1/radius) * D) d = g`` for the step ``d``, where ``D`` is the
    identity when ``y=0``, the (safeguarded) diagonal of ``H`` when ``y=1``, and an
    interpolation between the two for values in between.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        y (float, optional):
            when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian approximation
            is added. Values between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`. Defaults to 1e-4.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Gauss-Newton with Levenberg-Marquardt trust-region

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
            )

        LM-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
            )

        First order trust region (hessian is assumed to be identity)

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.Identity()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        # y and fallback are consumed by trust_solve; the rest go to TrustRegionBase
        defaults = dict(y=y, fallback=fallback)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            # LM controls the radius through the damping term, so the step norm
            # is not compared against a geometric boundary.
            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the damped system ``(H + (1/radius) * D) d = g`` for the LM step."""
        y = settings['y']

        # LM needs the hessian itself; a stored inverse must be densified first
        # (only done when the user explicitly opted in via fallback=True).
        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
                H = H.to_dense()
            else:
                raise RuntimeError(
                    f"{self.children['hess_module']} maintains a hessian inverse. "
                    "LevenbergMarquardt requires the hessian, not the inverse. "
                    "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
                    "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
                    "however that can be inefficient and unstable."
                )

        # larger trust radius -> smaller damping
        reg = 1/radius
        if y == 0:
            # classic LM: damp with a scaled identity
            return H.add_diagonal(reg).solve(g)

        # Marquardt scaling: damp with (a fraction of) the hessian diagonal.
        diag = H.diagonal()
        # guard against zero/denormal diagonal entries which would remove damping
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)  # interpolate identity (y=0) <-> diagonal (y=1)
        return H.add_diagonal(diag*reg).solve(g)
|
|
127
|
+
|
|
128
|
+
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from ...core import Chainable, Module
|
|
4
|
+
from ...utils.linalg import cg, linear_operator
|
|
5
|
+
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method.

    .. note::

        If you wish to use exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`
        which only uses hessian-vector products. While passing ``tz.m.Newton`` to this
        is possible, it is usually less efficient.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`. Defaults to 1e-4.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
        cg_tol (float, optional): tolerance for the conjugate gradient solver. Defaults to 1e-4.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        boundary_tol (float | None, optional):
            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-1.
        prefer_exact (bool, optional):
            when exact solution can be easily calculated without CG (e.g. hessian is stored as scaled identity),
            uses the exact solution. If False, always uses CG. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = 1e-1,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        reg: float = 0,
        cg_tol: float = 1e-4,
        prefer_exact: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        # CG-specific settings consumed by trust_solve; the rest go to TrustRegionBase
        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            # the step norm defines the geometric boundary for the radius strategy
            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the trust region subproblem with Steihaug-Toint CG (or exactly when cheap)."""
        # scaled identity admits a closed-form bounded solution; skip CG entirely
        if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
            return H.solve_bounded(g, radius)

        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
        return x
|
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import warnings
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import Any, Literal, Protocol, cast, final, overload
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ...core import Chainable, Module, Var, apply_transform
|
|
11
|
+
from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
|
|
12
|
+
from ...utils.linalg.linear_operator import LinearOperator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _flatten_tensors(tensors: list[torch.Tensor]):
|
|
16
|
+
return torch.cat([t.ravel() for t in tensors])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _RadiusStrategy(Protocol):
    """Structural type for the trust-radius update rule used by ``TrustRegionBase``."""
    def __call__(
        self,
        params: Sequence[torch.Tensor],
        closure: Callable,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        d: torch.Tensor,
        trust_radius: float,
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None,
        init: float,
        state: Mapping[str, Any],
        settings: Mapping[str, Any],
        radius_fn: Callable | None = torch.linalg.vector_norm,
    ) -> tuple[float, bool]:
        """returns (new trust_region value, success).

        Args:
            params (Sequence[torch.Tensor]): params tensor list
            closure (Callable): closure
            d (torch.Tensor):
                current update vector with current trust_region, which is SUBTRACTED from parameters.
                May be exact solution to (B+yI)x=g, approximate, or a solution to a different subproblem
                (e.g. cubic regularization).
            f (float): loss at x0
            g (torch.Tensor): gradient vector
            H (LinearOperator): hessian approximation
            trust_radius (float): current trust region value
            eta (float, optional):
                if ratio of actual to predicted reduction is larger than this, step is accepted.
                When :code:`hess_module` is GaussNewton, this can be set to 0. Suggested value 0.0.
            nplus (float, optional): increase factor on successful steps. Suggested value 3.5.
            nminus (float, optional): decrease factor on unsuccessful steps. Suggested value 0.25.
            rho_good (float, optional):
                if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
            rho_bad (float, optional):
                if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
            boundary_tol (float | None, optional):
                The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
                This prevents increasing trust region when solution is not on the boundary.
            init (float, optional): Initial trust region value. Defaults to 1.
            state (dict, optional): global state of the module for storing persistent info.
            settings (dict, optional): all settings in case this strategy has other settings.
            radius_fn (Callable | None, optional):
                function that accepts ``(d: torch.Tensor)`` and returns the actual region of ``d``
                (e.g. L2) norm for L2 trust region.
        """
        ... # pylint:disable=unnecessary-ellipsis
|
|
74
|
+
|
|
75
|
+
def _get_rho(params: Sequence[torch.Tensor], closure:Callable,
             f: float, g: torch.Tensor, H: LinearOperator, d:torch.Tensor, ):
    """Compute the acceptance ratio ``rho = reduction / pred_reduction`` for step ``d``.

    ``d`` is flattened and SUBTRACTED from ``params``; the closure is evaluated at
    the trial point and ``params`` are restored before returning.

    Returns:
        ``(rho, f_star, reduction, pred_reduction)`` where ``f_star`` is the loss
        at the trial point ``x0 - d``.
    """

    # evaluate actual loss reduction
    update_unflattned = vec_to_tensors(d, params)
    params = TensorList(params)
    x0 = params.clone() # same as in line searches, large directions are undone very imprecisely

    params -= update_unflattned
    f_star = closure(False)
    # restore original parameters exactly (no add-back roundoff)
    params.set_(x0)

    reduction = f - f_star

    # predicted reduction of the quadratic model; the step is -d, so
    # m(0) - m(-d) = g.T @ d - 0.5 * d.T @ H @ d
    Hu = H.matvec(d)
    pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)

    # clamp denominator away from zero/negative to keep rho well-defined
    rho = reduction / (pred_reduction.clip(min=torch.finfo(g.dtype).tiny * 2))
    return rho, f_star, reduction, pred_reduction
|
|
96
|
+
|
|
97
|
+
def _get_rho_tensorlist(params: Sequence[torch.Tensor], closure:Callable,
             f: float, g: TensorList, Hvp: Callable[[TensorList], TensorList], d:TensorList):
    """TensorList/matrix-free variant of ``_get_rho``.

    Same contract as ``_get_rho`` but ``g`` and ``d`` are per-parameter tensor lists
    and the hessian is given as a hessian-vector-product callable ``Hvp``.
    Returns ``(rho, f_star, reduction, pred_reduction)``.
    """
    params = TensorList(params)
    x0 = params.clone() # same as in line searches, large directions are undone very imprecisely

    # evaluate before modifying params to not break autograd
    Hu = Hvp(d)

    # actual f at the trial point x0 - d
    params -= d
    f_star = closure(False)
    # restore original parameter values
    params.copy_(x0)

    reduction = f - f_star

    # predicted reduction of the quadratic model; step is -d, so
    # m(0) - m(-d) = g.T @ d - 0.5 * d.T @ B @ d
    pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)

    # clamp denominator away from zero/negative to keep rho well-defined
    rho = reduction / (pred_reduction.clip(min=torch.finfo(g[0].dtype).tiny * 2))
    return rho, f_star, reduction, pred_reduction
|
|
118
|
+
|
|
119
|
+
@torch.no_grad
def default_radius(
    params: Sequence[torch.Tensor],
    closure: Callable,
    f: float,
    g: torch.Tensor | TensorList,
    H: LinearOperator | Callable,
    d: torch.Tensor | TensorList,
    trust_radius: float,
    eta: float, # 0.0
    nplus: float, # 3.5
    nminus: float, # 0.25
    rho_good: float, # 0.99
    rho_bad: float, # 1e-4
    boundary_tol: float | None,
    init: float,
    state: Mapping[str, Any],
    settings: Mapping[str, Any],
    radius_fn: Callable | None = generic_vector_norm,
    check_overflow: bool = True,
    # dynamic_nminus: bool=False,
) -> tuple[float, bool]:
    """Default trust-radius update rule.

    Accepts the step when ``rho > eta`` (and the trial loss is finite), shrinks the
    radius by ``nminus`` when ``rho < rho_bad`` or the trial loss is non-finite, and
    grows it by ``nplus`` when ``rho > rho_good`` and the step reached the boundary.
    Returns ``(new_trust_radius, success)``; see ``_RadiusStrategy`` for argument docs.
    """

    # when rho_bad < rho < eta, no update is made but trust region is not updated.
    if eta > rho_bad:
        warnings.warn(f"trust region eta={eta} is larger than rho_bad={rho_bad}, "
                      "this can lead to trust region getting stuck.")

    # dispatch on flat-tensor vs tensor-list representation of the problem
    if isinstance(g, torch.Tensor):
        rho, f_star, _, _ = _get_rho(params=params, closure=closure, f=f, g=g, H=H, d=d) # pyright:ignore[reportArgumentType]
    else:
        rho, f_star, _, _ = _get_rho_tensorlist(params=params, closure=closure, f=f, g=g, Hvp=H, d=d) # pyright:ignore[reportArgumentType]

    is_finite = math.isfinite(f_star)

    # find boundary of current step
    if radius_fn is None: d_radius = trust_radius
    else: d_radius = radius_fn(d)

    # failed step: shrink relative to the actual step size, not the nominal radius
    if rho < rho_bad or not is_finite:
        # if dynamic_nminus and rho > 0: nminus = nminus * max(rho, 1e-4)
        trust_radius = d_radius*nminus

    # very good step: only grow when the step actually hit the boundary
    elif rho > rho_good and is_finite:
        if (boundary_tol is None) or (trust_radius-d_radius)/trust_radius < boundary_tol:
            trust_radius = max(trust_radius, d_radius*nplus)

    # prevent very small or large values by resetting to the initial radius
    if check_overflow:
        finfo = generic_finfo(g)
        if trust_radius < finfo.tiny*2 or trust_radius > finfo.max/2:
            trust_radius = init

    # return new trust region and success boolean
    return tofloat(trust_radius), rho > eta and is_finite
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def fixed_radius(
    params: Sequence[torch.Tensor],
    closure: Callable,
    f: float,
    g: torch.Tensor,
    H: LinearOperator,
    d: torch.Tensor,
    trust_radius: float,
    eta: float, # 0.0
    nplus: float, # 3.5
    nminus: float, # 0.25
    rho_good: float, # 0.99
    rho_bad: float, # 1e-4
    boundary_tol: float | None,
    init: float,
    state: Mapping[str, Any],
    settings: Mapping[str, Any],
    radius_fn: Callable | None = torch.linalg.vector_norm,
) -> tuple[float, bool]:
    """Degenerate radius strategy: every step is accepted and the radius stays at ``init``.

    All other arguments are ignored; they exist only to satisfy the
    ``_RadiusStrategy`` protocol.
    """
    # never adapt, never reject
    pinned_radius = init
    return pinned_radius, True
|
|
198
|
+
|
|
199
|
+
# String keys accepted in place of an explicit _RadiusStrategy callable.
_RADIUS_KEYS = Literal['default', 'fixed']
# Registry mapping each key to its strategy implementation.
_RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
    "default": default_radius,
    "fixed": fixed_radius,
    # "dynamic": partial(default_radius, dynamic_nminus=True)
}
|
|
205
|
+
|
|
206
|
+
class TrustRegionBase(Module, ABC):
    """Base class for trust region modules.

    Subclasses implement :meth:`trust_solve`, which solves the trust region
    subproblem for a given radius, and may override :meth:`trust_region_update`
    and :meth:`trust_region_apply`. The current hessian approximation (from the
    ``hess_module`` child) and the trust radius are cached in ``global_state``.
    """
    def __init__(
        self,
        defaults: dict | None,
        hess_module: Chainable,
        # suggested default values:
        # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
        # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None, # None or 1e-1
        init: float, # 1
        max_attempts: int, # 10
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
        radius_fn: Callable | None, # torch.linalg.vector_norm
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        # resolve a string key into the concrete strategy callable
        if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
        if defaults is None: defaults = {}

        # merge trust-region settings into the subclass's own defaults,
        # erroring on key collisions (safe_dict_update_)
        safe_dict_update_(
            defaults,
            dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
                 update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
                 boundary_tol=boundary_tol)
        )

        super().__init__(defaults)

        self._radius_fn = radius_fn
        self.set_child('hess_module', hess_module)

        if inner is not None:
            self.set_child('inner', inner)

    @abstractmethod
    def trust_solve(
        self,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        radius: float,
        params: list[torch.Tensor],
        closure: Callable,
        settings: Mapping[str, Any],
    ) -> torch.Tensor:
        """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
        ... # pylint:disable=unnecessary-ellipsis

    def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
        """updates the state of this module after H or B have been updated, if necessary"""

    def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
        """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
        assert H is not None

        params = TensorList(var.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        max_attempts = settings['max_attempts']

        # loss at x_0
        loss = var.loss
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = var.get_loss(False)
        loss = tofloat(loss)

        # trust region step and update: retry with shrinking radius until the
        # strategy accepts the step or attempts run out.
        # NOTE(review): assumes max_attempts >= 1, otherwise d stays None and the
        # assert below fires — confirm callers never pass 0.
        success = False
        d = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', settings['init'])

            # solve Hx=g
            d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)

            # update trust radius (the strategy also decides acceptance)
            radius_strategy: _RadiusStrategy = settings['radius_strategy']
            self.global_state["trust_radius"], success = radius_strategy(
                params=params,
                closure=closure,
                d=d,
                f=loss,
                g=g,
                H=H,
                trust_radius=trust_radius,

                eta=settings["eta"],
                nplus=settings["nplus"],
                nminus=settings["nminus"],
                rho_good=settings["rho_good"],
                rho_bad=settings["rho_bad"],
                boundary_tol=settings["boundary_tol"],
                init=settings["init"],

                state=self.global_state,
                settings=settings,
                radius_fn=self._radius_fn,
            )

        assert d is not None
        # on failure return a zero update (no movement) instead of a rejected step
        if success: var.update = vec_to_tensors(d, params)
        else: var.update = params.zeros_like()

        return var


    @final
    @torch.no_grad
    def update(self, var):
        # refresh the cached hessian approximation every `update_freq` steps
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults["update_freq"] == 0:

            hessian_module = self.children['hess_module']
            hessian_module.update(var)
            H = hessian_module.get_H(var)
            self.global_state["H"] = H

            self.trust_region_update(var, H=H)


    @final
    @torch.no_grad
    def apply(self, var):
        H = self.global_state.get('H', None)

        # -------------------------------- inner step -------------------------------- #
        # optionally precondition the update through the `inner` child chain
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)

        # ----------------------------------- apply ---------------------------------- #
        return self.trust_region_apply(var=var, tensors=update, H=H)
|
|
350
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .svrg import SVRG
|