torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -1,150 +1,141 @@
|
|
|
1
|
-
from
|
|
2
|
-
from typing import Literal, overload
|
|
3
|
-
import warnings
|
|
4
|
-
import torch
|
|
1
|
+
from typing import Literal
|
|
5
2
|
|
|
6
|
-
|
|
7
|
-
from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
|
|
3
|
+
import torch
|
|
8
4
|
|
|
9
|
-
from ...core import Chainable,
|
|
10
|
-
from ...utils
|
|
5
|
+
from ...core import Chainable, Transform, HVPMethod
|
|
6
|
+
from ...utils import TensorList, vec_to_tensors
|
|
7
|
+
from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
|
|
8
|
+
from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity
|
|
11
9
|
|
|
12
|
-
class NystromSketchAndSolve(
|
|
10
|
+
class NystromSketchAndSolve(Transform):
|
|
13
11
|
"""Newton's method with a Nyström sketch-and-solve solver.
|
|
14
12
|
|
|
15
|
-
|
|
16
|
-
This module requires the a closure passed to the optimizer step,
|
|
17
|
-
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
18
|
-
The closure must accept a ``backward`` argument (refer to documentation).
|
|
19
|
-
|
|
20
|
-
.. note::
|
|
21
|
-
In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
13
|
+
Notes:
|
|
14
|
+
- This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
|
|
22
15
|
|
|
23
|
-
|
|
24
|
-
If this is unstable, increase the :code:`reg` parameter and tune the rank.
|
|
16
|
+
- In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
|
|
25
17
|
|
|
26
|
-
|
|
27
|
-
:code:`tz.m.NystromPCG` usually outperforms this.
|
|
18
|
+
- If this is unstable, increase the ``reg`` parameter and tune the rank.
|
|
28
19
|
|
|
29
20
|
Args:
|
|
30
21
|
rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
|
|
31
22
|
reg (float, optional): regularization parameter. Defaults to 1e-3.
|
|
32
23
|
hvp_method (str, optional):
|
|
33
|
-
Determines how Hessian-vector products are
|
|
34
|
-
|
|
35
|
-
- ``"
|
|
36
|
-
|
|
37
|
-
- ``"
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
24
|
+
Determines how Hessian-vector products are computed.
|
|
25
|
+
|
|
26
|
+
- ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
|
|
27
|
+
- ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
|
|
28
|
+
- ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
|
|
29
|
+
- ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
|
|
30
|
+
|
|
31
|
+
Defaults to ``"autograd"``.
|
|
32
|
+
h (float, optional):
|
|
33
|
+
The step size for finite difference if ``hvp_method`` is
|
|
34
|
+
``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
|
|
44
35
|
inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
|
|
45
36
|
seed (int | None, optional): seed for random generator. Defaults to None.
|
|
46
37
|
|
|
38
|
+
|
|
47
39
|
Examples:
|
|
48
|
-
|
|
40
|
+
NystromSketchAndSolve with backtracking line search
|
|
49
41
|
|
|
50
|
-
|
|
42
|
+
```py
|
|
43
|
+
opt = tz.Modular(
|
|
44
|
+
model.parameters(),
|
|
45
|
+
tz.m.NystromSketchAndSolve(100),
|
|
46
|
+
tz.m.Backtracking()
|
|
47
|
+
)
|
|
48
|
+
```
|
|
51
49
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
50
|
+
Trust region NystromSketchAndSolve
|
|
51
|
+
|
|
52
|
+
```py
|
|
53
|
+
opt = tz.Modular(
|
|
54
|
+
model.parameters(),
|
|
55
|
+
tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
|
|
56
|
+
)
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
References:
|
|
60
|
+
- [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
|
|
61
|
+
- [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)
|
|
57
62
|
|
|
58
|
-
Reference:
|
|
59
|
-
Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
|
|
60
63
|
"""
|
|
61
64
|
def __init__(
|
|
62
65
|
self,
|
|
63
66
|
rank: int,
|
|
64
67
|
reg: float = 1e-3,
|
|
65
|
-
hvp_method:
|
|
68
|
+
hvp_method: HVPMethod = "batched_autograd",
|
|
66
69
|
h: float = 1e-3,
|
|
70
|
+
update_freq: int = 1,
|
|
67
71
|
inner: Chainable | None = None,
|
|
68
72
|
seed: int | None = None,
|
|
69
73
|
):
|
|
70
|
-
defaults =
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
if inner is not None:
|
|
74
|
-
self.set_child('inner', inner)
|
|
74
|
+
defaults = locals().copy()
|
|
75
|
+
del defaults['self'], defaults['inner'], defaults["update_freq"]
|
|
76
|
+
super().__init__(defaults, update_freq=update_freq, inner=inner)
|
|
75
77
|
|
|
76
78
|
@torch.no_grad
|
|
77
|
-
def
|
|
78
|
-
params = TensorList(
|
|
79
|
-
|
|
80
|
-
closure = var.closure
|
|
81
|
-
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
82
|
-
|
|
83
|
-
settings = self.settings[params[0]]
|
|
84
|
-
rank = settings['rank']
|
|
85
|
-
reg = settings['reg']
|
|
86
|
-
hvp_method = settings['hvp_method']
|
|
87
|
-
h = settings['h']
|
|
88
|
-
|
|
89
|
-
seed = settings['seed']
|
|
90
|
-
generator = None
|
|
91
|
-
if seed is not None:
|
|
92
|
-
if 'generator' not in self.global_state:
|
|
93
|
-
self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
|
|
94
|
-
generator = self.global_state['generator']
|
|
79
|
+
def update_states(self, objective, states, settings):
|
|
80
|
+
params = TensorList(objective.params)
|
|
81
|
+
fs = settings[0]
|
|
95
82
|
|
|
96
83
|
# ---------------------- Hessian vector product function --------------------- #
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def H_mm(x):
|
|
101
|
-
with torch.enable_grad():
|
|
102
|
-
Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
|
|
103
|
-
return torch.cat([t.ravel() for t in Hvp])
|
|
84
|
+
hvp_method = fs['hvp_method']
|
|
85
|
+
h = fs['h']
|
|
86
|
+
_, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
|
|
104
87
|
|
|
105
|
-
|
|
88
|
+
# ---------------------------------- sketch ---------------------------------- #
|
|
89
|
+
ndim = sum(t.numel() for t in objective.params)
|
|
90
|
+
device = params[0].device
|
|
91
|
+
dtype = params[0].dtype
|
|
106
92
|
|
|
107
|
-
|
|
108
|
-
|
|
93
|
+
generator = self.get_generator(params[0].device, seed=fs['seed'])
|
|
94
|
+
try:
|
|
95
|
+
L, Q = nystrom_approximation(A_mv=H_mv, A_mm=H_mm, ndim=ndim, rank=fs['rank'],
|
|
96
|
+
dtype=dtype, device=device, generator=generator)
|
|
109
97
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
98
|
+
self.global_state["L"] = L
|
|
99
|
+
self.global_state["Q"] = Q
|
|
100
|
+
except torch.linalg.LinAlgError:
|
|
101
|
+
pass
|
|
114
102
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
return torch.cat([t.ravel() for t in Hvp])
|
|
103
|
+
def apply_states(self, objective, states, settings):
|
|
104
|
+
fs = settings[0]
|
|
105
|
+
b = objective.get_updates()
|
|
119
106
|
|
|
120
|
-
|
|
121
|
-
|
|
107
|
+
# ----------------------------------- solve ---------------------------------- #
|
|
108
|
+
if "L" not in self.global_state:
|
|
109
|
+
return objective
|
|
122
110
|
|
|
111
|
+
L = self.global_state["L"]
|
|
112
|
+
Q = self.global_state["Q"]
|
|
113
|
+
x = nystrom_sketch_and_solve(L=L, Q=Q, b=torch.cat([t.ravel() for t in b]), reg=fs["reg"])
|
|
123
114
|
|
|
124
|
-
# --------------------------------
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
115
|
+
# -------------------------------- set update -------------------------------- #
|
|
116
|
+
objective.updates = vec_to_tensors(x, reference=objective.params)
|
|
117
|
+
return objective
|
|
128
118
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
return var
|
|
119
|
+
def get_H(self, objective=...):
|
|
120
|
+
if "L" not in self.global_state:
|
|
121
|
+
return ScaledIdentity()
|
|
133
122
|
|
|
123
|
+
L = self.global_state["L"]
|
|
124
|
+
Q = self.global_state["Q"]
|
|
125
|
+
return Eigendecomposition(L, Q)
|
|
134
126
|
|
|
135
127
|
|
|
136
|
-
class NystromPCG(
|
|
128
|
+
class NystromPCG(Transform):
|
|
137
129
|
"""Newton's method with a Nyström-preconditioned conjugate gradient solver.
|
|
138
130
|
This tends to outperform NewtonCG but requires tuning sketch size.
|
|
139
131
|
An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.
|
|
140
132
|
|
|
141
|
-
|
|
142
|
-
This module requires the a closure passed to the optimizer step,
|
|
133
|
+
Notes:
|
|
134
|
+
- This module requires the a closure passed to the optimizer step,
|
|
143
135
|
as it needs to re-evaluate the loss and gradients for calculating HVPs.
|
|
144
136
|
The closure must accept a ``backward`` argument (refer to documentation).
|
|
145
137
|
|
|
146
|
-
|
|
147
|
-
In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
|
|
138
|
+
- In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
|
|
148
139
|
|
|
149
140
|
Args:
|
|
150
141
|
sketch_size (int):
|
|
@@ -159,31 +150,31 @@ class NystromPCG(Module):
|
|
|
159
150
|
tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
|
|
160
151
|
reg (float, optional): regularization parameter. Defaults to 1e-8.
|
|
161
152
|
hvp_method (str, optional):
|
|
162
|
-
Determines how Hessian-vector products are
|
|
163
|
-
|
|
164
|
-
- ``"
|
|
165
|
-
|
|
166
|
-
- ``"
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
153
|
+
Determines how Hessian-vector products are computed.
|
|
154
|
+
|
|
155
|
+
- ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
|
|
156
|
+
- ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
|
|
157
|
+
- ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
|
|
158
|
+
- ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
|
|
159
|
+
|
|
160
|
+
Defaults to ``"autograd"``.
|
|
161
|
+
h (float, optional):
|
|
162
|
+
The step size for finite difference if ``hvp_method`` is
|
|
163
|
+
``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
|
|
173
164
|
inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
|
|
174
165
|
seed (int | None, optional): seed for random generator. Defaults to None.
|
|
175
166
|
|
|
176
167
|
Examples:
|
|
177
168
|
|
|
178
|
-
|
|
169
|
+
NystromPCG with backtracking line search
|
|
179
170
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
171
|
+
```python
|
|
172
|
+
opt = tz.Modular(
|
|
173
|
+
model.parameters(),
|
|
174
|
+
tz.m.NystromPCG(10),
|
|
175
|
+
tz.m.Backtracking()
|
|
176
|
+
)
|
|
177
|
+
```
|
|
187
178
|
|
|
188
179
|
Reference:
|
|
189
180
|
Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
|
|
@@ -191,81 +182,70 @@ class NystromPCG(Module):
|
|
|
191
182
|
"""
|
|
192
183
|
def __init__(
|
|
193
184
|
self,
|
|
194
|
-
|
|
185
|
+
rank: int,
|
|
195
186
|
maxiter=None,
|
|
196
|
-
tol=1e-
|
|
187
|
+
tol=1e-8,
|
|
197
188
|
reg: float = 1e-6,
|
|
198
|
-
|
|
189
|
+
update_freq: int = 1, # here update_freq is within update_states
|
|
190
|
+
hvp_method: HVPMethod = "batched_autograd",
|
|
199
191
|
h=1e-3,
|
|
200
192
|
inner: Chainable | None = None,
|
|
201
193
|
seed: int | None = None,
|
|
202
194
|
):
|
|
203
|
-
defaults =
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
if inner is not None:
|
|
207
|
-
self.set_child('inner', inner)
|
|
195
|
+
defaults = locals().copy()
|
|
196
|
+
del defaults['self'], defaults['inner']
|
|
197
|
+
super().__init__(defaults, inner=inner)
|
|
208
198
|
|
|
209
199
|
@torch.no_grad
|
|
210
|
-
def
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
closure = var.closure
|
|
214
|
-
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
215
|
-
|
|
216
|
-
settings = self.settings[params[0]]
|
|
217
|
-
sketch_size = settings['sketch_size']
|
|
218
|
-
maxiter = settings['maxiter']
|
|
219
|
-
tol = settings['tol']
|
|
220
|
-
reg = settings['reg']
|
|
221
|
-
hvp_method = settings['hvp_method']
|
|
222
|
-
h = settings['h']
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
seed = settings['seed']
|
|
226
|
-
generator = None
|
|
227
|
-
if seed is not None:
|
|
228
|
-
if 'generator' not in self.global_state:
|
|
229
|
-
self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
|
|
230
|
-
generator = self.global_state['generator']
|
|
231
|
-
|
|
200
|
+
def update_states(self, objective, states, settings):
|
|
201
|
+
fs = settings[0]
|
|
232
202
|
|
|
233
203
|
# ---------------------- Hessian vector product function --------------------- #
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
|
|
240
|
-
return torch.cat([t.ravel() for t in Hvp])
|
|
241
|
-
|
|
242
|
-
else:
|
|
204
|
+
# this should run on every update_states
|
|
205
|
+
hvp_method = fs['hvp_method']
|
|
206
|
+
h = fs['h']
|
|
207
|
+
_, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
|
|
208
|
+
objective.temp = H_mv
|
|
243
209
|
|
|
244
|
-
|
|
245
|
-
|
|
210
|
+
# --------------------------- update preconditioner -------------------------- #
|
|
211
|
+
step = self.increment_counter("step", 0)
|
|
212
|
+
update_freq = self.defaults["update_freq"]
|
|
246
213
|
|
|
247
|
-
|
|
248
|
-
def H_mm(x):
|
|
249
|
-
Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
|
|
250
|
-
return torch.cat([t.ravel() for t in Hvp])
|
|
214
|
+
if step % update_freq == 0:
|
|
251
215
|
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
216
|
+
rank = fs['rank']
|
|
217
|
+
ndim = sum(t.numel() for t in objective.params)
|
|
218
|
+
device = objective.params[0].device
|
|
219
|
+
dtype = objective.params[0].dtype
|
|
220
|
+
generator = self.get_generator(device, seed=fs['seed'])
|
|
256
221
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
# -------------------------------- inner step -------------------------------- #
|
|
262
|
-
b = var.get_update()
|
|
263
|
-
if 'inner' in self.children:
|
|
264
|
-
b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
|
|
265
|
-
|
|
266
|
-
# ------------------------------ sketch&n&solve ------------------------------ #
|
|
267
|
-
x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
|
|
268
|
-
var.update = vec_to_tensors(x, reference=params)
|
|
269
|
-
return var
|
|
222
|
+
try:
|
|
223
|
+
L, Q = nystrom_approximation(A_mv=None, A_mm=H_mm, ndim=ndim, rank=rank,
|
|
224
|
+
dtype=dtype, device=device, generator=generator)
|
|
270
225
|
|
|
226
|
+
self.global_state["L"] = L
|
|
227
|
+
self.global_state["Q"] = Q
|
|
228
|
+
except torch.linalg.LinAlgError:
|
|
229
|
+
pass
|
|
271
230
|
|
|
231
|
+
@torch.no_grad
|
|
232
|
+
def apply_states(self, objective, states, settings):
|
|
233
|
+
b = objective.get_updates()
|
|
234
|
+
H_mv = objective.poptemp()
|
|
235
|
+
fs = self.settings[objective.params[0]]
|
|
236
|
+
|
|
237
|
+
# ----------------------------------- solve ---------------------------------- #
|
|
238
|
+
if "L" not in self.global_state:
|
|
239
|
+
# fallback on cg
|
|
240
|
+
sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
|
|
241
|
+
objective.updates = sol.x
|
|
242
|
+
return objective
|
|
243
|
+
|
|
244
|
+
L = self.global_state["L"]
|
|
245
|
+
Q = self.global_state["Q"]
|
|
246
|
+
x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
|
|
247
|
+
reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])
|
|
248
|
+
|
|
249
|
+
# -------------------------------- set update -------------------------------- #
|
|
250
|
+
objective.updates = vec_to_tensors(x, reference=objective.params)
|
|
251
|
+
return objective
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from collections import deque
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from ...core import Chainable, Transform, HVPMethod
|
|
9
|
+
from ...utils import vec_to_tensors
|
|
10
|
+
from ...linalg.linear_operator import Sketched
|
|
11
|
+
|
|
12
|
+
from .newton import _newton_step
|
|
13
|
+
|
|
14
|
+
def _qr_orthonormalize(A:torch.Tensor):
|
|
15
|
+
m,n = A.shape
|
|
16
|
+
if m < n:
|
|
17
|
+
q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
|
|
18
|
+
return q.T
|
|
19
|
+
|
|
20
|
+
q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
|
|
21
|
+
return q
|
|
22
|
+
|
|
23
|
+
def _orthonormal_sketch(m, n, dtype, device, generator):
|
|
24
|
+
return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
|
|
25
|
+
|
|
26
|
+
def _gaussian_sketch(m, n, dtype, device, generator):
|
|
27
|
+
return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
|
|
28
|
+
|
|
29
|
+
def _rademacher_sketch(m, n, dtype, device, generator):
|
|
30
|
+
rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
|
|
31
|
+
return rademacher.mul_(1 / math.sqrt(m))
|
|
32
|
+
|
|
33
|
+
class SubspaceNewton(Transform):
|
|
34
|
+
"""Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
sketch_size (int):
|
|
38
|
+
size of the random sketch. This many hessian-vector products will need to be evaluated each step.
|
|
39
|
+
sketch_type (str, optional):
|
|
40
|
+
- "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
|
|
41
|
+
- "rademacher" - approximately orthonormal scaled random rademacher basis.
|
|
42
|
+
- "gaussian" - random gaussian (not orthonormal) basis.
|
|
43
|
+
- "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
|
|
44
|
+
- "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction (default).
|
|
45
|
+
damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
|
|
46
|
+
hvp_method (str, optional):
|
|
47
|
+
How to compute hessian-matrix product:
|
|
48
|
+
- "batched_autograd" - uses batched autograd
|
|
49
|
+
- "autograd" - uses unbatched autograd
|
|
50
|
+
- "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
|
|
51
|
+
- "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
|
|
52
|
+
|
|
53
|
+
. Defaults to "batched_autograd".
|
|
54
|
+
h (float, optional): finite difference step size. Defaults to 1e-2.
|
|
55
|
+
use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
|
|
56
|
+
update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
|
|
57
|
+
H_tfm (Callable | None, optional):
|
|
58
|
+
optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
|
|
59
|
+
|
|
60
|
+
must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
|
|
61
|
+
which must be True if transform inverted the hessian and False otherwise.
|
|
62
|
+
|
|
63
|
+
Or it returns a single tensor which is used as the update.
|
|
64
|
+
|
|
65
|
+
Defaults to None.
|
|
66
|
+
eigval_fn (Callable | None, optional):
|
|
67
|
+
optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
|
|
68
|
+
If this is specified, eigendecomposition will be used to invert the hessian.
|
|
69
|
+
seed (int | None, optional): seed for random generator. Defaults to None.
|
|
70
|
+
inner (Chainable | None, optional): preconditions output of this module. Defaults to None.
|
|
71
|
+
|
|
72
|
+
### Examples
|
|
73
|
+
|
|
74
|
+
RSN with line search
|
|
75
|
+
```python
|
|
76
|
+
opt = tz.Modular(
|
|
77
|
+
model.parameters(),
|
|
78
|
+
tz.m.RSN(),
|
|
79
|
+
tz.m.Backtracking()
|
|
80
|
+
)
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
RSN with trust region
|
|
84
|
+
```python
|
|
85
|
+
opt = tz.Modular(
|
|
86
|
+
model.parameters(),
|
|
87
|
+
tz.m.LevenbergMarquardt(tz.m.RSN()),
|
|
88
|
+
)
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
References:
|
|
93
|
+
1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
|
|
94
|
+
2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
|
|
95
|
+
"""
|
|
96
|
+
|
|
97
|
+
def __init__(
|
|
98
|
+
self,
|
|
99
|
+
sketch_size: int,
|
|
100
|
+
sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
|
|
101
|
+
damping:float=0,
|
|
102
|
+
hvp_method: HVPMethod = "batched_autograd",
|
|
103
|
+
h: float = 1e-2,
|
|
104
|
+
use_lstsq: bool = True,
|
|
105
|
+
update_freq: int = 1,
|
|
106
|
+
H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
|
107
|
+
eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
|
|
108
|
+
seed: int | None = None,
|
|
109
|
+
inner: Chainable | None = None,
|
|
110
|
+
):
|
|
111
|
+
defaults = locals().copy()
|
|
112
|
+
del defaults['self'], defaults['inner'], defaults["update_freq"]
|
|
113
|
+
super().__init__(defaults, update_freq=update_freq, inner=inner)
|
|
114
|
+
|
|
115
|
+
@torch.no_grad
|
|
116
|
+
def update_states(self, objective, states, settings):
|
|
117
|
+
fs = settings[0]
|
|
118
|
+
params = objective.params
|
|
119
|
+
generator = self.get_generator(params[0].device, fs["seed"])
|
|
120
|
+
|
|
121
|
+
ndim = sum(p.numel() for p in params)
|
|
122
|
+
|
|
123
|
+
device=params[0].device
|
|
124
|
+
dtype=params[0].dtype
|
|
125
|
+
|
|
126
|
+
# sample sketch matrix S: (ndim, sketch_size)
|
|
127
|
+
sketch_size = min(fs["sketch_size"], ndim)
|
|
128
|
+
sketch_type = fs["sketch_type"]
|
|
129
|
+
hvp_method = fs["hvp_method"]
|
|
130
|
+
|
|
131
|
+
if sketch_type in ('normal', 'gaussian'):
|
|
132
|
+
S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
|
|
133
|
+
|
|
134
|
+
elif sketch_type == "rademacher":
|
|
135
|
+
S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
|
|
136
|
+
|
|
137
|
+
elif sketch_type == 'orthonormal':
|
|
138
|
+
S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
|
|
139
|
+
|
|
140
|
+
elif sketch_type == 'common_directions':
|
|
141
|
+
# Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
|
|
142
|
+
g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
|
|
143
|
+
g = torch.cat([t.ravel() for t in g_list])
|
|
144
|
+
|
|
145
|
+
# initialize directions deque
|
|
146
|
+
if "directions" not in self.global_state:
|
|
147
|
+
|
|
148
|
+
g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
|
|
149
|
+
if g_norm < torch.finfo(g.dtype).tiny * 2:
|
|
150
|
+
g = torch.randn_like(g)
|
|
151
|
+
g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
|
|
152
|
+
|
|
153
|
+
self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
|
|
154
|
+
S = self.global_state["directions"][0].unsqueeze(1)
|
|
155
|
+
|
|
156
|
+
# add new steepest descent direction orthonormal to existing columns
|
|
157
|
+
else:
|
|
158
|
+
S = torch.stack(tuple(self.global_state["directions"]), dim=1)
|
|
159
|
+
p = g - S @ (S.T @ g)
|
|
160
|
+
p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
|
|
161
|
+
if p_norm > torch.finfo(p.dtype).tiny * 2:
|
|
162
|
+
p = p / p_norm
|
|
163
|
+
self.global_state["directions"].append(p)
|
|
164
|
+
S = torch.cat([S, p.unsqueeze(1)], dim=1)
|
|
165
|
+
|
|
166
|
+
elif sketch_type == "mixed":
|
|
167
|
+
g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
|
|
168
|
+
g = torch.cat([t.ravel() for t in g_list])
|
|
169
|
+
|
|
170
|
+
# initialize state
|
|
171
|
+
if "slow_ema" not in self.global_state:
|
|
172
|
+
self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
|
|
173
|
+
self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
|
|
174
|
+
self.global_state["p_prev"] = torch.randn_like(g)
|
|
175
|
+
|
|
176
|
+
# previous update direction
|
|
177
|
+
p_cur = torch.cat([t.ravel() for t in params])
|
|
178
|
+
prev_dir = p_cur - self.global_state["p_prev"]
|
|
179
|
+
self.global_state["p_prev"] = p_cur
|
|
180
|
+
|
|
181
|
+
# EMAs
|
|
182
|
+
slow_ema = self.global_state["slow_ema"]
|
|
183
|
+
fast_ema = self.global_state["fast_ema"]
|
|
184
|
+
slow_ema.lerp_(g, 0.001)
|
|
185
|
+
fast_ema.lerp_(g, 0.1)
|
|
186
|
+
|
|
187
|
+
# form and orthogonalize sketching matrix
|
|
188
|
+
S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
|
|
189
|
+
if sketch_size > 4:
|
|
190
|
+
S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
|
|
191
|
+
S = torch.cat([S, S_random], dim=1)
|
|
192
|
+
|
|
193
|
+
S = _qr_orthonormalize(S)
|
|
194
|
+
|
|
195
|
+
else:
|
|
196
|
+
raise ValueError(f'Unknown sketch_type {sketch_type}')
|
|
197
|
+
|
|
198
|
+
# form sketched hessian
|
|
199
|
+
HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
|
|
200
|
+
hvp_method=fs["hvp_method"], h=fs["h"])
|
|
201
|
+
H_sketched = S.T @ HS
|
|
202
|
+
|
|
203
|
+
self.global_state["H_sketched"] = H_sketched
|
|
204
|
+
self.global_state["S"] = S
|
|
205
|
+
|
|
206
|
+
def apply_states(self, objective, states, settings):
    """Compute the Newton step inside the sketch subspace and lift it to parameter space.

    Uses the ``S`` and ``H_sketched`` stored by ``update_states``.
    ``_newton_step`` receives the sketched Hessian, a gradient projection
    ``g -> S.T @ g`` into sketch coordinates, and the damping / transform
    settings. Its result ``d_proj`` (a vector in sketch coordinates) is mapped
    back with ``S @ d_proj``. The result is written to ``objective.updates``,
    split per parameter with ``vec_to_tensors``.

    Settings are read from ``settings[0]``, like in ``update_states``, so
    per-group overrides of ``damping``, ``H_tfm``, ``eigval_fn`` and
    ``use_lstsq`` are respected.

    Returns:
        The same ``objective``, with ``updates`` replaced.
    """
    fs = settings[0]
    S: torch.Tensor = self.global_state["S"]

    d_proj = _newton_step(
        objective=objective,
        H=self.global_state["H_sketched"],
        damping=fs["damping"],
        H_tfm=fs["H_tfm"],
        eigval_fn=fs["eigval_fn"],
        use_lstsq=fs["use_lstsq"],
        g_proj=lambda g: S.T @ g,
    )

    # map the subspace step back to the full parameter space
    d = S @ d_proj
    objective.updates = vec_to_tensors(d, objective.params)
    return objective
|
|
222
|
+
|
|
223
|
+
def get_H(self, objective=...):
    """Return the current sketched Hessian approximation as a ``Sketched`` operator.

    The operator is built from the stored sketch basis ``S`` and the sketched
    Hessian ``H_sketched``. It presumably represents ``S @ H_sketched @ S.T``;
    see ``Sketched`` to confirm.

    If ``eigval_fn`` is set, ``H_sketched`` is first rebuilt as
    ``Q @ diag(eigval_fn(L)) @ Q^H`` from its ``eigh`` decomposition. If that
    raises ``torch.linalg.LinAlgError``, the untransformed matrix is used.
    ``objective`` is not used.
    """
    transform = self.defaults["eigval_fn"]
    reduced: torch.Tensor = self.global_state["H_sketched"]
    basis: torch.Tensor = self.global_state["S"]

    if transform is None:
        return Sketched(basis, reduced)

    try:
        evals, evecs = torch.linalg.eigh(reduced) # pylint:disable=not-callable
        reduced = evecs @ transform(evals).diag_embed() @ evecs.mH
    except torch.linalg.LinAlgError:
        # best effort: fall back to the untransformed sketched Hessian
        pass

    return Sketched(basis, reduced)
|