torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +4 -10
- torchzero/core/__init__.py +4 -1
- torchzero/core/chain.py +50 -0
- torchzero/core/functional.py +37 -0
- torchzero/core/modular.py +237 -0
- torchzero/core/module.py +12 -599
- torchzero/core/reformulation.py +3 -1
- torchzero/core/transform.py +7 -5
- torchzero/core/var.py +376 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/conjugate_gradient/cg.py +16 -16
- torchzero/modules/experimental/__init__.py +1 -0
- torchzero/modules/experimental/newtonnewton.py +5 -5
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/functional.py +7 -0
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +2 -5
- torchzero/modules/grad_approximation/rfdm.py +27 -110
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +1 -1
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +11 -20
- torchzero/modules/line_search/scipy.py +15 -3
- torchzero/modules/line_search/strong_wolfe.py +3 -5
- torchzero/modules/misc/misc.py +2 -2
- torchzero/modules/misc/multistep.py +13 -13
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/quasi_newton.py +15 -6
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +5 -4
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +89 -0
- torchzero/modules/second_order/inm.py +105 -0
- torchzero/modules/second_order/newton.py +103 -193
- torchzero/modules/second_order/newton_cg.py +86 -110
- torchzero/modules/second_order/nystrom.py +1 -1
- torchzero/modules/second_order/rsn.py +227 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +6 -4
- torchzero/modules/wrappers/optim_wrapper.py +49 -42
- torchzero/modules/zeroth_order/__init__.py +1 -1
- torchzero/modules/zeroth_order/cd.py +1 -238
- torchzero/utils/derivatives.py +19 -19
- torchzero/utils/linalg/linear_operator.py +50 -2
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +1 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
- torchzero/modules/higher_order/__init__.py +0 -1
- /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
|
@@ -13,21 +13,12 @@ from ..trust_region.trust_region import default_radius
|
|
|
13
13
|
class NewtonCG(Module):
|
|
14
14
|
"""Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
forming the full Hessian matrix, it only requires Hessian-vector products
|
|
19
|
-
(HVPs). These can be calculated efficiently using automatic
|
|
20
|
-
differentiation or approximated using finite differences.
|
|
16
|
+
Notes:
|
|
17
|
+
* In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
|
|
21
18
|
|
|
22
|
-
|
|
23
|
-
In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
19
|
+
* This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
|
|
24
20
|
|
|
25
|
-
|
|
26
|
-
This module requires the a closure passed to the optimizer step,
|
|
27
|
-
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
28
|
-
The closure must accept a ``backward`` argument (refer to documentation).
|
|
29
|
-
|
|
30
|
-
.. warning::
|
|
21
|
+
Warning:
|
|
31
22
|
CG may fail if hessian is not positive-definite.
|
|
32
23
|
|
|
33
24
|
Args:
|
|
@@ -66,26 +57,24 @@ class NewtonCG(Module):
|
|
|
66
57
|
NewtonCG will attempt to apply preconditioning to the output of this module.
|
|
67
58
|
|
|
68
59
|
Examples:
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
)
|
|
88
|
-
|
|
60
|
+
Newton-CG with a backtracking line search:
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
opt = tz.Modular(
|
|
64
|
+
model.parameters(),
|
|
65
|
+
tz.m.NewtonCG(),
|
|
66
|
+
tz.m.Backtracking()
|
|
67
|
+
)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
Truncated Newton method (useful for large-scale problems):
|
|
71
|
+
```
|
|
72
|
+
opt = tz.Modular(
|
|
73
|
+
model.parameters(),
|
|
74
|
+
tz.m.NewtonCG(maxiter=10),
|
|
75
|
+
tz.m.Backtracking()
|
|
76
|
+
)
|
|
77
|
+
```
|
|
89
78
|
|
|
90
79
|
"""
|
|
91
80
|
def __init__(
|
|
@@ -95,7 +84,7 @@ class NewtonCG(Module):
|
|
|
95
84
|
reg: float = 1e-8,
|
|
96
85
|
hvp_method: Literal["forward", "central", "autograd"] = "autograd",
|
|
97
86
|
solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
|
|
98
|
-
h: float = 1e-3,
|
|
87
|
+
h: float = 1e-3, # tuned 1e-4 or 1e-3
|
|
99
88
|
miniter:int = 1,
|
|
100
89
|
warm_start=False,
|
|
101
90
|
inner: Chainable | None = None,
|
|
@@ -187,96 +176,95 @@ class NewtonCG(Module):
|
|
|
187
176
|
|
|
188
177
|
|
|
189
178
|
class NewtonCGSteihaug(Module):
|
|
190
|
-
"""
|
|
179
|
+
"""Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.
|
|
191
180
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
forming the full Hessian matrix, it only requires Hessian-vector products
|
|
195
|
-
(HVPs). These can be calculated efficiently using automatic
|
|
196
|
-
differentiation or approximated using finite differences.
|
|
181
|
+
Notes:
|
|
182
|
+
* In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
|
|
197
183
|
|
|
198
|
-
|
|
199
|
-
In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
200
|
-
|
|
201
|
-
.. note::
|
|
202
|
-
This module requires the a closure passed to the optimizer step,
|
|
203
|
-
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
204
|
-
The closure must accept a ``backward`` argument (refer to documentation).
|
|
205
|
-
|
|
206
|
-
.. warning::
|
|
207
|
-
CG may fail if hessian is not positive-definite.
|
|
184
|
+
* This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
|
|
208
185
|
|
|
209
186
|
Args:
|
|
210
|
-
maxiter (int | None, optional):
|
|
211
|
-
Maximum number of iterations for the conjugate gradient solver.
|
|
212
|
-
By default, this is set to the number of dimensions in the
|
|
213
|
-
objective function, which is the theoretical upper bound for CG
|
|
214
|
-
convergence. Setting this to a smaller value (truncated Newton)
|
|
215
|
-
can still generate good search directions. Defaults to None.
|
|
216
187
|
eta (float, optional):
|
|
217
|
-
|
|
218
|
-
nplus (float, optional):
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
trust region
|
|
222
|
-
|
|
188
|
+
if ratio of actual to predicted reduction is larger than this, step is accepted. Defaults to 0.0.
|
|
189
|
+
nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
|
|
190
|
+
nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
|
|
191
|
+
rho_good (float, optional):
|
|
192
|
+
if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`. Defaults to 0.99.
|
|
193
|
+
rho_bad (float, optional):
|
|
194
|
+
if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`. Defaults to 1e-4.
|
|
195
|
+
init (float, optional): Initial trust region value. Defaults to 1.
|
|
196
|
+
max_attempts (int, optional):
|
|
197
|
+
maximum number of trust radius reductions per step. A zero update vector is returned when
|
|
198
|
+
this limit is exceeded. Defaults to 100.
|
|
199
|
+
max_history (int, optional):
|
|
200
|
+
CG will store this many intermediate solutions, reusing them when trust radius is reduced
|
|
201
|
+
instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.
|
|
202
|
+
boundary_tol (float | None, optional):
|
|
203
|
+
The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
|
|
204
|
+
This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-6.
|
|
205
|
+
|
|
206
|
+
maxiter (int | None, optional):
|
|
207
|
+
maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
|
|
208
|
+
miniter (int, optional):
|
|
209
|
+
minimal number of CG iterations. This prevents making no progress
|
|
223
210
|
tol (float, optional):
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
reg (float, optional):
|
|
227
|
-
|
|
228
|
-
|
|
211
|
+
terminates CG when norm of the residual is less than this value. Defaults to 1e-8.
|
|
212
|
+
when initial guess is below tolerance. Defaults to 1.
|
|
213
|
+
reg (float, optional): hessian regularization. Defaults to 1e-8.
|
|
214
|
+
solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
|
|
215
|
+
adapt_tol (bool, optional):
|
|
216
|
+
if True, whenever trust radius collapses to smallest representable number,
|
|
217
|
+
the tolerance is multiplied by 0.1. Defaults to True.
|
|
218
|
+
npc_terminate (bool, optional):
|
|
219
|
+
whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
|
|
220
|
+
|
|
229
221
|
hvp_method (str, optional):
|
|
230
|
-
|
|
222
|
+
either "forward" to use forward formula which requires one backward pass per Hvp, or "central" to use a more accurate central formula which requires two backward passes. "forward" is usually accurate enough. Defaults to "central".
|
|
223
|
+
h (float, optional): finite difference step size. Defaults to 1e-3.
|
|
231
224
|
|
|
232
|
-
- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
|
|
233
|
-
This requires creating a graph for the gradient.
|
|
234
|
-
- ``"forward"``: Use a forward finite difference formula to
|
|
235
|
-
approximate the HVP. This requires one extra gradient evaluation.
|
|
236
|
-
- ``"central"``: Use a central finite difference formula for a
|
|
237
|
-
more accurate HVP approximation. This requires two extra
|
|
238
|
-
gradient evaluations.
|
|
239
|
-
Defaults to "autograd".
|
|
240
|
-
h (float, optional):
|
|
241
|
-
The step size for finite differences if :code:`hvp_method` is
|
|
242
|
-
``"forward"`` or ``"central"``. Defaults to 1e-3.
|
|
243
225
|
inner (Chainable | None, optional):
|
|
244
|
-
|
|
226
|
+
applies preconditioning to output of this module. Defaults to None.
|
|
245
227
|
|
|
246
|
-
Examples:
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
.. code-block:: python
|
|
228
|
+
### Examples:
|
|
229
|
+
Trust-region Newton-CG:
|
|
250
230
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
231
|
+
```python
|
|
232
|
+
opt = tz.Modular(
|
|
233
|
+
model.parameters(),
|
|
234
|
+
tz.m.NewtonCGSteihaug(),
|
|
235
|
+
)
|
|
236
|
+
```
|
|
255
237
|
|
|
256
|
-
Reference:
|
|
238
|
+
### Reference:
|
|
257
239
|
Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
|
|
258
240
|
"""
|
|
259
241
|
def __init__(
|
|
260
242
|
self,
|
|
261
|
-
|
|
243
|
+
# trust region settings
|
|
262
244
|
eta: float= 0.0,
|
|
263
245
|
nplus: float = 3.5,
|
|
264
246
|
nminus: float = 0.25,
|
|
265
247
|
rho_good: float = 0.99,
|
|
266
248
|
rho_bad: float = 1e-4,
|
|
267
249
|
init: float = 1,
|
|
268
|
-
tol: float = 1e-8,
|
|
269
|
-
reg: float = 1e-8,
|
|
270
|
-
hvp_method: Literal["forward", "central"] = "forward",
|
|
271
|
-
solver: Literal['cg', "minres"] = 'cg',
|
|
272
|
-
h: float = 1e-3,
|
|
273
250
|
max_attempts: int = 100,
|
|
274
251
|
max_history: int = 100,
|
|
275
|
-
boundary_tol: float = 1e-
|
|
252
|
+
boundary_tol: float = 1e-6, # tuned
|
|
253
|
+
|
|
254
|
+
# cg settings
|
|
255
|
+
maxiter: int | None = None,
|
|
276
256
|
miniter: int = 1,
|
|
277
|
-
|
|
257
|
+
tol: float = 1e-8,
|
|
258
|
+
reg: float = 1e-8,
|
|
259
|
+
solver: Literal['cg', "minres"] = 'cg',
|
|
278
260
|
adapt_tol: bool = True,
|
|
279
261
|
npc_terminate: bool = False,
|
|
262
|
+
|
|
263
|
+
# hvp settings
|
|
264
|
+
hvp_method: Literal["forward", "central"] = "central",
|
|
265
|
+
h: float = 1e-3, # tuned 1e-4 or 1e-3
|
|
266
|
+
|
|
267
|
+
# inner
|
|
280
268
|
inner: Chainable | None = None,
|
|
281
269
|
):
|
|
282
270
|
defaults = locals().copy()
|
|
@@ -336,19 +324,8 @@ class NewtonCGSteihaug(Module):
|
|
|
336
324
|
raise ValueError(hvp_method)
|
|
337
325
|
|
|
338
326
|
|
|
339
|
-
# ------------------------- update RMS preconditioner ------------------------ #
|
|
340
|
-
b = var.get_update()
|
|
341
|
-
P_mm = None
|
|
342
|
-
rms_beta = self.defaults["rms_beta"]
|
|
343
|
-
if rms_beta is not None:
|
|
344
|
-
exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
|
|
345
|
-
exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
|
|
346
|
-
exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
|
|
347
|
-
def _P_mm(x):
|
|
348
|
-
return x / exp_avg_sq_sqrt
|
|
349
|
-
P_mm = _P_mm
|
|
350
|
-
|
|
351
327
|
# -------------------------------- inner step -------------------------------- #
|
|
328
|
+
b = var.get_update()
|
|
352
329
|
if 'inner' in self.children:
|
|
353
330
|
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
354
331
|
b = as_tensorlist(b)
|
|
@@ -392,7 +369,6 @@ class NewtonCGSteihaug(Module):
|
|
|
392
369
|
miniter=miniter,
|
|
393
370
|
npc_terminate=npc_terminate,
|
|
394
371
|
history_size=max_history,
|
|
395
|
-
P_mm=P_mm,
|
|
396
372
|
)
|
|
397
373
|
|
|
398
374
|
elif solver == 'minres':
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from collections import deque
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from ...core import Chainable, Module, apply_transform
|
|
9
|
+
from ...utils import Distributions, TensorList, vec_to_tensors
|
|
10
|
+
from ...utils.linalg.linear_operator import Sketched
|
|
11
|
+
from .newton import _newton_step
|
|
12
|
+
|
|
13
|
+
def _qr_orthonormalize(A:torch.Tensor):
|
|
14
|
+
m,n = A.shape
|
|
15
|
+
if m < n:
|
|
16
|
+
q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
|
|
17
|
+
return q.T
|
|
18
|
+
else:
|
|
19
|
+
q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
|
|
20
|
+
return q
|
|
21
|
+
|
|
22
|
+
def _orthonormal_sketch(m, n, dtype, device, generator):
|
|
23
|
+
return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
|
|
24
|
+
|
|
25
|
+
def _gaussian_sketch(m, n, dtype, device, generator):
|
|
26
|
+
return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
|
|
27
|
+
|
|
28
|
+
class RSN(Module):
|
|
29
|
+
"""Randomized Subspace Newton. Performs a Newton step in a random subspace.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
sketch_size (int):
|
|
33
|
+
size of the random sketch. This many hessian-vector products will need to be evaluated each step.
|
|
34
|
+
sketch_type (str, optional):
|
|
35
|
+
- "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
|
|
36
|
+
- "gaussian" - random gaussian (not orthonormal) basis.
|
|
37
|
+
- "common_directions" - uses the history of steepest descent directions as the basis [2]. It is orthonormalized on-line using Gram-Schmidt.
|
|
38
|
+
- "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default).
|
|
39
|
+
damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
|
|
40
|
+
hvp_method (str, optional):
|
|
41
|
+
How to compute hessian-matrix product:
|
|
42
|
+
- "batched" - uses batched autograd
|
|
43
|
+
- "autograd" - uses unbatched autograd
|
|
44
|
+
- "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
|
|
45
|
+
- "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
|
|
46
|
+
|
|
47
|
+
Defaults to "batched".
|
|
48
|
+
h (float, optional): finite difference step size. Defaults to 1e-2.
|
|
49
|
+
use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to True.
|
|
50
|
+
update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
|
|
51
|
+
H_tfm (Callable | None, optional):
|
|
52
|
+
optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
|
|
53
|
+
|
|
54
|
+
must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
|
|
55
|
+
which must be True if transform inverted the hessian and False otherwise.
|
|
56
|
+
|
|
57
|
+
Or it returns a single tensor which is used as the update.
|
|
58
|
+
|
|
59
|
+
Defaults to None.
|
|
60
|
+
eigval_fn (Callable | None, optional):
|
|
61
|
+
optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
|
|
62
|
+
If this is specified, eigendecomposition will be used to invert the hessian.
|
|
63
|
+
seed (int | None, optional): seed for random generator. Defaults to None.
|
|
64
|
+
inner (Chainable | None, optional): preconditions output of this module. Defaults to None.
|
|
65
|
+
|
|
66
|
+
### Examples
|
|
67
|
+
|
|
68
|
+
RSN with line search
|
|
69
|
+
```python
|
|
70
|
+
opt = tz.Modular(
|
|
71
|
+
model.parameters(),
|
|
72
|
+
tz.m.RSN(),
|
|
73
|
+
tz.m.Backtracking()
|
|
74
|
+
)
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
RSN with trust region
|
|
78
|
+
```python
|
|
79
|
+
opt = tz.Modular(
|
|
80
|
+
model.parameters(),
|
|
81
|
+
tz.m.LevenbergMarquardt(tz.m.RSN()),
|
|
82
|
+
)
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
References:
|
|
87
|
+
1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
|
|
88
|
+
2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
def __init__(
|
|
92
|
+
self,
|
|
93
|
+
sketch_size: int,
|
|
94
|
+
sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
|
|
95
|
+
damping:float=0,
|
|
96
|
+
hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
|
|
97
|
+
h: float = 1e-2,
|
|
98
|
+
use_lstsq: bool = True,
|
|
99
|
+
update_freq: int = 1,
|
|
100
|
+
H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
|
101
|
+
eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
|
102
|
+
seed: int | None = None,
|
|
103
|
+
inner: Chainable | None = None,
|
|
104
|
+
):
|
|
105
|
+
defaults = dict(sketch_size=sketch_size, sketch_type=sketch_type,seed=seed,hvp_method=hvp_method, h=h, damping=damping, use_lstsq=use_lstsq, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
|
|
106
|
+
super().__init__(defaults)
|
|
107
|
+
|
|
108
|
+
if inner is not None:
|
|
109
|
+
self.set_child("inner", inner)
|
|
110
|
+
|
|
111
|
+
@torch.no_grad
|
|
112
|
+
def update(self, var):
|
|
113
|
+
step = self.global_state.get('step', 0)
|
|
114
|
+
self.global_state['step'] = step + 1
|
|
115
|
+
|
|
116
|
+
if step % self.defaults['update_freq'] == 0:
|
|
117
|
+
|
|
118
|
+
closure = var.closure
|
|
119
|
+
if closure is None:
|
|
120
|
+
raise RuntimeError("RSN requires closure")
|
|
121
|
+
params = var.params
|
|
122
|
+
generator = self.get_generator(params[0].device, self.defaults["seed"])
|
|
123
|
+
|
|
124
|
+
ndim = sum(p.numel() for p in params)
|
|
125
|
+
|
|
126
|
+
device=params[0].device
|
|
127
|
+
dtype=params[0].dtype
|
|
128
|
+
|
|
129
|
+
# sample sketch matrix S: (ndim, sketch_size)
|
|
130
|
+
sketch_size = min(self.defaults["sketch_size"], ndim)
|
|
131
|
+
sketch_type = self.defaults["sketch_type"]
|
|
132
|
+
hvp_method = self.defaults["hvp_method"]
|
|
133
|
+
|
|
134
|
+
if sketch_type in ('normal', 'gaussian'):
|
|
135
|
+
S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
|
|
136
|
+
|
|
137
|
+
elif sketch_type == 'orthonormal':
|
|
138
|
+
S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
|
|
139
|
+
|
|
140
|
+
elif sketch_type == 'common_directions':
|
|
141
|
+
# Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
|
|
142
|
+
g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
|
|
143
|
+
g = torch.cat([t.ravel() for t in g_list])
|
|
144
|
+
|
|
145
|
+
# initialize directions deque
|
|
146
|
+
if "directions" not in self.global_state:
|
|
147
|
+
|
|
148
|
+
g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
|
|
149
|
+
if g_norm < torch.finfo(g.dtype).tiny * 2:
|
|
150
|
+
g = torch.randn_like(g)
|
|
151
|
+
g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
|
|
152
|
+
|
|
153
|
+
self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
|
|
154
|
+
S = self.global_state["directions"][0].unsqueeze(1)
|
|
155
|
+
|
|
156
|
+
# add new steepest descent direction orthonormal to existing columns
|
|
157
|
+
else:
|
|
158
|
+
S = torch.stack(tuple(self.global_state["directions"]), dim=1)
|
|
159
|
+
p = g - S @ (S.T @ g)
|
|
160
|
+
p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
|
|
161
|
+
if p_norm > torch.finfo(p.dtype).tiny * 2:
|
|
162
|
+
p = p / p_norm
|
|
163
|
+
self.global_state["directions"].append(p)
|
|
164
|
+
S = torch.cat([S, p.unsqueeze(1)], dim=1)
|
|
165
|
+
|
|
166
|
+
elif sketch_type == "mixed":
|
|
167
|
+
g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
|
|
168
|
+
g = torch.cat([t.ravel() for t in g_list])
|
|
169
|
+
|
|
170
|
+
if "slow_ema" not in self.global_state:
|
|
171
|
+
self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
|
|
172
|
+
self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
|
|
173
|
+
|
|
174
|
+
slow_ema = self.global_state["slow_ema"]
|
|
175
|
+
fast_ema = self.global_state["fast_ema"]
|
|
176
|
+
slow_ema.lerp_(g, 0.001)
|
|
177
|
+
fast_ema.lerp_(g, 0.1)
|
|
178
|
+
|
|
179
|
+
S = torch.stack([g, slow_ema, fast_ema], dim=1)
|
|
180
|
+
if sketch_size > 3:
|
|
181
|
+
S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
|
|
182
|
+
S = torch.cat([S, S_random], dim=1)
|
|
183
|
+
|
|
184
|
+
S = _qr_orthonormalize(S)
|
|
185
|
+
|
|
186
|
+
else:
|
|
187
|
+
raise ValueError(f'Unknown sketch_type {sketch_type}')
|
|
188
|
+
|
|
189
|
+
# form sketched hessian
|
|
190
|
+
HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
|
|
191
|
+
H_sketched = S.T @ HS
|
|
192
|
+
|
|
193
|
+
self.global_state["H_sketched"] = H_sketched
|
|
194
|
+
self.global_state["S"] = S
|
|
195
|
+
|
|
196
|
+
def apply(self, var):
|
|
197
|
+
S: torch.Tensor = self.global_state["S"]
|
|
198
|
+
d_proj = _newton_step(
|
|
199
|
+
var=var,
|
|
200
|
+
H=self.global_state["H_sketched"],
|
|
201
|
+
damping=self.defaults["damping"],
|
|
202
|
+
inner=self.children.get("inner", None),
|
|
203
|
+
H_tfm=self.defaults["H_tfm"],
|
|
204
|
+
eigval_fn=self.defaults["eigval_fn"],
|
|
205
|
+
use_lstsq=self.defaults["use_lstsq"],
|
|
206
|
+
g_proj = lambda g: S.T @ g
|
|
207
|
+
)
|
|
208
|
+
d = S @ d_proj
|
|
209
|
+
var.update = vec_to_tensors(d, var.params)
|
|
210
|
+
|
|
211
|
+
return var
|
|
212
|
+
|
|
213
|
+
def get_H(self, var=...):
|
|
214
|
+
eigval_fn = self.defaults["eigval_fn"]
|
|
215
|
+
H_sketched: torch.Tensor = self.global_state["H_sketched"]
|
|
216
|
+
S: torch.Tensor = self.global_state["S"]
|
|
217
|
+
|
|
218
|
+
if eigval_fn is not None:
|
|
219
|
+
try:
|
|
220
|
+
L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
|
|
221
|
+
L: torch.Tensor = eigval_fn(L)
|
|
222
|
+
H_sketched = Q @ L.diag_embed() @ Q.mH
|
|
223
|
+
|
|
224
|
+
except torch.linalg.LinAlgError:
|
|
225
|
+
pass
|
|
226
|
+
|
|
227
|
+
return Sketched(S, H_sketched)
|
|
@@ -14,13 +14,13 @@ class LevenbergMarquardt(TrustRegionBase):
|
|
|
14
14
|
hess_module (Module | None, optional):
|
|
15
15
|
A module that maintains a hessian approximation (not hessian inverse!).
|
|
16
16
|
This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
|
|
17
|
-
When using quasi-newton methods, set
|
|
17
|
+
When using quasi-newton methods, set ``inverse=False`` when constructing them.
|
|
18
18
|
y (float, optional):
|
|
19
19
|
when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian approximation
|
|
20
20
|
is added. Values between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
|
|
21
21
|
eta (float, optional):
|
|
22
22
|
if ratio of actual to predicted reduction is larger than this, step is accepted.
|
|
23
|
-
When
|
|
23
|
+
When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.15.
|
|
24
24
|
nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
|
|
25
25
|
nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
|
|
26
26
|
rho_good (float, optional):
|
|
@@ -60,17 +60,19 @@ class TrustCG(TrustRegionBase):
|
|
|
60
60
|
nminus: float = 0.25,
|
|
61
61
|
rho_good: float = 0.99,
|
|
62
62
|
rho_bad: float = 1e-4,
|
|
63
|
-
boundary_tol: float | None = 1e-
|
|
63
|
+
boundary_tol: float | None = 1e-6, # tuned
|
|
64
64
|
init: float = 1,
|
|
65
65
|
max_attempts: int = 10,
|
|
66
66
|
radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
|
|
67
67
|
reg: float = 0,
|
|
68
|
-
|
|
68
|
+
maxiter: int | None = None,
|
|
69
|
+
miniter: int = 1,
|
|
70
|
+
cg_tol: float = 1e-8,
|
|
69
71
|
prefer_exact: bool = True,
|
|
70
72
|
update_freq: int = 1,
|
|
71
73
|
inner: Chainable | None = None,
|
|
72
74
|
):
|
|
73
|
-
defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
|
|
75
|
+
defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
|
|
74
76
|
super().__init__(
|
|
75
77
|
defaults=defaults,
|
|
76
78
|
hess_module=hess_module,
|
|
@@ -93,5 +95,5 @@ class TrustCG(TrustRegionBase):
|
|
|
93
95
|
if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
|
|
94
96
|
return H.solve_bounded(g, radius)
|
|
95
97
|
|
|
96
|
-
x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
|
|
98
|
+
x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
|
|
97
99
|
return x
|