torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
# pylint:disable=not-callable
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from ...core import Chainable, Module
|
|
7
|
+
from ...utils import TensorList, vec_to_tensors
|
|
8
|
+
from ...utils.linalg.linear_operator import LinearOperator
|
|
9
|
+
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
|
|
13
|
+
# ported to pytorch and linear operator
|
|
14
|
+
def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_params_plus_x_fn: Callable | None, it_max=100, epsilon=1e-8, ):
|
|
15
|
+
"""
|
|
16
|
+
Solve min_z <g, z-x> + 1/2<z-x, H(z-x)> + M/3 ||z-x||^3
|
|
17
|
+
|
|
18
|
+
For explanation of Cauchy point, see "Gradient Descent
|
|
19
|
+
Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
|
|
20
|
+
https://arxiv.org/pdf/1612.00547.pdf
|
|
21
|
+
Other potential implementations can be found in paper
|
|
22
|
+
"Adaptive cubic regularisation methods"
|
|
23
|
+
https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf
|
|
24
|
+
"""
|
|
25
|
+
solver_it = 1
|
|
26
|
+
newton_step = H.solve(g).neg_()
|
|
27
|
+
if M == 0:
|
|
28
|
+
return newton_step, solver_it
|
|
29
|
+
|
|
30
|
+
def cauchy_point(g, H:LinearOperator, M):
|
|
31
|
+
if torch.linalg.vector_norm(g) == 0 or M == 0:
|
|
32
|
+
return 0 * g
|
|
33
|
+
g_dir = g / torch.linalg.vector_norm(g)
|
|
34
|
+
H_g_g = H.matvec(g_dir) @ g_dir
|
|
35
|
+
R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
|
|
36
|
+
return -R * g_dir
|
|
37
|
+
|
|
38
|
+
def conv_criterion(s, r):
|
|
39
|
+
"""
|
|
40
|
+
The convergence criterion is an increasing and concave function in r
|
|
41
|
+
and it is equal to 0 only if r is the solution to the cubic problem
|
|
42
|
+
"""
|
|
43
|
+
s_norm = torch.linalg.vector_norm(s)
|
|
44
|
+
return 1/s_norm - 1/r
|
|
45
|
+
|
|
46
|
+
# Solution s satisfies ||s|| >= Cauchy_radius
|
|
47
|
+
r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))
|
|
48
|
+
|
|
49
|
+
if (loss_at_params_plus_x_fn is not None) and (f > loss_at_params_plus_x_fn(newton_step)):
|
|
50
|
+
return newton_step, solver_it
|
|
51
|
+
|
|
52
|
+
r_max = torch.linalg.vector_norm(newton_step)
|
|
53
|
+
if r_max - r_min < epsilon:
|
|
54
|
+
return newton_step, solver_it
|
|
55
|
+
|
|
56
|
+
# id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
|
|
57
|
+
s_lam = None
|
|
58
|
+
for _ in range(it_max):
|
|
59
|
+
r_try = (r_min + r_max) / 2
|
|
60
|
+
lam = r_try * M
|
|
61
|
+
s_lam = H.add_diagonal(lam).solve(g).neg()
|
|
62
|
+
# s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
|
|
63
|
+
solver_it += 1
|
|
64
|
+
crit = conv_criterion(s_lam, r_try)
|
|
65
|
+
if torch.abs(crit) < epsilon:
|
|
66
|
+
return s_lam, solver_it
|
|
67
|
+
if crit < 0:
|
|
68
|
+
r_min = r_try
|
|
69
|
+
else:
|
|
70
|
+
r_max = r_try
|
|
71
|
+
if r_max - r_min < epsilon:
|
|
72
|
+
break
|
|
73
|
+
assert s_lam is not None
|
|
74
|
+
return s_lam, solver_it
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        maxiter (int, optional): maximum iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): epsilon for the solver, defaults to 1e-8.
        check_decrease (bool, optional):
            if ``True``, the cubic solver returns the Newton step early whenever it already
            decreases the loss (costs one extra closure evaluation). Defaults to False.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.


    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.CubicRegularization(tz.m.Newton()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        maxiter: int = 100,
        eps: float = 1e-8,
        check_decrease: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the cubic subproblem with ``M = 1/radius`` and return the update."""
        params = TensorList(params)

        # Optional decrease check: evaluate the loss at params + x without
        # permanently modifying the parameters.
        loss_at_params_plus_x_fn = None
        if settings['check_decrease']:
            def closure_plus_x(x):
                x_unflat = vec_to_tensors(x, params)
                params.add_(x_unflat)
                loss_x = closure(False)
                params.sub_(x_unflat)
                return loss_x
            loss_at_params_plus_x_fn = closure_plus_x

        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
                               it_max=settings['maxiter'], epsilon=settings['eps'])
        # ls_cubic_solver returns the minimizing step; negate it here —
        # presumably the base class applies updates with the opposite sign
        # (consistent with Dogleg/LevenbergMarquardt returning +H^{-1}g).
        return d.neg_()
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# pylint:disable=not-callable
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from ...core import Chainable, Module
|
|
5
|
+
from ...utils import TensorList, vec_to_tensors
|
|
6
|
+
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
|
|
7
|
+
|
|
8
|
+
class Dogleg(TrustRegionBase):
    """Dogleg trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        boundary_tol (float | None, optional):
            the trust region only increases when the suggested step's norm is at least
            ``(1-boundary_tol)*trust_region``. Defaults to None.
        init (float, optional): Initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        rho_good: float = 0.75,
        rho_bad: float = 0.25,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict()
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Compute the dogleg step for the given trust-region ``radius``.

        Returns a positive multiple of the descent direction (same sign
        convention as the gradient ``g``).
        """
        # NOTE(review): radius is capped at 2 — presumably because
        # radius_fn=vector_norm ties the radius to step norms; confirm.
        if radius > 2: radius = self.global_state['radius'] = 2
        eps = torch.finfo(g.dtype).tiny * 2

        gHg = g.dot(H.matvec(g))
        if gHg <= eps:
            # zero/negative curvature along g — take a boundary step along g
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        # unconstrained Cauchy (steepest descent) step and full Newton step
        p_cauchy = (g.dot(g) / gHg) * g
        p_newton = H.solve(g)

        # standard dogleg case 1: Newton step already inside the region
        if torch.linalg.vector_norm(p_newton) <= radius: # pylint:disable=not-callable
            return p_newton

        # standard dogleg case 2: Cauchy point outside — scale it to the boundary
        cauchy_norm = torch.linalg.vector_norm(p_cauchy) # pylint:disable=not-callable
        if cauchy_norm >= radius:
            return (radius / cauchy_norm) * p_cauchy

        a = p_newton - p_cauchy
        b = p_cauchy

        aa = a.dot(a)
        if aa < eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        # solve ||b + beta*a|| = radius for beta (quadratic in beta);
        # the guards above ensure a root exists in (0, 1)
        ab = a.dot(b)
        bb = b.dot(b)
        c = bb - radius**2
        discriminant = (2*ab)**2 - 4*aa*c
        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
        return p_cauchy + beta * (p_newton - p_cauchy)
|
|
92
|
+
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# pylint:disable=not-callable
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from ...core import Chainable, Module
|
|
5
|
+
from ...utils.linalg import linear_operator
|
|
6
|
+
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LevenbergMarquardt(TrustRegionBase):
    """Levenberg-Marquardt trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        y (float, optional):
            when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian approximation
            is added. Values between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Gauss-Newton with Levenberg-Marquardt trust-region

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
            )

        LM-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
            )

        First order trust region (hessian is assumed to be identity)

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.Identity()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(y=y, fallback=fallback)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the damped system ``(H + reg*D) x = g`` with ``reg = 1/radius``."""
        y = settings['y']

        # LM needs the hessian itself; a stored inverse can only be used via
        # the explicit (expensive, potentially unstable) fallback.
        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
                H = H.to_dense()
            else:
                raise RuntimeError(
                    f"{self.children['hess_module']} maintains a hessian inverse. "
                    "LevenbergMarquardt requires the hessian, not the inverse. "
                    "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
                    "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
                    "however that can be inefficient and unstable."
                )

        reg = 1/radius
        if y == 0:
            # classic Levenberg damping: shift by a scaled identity
            return H.add_diagonal(reg).solve(g)

        # Marquardt-style damping: scale by the hessian diagonal, replacing
        # non-positive/tiny entries with 1 to keep the shift well-defined.
        diag = H.diagonal()
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)
        return H.add_diagonal(diag*reg).solve(g)
|
|
127
|
+
|
|
128
|
+
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from ...core import Chainable, Module
|
|
4
|
+
from ...utils.linalg import cg, linear_operator
|
|
5
|
+
from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method.

    .. note::

        If you wish to use exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`
        which only uses hessian-vector products. While passing ``tz.m.Newton`` to this
        is possible, it is usually less efficient.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
        cg_tol (float, optional): tolerance for the conjugate gradient solver. Defaults to 1e-4.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        boundary_tol (float | None, optional):
            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-1.
        prefer_exact (bool, optional):
            when exact solution can be easily calculated without CG (e.g. hessian is stored as scaled identity),
            uses the exact solution. If False, always uses CG. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float = 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = 1e-1,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        reg: float = 0,
        cg_tol: float = 1e-4,
        prefer_exact: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve the trust-region subproblem with Steihaug-Toint CG."""
        # scaled-identity hessians admit a closed-form bounded solution —
        # skip CG entirely when allowed
        if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
            return H.solve_bounded(g, radius)

        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
        return x
|