torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
|
@@ -1,35 +1,111 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
2
1
|
from typing import Literal, overload
|
|
3
|
-
import warnings
|
|
4
2
|
import torch
|
|
5
3
|
|
|
6
|
-
from ...utils import TensorList, as_tensorlist,
|
|
4
|
+
from ...utils import TensorList, as_tensorlist, NumberList
|
|
7
5
|
from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
|
|
8
6
|
|
|
9
|
-
from ...core import Chainable,
|
|
10
|
-
from ...utils.linalg.solve import cg
|
|
7
|
+
from ...core import Chainable, apply_transform, Module
|
|
8
|
+
from ...utils.linalg.solve import cg, steihaug_toint_cg, minres
|
|
11
9
|
|
|
12
10
|
class NewtonCG(Module):
|
|
11
|
+
"""Newton's method with a matrix-free conjugate gradient or minimal-residual solver.
|
|
12
|
+
|
|
13
|
+
This optimizer implements Newton's method using a matrix-free conjugate
|
|
14
|
+
gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
|
|
15
|
+
forming the full Hessian matrix, it only requires Hessian-vector products
|
|
16
|
+
(HVPs). These can be calculated efficiently using automatic
|
|
17
|
+
differentiation or approximated using finite differences.
|
|
18
|
+
|
|
19
|
+
.. note::
|
|
20
|
+
In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
21
|
+
|
|
22
|
+
.. note::
|
|
23
|
+
This module requires a closure to be passed to the optimizer step,
|
|
24
|
+
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
25
|
+
The closure must accept a ``backward`` argument (refer to documentation).
|
|
26
|
+
|
|
27
|
+
.. warning::
|
|
28
|
+
CG may fail if hessian is not positive-definite.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
maxiter (int | None, optional):
|
|
32
|
+
Maximum number of iterations for the conjugate gradient solver.
|
|
33
|
+
By default, this is set to the number of dimensions in the
|
|
34
|
+
objective function, which is the theoretical upper bound for CG
|
|
35
|
+
convergence. Setting this to a smaller value (truncated Newton)
|
|
36
|
+
can still generate good search directions. Defaults to None.
|
|
37
|
+
tol (float, optional):
|
|
38
|
+
Relative tolerance for the conjugate gradient solver to determine
|
|
39
|
+
convergence. Defaults to 1e-4.
|
|
40
|
+
reg (float, optional):
|
|
41
|
+
Regularization parameter (damping) added to the Hessian diagonal.
|
|
42
|
+
This helps ensure the system is positive-definite. Defaults to 1e-8.
|
|
43
|
+
hvp_method (str, optional):
|
|
44
|
+
Determines how Hessian-vector products are evaluated.
|
|
45
|
+
|
|
46
|
+
- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
|
|
47
|
+
This requires creating a graph for the gradient.
|
|
48
|
+
- ``"forward"``: Use a forward finite difference formula to
|
|
49
|
+
approximate the HVP. This requires one extra gradient evaluation.
|
|
50
|
+
- ``"central"``: Use a central finite difference formula for a
|
|
51
|
+
more accurate HVP approximation. This requires two extra
|
|
52
|
+
gradient evaluations.
|
|
53
|
+
Defaults to "autograd".
|
|
54
|
+
h (float, optional):
|
|
55
|
+
The step size for finite differences if :code:`hvp_method` is
|
|
56
|
+
``"forward"`` or ``"central"``. Defaults to 1e-3.
|
|
57
|
+
warm_start (bool, optional):
|
|
58
|
+
If ``True``, the conjugate gradient solver is initialized with the
|
|
59
|
+
solution from the previous optimization step. This can accelerate
|
|
60
|
+
convergence, especially in truncated Newton methods.
|
|
61
|
+
Defaults to False.
|
|
62
|
+
inner (Chainable | None, optional):
|
|
63
|
+
NewtonCG will attempt to apply preconditioning to the output of this module.
|
|
64
|
+
|
|
65
|
+
Examples:
|
|
66
|
+
Newton-CG with a backtracking line search:
|
|
67
|
+
|
|
68
|
+
.. code-block:: python
|
|
69
|
+
|
|
70
|
+
opt = tz.Modular(
|
|
71
|
+
model.parameters(),
|
|
72
|
+
tz.m.NewtonCG(),
|
|
73
|
+
tz.m.Backtracking()
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
Truncated Newton method (useful for large-scale problems):
|
|
77
|
+
|
|
78
|
+
.. code-block:: python
|
|
79
|
+
|
|
80
|
+
opt = tz.Modular(
|
|
81
|
+
model.parameters(),
|
|
82
|
+
tz.m.NewtonCG(maxiter=10, warm_start=True),
|
|
83
|
+
tz.m.Backtracking()
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
"""
|
|
13
88
|
def __init__(
|
|
14
89
|
self,
|
|
15
|
-
maxiter=None,
|
|
16
|
-
tol=1e-
|
|
90
|
+
maxiter: int | None = None,
|
|
91
|
+
tol: float = 1e-4,
|
|
17
92
|
reg: float = 1e-8,
|
|
18
|
-
hvp_method: Literal["forward", "central", "autograd"] = "
|
|
19
|
-
|
|
93
|
+
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
94
|
+
solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
|
|
95
|
+
h: float = 1e-3,
|
|
20
96
|
warm_start=False,
|
|
21
97
|
inner: Chainable | None = None,
|
|
22
98
|
):
|
|
23
|
-
defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
|
|
99
|
+
defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
|
|
24
100
|
super().__init__(defaults,)
|
|
25
101
|
|
|
26
102
|
if inner is not None:
|
|
27
103
|
self.set_child('inner', inner)
|
|
28
104
|
|
|
29
105
|
@torch.no_grad
|
|
30
|
-
def step(self,
|
|
31
|
-
params = TensorList(
|
|
32
|
-
closure =
|
|
106
|
+
def step(self, var):
|
|
107
|
+
params = TensorList(var.params)
|
|
108
|
+
closure = var.closure
|
|
33
109
|
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
34
110
|
|
|
35
111
|
settings = self.settings[params[0]]
|
|
@@ -37,12 +113,13 @@ class NewtonCG(Module):
|
|
|
37
113
|
reg = settings['reg']
|
|
38
114
|
maxiter = settings['maxiter']
|
|
39
115
|
hvp_method = settings['hvp_method']
|
|
116
|
+
solver = settings['solver'].lower().strip()
|
|
40
117
|
h = settings['h']
|
|
41
118
|
warm_start = settings['warm_start']
|
|
42
119
|
|
|
43
120
|
# ---------------------- Hessian vector product function --------------------- #
|
|
44
121
|
if hvp_method == 'autograd':
|
|
45
|
-
grad =
|
|
122
|
+
grad = var.get_grad(create_graph=True)
|
|
46
123
|
|
|
47
124
|
def H_mm(x):
|
|
48
125
|
with torch.enable_grad():
|
|
@@ -51,7 +128,7 @@ class NewtonCG(Module):
|
|
|
51
128
|
else:
|
|
52
129
|
|
|
53
130
|
with torch.enable_grad():
|
|
54
|
-
grad =
|
|
131
|
+
grad = var.get_grad()
|
|
55
132
|
|
|
56
133
|
if hvp_method == 'forward':
|
|
57
134
|
def H_mm(x):
|
|
@@ -66,19 +143,232 @@ class NewtonCG(Module):
|
|
|
66
143
|
|
|
67
144
|
|
|
68
145
|
# -------------------------------- inner step -------------------------------- #
|
|
69
|
-
b =
|
|
146
|
+
b = var.get_update()
|
|
70
147
|
if 'inner' in self.children:
|
|
71
|
-
b =
|
|
148
|
+
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
149
|
+
b = as_tensorlist(b)
|
|
72
150
|
|
|
73
151
|
# ---------------------------------- run cg ---------------------------------- #
|
|
74
152
|
x0 = None
|
|
75
|
-
if warm_start: x0 = self.get_state('prev_x',
|
|
76
|
-
|
|
153
|
+
if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
|
|
154
|
+
|
|
155
|
+
if solver == 'cg':
|
|
156
|
+
x = cg(A_mm=H_mm, b=b, x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
|
|
157
|
+
|
|
158
|
+
elif solver == 'minres':
|
|
159
|
+
x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
|
|
160
|
+
|
|
161
|
+
elif solver == 'minres_npc':
|
|
162
|
+
x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
|
|
163
|
+
|
|
164
|
+
else:
|
|
165
|
+
raise ValueError(f"Unknown solver {solver}")
|
|
166
|
+
|
|
77
167
|
if warm_start:
|
|
78
168
|
assert x0 is not None
|
|
79
169
|
x0.copy_(x)
|
|
80
170
|
|
|
81
|
-
|
|
82
|
-
return
|
|
171
|
+
var.update = x
|
|
172
|
+
return var
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class TruncatedNewtonCG(Module):
|
|
176
|
+
"""Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
|
|
177
|
+
|
|
178
|
+
This optimizer implements Newton's method using a matrix-free conjugate
|
|
179
|
+
gradient (CG) solver to approximate the search direction. Instead of
|
|
180
|
+
forming the full Hessian matrix, it only requires Hessian-vector products
|
|
181
|
+
(HVPs). These can be calculated efficiently using automatic
|
|
182
|
+
differentiation or approximated using finite differences.
|
|
183
|
+
|
|
184
|
+
.. note::
|
|
185
|
+
In most cases TruncatedNewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
186
|
+
|
|
187
|
+
.. note::
|
|
188
|
+
This module requires a closure to be passed to the optimizer step,
|
|
189
|
+
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
190
|
+
The closure must accept a ``backward`` argument (refer to documentation).
|
|
191
|
+
|
|
192
|
+
.. warning::
|
|
193
|
+
CG may fail if hessian is not positive-definite.
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
maxiter (int | None, optional):
|
|
197
|
+
Maximum number of iterations for the conjugate gradient solver.
|
|
198
|
+
By default, this is set to the number of dimensions in the
|
|
199
|
+
objective function, which is the theoretical upper bound for CG
|
|
200
|
+
convergence. Setting this to a smaller value (truncated Newton)
|
|
201
|
+
can still generate good search directions. Defaults to None.
|
|
202
|
+
eta (float, optional):
|
|
203
|
+
whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
|
|
204
|
+
nplus (float, optional):
|
|
205
|
+
trust region multiplier on successful steps.
|
|
206
|
+
nminus (float, optional):
|
|
207
|
+
trust region multiplier on unsuccessful steps.
|
|
208
|
+
init (float, optional): initial trust region.
|
|
209
|
+
tol (float, optional):
|
|
210
|
+
Relative tolerance for the conjugate gradient solver to determine
|
|
211
|
+
convergence. Defaults to 1e-4.
|
|
212
|
+
reg (float, optional):
|
|
213
|
+
Regularization parameter (damping) added to the Hessian diagonal.
|
|
214
|
+
This helps ensure the system is positive-definite. Defaults to 1e-8.
|
|
215
|
+
hvp_method (str, optional):
|
|
216
|
+
Determines how Hessian-vector products are evaluated.
|
|
217
|
+
|
|
218
|
+
- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
|
|
219
|
+
This requires creating a graph for the gradient.
|
|
220
|
+
- ``"forward"``: Use a forward finite difference formula to
|
|
221
|
+
approximate the HVP. This requires one extra gradient evaluation.
|
|
222
|
+
- ``"central"``: Use a central finite difference formula for a
|
|
223
|
+
more accurate HVP approximation. This requires two extra
|
|
224
|
+
gradient evaluations.
|
|
225
|
+
Defaults to "autograd".
|
|
226
|
+
h (float, optional):
|
|
227
|
+
The step size for finite differences if :code:`hvp_method` is
|
|
228
|
+
``"forward"`` or ``"central"``. Defaults to 1e-3.
|
|
229
|
+
inner (Chainable | None, optional):
|
|
230
|
+
TruncatedNewtonCG will attempt to apply preconditioning to the output of this module.
|
|
231
|
+
|
|
232
|
+
Examples:
|
|
233
|
+
Trust-region Newton-CG:
|
|
234
|
+
|
|
235
|
+
.. code-block:: python
|
|
236
|
+
|
|
237
|
+
opt = tz.Modular(
|
|
238
|
+
model.parameters(),
|
|
239
|
+
tz.m.NewtonCGSteihaug(),
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
Reference:
|
|
243
|
+
Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
|
|
244
|
+
"""
|
|
245
|
+
def __init__(
|
|
246
|
+
self,
|
|
247
|
+
maxiter: int | None = None,
|
|
248
|
+
eta: float= 1e-6,
|
|
249
|
+
nplus: float = 2,
|
|
250
|
+
nminus: float = 0.25,
|
|
251
|
+
init: float = 1,
|
|
252
|
+
tol: float = 1e-4,
|
|
253
|
+
reg: float = 1e-8,
|
|
254
|
+
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
255
|
+
solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
|
|
256
|
+
h: float = 1e-3,
|
|
257
|
+
max_attempts: int = 10,
|
|
258
|
+
inner: Chainable | None = None,
|
|
259
|
+
):
|
|
260
|
+
defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, eta=eta, nplus=nplus, nminus=nminus, init=init, max_attempts=max_attempts, solver=solver)
|
|
261
|
+
super().__init__(defaults,)
|
|
262
|
+
|
|
263
|
+
if inner is not None:
|
|
264
|
+
self.set_child('inner', inner)
|
|
265
|
+
|
|
266
|
+
@torch.no_grad
|
|
267
|
+
def step(self, var):
|
|
268
|
+
params = TensorList(var.params)
|
|
269
|
+
closure = var.closure
|
|
270
|
+
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
271
|
+
|
|
272
|
+
settings = self.settings[params[0]]
|
|
273
|
+
tol = settings['tol']
|
|
274
|
+
reg = settings['reg']
|
|
275
|
+
maxiter = settings['maxiter']
|
|
276
|
+
hvp_method = settings['hvp_method']
|
|
277
|
+
h = settings['h']
|
|
278
|
+
max_attempts = settings['max_attempts']
|
|
279
|
+
solver = settings['solver'].lower().strip()
|
|
280
|
+
|
|
281
|
+
eta = settings['eta']
|
|
282
|
+
nplus = settings['nplus']
|
|
283
|
+
nminus = settings['nminus']
|
|
284
|
+
init = settings['init']
|
|
285
|
+
|
|
286
|
+
# ---------------------- Hessian vector product function --------------------- #
|
|
287
|
+
if hvp_method == 'autograd':
|
|
288
|
+
grad = var.get_grad(create_graph=True)
|
|
289
|
+
|
|
290
|
+
def H_mm(x):
|
|
291
|
+
with torch.enable_grad():
|
|
292
|
+
return TensorList(hvp(params, grad, x, retain_graph=True))
|
|
293
|
+
|
|
294
|
+
else:
|
|
295
|
+
|
|
296
|
+
with torch.enable_grad():
|
|
297
|
+
grad = var.get_grad()
|
|
298
|
+
|
|
299
|
+
if hvp_method == 'forward':
|
|
300
|
+
def H_mm(x):
|
|
301
|
+
return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
|
|
302
|
+
|
|
303
|
+
elif hvp_method == 'central':
|
|
304
|
+
def H_mm(x):
|
|
305
|
+
return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
|
|
306
|
+
|
|
307
|
+
else:
|
|
308
|
+
raise ValueError(hvp_method)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
# -------------------------------- inner step -------------------------------- #
|
|
312
|
+
b = var.get_update()
|
|
313
|
+
if 'inner' in self.children:
|
|
314
|
+
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
315
|
+
b = as_tensorlist(b)
|
|
316
|
+
|
|
317
|
+
# ---------------------------------- run cg ---------------------------------- #
|
|
318
|
+
success = False
|
|
319
|
+
x = None
|
|
320
|
+
while not success:
|
|
321
|
+
max_attempts -= 1
|
|
322
|
+
if max_attempts < 0: break
|
|
323
|
+
|
|
324
|
+
trust_region = self.global_state.get('trust_region', init)
|
|
325
|
+
if trust_region < 1e-8 or trust_region > 1e8:
|
|
326
|
+
trust_region = self.global_state['trust_region'] = init
|
|
327
|
+
|
|
328
|
+
if solver == 'cg':
|
|
329
|
+
x = steihaug_toint_cg(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg)
|
|
330
|
+
|
|
331
|
+
elif solver == 'minres':
|
|
332
|
+
x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
|
|
333
|
+
|
|
334
|
+
elif solver == 'minres_npc':
|
|
335
|
+
x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
|
|
336
|
+
|
|
337
|
+
else:
|
|
338
|
+
raise ValueError(f"unknown solver {solver}")
|
|
339
|
+
|
|
340
|
+
# ------------------------------- trust region ------------------------------- #
|
|
341
|
+
Hx = H_mm(x)
|
|
342
|
+
pred_reduction = b.dot(x) - 0.5 * x.dot(Hx)
|
|
343
|
+
|
|
344
|
+
params -= x
|
|
345
|
+
loss_star = closure(False)
|
|
346
|
+
params += x
|
|
347
|
+
reduction = var.get_loss(False) - loss_star
|
|
348
|
+
|
|
349
|
+
rho = reduction / (pred_reduction.clip(min=1e-8))
|
|
350
|
+
|
|
351
|
+
# failed step
|
|
352
|
+
if rho < 0.25:
|
|
353
|
+
self.global_state['trust_region'] = trust_region * nminus
|
|
354
|
+
|
|
355
|
+
# very good step
|
|
356
|
+
elif rho > 0.75:
|
|
357
|
+
diff = trust_region - x.abs()
|
|
358
|
+
if (diff.global_min() / trust_region) > 1e-4: # hits boundary
|
|
359
|
+
self.global_state['trust_region'] = trust_region * nplus
|
|
360
|
+
|
|
361
|
+
# if the ratio is high enough then accept the proposed step
|
|
362
|
+
if rho > eta:
|
|
363
|
+
success = True
|
|
364
|
+
|
|
365
|
+
assert x is not None
|
|
366
|
+
if success:
|
|
367
|
+
var.update = x
|
|
368
|
+
|
|
369
|
+
else:
|
|
370
|
+
var.update = params.zeros_like()
|
|
371
|
+
|
|
372
|
+
return var
|
|
83
373
|
|
|
84
374
|
|
|
@@ -6,16 +6,64 @@ import torch
|
|
|
6
6
|
from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
|
|
7
7
|
from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
|
|
8
8
|
|
|
9
|
-
from ...core import Chainable,
|
|
9
|
+
from ...core import Chainable, apply_transform, Module
|
|
10
10
|
from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
|
|
11
11
|
|
|
12
12
|
class NystromSketchAndSolve(Module):
|
|
13
|
+
"""Newton's method with a Nyström sketch-and-solve solver.
|
|
14
|
+
|
|
15
|
+
.. note::
|
|
16
|
+
This module requires a closure to be passed to the optimizer step,
|
|
17
|
+
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
18
|
+
The closure must accept a ``backward`` argument (refer to documentation).
|
|
19
|
+
|
|
20
|
+
.. note::
|
|
21
|
+
In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
22
|
+
|
|
23
|
+
.. note::
|
|
24
|
+
If this is unstable, increase the :code:`reg` parameter and tune the rank.
|
|
25
|
+
|
|
26
|
+
.. note::
|
|
27
|
+
:code:`tz.m.NystromPCG` usually outperforms this.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
|
|
31
|
+
reg (float, optional): regularization parameter. Defaults to 1e-3.
|
|
32
|
+
hvp_method (str, optional):
|
|
33
|
+
Determines how Hessian-vector products are evaluated.
|
|
34
|
+
|
|
35
|
+
- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
|
|
36
|
+
This requires creating a graph for the gradient.
|
|
37
|
+
- ``"forward"``: Use a forward finite difference formula to
|
|
38
|
+
approximate the HVP. This requires one extra gradient evaluation.
|
|
39
|
+
- ``"central"``: Use a central finite difference formula for a
|
|
40
|
+
more accurate HVP approximation. This requires two extra
|
|
41
|
+
gradient evaluations.
|
|
42
|
+
Defaults to "autograd".
|
|
43
|
+
h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
|
|
44
|
+
inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
|
|
45
|
+
seed (int | None, optional): seed for random generator. Defaults to None.
|
|
46
|
+
|
|
47
|
+
Examples:
|
|
48
|
+
NystromSketchAndSolve with backtracking line search
|
|
49
|
+
|
|
50
|
+
.. code-block:: python
|
|
51
|
+
|
|
52
|
+
opt = tz.Modular(
|
|
53
|
+
model.parameters(),
|
|
54
|
+
tz.m.NystromSketchAndSolve(10),
|
|
55
|
+
tz.m.Backtracking()
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
Reference:
|
|
59
|
+
Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
|
|
60
|
+
"""
|
|
13
61
|
def __init__(
|
|
14
62
|
self,
|
|
15
63
|
rank: int,
|
|
16
64
|
reg: float = 1e-3,
|
|
17
65
|
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
18
|
-
h=1e-
|
|
66
|
+
h: float = 1e-3,
|
|
19
67
|
inner: Chainable | None = None,
|
|
20
68
|
seed: int | None = None,
|
|
21
69
|
):
|
|
@@ -26,10 +74,10 @@ class NystromSketchAndSolve(Module):
|
|
|
26
74
|
self.set_child('inner', inner)
|
|
27
75
|
|
|
28
76
|
@torch.no_grad
|
|
29
|
-
def step(self,
|
|
30
|
-
params = TensorList(
|
|
77
|
+
def step(self, var):
|
|
78
|
+
params = TensorList(var.params)
|
|
31
79
|
|
|
32
|
-
closure =
|
|
80
|
+
closure = var.closure
|
|
33
81
|
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
34
82
|
|
|
35
83
|
settings = self.settings[params[0]]
|
|
@@ -47,7 +95,7 @@ class NystromSketchAndSolve(Module):
|
|
|
47
95
|
|
|
48
96
|
# ---------------------- Hessian vector product function --------------------- #
|
|
49
97
|
if hvp_method == 'autograd':
|
|
50
|
-
grad =
|
|
98
|
+
grad = var.get_grad(create_graph=True)
|
|
51
99
|
|
|
52
100
|
def H_mm(x):
|
|
53
101
|
with torch.enable_grad():
|
|
@@ -57,7 +105,7 @@ class NystromSketchAndSolve(Module):
|
|
|
57
105
|
else:
|
|
58
106
|
|
|
59
107
|
with torch.enable_grad():
|
|
60
|
-
grad =
|
|
108
|
+
grad = var.get_grad()
|
|
61
109
|
|
|
62
110
|
if hvp_method == 'forward':
|
|
63
111
|
def H_mm(x):
|
|
@@ -74,18 +122,73 @@ class NystromSketchAndSolve(Module):
|
|
|
74
122
|
|
|
75
123
|
|
|
76
124
|
# -------------------------------- inner step -------------------------------- #
|
|
77
|
-
b =
|
|
125
|
+
b = var.get_update()
|
|
78
126
|
if 'inner' in self.children:
|
|
79
|
-
b =
|
|
127
|
+
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
80
128
|
|
|
81
129
|
# ------------------------------ sketch&n&solve ------------------------------ #
|
|
82
130
|
x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
|
|
83
|
-
|
|
84
|
-
return
|
|
131
|
+
var.update = vec_to_tensors(x, reference=params)
|
|
132
|
+
return var
|
|
85
133
|
|
|
86
134
|
|
|
87
135
|
|
|
88
136
|
class NystromPCG(Module):
|
|
137
|
+
"""Newton's method with a Nyström-preconditioned conjugate gradient solver.
|
|
138
|
+
This tends to outperform NewtonCG but requires tuning sketch size.
|
|
139
|
+
An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.
|
|
140
|
+
|
|
141
|
+
.. note::
|
|
142
|
+
This module requires a closure to be passed to the optimizer step,
|
|
143
|
+
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
144
|
+
The closure must accept a ``backward`` argument (refer to documentation).
|
|
145
|
+
|
|
146
|
+
.. note::
|
|
147
|
+
In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
sketch_size (int):
|
|
151
|
+
size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
|
|
152
|
+
running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
|
|
153
|
+
conjugate gradient.
|
|
154
|
+
maxiter (int | None, optional):
|
|
155
|
+
maximum number of iterations. By default this is set to the number of dimensions
|
|
156
|
+
in the objective function, which is supposed to be enough for conjugate gradient
|
|
157
|
+
to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
|
|
158
|
+
Defaults to None.
|
|
159
|
+
tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
|
|
160
|
+
reg (float, optional): regularization parameter. Defaults to 1e-8.
|
|
161
|
+
hvp_method (str, optional):
|
|
162
|
+
Determines how Hessian-vector products are evaluated.
|
|
163
|
+
|
|
164
|
+
- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
|
|
165
|
+
This requires creating a graph for the gradient.
|
|
166
|
+
- ``"forward"``: Use a forward finite difference formula to
|
|
167
|
+
approximate the HVP. This requires one extra gradient evaluation.
|
|
168
|
+
- ``"central"``: Use a central finite difference formula for a
|
|
169
|
+
more accurate HVP approximation. This requires two extra
|
|
170
|
+
gradient evaluations.
|
|
171
|
+
Defaults to "autograd".
|
|
172
|
+
h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
|
|
173
|
+
inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
|
|
174
|
+
seed (int | None, optional): seed for random generator. Defaults to None.
|
|
175
|
+
|
|
176
|
+
Examples:
|
|
177
|
+
|
|
178
|
+
NystromPCG with backtracking line search
|
|
179
|
+
|
|
180
|
+
.. code-block:: python
|
|
181
|
+
|
|
182
|
+
opt = tz.Modular(
|
|
183
|
+
model.parameters(),
|
|
184
|
+
tz.m.NystromPCG(10),
|
|
185
|
+
tz.m.Backtracking()
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
Reference:
|
|
189
|
+
Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
|
|
190
|
+
|
|
191
|
+
"""
|
|
89
192
|
def __init__(
|
|
90
193
|
self,
|
|
91
194
|
sketch_size: int,
|
|
@@ -93,7 +196,7 @@ class NystromPCG(Module):
|
|
|
93
196
|
tol=1e-3,
|
|
94
197
|
reg: float = 1e-6,
|
|
95
198
|
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
96
|
-
h=1e-
|
|
199
|
+
h=1e-3,
|
|
97
200
|
inner: Chainable | None = None,
|
|
98
201
|
seed: int | None = None,
|
|
99
202
|
):
|
|
@@ -104,10 +207,10 @@ class NystromPCG(Module):
|
|
|
104
207
|
self.set_child('inner', inner)
|
|
105
208
|
|
|
106
209
|
@torch.no_grad
|
|
107
|
-
def step(self,
|
|
108
|
-
params = TensorList(
|
|
210
|
+
def step(self, var):
|
|
211
|
+
params = TensorList(var.params)
|
|
109
212
|
|
|
110
|
-
closure =
|
|
213
|
+
closure = var.closure
|
|
111
214
|
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
112
215
|
|
|
113
216
|
settings = self.settings[params[0]]
|
|
@@ -129,7 +232,7 @@ class NystromPCG(Module):
|
|
|
129
232
|
|
|
130
233
|
# ---------------------- Hessian vector product function --------------------- #
|
|
131
234
|
if hvp_method == 'autograd':
|
|
132
|
-
grad =
|
|
235
|
+
grad = var.get_grad(create_graph=True)
|
|
133
236
|
|
|
134
237
|
def H_mm(x):
|
|
135
238
|
with torch.enable_grad():
|
|
@@ -139,7 +242,7 @@ class NystromPCG(Module):
|
|
|
139
242
|
else:
|
|
140
243
|
|
|
141
244
|
with torch.enable_grad():
|
|
142
|
-
grad =
|
|
245
|
+
grad = var.get_grad()
|
|
143
246
|
|
|
144
247
|
if hvp_method == 'forward':
|
|
145
248
|
def H_mm(x):
|
|
@@ -156,13 +259,13 @@ class NystromPCG(Module):
|
|
|
156
259
|
|
|
157
260
|
|
|
158
261
|
# -------------------------------- inner step -------------------------------- #
|
|
159
|
-
b =
|
|
262
|
+
b = var.get_update()
|
|
160
263
|
if 'inner' in self.children:
|
|
161
|
-
b =
|
|
264
|
+
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
162
265
|
|
|
163
266
|
# ------------------------------ sketch&n&solve ------------------------------ #
|
|
164
267
|
x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
|
|
165
|
-
|
|
166
|
-
return
|
|
268
|
+
var.update = vec_to_tensors(x, reference=params)
|
|
269
|
+
return var
|
|
167
270
|
|
|
168
271
|
|