torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
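
The `{utils/linalg → linalg}` renames above promote the linear-algebra helpers from `torchzero.utils.linalg` to a top-level `torchzero.linalg` subpackage. A minimal import sketch of what that move implies, assuming the names used by the second-order modules in the diff below are re-exported from the new package (the old import path in the comment is a hypothetical reconstruction):

```python
# torchzero 0.3.15 (hypothetical old location, per the renamed paths above):
# from torchzero.utils.linalg.solve import nystrom_pcg

# torchzero 0.4.0 (matches the relative imports in the diff below):
from torchzero.linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
from torchzero.linalg.linear_operator import Eigendecomposition, ScaledIdentity
```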
torchzero/modules/second_order/nystrom.py

````diff
@@ -1,150 +1,141 @@
-from
-from typing import Literal, overload
-import warnings
-import torch
+from typing import Literal

-
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+import torch

-from ...core import Chainable,
-from ...utils
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import TensorList, vec_to_tensors
+from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
+from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity

-class NystromSketchAndSolve(
+class NystromSketchAndSolve(Transform):
     """Newton's method with a Nyström sketch-and-solve solver.

-
-    This module requires the a closure passed to the optimizer step,
-    as it needs to re-evaluate the loss and gradients for calculating HVPs.
-    The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. note::
-        In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+    Notes:
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

-
-    If this is unstable, increase the :code:`reg` parameter and tune the rank.
+        - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

-
-    :code:`tz.m.NystromPCG` usually outperforms this.
+        - If this is unstable, increase the ``reg`` parameter and tune the rank.

     Args:
         rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
         reg (float, optional): regularization parameter. Defaults to 1e-3.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are
-
-            - ``"
-
-            - ``"
-
-
-
-
-
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         seed (int | None, optional): seed for random generator. Defaults to None.

+
     Examples:
-
+        NystromSketchAndSolve with backtracking line search

-
+        ```py
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NystromSketchAndSolve(100),
+            tz.m.Backtracking()
+        )
+        ```

-
-
-
-
-
+        Trust region NystromSketchAndSolve
+
+        ```py
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
+        )
+        ```
+
+    References:
+        - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
+        - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)

-    Reference:
-        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
     """
     def __init__(
         self,
         rank: int,
         reg: float = 1e-3,
-        hvp_method:
+        hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-3,
+        update_freq: int = 1,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)

     @torch.no_grad
-    def
-        params = TensorList(
-
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        rank = settings['rank']
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+    def update_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+        fs = settings[0]

         # ---------------------- Hessian vector product function --------------------- #
-
-
-
-        def H_mm(x):
-            with torch.enable_grad():
-                Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
-            return torch.cat([t.ravel() for t in Hvp])
+        hvp_method = fs['hvp_method']
+        h = fs['h']
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)

-
+        # ---------------------------------- sketch ---------------------------------- #
+        ndim = sum(t.numel() for t in objective.params)
+        device = params[0].device
+        dtype = params[0].dtype

-
-
+        generator = self.get_generator(params[0].device, seed=fs['seed'])
+        try:
+            L, Q = nystrom_approximation(A_mv=H_mv, A_mm=H_mm, ndim=ndim, rank=fs['rank'],
+                                         dtype=dtype, device=device, generator=generator)

-
-
-
-
+            self.global_state["L"] = L
+            self.global_state["Q"] = Q
+        except torch.linalg.LinAlgError:
+            pass

-
-
-
-            return torch.cat([t.ravel() for t in Hvp])
+    def apply_states(self, objective, states, settings):
+        fs = settings[0]
+        b = objective.get_updates()

-
-
+        # ----------------------------------- solve ---------------------------------- #
+        if "L" not in self.global_state:
+            return objective

+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        x = nystrom_sketch_and_solve(L=L, Q=Q, b=torch.cat([t.ravel() for t in b]), reg=fs["reg"])

-        # --------------------------------
-
-
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+        # -------------------------------- set update -------------------------------- #
+        objective.updates = vec_to_tensors(x, reference=objective.params)
+        return objective

-
-
-
-        return var
+    def get_H(self, objective=...):
+        if "L" not in self.global_state:
+            return ScaledIdentity()

+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        return Eigendecomposition(L, Q)


-class NystromPCG(
+class NystromPCG(Transform):
     """Newton's method with a Nyström-preconditioned conjugate gradient solver.
     This tends to outperform NewtonCG but requires tuning sketch size.
     An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.

-
-    This module requires the a closure passed to the optimizer step,
+    Notes:
+        - This module requires the a closure passed to the optimizer step,
          as it needs to re-evaluate the loss and gradients for calculating HVPs.
          The closure must accept a ``backward`` argument (refer to documentation).

-
-        In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+        - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

     Args:
         sketch_size (int):
@@ -159,31 +150,31 @@ class NystromPCG(Module):
         tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
         reg (float, optional): regularization parameter. Defaults to 1e-8.
         hvp_method (str, optional):
-            Determines how Hessian-vector products are
-
-            - ``"
-
-            - ``"
-
-
-
-
-
+            Determines how Hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         seed (int | None, optional): seed for random generator. Defaults to None.

     Examples:

-
+        NystromPCG with backtracking line search

-
-
-
-
-
-
-
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.NystromPCG(10),
+            tz.m.Backtracking()
+        )
+        ```

     Reference:
         Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
@@ -191,81 +182,70 @@ class NystromPCG(Module):
     """
     def __init__(
         self,
-
+        rank: int,
         maxiter=None,
         tol=1e-8,
         reg: float = 1e-6,
-
+        update_freq: int = 1, # here update_freq is within update_states
+        hvp_method: HVPMethod = "batched_autograd",
         h=1e-3,
         inner: Chainable | None = None,
         seed: int | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults, inner=inner)

     @torch.no_grad
-    def
-
-
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        sketch_size = settings['sketch_size']
-        maxiter = settings['maxiter']
-        tol = settings['tol']
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-
-
-        seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]

         # ---------------------- Hessian vector product function --------------------- #
-
-
-
-
-
-                Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
-            return torch.cat([t.ravel() for t in Hvp])
-
-        else:
+        # this should run on every update_states
+        hvp_method = fs['hvp_method']
+        h = fs['h']
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        objective.temp = H_mv

-
-
+        # --------------------------- update preconditioner -------------------------- #
+        step = self.increment_counter("step", 0)
+        update_freq = self.defaults["update_freq"]

-
-            def H_mm(x):
-                Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
-                return torch.cat([t.ravel() for t in Hvp])
+        if step % update_freq == 0:

-
-
-
-
+            rank = fs['rank']
+            ndim = sum(t.numel() for t in objective.params)
+            device = objective.params[0].device
+            dtype = objective.params[0].dtype
+            generator = self.get_generator(device, seed=fs['seed'])

-
-
-
-
-        # -------------------------------- inner step -------------------------------- #
-        b = var.get_update()
-        if 'inner' in self.children:
-            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-
-        # ------------------------------ sketch&n&solve ------------------------------ #
-        x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
-        var.update = vec_to_tensors(x, reference=params)
-        return var
+            try:
+                L, Q = nystrom_approximation(A_mv=None, A_mm=H_mm, ndim=ndim, rank=rank,
+                                             dtype=dtype, device=device, generator=generator)

+                self.global_state["L"] = L
+                self.global_state["Q"] = Q
+            except torch.linalg.LinAlgError:
+                pass

+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        b = objective.get_updates()
+        H_mv = objective.poptemp()
+        fs = self.settings[objective.params[0]]
+
+        # ----------------------------------- solve ---------------------------------- #
+        if "L" not in self.global_state:
+            # fallback on cg
+            sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
+            objective.updates = sol.x
+            return objective
+
+        L = self.global_state["L"]
+        Q = self.global_state["Q"]
+        x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
+                        reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
+
+        # -------------------------------- set update -------------------------------- #
+        objective.updates = vec_to_tensors(x, reference=objective.params)
+        return objective
````
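
In 0.4.0 both Nyström modules subclass `Transform` and split the old single step into `update_states` (build or refresh the Nyström factors `L`, `Q`; for `NystromPCG` only every `update_freq` steps) and `apply_states` (solve against the current update, falling back to plain CG when no factorization is available). A hedged usage sketch of the reworked `NystromPCG`, following the `tz.Modular` and closure conventions from the docstring examples above; the model, loss, and data are placeholders:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(16, 1)
X, y = torch.randn(64, 16), torch.randn(64, 1)

# Rank-10 Nyström-preconditioned CG Newton step with a backtracking line search;
# update_freq (new in 0.4.0) rebuilds the preconditioner only every 5 steps.
opt = tz.Modular(
    model.parameters(),
    tz.m.NystromPCG(10, update_freq=5),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # per the docstrings above, the closure must accept a ``backward`` argument
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```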