torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/second_order/newton_cg.py

@@ -1,30 +1,24 @@
-
+import warnings
+import math
+from typing import Literal, cast
+from operator import itemgetter
 import torch
 
-from ...
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, as_tensorlist, tofloat
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
-from
-from ...utils.linalg.solve import cg, steihaug_toint_cg, minres
+from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+from ..trust_region.trust_region import default_radius
 
 class NewtonCG(Module):
     """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
 
-
-
-    forming the full Hessian matrix, it only requires Hessian-vector products
-    (HVPs). These can be calculated efficiently using automatic
-    differentiation or approximated using finite differences.
-
-    .. note::
-        In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+    Notes:
+        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        * This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
-
+    Warning:
         CG may fail if hessian is not positive-definite.
 
     Args:
@@ -63,45 +57,48 @@ class NewtonCG(Module):
             NewtonCG will attempt to apply preconditioning to the output of this module.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
+        Newton-CG with a backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NewtonCG(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        Truncated Newton method (useful for large-scale problems):
+        ```
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NewtonCG(maxiter=10),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(
         self,
         maxiter: int | None = None,
-        tol: float = 1e-
+        tol: float = 1e-8,
         reg: float = 1e-8,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
         solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
-        h: float = 1e-3,
+        h: float = 1e-3, # tuned 1e-4 or 1e-3
+        miniter:int = 1,
         warm_start=False,
         inner: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults,)
 
         if inner is not None:
             self.set_child('inner', inner)
 
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
@@ -117,11 +114,13 @@ class NewtonCG(Module):
         h = settings['h']
         warm_start = settings['warm_start']
 
+        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)
 
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))
 
@@ -132,10 +131,12 @@ class NewtonCG(Module):
 
         if hvp_method == 'forward':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
 
         elif hvp_method == 'central':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
 
         else:
@@ -153,141 +154,154 @@ class NewtonCG(Module):
         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
 
         if solver == 'cg':
-
+            d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)
 
         elif solver == 'minres':
-
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
 
         elif solver == 'minres_npc':
-
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
 
         else:
             raise ValueError(f"Unknown solver {solver}")
 
         if warm_start:
             assert x0 is not None
-            x0.copy_(
-
-        var.update = x
-        return var
+            x0.copy_(d)
 
+        var.update = d
 
-
-
+        self._num_hvps += self._num_hvps_last_step
+        return var
 
-    This optimizer implements Newton's method using a matrix-free conjugate
-    gradient (CG) solver to approximate the search direction. Instead of
-    forming the full Hessian matrix, it only requires Hessian-vector products
-    (HVPs). These can be calculated efficiently using automatic
-    differentiation or approximated using finite differences.
 
-
-
+class NewtonCGSteihaug(Module):
+    """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
 
-
-
-    as it needs to re-evaluate the loss and gradients for calculating HVPs.
-    The closure must accept a ``backward`` argument (refer to documentation).
+    Notes:
+        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
-
-        CG may fail if hessian is not positive-definite.
+        * This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
-        maxiter (int | None, optional):
-            Maximum number of iterations for the conjugate gradient solver.
-            By default, this is set to the number of dimensions in the
-            objective function, which is the theoretical upper bound for CG
-            convergence. Setting this to a smaller value (truncated Newton)
-            can still generate good search directions. Defaults to None.
         eta (float, optional):
-
-        nplus (float, optional):
-
-
-            trust region
-
+            if ratio of actual to predicted rediction is larger than this, step is accepted. Defaults to 0.0.
+        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+        rho_good (float, optional):
+            if ratio of actual to predicted rediction is larger than this, trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if ratio of actual to predicted rediction is less than this, trust region size is multiplied by `nminus`.
+        init (float, optional): Initial trust region value. Defaults to 1.
+        max_attempts (max_attempts, optional):
+            maximum number of trust radius reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 10.
+        max_history (int, optional):
+            CG will store this many intermediate solutions, reusing them when trust radius is reduced
+            instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.
+        boundary_tol (float | None, optional):
+            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.
+
+        maxiter (int | None, optional):
+            maximum number of CG iterations per step. Each iteration requies one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
+        miniter (int, optional):
+            minimal number of CG iterations. This prevents making no progress
         tol (float, optional):
-
-
-        reg (float, optional):
-
-
+            terminates CG when norm of the residual is less than this value. Defaults to 1e-8.
+            when initial guess is below tolerance. Defaults to 1.
+        reg (float, optional): hessian regularization. Defaults to 1e-8.
+        solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
+        adapt_tol (bool, optional):
+            if True, whenever trust radius collapses to smallest representable number,
+            the tolerance is multiplied by 0.1. Defaults to True.
+        npc_terminate (bool, optional):
+            whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
+
         hvp_method (str, optional):
-
+            either "forward" to use forward formula which requires one backward pass per Hvp, or "central" to use a more accurate central formula which requires two backward passes. "forward" is usually accurate enough. Defaults to "forward".
+        h (float, optional): finite difference step size. Defaults to 1e-3.
 
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional):
-            The step size for finite differences if :code:`hvp_method` is
-            ``"forward"`` or ``"central"``. Defaults to 1e-3.
         inner (Chainable | None, optional):
-
-
-    Examples:
-        Trust-region Newton-CG:
+            applies preconditioning to output of this module. Defaults to None.
 
-
+    ### Examples:
+    Trust-region Newton-CG:
 
-
-
-
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.NewtonCGSteihaug(),
+    )
+    ```
 
-    Reference:
+    ### Reference:
         Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
     """
     def __init__(
         self,
-
-        eta: float=
-        nplus: float =
+        # trust region settings
+        eta: float= 0.0,
+        nplus: float = 3.5,
         nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
         init: float = 1,
-
+        max_attempts: int = 100,
+        max_history: int = 100,
+        boundary_tol: float = 1e-6, # tuned
+
+        # cg settings
+        maxiter: int | None = None,
+        miniter: int = 1,
+        tol: float = 1e-8,
         reg: float = 1e-8,
-
-
-
-
+        solver: Literal['cg', "minres"] = 'cg',
+        adapt_tol: bool = True,
+        npc_terminate: bool = False,
+
+        # hvp settings
+        hvp_method: Literal["forward", "central"] = "central",
+        h: float = 1e-3, # tuned 1e-4 or 1e-3
+
+        # inner
         inner: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults,)
 
         if inner is not None:
             self.set_child('inner', inner)
 
+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
-
-
-
-        maxiter
-
-
-
-
+        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
+        solver = self.defaults['solver'].lower().strip()
+
+        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
+         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
+         miniter, max_history, adapt_tol) = itemgetter(
+            "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
+            "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
+            "miniter", "max_history", "adapt_tol",
+        )(self.defaults)
 
-
-        nplus = settings['nplus']
-        nminus = settings['nminus']
-        init = settings['init']
+        self._num_hvps_last_step = 0
 
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)
 
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))
 
@@ -298,10 +312,12 @@ class TruncatedNewtonCG(Module):
 
         if hvp_method == 'forward':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
 
         elif hvp_method == 'central':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
 
         else:
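
The `hvp_fd_forward` and `hvp_fd_central` helpers whose calls are counted above approximate Hessian-vector products by finite differences of the gradient, which is what the new docstring means by one backward pass per Hvp for `"forward"` and two for `"central"`. The sketch below is only an illustration of those two formulas under simplified assumptions: it uses a plain `grad_fn(x)` callable rather than torchzero's closure-and-TensorList interface, and the helper names `hvp_forward` / `hvp_central` are made up for the example.

```python
import torch

# Illustrative finite-difference HVP formulas (not torchzero's actual helpers).
def hvp_forward(grad_fn, x, v, h=1e-3, g0=None):
    # forward difference: Hv ~ (grad(x + h*v) - grad(x)) / h; one extra gradient
    # evaluation when g0 (the gradient at x) is already available.
    if g0 is None:
        g0 = grad_fn(x)
    return (grad_fn(x + h * v) - g0) / h

def hvp_central(grad_fn, x, v, h=1e-3):
    # central difference: Hv ~ (grad(x + h*v) - grad(x - h*v)) / (2h);
    # two extra gradient evaluations, but more accurate.
    return (grad_fn(x + h * v) - grad_fn(x - h * v)) / (2 * h)

# quadratic sanity check: f(x) = 0.5 * x^T A x, so the exact HVP is A @ v
A = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
grad_fn = lambda x: A @ x
x, v = torch.tensor([1.0, -2.0]), torch.tensor([0.5, 1.0])
print(hvp_forward(grad_fn, x, v), hvp_central(grad_fn, x, v), A @ v)
```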
@@ -314,61 +330,82 @@ class TruncatedNewtonCG(Module):
             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
         b = as_tensorlist(b)
 
-        #
+        # ------------------------------- trust region ------------------------------- #
         success = False
-
+        d = None
+        x0 = [p.clone() for p in params]
+        solution = None
+
         while not success:
             max_attempts -= 1
             if max_attempts < 0: break
 
-
-
-
-
-            if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            #
-
-
+            trust_radius = self.global_state.get('trust_radius', init)
+
+            # -------------- make sure trust radius isn't too small or large ------------- #
+            finfo = torch.finfo(x0[0].dtype)
+            if trust_radius < finfo.tiny * 2:
+                trust_radius = self.global_state['trust_radius'] = init
+                if adapt_tol:
+                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
+
+            elif trust_radius > finfo.max / 2:
+                trust_radius = self.global_state['trust_radius'] = init
+
+            # ----------------------------------- solve ---------------------------------- #
+            d = None
+            if solution is not None and solution.history is not None:
+                d = find_within_trust_radius(solution.history, trust_radius)
+
+            if d is None:
+                if solver == 'cg':
+                    d, solution = cg(
+                        A_mm=H_mm,
+                        b=b,
+                        tol=tol,
+                        maxiter=maxiter,
+                        reg=reg,
+                        trust_radius=trust_radius,
+                        miniter=miniter,
+                        npc_terminate=npc_terminate,
+                        history_size=max_history,
+                    )
+
+                elif solver == 'minres':
+                    d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+
+                else:
+                    raise ValueError(f"unknown solver {solver}")
+
+            # ---------------------------- update trust radius --------------------------- #
+            self.global_state["trust_radius"], success = default_radius(
+                params=params,
+                closure=closure,
+                f=tofloat(var.get_loss(False)),
+                g=b,
+                H=H_mm,
+                d=d,
+                trust_radius=trust_radius,
+                eta=eta,
+                nplus=nplus,
+                nminus=nminus,
+                rho_good=rho_good,
+                rho_bad=rho_bad,
+                boundary_tol=boundary_tol,
+
+                init=init, # init isn't used because check_overflow=False
+                state=self.global_state, # not used
+                settings=self.defaults, # not used
+                check_overflow=False, # this is checked manually to adapt tolerance
+            )
 
-
+        # --------------------------- assign new direction --------------------------- #
+        assert d is not None
         if success:
-            var.update =
+            var.update = d
 
         else:
             var.update = params.zeros_like()
 
-
-
-
+        self._num_hvps += self._num_hvps_last_step
+        return var
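
The `default_radius` call above drives the trust-region loop using the parameters documented in the new `NewtonCGSteihaug` docstring (`eta`, `nplus`, `nminus`, `rho_good`, `rho_bad`, `boundary_tol`). Its real implementation lives in the new `torchzero/modules/trust_region/trust_region.py` and is not shown in this diff; the sketch below is only an assumption-based illustration of the documented accept/grow/shrink rule, with defaults taken from the new `__init__`, and `update_trust_radius` is a hypothetical name.

```python
# Illustrative sketch of the documented trust-region rule; not torchzero's
# actual default_radius.
def update_trust_radius(f_old, f_new, predicted_reduction, step_norm, trust_radius,
                        eta=0.0, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    # ratio of actual to predicted reduction of the loss
    rho = (f_old - f_new) / max(predicted_reduction, 1e-32)

    accepted = rho > eta  # step is accepted when rho exceeds eta (0.0 by default)
    if rho > rho_good and step_norm >= (1 - boundary_tol) * trust_radius:
        trust_radius *= nplus    # grow only if the step actually reached the boundary
    elif rho < rho_bad:
        trust_radius *= nminus   # shrink when the model badly over-predicted the reduction
    return trust_radius, accepted
```

On rejection the surrounding loop retries with the smaller radius, up to `max_attempts` times, reusing stored intermediate CG solutions via `find_within_trust_radius` instead of re-running CG when possible.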
torchzero/modules/smoothing/__init__.py

@@ -1,2 +1,2 @@
 from .laplacian import LaplacianSmoothing
-from .
+from .sampling import GradientSampling
|