torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
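The removal of `torchzero/core/var.py` and the addition of `torchzero/core/objective.py` in the listing above show up concretely in the `nystrom.py` diff below: both Nyström modules now subclass `Transform` and split their work into `update_states` (evaluate Hessian-vector products and build or refresh the Nyström factors) and `apply_states` (use the cached factors to turn the current update into a Newton-like step), instead of the old single `var`-based method. A rough skeleton of that pattern, inferred only from the diff below; the method names, `objective.updates`, `get_updates`, and `global_state` appear there, while everything else is an illustrative assumption, not documented API:

```python
from torchzero.core import Transform  # Transform is exported from torchzero.core per the imports in the diff below


class MyCurvatureTransform(Transform):
    """Sketch of the update/apply split the 0.4.1 Nystrom modules follow."""

    def update_states(self, objective, states, settings):
        # expensive part: evaluate hessian-vector products and rebuild the
        # preconditioner, caching its factors in self.global_state
        ...

    def apply_states(self, objective, states, settings):
        # cheap part: read the cached factors, solve for the step, write it
        # back through objective.updates, and return the objective
        ...
        return objective
```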
torchzero/modules/second_order/nystrom.py

````diff
@@ -1,153 +1,187 @@
-from collections.abc import Callable
-from typing import Literal, overload
 import warnings
-import
+from typing import Literal

-
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+import torch

-from ...core import Chainable,
-from ...utils
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import TensorList, vec_to_tensors
+from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
+from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity

-class NystromSketchAndSolve(
+class NystromSketchAndSolve(Transform):
     """Newton's method with a Nyström sketch-and-solve solver.

-
-    This module requires the a closure passed to the optimizer step,
-    as it needs to re-evaluate the loss and gradients for calculating HVPs.
-    The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. note::
-        In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+    Notes:
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

-
-    If this is unstable, increase the :code:`reg` parameter and tune the rank.
+        - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

-
-    :code:`tz.m.NystromPCG` usually outperforms this.
+        - If this is unstable, increase the ``reg`` parameter and tune the rank.

     Args:
         rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
-        reg (float, optional):
+        reg (float | None, optional):
+            scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve
+            is used to compute ``(Q diag(L) Q.T + reg*I)x = b``. It is very unstable when ``reg`` is small,
+            i.e. smaller than 1e-4. If this is None,``(Q diag(L) Q.T)x = b`` is computed by simply taking
+            reciprocal of eigenvalues. Defaults to 1e-3.
+        eigv_tol (float, optional):
+            all eigenvalues smaller than largest eigenvalue times ``eigv_tol`` are removed. Defaults to None.
+        truncate (int | None, optional):
+            keeps top ``truncate`` eigenvalues. Defaults to None.
+        damping (float, optional): scalar added to eigenvalues. Defaults to 0.
+        rdamping (float, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
+        update_freq (int, optional): frequency of updating preconditioner. Defaults to 1.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         seed (int | None, optional): seed for random generator. Defaults to None.

+
     Examples:
-
+        NystromSketchAndSolve with backtracking line search

-
+        ```py
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.NystromSketchAndSolve(100),
+            tz.m.Backtracking()
+        )
+        ```

-
-
-
-
-
+        Trust region NystromSketchAndSolve
+
+        ```py
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
+        )
+        ```
+
+    References:
+        - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
+        - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)

-    Reference:
-        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
     """
     def __init__(
         self,
         rank: int,
-        reg: float = 1e-
-
+        reg: float | None = 1e-2,
+        eigv_tol: float = 0,
+        truncate: int | None = None,
+        damping: float = 0,
+        rdamping: float = 0,
+        update_freq: int = 1,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+        hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-3,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)

     @torch.no_grad
-    def
-        params = TensorList(
-
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        rank = settings['rank']
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+    def update_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+        fs = settings[0]

         # ---------------------- Hessian vector product function --------------------- #
-
-
+        hvp_method = fs['hvp_method']
+        h = fs['h']
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+
+        # ---------------------------------- sketch ---------------------------------- #
+        ndim = sum(t.numel() for t in objective.params)
+        device = params[0].device
+        dtype = params[0].dtype
+
+        generator = self.get_generator(params[0].device, seed=fs['seed'])
+        try:
+            # compute the approximation
+            L, Q = nystrom_approximation(
+                A_mv=H_mv,
+                A_mm=H_mm,
+                ndim=ndim,
+                rank=min(fs["rank"], ndim),
+                eigv_tol=fs["eigv_tol"],
+                orthogonalize_method=fs["orthogonalize_method"],
+                dtype=dtype,
+                device=device,
+                generator=generator,
+            )

-
-
-
-
+            # regularize
+            L, Q = regularize_eigh(
+                L=L,
+                Q=Q,
+                truncate=fs["truncate"],
+                tol=fs["eigv_tol"],
+                damping=fs["damping"],
+                rdamping=fs["rdamping"],
+            )

-
+            # store
+            if L is not None:
+                self.global_state["L"] = L
+                self.global_state["Q"] = Q

-
-
+        except torch.linalg.LinAlgError as e:
+            warnings.warn(f"Nystrom approximation failed with: {e}")

-
-
-
-            return torch.cat([t.ravel() for t in Hvp])
+    def apply_states(self, objective, states, settings):
+        if "L" not in self.global_state:
+            return objective

-
-
-
-            return torch.cat([t.ravel() for t in Hvp])
+        fs = settings[0]
+        updates = objective.get_updates()
+        b=torch.cat([t.ravel() for t in updates])

-
-
+        # ----------------------------------- solve ---------------------------------- #
+        L = self.global_state["L"]
+        Q = self.global_state["Q"]

+        if fs["reg"] is None:
+            x = Q @ ((Q.mH @ b) / L)
+        else:
+            x = nystrom_sketch_and_solve(L=L, Q=Q, b=b, reg=fs["reg"])

-        # --------------------------------
-
-
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        # -------------------------------- set update -------------------------------- #
+        objective.updates = vec_to_tensors(x, reference=objective.params)
+        return objective

-
-
-
-        return var
+    def get_H(self, objective=...):
+        if "L" not in self.global_state:
+            return ScaledIdentity()

+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        return Eigendecomposition(L, Q)


-class NystromPCG(
+class NystromPCG(Transform):
     """Newton's method with a Nyström-preconditioned conjugate gradient solver.
-    This tends to outperform NewtonCG but requires tuning sketch size.
-    An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.

-
-    This module requires the a closure passed to the optimizer step,
+    Notes:
+        - This module requires the a closure passed to the optimizer step,
     as it needs to re-evaluate the loss and gradients for calculating HVPs.
     The closure must accept a ``backward`` argument (refer to documentation).

-
-        In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+        - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

     Args:
-
+        rank (int):
            size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
            running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
            conjugate gradient.
@@ -159,31 +193,31 @@ class NystromPCG(Module):
         tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
         reg (float, optional): regularization parameter. Defaults to 1e-8.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         seed (int | None, optional): seed for random generator. Defaults to None.

     Examples:

-
-
-        .. code-block:: python
+        NystromPCG with backtracking line search

-
-
-
-
-
-
+        ```python
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.NystromPCG(10),
+            tz.m.Backtracking()
+        )
+        ```

     Reference:
         Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
@@ -191,81 +225,78 @@ class NystromPCG(Module):
     """
     def __init__(
         self,
-
+        rank: int,
         maxiter=None,
         tol=1e-8,
         reg: float = 1e-6,
-
+        update_freq: int = 1, # here update_freq is within update_states
+        eigv_tol: float = 0,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
+        hvp_method: HVPMethod = "batched_autograd",
         h=1e-3,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults, inner=inner)

     @torch.no_grad
-    def
-
-
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        sketch_size = settings['sketch_size']
-        maxiter = settings['maxiter']
-        tol = settings['tol']
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]

         # ---------------------- Hessian vector product function --------------------- #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # ------------------------------ sketch&n&solve ------------------------------ #
-        x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
-        var.update = vec_to_tensors(x, reference=params)
-        return var
-
+        # this should run on every update_states
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=fs['hvp_method'], h=fs['h'], at_x0=True)
+        objective.temp = H_mv
+
+        # --------------------------- update preconditioner -------------------------- #
+        step = self.increment_counter("step", 0)
+        if step % fs["update_freq"] == 0:
+
+            ndim = sum(t.numel() for t in objective.params)
+            device = objective.params[0].device
+            dtype = objective.params[0].dtype
+            generator = self.get_generator(device, seed=fs['seed'])
+
+            try:
+                L, Q = nystrom_approximation(
+                    A_mv=None,
+                    A_mm=H_mm,
+                    ndim=ndim,
+                    rank=min(fs["rank"], ndim),
+                    eigv_tol=fs["eigv_tol"],
+                    orthogonalize_method=fs["orthogonalize_method"],
+                    dtype=dtype,
+                    device=device,
+                    generator=generator,
+                )
+
+                self.global_state["L"] = L
+                self.global_state["Q"] = Q
+
+            except torch.linalg.LinAlgError as e:
+                warnings.warn(f"Nystrom approximation failed with: {e}")

+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        b = objective.get_updates()
+        H_mv = objective.poptemp()
+        fs = self.settings[objective.params[0]]
+
+        # ----------------------------------- solve ---------------------------------- #
+        if "L" not in self.global_state:
+            # fallback on cg
+            sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
+            objective.updates = sol.x
+            return objective
+
+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+
+        x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
+                        reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
+
+        # -------------------------------- set update -------------------------------- #
+        objective.updates = vec_to_tensors(x, reference=objective.params)
+        return objective
````
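For a sense of how these modules are driven end to end, here is a minimal sketch based on the docstring examples added in this diff. It assumes `torchzero` is imported as `tz` (matching the `tz.` prefix in the docstrings) and follows the closure convention the docstrings require (a closure that accepts a `backward` argument); the toy model, data, and training loop are illustrative and not part of the diff:

```python
import torch
import torchzero as tz  # assumed alias, matching the "tz." prefix in the docstring examples

model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1))
X, y = torch.randn(64, 10), torch.randn(64, 1)

# NystromSketchAndSolve followed by a backtracking line search, as in the added docstring example
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NystromSketchAndSolve(100),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # the module re-evaluates loss and gradients for hessian-vector products,
    # so the closure must accept a ``backward`` flag per the docstrings
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```

The same pattern applies to `tz.m.NystromPCG(10)`, or to the trust-region variant from the second docstring example, where the Nyström module is wrapped in a Levenberg-Marquardt trust region instead of being followed by a line search.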