torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff shows the changes between the publicly released 0.3.10 and 0.3.13 wheels as they appear in their public registry, and is provided for informational purposes only.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/second_order/newton_cg.py

```diff
@@ -1,31 +1,115 @@
-from collections.abc import Callable
-from typing import Literal, overload
 import warnings
+import math
+from typing import Literal, cast
+from operator import itemgetter
 import torch

-from ...
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, as_tensorlist, tofloat
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
-from
-from ...utils.linalg.solve import cg
+from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+from ..trust_region.trust_region import default_radius

 class NewtonCG(Module):
+    """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
+
+    This optimizer implements Newton's method using a matrix-free conjugate
+    gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
+    forming the full Hessian matrix, it only requires Hessian-vector products
+    (HVPs). These can be calculated efficiently using automatic
+    differentiation or approximated using finite differences.
+
+    .. note::
+        In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        CG may fail if hessian is not positive-definite.
+
+    Args:
+        maxiter (int | None, optional):
+            Maximum number of iterations for the conjugate gradient solver.
+            By default, this is set to the number of dimensions in the
+            objective function, which is the theoretical upper bound for CG
+            convergence. Setting this to a smaller value (truncated Newton)
+            can still generate good search directions. Defaults to None.
+        tol (float, optional):
+            Relative tolerance for the conjugate gradient solver to determine
+            convergence. Defaults to 1e-4.
+        reg (float, optional):
+            Regularization parameter (damping) added to the Hessian diagonal.
+            This helps ensure the system is positive-definite. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional):
+            The step size for finite differences if :code:`hvp_method` is
+            ``"forward"`` or ``"central"``. Defaults to 1e-3.
+        warm_start (bool, optional):
+            If ``True``, the conjugate gradient solver is initialized with the
+            solution from the previous optimization step. This can accelerate
+            convergence, especially in truncated Newton methods.
+            Defaults to False.
+        inner (Chainable | None, optional):
+            NewtonCG will attempt to apply preconditioning to the output of this module.
+
+    Examples:
+        Newton-CG with a backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.NewtonCG(),
+                tz.m.Backtracking()
+            )
+
+        Truncated Newton method (useful for large-scale problems):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.NewtonCG(maxiter=10, warm_start=True),
+                tz.m.Backtracking()
+            )
+
+
+    """
     def __init__(
         self,
-        maxiter=None,
-        tol=1e-
+        maxiter: int | None = None,
+        tol: float = 1e-8,
         reg: float = 1e-8,
-        hvp_method: Literal["forward", "central", "autograd"] = "
-
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+        h: float = 1e-3,
+        miniter:int = 1,
         warm_start=False,
         inner: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults,)

         if inner is not None:
             self.set_child('inner', inner)

+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
@@ -37,14 +121,17 @@ class NewtonCG(Module):
         reg = settings['reg']
         maxiter = settings['maxiter']
         hvp_method = settings['hvp_method']
+        solver = settings['solver'].lower().strip()
         h = settings['h']
         warm_start = settings['warm_start']

+        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)

             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))

@@ -55,10 +142,12 @@ class NewtonCG(Module):

             if hvp_method == 'forward':
                 def H_mm(x):
+                    self._num_hvps_last_step += 1
                     return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

             elif hvp_method == 'central':
                 def H_mm(x):
+                    self._num_hvps_last_step += 1
                     return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

             else:
@@ -68,18 +157,279 @@ class NewtonCG(Module):
         # -------------------------------- inner step -------------------------------- #
         b = var.get_update()
         if 'inner' in self.children:
-            b =
+            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        b = as_tensorlist(b)

         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

-
+        if solver == 'cg':
+            d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)
+
+        elif solver == 'minres':
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+
+        elif solver == 'minres_npc':
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+
+        else:
+            raise ValueError(f"Unknown solver {solver}")
+
         if warm_start:
             assert x0 is not None
-            x0.copy_(
+            x0.copy_(d)

-        var.update =
+        var.update = d
+
+        self._num_hvps += self._num_hvps_last_step
         return var


+class NewtonCGSteihaug(Module):
+    """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
+
+    This optimizer implements Newton's method using a matrix-free conjugate
+    gradient (CG) solver to approximate the search direction. Instead of
+    forming the full Hessian matrix, it only requires Hessian-vector products
+    (HVPs). These can be calculated efficiently using automatic
+    differentiation or approximated using finite differences.
+
+    .. note::
+        In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        CG may fail if hessian is not positive-definite.
+
+    Args:
+        maxiter (int | None, optional):
+            Maximum number of iterations for the conjugate gradient solver.
+            By default, this is set to the number of dimensions in the
+            objective function, which is the theoretical upper bound for CG
+            convergence. Setting this to a smaller value (truncated Newton)
+            can still generate good search directions. Defaults to None.
+        eta (float, optional):
+            whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
+        nplus (float, optional):
+            trust region multiplier on successful steps.
+        nminus (float, optional):
+            trust region multiplier on unsuccessful steps.
+        init (float, optional): initial trust region.
+        tol (float, optional):
+            Relative tolerance for the conjugate gradient solver to determine
+            convergence. Defaults to 1e-4.
+        reg (float, optional):
+            Regularization parameter (damping) added to the Hessian diagonal.
+            This helps ensure the system is positive-definite. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional):
+            The step size for finite differences if :code:`hvp_method` is
+            ``"forward"`` or ``"central"``. Defaults to 1e-3.
+        inner (Chainable | None, optional):
+            NewtonCG will attempt to apply preconditioning to the output of this module.
+
+    Examples:
+        Trust-region Newton-CG:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.NewtonCGSteihaug(),
+            )
+
+    Reference:
+        Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
+    """
+    def __init__(
+        self,
+        maxiter: int | None = None,
+        eta: float= 0.0,
+        nplus: float = 3.5,
+        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
+        init: float = 1,
+        tol: float = 1e-8,
+        reg: float = 1e-8,
+        hvp_method: Literal["forward", "central"] = "forward",
+        solver: Literal['cg', "minres"] = 'cg',
+        h: float = 1e-3,
+        max_attempts: int = 100,
+        max_history: int = 100,
+        boundary_tol: float = 1e-1,
+        miniter: int = 1,
+        rms_beta: float | None = None,
+        adapt_tol: bool = True,
+        npc_terminate: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
+        solver = self.defaults['solver'].lower().strip()
+
+        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
+         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
+         miniter, max_history, adapt_tol) = itemgetter(
+            "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
+            "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
+            "miniter", "max_history", "adapt_tol",
+        )(self.defaults)
+
+        self._num_hvps_last_step = 0
+
+        # ---------------------- Hessian vector product function --------------------- #
+        if hvp_method == 'autograd':
+            grad = var.get_grad(create_graph=True)
+
+            def H_mm(x):
+                self._num_hvps_last_step += 1
+                with torch.enable_grad():
+                    return TensorList(hvp(params, grad, x, retain_graph=True))
+
+        else:
+
+            with torch.enable_grad():
+                grad = var.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x):
+                    self._num_hvps_last_step += 1
+                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+
+            elif hvp_method == 'central':
+                def H_mm(x):
+                    self._num_hvps_last_step += 1
+                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+
+            else:
+                raise ValueError(hvp_method)
+
+
+        # ------------------------- update RMS preconditioner ------------------------ #
+        b = var.get_update()
+        P_mm = None
+        rms_beta = self.defaults["rms_beta"]
+        if rms_beta is not None:
+            exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
+            exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
+            exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
+            def _P_mm(x):
+                return x / exp_avg_sq_sqrt
+            P_mm = _P_mm
+
+        # -------------------------------- inner step -------------------------------- #
+        if 'inner' in self.children:
+            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        b = as_tensorlist(b)
+
+        # ------------------------------- trust region ------------------------------- #
+        success = False
+        d = None
+        x0 = [p.clone() for p in params]
+        solution = None
+
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            trust_radius = self.global_state.get('trust_radius', init)
+
+            # -------------- make sure trust radius isn't too small or large ------------- #
+            finfo = torch.finfo(x0[0].dtype)
+            if trust_radius < finfo.tiny * 2:
+                trust_radius = self.global_state['trust_radius'] = init
+                if adapt_tol:
+                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
+
+            elif trust_radius > finfo.max / 2:
+                trust_radius = self.global_state['trust_radius'] = init
+
+            # ----------------------------------- solve ---------------------------------- #
+            d = None
+            if solution is not None and solution.history is not None:
+                d = find_within_trust_radius(solution.history, trust_radius)
+
+            if d is None:
+                if solver == 'cg':
+                    d, solution = cg(
+                        A_mm=H_mm,
+                        b=b,
+                        tol=tol,
+                        maxiter=maxiter,
+                        reg=reg,
+                        trust_radius=trust_radius,
+                        miniter=miniter,
+                        npc_terminate=npc_terminate,
+                        history_size=max_history,
+                        P_mm=P_mm,
+                    )
+
+                elif solver == 'minres':
+                    d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+
+                else:
+                    raise ValueError(f"unknown solver {solver}")
+
+            # ---------------------------- update trust radius --------------------------- #
+            self.global_state["trust_radius"], success = default_radius(
+                params=params,
+                closure=closure,
+                f=tofloat(var.get_loss(False)),
+                g=b,
+                H=H_mm,
+                d=d,
+                trust_radius=trust_radius,
+                eta=eta,
+                nplus=nplus,
+                nminus=nminus,
+                rho_good=rho_good,
+                rho_bad=rho_bad,
+                boundary_tol=boundary_tol,
+
+                init=init, # init isn't used because check_overflow=False
+                state=self.global_state, # not used
+                settings=self.defaults, # not used
+                check_overflow=False, # this is checked manually to adapt tolerance
+            )
+
+        # --------------------------- assign new direction --------------------------- #
+        assert d is not None
+        if success:
+            var.update = d
+
+        else:
+            var.update = params.zeros_like()
+
+        self._num_hvps += self._num_hvps_last_step
+        return var
```
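For orientation, here is a hedged usage sketch of the new ``solver`` option and the added ``NewtonCGSteihaug`` module, assembled from the docstring examples and ``__init__`` signatures above. The ``import torchzero as tz`` alias, the placeholder ``model``, and the specific ``maxiter`` value are assumptions for illustration; check argument names against the installed version.

```python
import torch
import torchzero as tz  # assumed import alias; docstrings abbreviate it as "tz"

model = torch.nn.Linear(4, 1)  # placeholder model

# Newton-CG with the new MINRES solver, warm starts, and a backtracking line search
opt = tz.Modular(
    model.parameters(),
    tz.m.NewtonCG(maxiter=20, solver="minres", warm_start=True),
    tz.m.Backtracking(),
)

# Trust-region variant (Steihaug-Toint CG); per its docstring, no line search is chained after it
opt_tr = tz.Modular(
    model.parameters(),
    tz.m.NewtonCGSteihaug(),
)
```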
torchzero/modules/second_order/nystrom.py

```diff
@@ -10,12 +10,60 @@ from ...core import Chainable, apply_transform, Module
 from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg

 class NystromSketchAndSolve(Module):
+    """Newton's method with a Nyström sketch-and-solve solver.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. note::
+        In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        If this is unstable, increase the :code:`reg` parameter and tune the rank.
+
+    .. note:
+        :code:`tz.m.NystromPCG` usually outperforms this.
+
+    Args:
+        rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
+        reg (float, optional): regularization parameter. Defaults to 1e-3.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
+        seed (int | None, optional): seed for random generator. Defaults to None.
+
+    Examples:
+        NystromSketchAndSolve with backtracking line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.NystromSketchAndSolve(10),
+                tz.m.Backtracking()
+            )
+
+    Reference:
+        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
+    """
     def __init__(
         self,
         rank: int,
         reg: float = 1e-3,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
-        h=1e-3,
+        h: float = 1e-3,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
@@ -86,6 +134,61 @@ class NystromSketchAndSolve(Module):


 class NystromPCG(Module):
+    """Newton's method with a Nyström-preconditioned conjugate gradient solver.
+    This tends to outperform NewtonCG but requires tuning sketch size.
+    An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. note::
+        In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    Args:
+        sketch_size (int):
+            size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
+            running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
+            conjugate gradient.
+        maxiter (int | None, optional):
+            maximum number of iterations. By default this is set to the number of dimensions
+            in the objective function, which is supposed to be enough for conjugate gradient
+            to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
+            Defaults to None.
+        tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
+        seed (int | None, optional): seed for random generator. Defaults to None.
+
+    Examples:
+
+        NystromPCG with backtracking line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.NystromPCG(10),
+                tz.m.Backtracking()
+            )
+
+    Reference:
+        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
+
+    """
     def __init__(
         self,
         sketch_size: int,
```
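As a quick illustration of the Nyström modules, the following sketch mirrors the docstring example above. The data, the placeholder model, and the closure body are assumptions; the only documented requirement is that the closure accepts a ``backward`` flag, so the pattern shown here is just the standard PyTorch loss/backward idiom.

```python
import torch
import torchzero as tz  # assumed import alias

model = torch.nn.Linear(4, 1)                 # placeholder model
X, y = torch.randn(32, 4), torch.randn(32, 1)  # placeholder data

opt = tz.Modular(
    model.parameters(),
    tz.m.NystromPCG(10),   # sketch size of 10, as in the docstring example
    tz.m.Backtracking(),
)

def closure(backward=True):
    # the module re-evaluates loss and gradients through this closure
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```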
torchzero/modules/smoothing/__init__.py

```diff
@@ -1,2 +1,2 @@
 from .laplacian import LaplacianSmoothing
-from .
+from .sampling import GradientSampling
```

torchzero/modules/smoothing/laplacian.py

```diff
@@ -56,7 +56,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable

 class LaplacianSmoothing(Transform):
-    """Applies laplacian smoothing via a fast Fourier transform solver.
+    """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.

     Args:
         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
@@ -69,9 +69,19 @@ class LaplacianSmoothing(Transform):
         target (str, optional):
             what to set on var.

+    Examples:
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LaplacianSmoothing(),
+                tz.m.LR(1e-2),
+            )
+
     Reference:
-
-        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

     """
     def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
@@ -82,7 +92,7 @@ class LaplacianSmoothing(Transform):


     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         layerwise = settings[0]['layerwise']

         # layerwise laplacian smoothing
```
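For context on the ``_precompute_denominator`` line shown above: Laplacian smoothing gradient descent (Osher et al., 2022) filters the gradient by solving ``(I - sigma * L) g_s = g`` with a periodic 1-D Laplacian ``L``, which becomes an elementwise division in the Fourier domain. Below is a minimal standalone sketch of that solve, assuming the stencil ``v = (-2, 1, 0, ..., 0, 1)`` from the paper; it is an illustration, not the module's exact code.

```python
import torch

def laplacian_smooth(g: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Solve (I - sigma * Laplacian) g_s = g via FFT, as in Osher et al. (2022)."""
    v = torch.zeros_like(g)
    v[0], v[1], v[-1] = -2.0, 1.0, 1.0       # periodic second-difference stencil
    denom = 1 - sigma * torch.fft.fft(v)      # same form as _precompute_denominator above
    return torch.fft.ifft(torch.fft.fft(g) / denom).real

g = torch.randn(16)
g_smooth = laplacian_smooth(g)  # mean is preserved, high-frequency noise is damped
```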