torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/rsn.py

@@ -5,9 +5,10 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...utils import
-from ...
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import vec_to_tensors
+from ...linalg.linear_operator import Sketched
+
 from .newton import _newton_step
 
 def _qr_orthonormalize(A:torch.Tensor):
@@ -15,9 +16,9 @@ def _qr_orthonormalize(A:torch.Tensor):
     if m < n:
         q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
         return q.T
-
-
-
+
+    q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
+    return q
 
 def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
@@ -25,26 +26,31 @@ def _orthonormal_sketch(m, n, dtype, device, generator):
 def _gaussian_sketch(m, n, dtype, device, generator):
     return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
 
-
-
+def _rademacher_sketch(m, n, dtype, device, generator):
+    rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+    return rademacher.mul_(1 / math.sqrt(m))
+
+class SubspaceNewton(Transform):
+    """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
 
     Args:
         sketch_size (int):
             size of the random sketch. This many hessian-vector products will need to be evaluated each step.
         sketch_type (str, optional):
             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
+            - "rademacher" - approximately orthonormal scaled random rademacher basis.
             - "gaussian" - random gaussian (not orthonormal) basis.
             - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
-            - "mixed" - random orthonormal basis but with
+            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction (default).
         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
         hvp_method (str, optional):
            How to compute hessian-matrix product:
-            - "
+            - "batched_autograd" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
 
-            . Defaults to "
+            . Defaults to "batched_autograd".
         h (float, optional): finite difference step size. Defaults to 1e-2.
         use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
         update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
@@ -93,7 +99,7 @@ class RSN(Module):
         sketch_size: int,
         sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
         damping:float=0,
-        hvp_method:
+        hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-2,
         use_lstsq: bool = True,
         update_freq: int = 1,
@@ -102,115 +108,119 @@ class RSN(Module):
         seed: int | None = None,
         inner: Chainable | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-
-
-        if step % self.defaults['update_freq'] == 0:
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        params = objective.params
+        generator = self.get_generator(params[0].device, fs["seed"])
 
-
-        if closure is None:
-            raise RuntimeError("RSN requires closure")
-        params = var.params
-        generator = self.get_generator(params[0].device, self.defaults["seed"])
+        ndim = sum(p.numel() for p in params)
 
-
+        device=params[0].device
+        dtype=params[0].dtype
 
-
-
+        # sample sketch matrix S: (ndim, sketch_size)
+        sketch_size = min(fs["sketch_size"], ndim)
+        sketch_type = fs["sketch_type"]
+        hvp_method = fs["hvp_method"]
 
-
-
-        sketch_type = self.defaults["sketch_type"]
-        hvp_method = self.defaults["hvp_method"]
+        if sketch_type in ('normal', 'gaussian'):
+            S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-
-
+        elif sketch_type == "rademacher":
+            S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-
-
+        elif sketch_type == 'orthonormal':
+            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-
-
-
-
+        elif sketch_type == 'common_directions':
+            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
 
-
-
+            # initialize directions deque
+            if "directions" not in self.global_state:
 
+                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+                if g_norm < torch.finfo(g.dtype).tiny * 2:
+                    g = torch.randn_like(g)
                 g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-            if g_norm < torch.finfo(g.dtype).tiny * 2:
-                g = torch.randn_like(g)
-                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-
-            self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
-            S = self.global_state["directions"][0].unsqueeze(1)
-
-            # add new steepest descent direction orthonormal to existing columns
-            else:
-                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
-                p = g - S @ (S.T @ g)
-                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-                if p_norm > torch.finfo(p.dtype).tiny * 2:
-                    p = p / p_norm
-                    self.global_state["directions"].append(p)
-                    S = torch.cat([S, p.unsqueeze(1)], dim=1)
-
-        elif sketch_type == "mixed":
-            g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-            g = torch.cat([t.ravel() for t in g_list])
-
-            if "slow_ema" not in self.global_state:
-                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
-                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
-
-            slow_ema = self.global_state["slow_ema"]
-            fast_ema = self.global_state["fast_ema"]
-            slow_ema.lerp_(g, 0.001)
-            fast_ema.lerp_(g, 0.1)
-
-            S = torch.stack([g, slow_ema, fast_ema], dim=1)
-            if sketch_size > 3:
-                S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
-                S = torch.cat([S, S_random], dim=1)
-
-            S = _qr_orthonormalize(S)
-
-        else:
-            raise ValueError(f'Unknown sketch_type {sketch_type}')
-
-        # form sketched hessian
-        HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
-        H_sketched = S.T @ HS
 
-
-
+                self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
+                S = self.global_state["directions"][0].unsqueeze(1)
 
-
+            # add new steepest descent direction orthonormal to existing columns
+            else:
+                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
+                p = g - S @ (S.T @ g)
+                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+                if p_norm > torch.finfo(p.dtype).tiny * 2:
+                    p = p / p_norm
+                    self.global_state["directions"].append(p)
+                    S = torch.cat([S, p.unsqueeze(1)], dim=1)
+
+        elif sketch_type == "mixed":
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
+
+            # initialize state
+            if "slow_ema" not in self.global_state:
+                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["p_prev"] = torch.randn_like(g)
+
+            # previous update direction
+            p_cur = torch.cat([t.ravel() for t in params])
+            prev_dir = p_cur - self.global_state["p_prev"]
+            self.global_state["p_prev"] = p_cur
+
+            # EMAs
+            slow_ema = self.global_state["slow_ema"]
+            fast_ema = self.global_state["fast_ema"]
+            slow_ema.lerp_(g, 0.001)
+            fast_ema.lerp_(g, 0.1)
+
+            # form and orthogonalize sketching matrix
+            S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
+            if sketch_size > 4:
+                S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
+                S = torch.cat([S, S_random], dim=1)
+
+            S = _qr_orthonormalize(S)
+
+        else:
+            raise ValueError(f'Unknown sketch_type {sketch_type}')
+
+        # form sketched hessian
+        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
+                                                 hvp_method=fs["hvp_method"], h=fs["h"])
+        H_sketched = S.T @ HS
+
+        self.global_state["H_sketched"] = H_sketched
+        self.global_state["S"] = S
+
+    def apply_states(self, objective, states, settings):
         S: torch.Tensor = self.global_state["S"]
+
         d_proj = _newton_step(
-
+            objective=objective,
             H=self.global_state["H_sketched"],
             damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
             H_tfm=self.defaults["H_tfm"],
             eigval_fn=self.defaults["eigval_fn"],
             use_lstsq=self.defaults["use_lstsq"],
             g_proj = lambda g: S.T @ g
         )
-        d = S @ d_proj
-        var.update = vec_to_tensors(d, var.params)
 
-
+        d = S @ d_proj
+        objective.updates = vec_to_tensors(d, objective.params)
+        return objective
 
-    def get_H(self,
+    def get_H(self, objective=...):
         eigval_fn = self.defaults["eigval_fn"]
         H_sketched: torch.Tensor = self.global_state["H_sketched"]
         S: torch.Tensor = self.global_state["S"]
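The hunks above (torchzero/modules/second_order/rsn.py) move RSN/SubspaceNewton from the old `var`-based `Module` API onto the 0.4.0 `Transform` API with `update_states`/`apply_states` and an `Objective`, but the computation itself is unchanged: sample a sketch S, form the projected Hessian SᵀHS from sketch-size many Hessian-vector products, solve the small Newton system, and map the step back through S. Below is a minimal standalone sketch of that computation written against plain `torch.autograd` rather than torchzero's `Objective`/`Transform` API; the helper name `subspace_newton_step` and the toy objective are illustrative, not part of the package.

```python
import torch

def subspace_newton_step(f, x, sketch_size=10, damping=1e-4, generator=None):
    """One sketched Newton step on a scalar function f of a flat parameter vector x."""
    ndim = x.numel()
    # orthonormal sketch S: (ndim, sketch_size), analogous to _orthonormal_sketch above
    S, _ = torch.linalg.qr(torch.randn(ndim, sketch_size, dtype=x.dtype, generator=generator))

    x = x.detach().requires_grad_(True)
    loss = f(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)

    # HS via one Hessian-vector product per sketch column
    HS = torch.stack([torch.autograd.grad(g, x, v, retain_graph=True)[0] for v in S.T], dim=1)

    H_sketched = S.T @ HS                      # (k, k) projected Hessian
    g_proj = S.T @ g.detach()                  # (k,) projected gradient
    d_proj = torch.linalg.solve(H_sketched + damping * torch.eye(sketch_size, dtype=x.dtype), g_proj)
    return x.detach() - S @ d_proj             # Newton step mapped back to the full space

# usage on a toy objective
f = lambda x: (x ** 4).sum() + (x ** 2).sum()
x = torch.randn(100, dtype=torch.float64)
for _ in range(20):
    x = subspace_newton_step(f, x, sketch_size=10)
```

With `damping > 0` the small system stays solvable even when the sketched Hessian is singular, which is the same role the module's `damping` argument plays.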
torchzero/modules/smoothing/laplacian.py

@@ -4,7 +4,7 @@ from collections.abc import Iterable
 import torch
 
 from ...utils.tensorlist import TensorList
-from ...core import
+from ...core import TensorTransform
 
 
 def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
@@ -55,7 +55,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
     v[-1] = 1
     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
-class LaplacianSmoothing(
+class LaplacianSmoothing(TensorTransform):
     """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
     Args:
@@ -70,29 +70,30 @@ class LaplacianSmoothing(Transform):
             what to set on var.
 
     Examples:
-
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
 
-
+        ```python
 
-
-
-
-
-
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LaplacianSmoothing(),
+            tz.m.LR(1e-2),
+        )
+        ```
 
     Reference:
         Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
     """
-    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4
+    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
         defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
-        super().__init__(defaults
+        super().__init__(defaults)
         # precomputed denominator for when layerwise=False
         self.global_state['full_denominator'] = None
 
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         layerwise = settings[0]['layerwise']
 
         # layerwise laplacian smoothing
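For context on the `_precompute_denominator` helper shown above (torchzero/modules/smoothing/laplacian.py): Laplacian smoothing replaces the gradient g with the solution of (I − σL)d = g, where L is the periodic 1-D discrete Laplacian, and the circulant structure turns that solve into a single FFT division, which is exactly the `1 - sigma * torch.fft.fft(v)` denominator. A standalone sketch follows; the stencil entries `v[0] = -2` and `v[1] = 1` come from the referenced Osher et al. paper, while only `v[-1] = 1` is visible in the hunk.

```python
import torch

def laplacian_smooth(g: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Solve (I - sigma * L) d = g for a flat gradient g, with L the periodic
    1-D discrete Laplacian, via FFT (cf. Osher et al., 2022)."""
    v = torch.zeros_like(g)
    v[0], v[1], v[-1] = -2.0, 1.0, 1.0       # Laplacian stencil as a circular kernel
    denom = 1 - sigma * torch.fft.fft(v)     # matches _precompute_denominator above
    return torch.fft.ifft(torch.fft.fft(g) / denom).real

g = torch.randn(1000)
smoothed = laplacian_smooth(g, sigma=1.0)
```

Larger `sigma` smooths more aggressively; `sigma = 0` returns the gradient unchanged.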
torchzero/modules/smoothing/sampling.py

@@ -7,14 +7,15 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable, Modular, Module,
+from ...core import Chainable, Modular, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
-def _reset_except_self(
-
+def _reset_except_self(objective: Objective, modules, self: Module):
+    assert objective.modular is not None
+    for m in objective.modular.flat_modules:
         if m is not self:
             m.reset()
 
@@ -98,15 +99,15 @@ class GradientSampling(Reformulation):
         self.set_child('termination', make_termination_criteria(extra=termination))
 
     @torch.no_grad
-    def pre_step(self,
-        params = TensorList(
+    def pre_step(self, objective):
+        params = TensorList(objective.params)
 
         fixed = self.defaults['fixed']
 
         # check termination criteria
         if 'termination' in self.children:
             termination = cast(TerminationCriteriaBase, self.children['termination'])
-            if termination.should_terminate(
+            if termination.should_terminate(objective):
 
                 # decay sigmas
                 states = [self.state[p] for p in params]
@@ -118,7 +119,7 @@ class GradientSampling(Reformulation):
 
                 # reset on sigmas decay
                 if self.defaults['reset_on_termination']:
-
+                    objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
                 # clear perturbations
                 self.global_state.pop('perts', None)
@@ -136,7 +137,7 @@ class GradientSampling(Reformulation):
         self.global_state['perts'] = perts
 
     @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, objective):
         params = TensorList(params)
         loss_agg = None
         grad_agg = None
@@ -160,7 +161,7 @@ class GradientSampling(Reformulation):
 
         # evaluate at x_0
         if include_x0:
-            f_0 =
+            f_0 = objective.get_loss(backward=backward)
 
             isfinite = math.isfinite(f_0)
             if isfinite:
@@ -168,7 +169,7 @@ class GradientSampling(Reformulation):
                 loss_agg = f_0
 
             if backward:
-                g_0 =
+                g_0 = objective.get_grads()
                 if isfinite: grad_agg = g_0
 
         # evaluate at x_0 + p for each perturbation
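The `GradientSampling` changes above (torchzero/modules/smoothing/sampling.py) are essentially the `var` to `objective` API rename; the reformulation itself still averages the loss and gradient over stored perturbations, evaluating at x₀ when `include_x0` is set and at x₀ + p for each perturbation. A rough standalone illustration of that smoothing, written with plain autograd rather than torchzero's wrapped closure; the function name, gaussian perturbations, and averaging details here are assumptions for the example.

```python
import torch

def sampled_loss_and_grad(f, params, sigma=1e-2, n_samples=4, include_x0=True):
    """Average loss and gradient of f over perturbed copies of the parameters,
    the smoothing that GradientSampling builds its closure around."""
    points = []
    if include_x0:
        points.append([p.detach().clone() for p in params])
    for _ in range(n_samples):
        points.append([p.detach() + sigma * torch.randn_like(p) for p in params])

    loss_agg = 0.0
    grad_agg = [torch.zeros_like(p) for p in params]
    for point in points:
        xs = [x.requires_grad_(True) for x in point]
        loss = f(xs)
        gs = torch.autograd.grad(loss, xs)
        loss_agg += loss.item()
        for acc, g in zip(grad_agg, gs):
            acc += g                      # accumulate per-parameter gradients in place
    k = len(points)
    return loss_agg / k, [g / k for g in grad_agg]

# usage on a toy quadratic over two parameter tensors
w, b = torch.randn(5), torch.randn(1)
f = lambda xs: (xs[0] ** 2).sum() + (xs[1] ** 2).sum()
loss, grads = sampled_loss_and_grad(f, [w, b])
```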
torchzero/modules/step_size/adaptive.py

@@ -5,9 +5,9 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable,
+from ...core import Chainable, TensorTransform
 from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
-from ...
+from ...linalg.linear_operator import ScaledIdentity
 from ..functional import epsilon_step_size
 
 def _acceptable_alpha(alpha, param:torch.Tensor):
@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
         return False
     return True
 
-def _get_H(self:
+def _get_H(self: TensorTransform, var):
     n = sum(p.numel() for p in var.params)
     p = var.params[0]
     alpha = self.global_state.get('alpha', 1)
@@ -25,7 +25,7 @@ def _get_H(self: Transform, var):
     return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
 
 
-class PolyakStepSize(
+class PolyakStepSize(TensorTransform):
     """Polyak's subgradient method with known or unknown f*.
 
     Args:
@@ -47,7 +47,7 @@ class PolyakStepSize(Transform):
         super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
 
     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         assert grads is not None and loss is not None
         tensors = TensorList(tensors)
         grads = TensorList(grads)
@@ -79,15 +79,15 @@ class PolyakStepSize(Transform):
         self.global_state['alpha'] = alpha
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', 1)
         if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
 
         torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
         return tensors
 
-    def get_H(self,
-        return _get_H(self,
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
 
 def _bb_short(s: TensorList, y: TensorList, sy, eps):
@@ -116,7 +116,7 @@ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
         return None
     return (short * long) ** 0.5
 
-class BarzilaiBorwein(
+class BarzilaiBorwein(TensorTransform):
     """Barzilai-Borwein step size method.
 
     Args:
@@ -144,7 +144,7 @@ class BarzilaiBorwein(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -175,11 +175,11 @@ class BarzilaiBorwein(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self,
-        return _get_H(self,
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -189,7 +189,7 @@ class BarzilaiBorwein(Transform):
         return tensors
 
 
-class BBStab(
+class BBStab(TensorTransform):
     """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
 
     This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
@@ -228,7 +228,7 @@ class BBStab(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -287,11 +287,11 @@ class BBStab(Transform):
         prev_p.copy_(params)
         prev_g.copy_(g)
 
-    def get_H(self,
-        return _get_H(self,
+    def get_H(self, objective):
+        return _get_H(self, objective)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -301,7 +301,7 @@ class BBStab(Transform):
         return tensors
 
 
-class AdGD(
+class AdGD(TensorTransform):
     """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
     def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
         defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
@@ -313,7 +313,7 @@ class AdGD(Transform):
         self.global_state['reset'] = True
 
     @torch.no_grad
-    def
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         variant = settings[0]['variant']
         theta_0 = 0 if variant == 1 else 1/3
         theta = self.global_state.get('theta', theta_0)
@@ -371,7 +371,7 @@ class AdGD(Transform):
         prev_g.copy_(g)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alpha = self.global_state.get('alpha', None)
 
         if not _acceptable_alpha(alpha, tensors[0]):
@@ -383,5 +383,5 @@ class AdGD(Transform):
         torch._foreach_mul_(tensors, alpha)
         return tensors
 
-    def get_H(self,
-        return _get_H(self,
+    def get_H(self, objective):
+        return _get_H(self, objective)
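The step-size classes above (torchzero/modules/step_size/adaptive.py) only change their base class and hook names; the Barzilai-Borwein math behind `_bb_short`/`_bb_geom` is the usual pair of step sizes built from consecutive parameter and gradient differences, with `_bb_geom` returning their geometric mean `(short * long) ** 0.5`. A compact sketch of those formulas follows; the torchzero versions additionally handle `eps` thresholds and fallbacks, which this omits, and the helper name is illustrative.

```python
import torch

def bb_step_sizes(p, p_prev, g, g_prev, eps=1e-12):
    """The two classical Barzilai-Borwein step sizes and their geometric mean,
    computed from consecutive flat parameters and gradients."""
    s = p - p_prev            # parameter difference
    y = g - g_prev            # gradient difference
    sy = s.dot(y)
    if sy <= eps:             # curvature not positive: no reliable step size
        return None
    long = s.dot(s) / sy      # "long" BB1 step size
    short = sy / y.dot(y)     # "short" BB2 step size
    return long, short, (long * short) ** 0.5
```

Whichever value is selected is then used as `alpha` to scale the update, which is what `multi_tensor_apply` does above via `torch._foreach_mul_`.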
torchzero/modules/step_size/lr.py

@@ -2,7 +2,7 @@
 import torch
 import random
 
-from ...core import
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
 
 def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
@@ -12,24 +12,24 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
         return tensors * lr
     return tensors
 
-class LR(
+class LR(TensorTransform):
     """Learning rate. Adding this module also adds support for LR schedulers."""
     def __init__(self, lr: float):
         defaults=dict(lr=lr)
-        super().__init__(defaults
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
 
-class StepSize(
+class StepSize(TensorTransform):
     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
     def __init__(self, step_size: float, key = 'step_size'):
         defaults={"key": key, key: step_size}
-        super().__init__(defaults
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
 
 
@@ -38,8 +38,8 @@ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberLi
     if step > steps: return end_lr
     return start_lr + (end_lr - start_lr) * (step / steps)
 
-class Warmup(
-    """Learning rate warmup, linearly increases learning rate multiplier from
+class Warmup(TensorTransform):
+    """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.
 
     Args:
         steps (int, optional): number of steps to perform warmup for. Defaults to 100.
@@ -64,7 +64,7 @@ class Warmup(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)
@@ -77,7 +77,7 @@ class Warmup(Transform):
         self.global_state['step'] = step + 1
         return tensors
 
-class WarmupNormClip(
+class WarmupNormClip(TensorTransform):
     """Warmup via clipping of the update norm.
 
     Args:
@@ -102,7 +102,7 @@ class WarmupNormClip(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
         num_steps = settings[0]['steps']
         step = self.global_state.get('step', 0)
@@ -118,8 +118,8 @@ class WarmupNormClip(Transform):
         return tensors
 
 
-class RandomStepSize(
-    """Uses random global or layer-wise step size from
+class RandomStepSize(TensorTransform):
+    """Uses random global or layer-wise step size from ``low`` to ``high``.
 
     Args:
         low (float, optional): minimum learning rate. Defaults to 0.
@@ -133,7 +133,7 @@ class RandomStepSize(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         s = settings[0]
         parameterwise = s['parameterwise']
 
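`_warmup_lr` above is plain linear interpolation from `start_lr` to `end_lr` over `steps` steps, and `LR`/`Warmup` are meant to be chained inside `tz.Modular` the same way the `LaplacianSmoothing` docstring example earlier in this diff does. A hedged usage sketch follows; the module spellings (`tz.m.Warmup(steps=...)`, `tz.m.LR(...)`) and the closure-based `opt.step(closure)` call follow that docstring and this file's signatures, but are not verified against the 0.4.0 public API.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.Warmup(steps=100),   # linear ramp of the update multiplier, per _warmup_lr
    tz.m.LR(1e-2),            # final learning rate; LR also enables LR schedulers
)

def closure():
    opt.zero_grad()
    loss = model(torch.randn(32, 10)).pow(2).mean()
    loss.backward()
    return loss

for _ in range(200):
    opt.step(closure)
```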