torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -5,46 +5,49 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...utils import
-from ...
-
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import vec_to_tensors_
+from ...linalg.linear_operator import Sketched
+
+from .newton import _newton_update_state_, _newton_solve
 
 def _qr_orthonormalize(A:torch.Tensor):
     m,n = A.shape
     if m < n:
         q, _ = torch.linalg.qr(A.T) # pylint:disable=not-callable
         return q.T
-
-
-
+
+    q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
+    return q
+
 
 def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
 
-def
-
+def _rademacher_sketch(m, n, dtype, device, generator):
+    rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+    return rademacher.mul_(1 / math.sqrt(m))
 
-class
-    """
+class SubspaceNewton(Transform):
+    """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
 
     Args:
         sketch_size (int):
            size of the random sketch. This many hessian-vector products will need to be evaluated each step.
        sketch_type (str, optional):
+            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
            - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
-            - "
-            - "
-            - "mixed" - random orthonormal basis but with three directions set to gradient, slow EMA and fast EMA (default).
+            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis. It is recommended to use at least "orthonormal" - it requires QR but it is still very cheap.
+            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
        damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
        hvp_method (str, optional):
            How to compute hessian-matrix product:
-            - "
+            - "batched_autograd" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.
 
-            . Defaults to "
+            . Defaults to "batched_autograd".
        h (float, optional): finite difference step size. Defaults to 1e-2.
        use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
@@ -67,7 +70,7 @@ class RSN(Module):
 
    RSN with line search
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RSN(),
        tz.m.Backtracking()
@@ -76,7 +79,7 @@ class RSN(Module):
 
    RSN with trust region
    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.RSN()),
    )
@@ -91,137 +94,141 @@ class RSN(Module):
    def __init__(
        self,
        sketch_size: int,
-        sketch_type: Literal["orthonormal", "
+        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
        damping:float=0,
-        hvp_method: Literal["batched", "autograd", "forward", "central"] = "batched",
-        h: float = 1e-2,
-        use_lstsq: bool = True,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool = False,
+        use_lstsq: bool = True,
+        hvp_method: HVPMethod = "batched_autograd",
+        h: float = 1e-2,
        seed: int | None = None,
        inner: Chainable | None = None,
    ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child("inner", inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["update_freq"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
    @torch.no_grad
-    def
-
-
-
-        if step % self.defaults['update_freq'] == 0:
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+        params = objective.params
+        generator = self.get_generator(params[0].device, fs["seed"])
 
-
-        if closure is None:
-            raise RuntimeError("RSN requires closure")
-        params = var.params
-        generator = self.get_generator(params[0].device, self.defaults["seed"])
+        ndim = sum(p.numel() for p in params)
 
-
+        device=params[0].device
+        dtype=params[0].dtype
 
-
-
+        # sample sketch matrix S: (ndim, sketch_size)
+        sketch_size = min(fs["sketch_size"], ndim)
+        sketch_type = fs["sketch_type"]
+        hvp_method = fs["hvp_method"]
 
-
-
-        sketch_type = self.defaults["sketch_type"]
-        hvp_method = self.defaults["hvp_method"]
+        if sketch_type == "rademacher":
+            S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-
-
+        elif sketch_type == 'orthonormal':
+            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
-
-
+        elif sketch_type == 'common_directions':
+            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
 
-
-
-        g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-        g = torch.cat([t.ravel() for t in g_list])
-
-        # initialize directions deque
-        if "directions" not in self.global_state:
+            # initialize directions deque
+            if "directions" not in self.global_state:
 
+                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+                if g_norm < torch.finfo(g.dtype).tiny * 2:
+                    g = torch.randn_like(g)
                    g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-            if g_norm < torch.finfo(g.dtype).tiny * 2:
-                g = torch.randn_like(g)
-                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-
-            self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
-            S = self.global_state["directions"][0].unsqueeze(1)
-
-        # add new steepest descent direction orthonormal to existing columns
-        else:
-            S = torch.stack(tuple(self.global_state["directions"]), dim=1)
-            p = g - S @ (S.T @ g)
-            p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
-            if p_norm > torch.finfo(p.dtype).tiny * 2:
-                p = p / p_norm
-                self.global_state["directions"].append(p)
-                S = torch.cat([S, p.unsqueeze(1)], dim=1)
-
-        elif sketch_type == "mixed":
-            g_list = var.get_grad(create_graph=hvp_method in ("batched", "autograd"))
-            g = torch.cat([t.ravel() for t in g_list])
-
-            if "slow_ema" not in self.global_state:
-                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
-                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
-
-            slow_ema = self.global_state["slow_ema"]
-            fast_ema = self.global_state["fast_ema"]
-            slow_ema.lerp_(g, 0.001)
-            fast_ema.lerp_(g, 0.1)
-
-            S = torch.stack([g, slow_ema, fast_ema], dim=1)
-            if sketch_size > 3:
-                S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
-                S = torch.cat([S, S_random], dim=1)
-
-            S = _qr_orthonormalize(S)
 
+                self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
+                S = self.global_state["directions"][0].unsqueeze(1)
+
+            # add new steepest descent direction orthonormal to existing columns
            else:
-
+                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
+                p = g - S @ (S.T @ g)
+                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
+                if p_norm > torch.finfo(p.dtype).tiny * 2:
+                    p = p / p_norm
+                    self.global_state["directions"].append(p)
+                    S = torch.cat([S, p.unsqueeze(1)], dim=1)
+
+        elif sketch_type == "mixed":
+            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+            g = torch.cat([t.ravel() for t in g_list])
+
+            # initialize state
+            if "slow_ema" not in self.global_state:
+                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
+                self.global_state["p_prev"] = torch.randn_like(g)
+
+            # previous update direction
+            p_cur = torch.cat([t.ravel() for t in params])
+            prev_dir = p_cur - self.global_state["p_prev"]
+            self.global_state["p_prev"] = p_cur
+
+            # EMAs
+            slow_ema = self.global_state["slow_ema"]
+            fast_ema = self.global_state["fast_ema"]
+            slow_ema.lerp_(g, 0.001)
+            fast_ema.lerp_(g, 0.1)
+
+            # form and orthogonalize sketching matrix
+            S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
+            if sketch_size > 4:
+                S_random = torch.randn(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator) / math.sqrt(ndim)
+                S = torch.cat([S, S_random], dim=1)
+
+            S = _qr_orthonormalize(S)
+
+        else:
+            raise ValueError(f'Unknown sketch_type {sketch_type}')
+
+        # form sketched hessian
+        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
+                                                 hvp_method=fs["hvp_method"], h=fs["h"])
+        H_sketched = S.T @ HS
+
+        # update state
+        _newton_update_state_(
+            state = self.global_state,
+            H = H_sketched,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = fs["precompute_inverse"],
+            use_lstsq = fs["use_lstsq"]
 
-
-        HS, _ = var.hessian_matrix_product(S, at_x0=True, rgrad=None, hvp_method=self.defaults["hvp_method"], normalize=True, retain_graph=False, h=self.defaults["h"])
-        H_sketched = S.T @ HS
+        )
 
-
-        self.global_state["S"] = S
+        self.global_state["S"] = S
 
-    def
-
-
-            var=var,
-            H=self.global_state["H_sketched"],
-            damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj = lambda g: S.T @ g
-        )
-        d = S @ d_proj
-        var.update = vec_to_tensors(d, var.params)
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        fs = settings[0]
 
-
+        S = self.global_state["S"]
+        b = torch.cat([t.ravel() for t in updates])
+        b_proj = S.T @ b
 
-
-        eigval_fn = self.defaults["eigval_fn"]
-        H_sketched: torch.Tensor = self.global_state["H_sketched"]
-        S: torch.Tensor = self.global_state["S"]
+        d_proj = _newton_solve(b=b_proj, state=self.global_state, use_lstsq=fs["use_lstsq"])
 
-
-
-
-            L: torch.Tensor = eigval_fn(L)
-            H_sketched = Q @ L.diag_embed() @ Q.mH
+        d = S @ d_proj
+        vec_to_tensors_(d, updates)
+        return objective
 
-
-
+    def get_H(self, objective=...):
+        if "H" in self.global_state:
+            H_sketched = self.global_state["H"]
 
+        else:
+            L = self.global_state["L"]
+            Q = self.global_state["Q"]
+            H_sketched = Q @ L.diag_embed() @ Q.mH
+
+        S: torch.Tensor = self.global_state["S"]
        return Sketched(S, H_sketched)
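The rewritten `SubspaceNewton` (formerly `RSN`) builds a sketch matrix `S`, forms the sketched Hessian `S.T @ HS` from Hessian-vector products, solves the small Newton system, and lifts the solution back through `S`. As a reference point, here is a minimal self-contained sketch of that update on a toy quadratic; it uses plain `torch.linalg` calls rather than torchzero's `Objective`/`_newton_solve` machinery, and the variable names are illustrative only.

```python
import torch

torch.manual_seed(0)
n, k = 100, 10                       # problem dimension and sketch size

# toy quadratic f(x) = 0.5 x^T A x - b^T x with SPD Hessian A
M = torch.randn(n, n)
A = M @ M.T + 0.1 * torch.eye(n)
b = torch.randn(n)
x = torch.randn(n)

g = A @ x - b                        # gradient at x

# random orthonormal sketch S of shape (n, k), as in _orthonormal_sketch
S, _ = torch.linalg.qr(torch.randn(n, k))

# k Hessian-vector products give the sketched Hessian S^T H S of shape (k, k)
HS = A @ S
H_sketched = S.T @ HS

# Newton step restricted to range(S): solve (S^T H S) d_proj = S^T g, then lift back
d_proj = torch.linalg.solve(H_sketched, S.T @ g)
x = x - S @ d_proj
```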
@@ -4,7 +4,7 @@ from collections.abc import Iterable
 import torch
 
 from ...utils.tensorlist import TensorList
-from ...core import
+from ...core import TensorTransform
 
 
 def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
@@ -55,7 +55,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
    v[-1] = 1
    return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
-class LaplacianSmoothing(
+class LaplacianSmoothing(TensorTransform):
    """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
    Args:
@@ -70,29 +70,30 @@ class LaplacianSmoothing(Transform):
        what to set on var.
 
    Examples:
-
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
 
-
+        ```python
 
-
-
-
-
-
+        opt = tz.Optimizer(
+            model.parameters(),
+            tz.m.LaplacianSmoothing(),
+            tz.m.LR(1e-2),
+        )
+        ```
 
    Reference:
        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
    """
-    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4
+    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
        defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
-        super().__init__(defaults
+        super().__init__(defaults)
        # precomputed denominator for when layerwise=False
        self.global_state['full_denominator'] = None
 
 
    @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        layerwise = settings[0]['layerwise']
 
        # layerwise laplacian smoothing
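`LaplacianSmoothing` smooths each update `d` by solving `(I - sigma * L) d_smooth = d` with a circulant discrete Laplacian `L`, which the module does in the Fourier domain using the precomputed denominator `1 - sigma * fft(v)`. A rough standalone sketch of that operation follows; the stencil values `v[0] = -2, v[1] = 1` come from the Osher et al. (2022) formulation (the hunk above only shows `v[-1] = 1`), so treat them as an assumption.

```python
import torch

def laplacian_smooth(d: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # periodic discrete Laplacian stencil [-2, 1, 0, ..., 0, 1] (assumed, per Osher et al. 2022)
    v = torch.zeros_like(d)
    v[0], v[1], v[-1] = -2.0, 1.0, 1.0
    denominator = 1 - sigma * torch.fft.fft(v)   # same form as _precompute_denominator
    return torch.fft.ifft(torch.fft.fft(d) / denominator).real

g = torch.randn(1024)
g_smooth = laplacian_smooth(g, sigma=1.0)
```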
@@ -7,14 +7,14 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable,
+from ...core import Chainable, Optimizer, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
-def _reset_except_self(
-    for m in
+def _reset_except_self(objective: Objective, modules, self: Module):
+    for m in modules:
        if m is not self:
            m.reset()
 
@@ -98,15 +98,15 @@ class GradientSampling(Reformulation):
        self.set_child('termination', make_termination_criteria(extra=termination))
 
    @torch.no_grad
-    def pre_step(self,
-        params = TensorList(
+    def pre_step(self, objective):
+        params = TensorList(objective.params)
 
        fixed = self.defaults['fixed']
 
        # check termination criteria
        if 'termination' in self.children:
            termination = cast(TerminationCriteriaBase, self.children['termination'])
-            if termination.should_terminate(
+            if termination.should_terminate(objective):
 
                # decay sigmas
                states = [self.state[p] for p in params]
@@ -118,7 +118,7 @@ class GradientSampling(Reformulation):
 
                # reset on sigmas decay
                if self.defaults['reset_on_termination']:
-
+                    objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
                # clear perturbations
                self.global_state.pop('perts', None)
@@ -136,7 +136,7 @@ class GradientSampling(Reformulation):
        self.global_state['perts'] = perts
 
    @torch.no_grad
-    def closure(self, backward, closure, params,
+    def closure(self, backward, closure, params, objective):
        params = TensorList(params)
        loss_agg = None
        grad_agg = None
@@ -160,7 +160,7 @@ class GradientSampling(Reformulation):
 
        # evaluate at x_0
        if include_x0:
-            f_0 =
+            f_0 = objective.get_loss(backward=backward)
 
            isfinite = math.isfinite(f_0)
            if isfinite:
@@ -168,7 +168,7 @@ class GradientSampling(Reformulation):
                loss_agg = f_0
 
                if backward:
-                    g_0 =
+                    g_0 = objective.get_grads()
                    if isfinite: grad_agg = g_0
 
        # evaluate at x_0 + p for each perturbation
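`GradientSampling` reformulates the objective as an average of the loss and gradient over the current point plus random perturbations (the `f_0`/`g_0` aggregation in the closure above), with optional termination-driven sigma decay. A rough standalone sketch of the core averaging step, assuming Gaussian perturbations and without torchzero's `Reformulation` plumbing; the function name is illustrative only.

```python
import torch

def sampled_loss_and_grad(f, x, sigma=1e-2, n_samples=4, include_x0=True):
    # average loss and gradient of f over x and perturbed copies x + sigma * eps
    points = [x.clone()] if include_x0 else []
    points += [x + sigma * torch.randn_like(x) for _ in range(n_samples)]

    loss_agg = torch.zeros((), dtype=x.dtype)
    grad_agg = torch.zeros_like(x)
    for p in points:
        p = p.detach().requires_grad_(True)
        loss = f(p)
        (g,) = torch.autograd.grad(loss, p)
        loss_agg += loss.detach()
        grad_agg += g
    return loss_agg / len(points), grad_agg / len(points)

f = lambda v: (v ** 4).sum()          # any differentiable scalar objective
loss, grad = sampled_loss_and_grad(f, torch.randn(8))
```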
|
@@ -5,10 +5,10 @@ from typing import Any, Literal
|
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from ...core import Chainable,
|
|
8
|
+
from ...core import Chainable, TensorTransform
|
|
9
9
|
from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
|
|
10
|
-
from ...
|
|
11
|
-
from ..
|
|
10
|
+
from ...linalg.linear_operator import ScaledIdentity
|
|
11
|
+
from ..opt_utils import epsilon_step_size
|
|
12
12
|
|
|
13
13
|
def _acceptable_alpha(alpha, param:torch.Tensor):
|
|
14
14
|
finfo = torch.finfo(param.dtype)
|
|
@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
|
|
|
16
16
|
return False
|
|
17
17
|
return True
|
|
18
18
|
|
|
19
|
-
def
|
|
19
|
+
def _get_scaled_identity_H(self: TensorTransform, var):
|
|
20
20
|
n = sum(p.numel() for p in var.params)
|
|
21
21
|
p = var.params[0]
|
|
22
22
|
alpha = self.global_state.get('alpha', 1)
|
|
@@ -25,7 +25,7 @@ def _get_H(self: Transform, var):
|
|
|
25
25
|
return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
class PolyakStepSize(
|
|
28
|
+
class PolyakStepSize(TensorTransform):
|
|
29
29
|
"""Polyak's subgradient method with known or unknown f*.
|
|
30
30
|
|
|
31
31
|
Args:
|
|
@@ -47,7 +47,7 @@ class PolyakStepSize(Transform):
|
|
|
47
47
|
super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
|
|
48
48
|
|
|
49
49
|
@torch.no_grad
|
|
50
|
-
def
|
|
50
|
+
def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
|
|
51
51
|
assert grads is not None and loss is not None
|
|
52
52
|
tensors = TensorList(tensors)
|
|
53
53
|
grads = TensorList(grads)
|
|
@@ -79,15 +79,15 @@ class PolyakStepSize(Transform):
|
|
|
79
79
|
self.global_state['alpha'] = alpha
|
|
80
80
|
|
|
81
81
|
@torch.no_grad
|
|
82
|
-
def
|
|
82
|
+
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
83
83
|
alpha = self.global_state.get('alpha', 1)
|
|
84
84
|
if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
|
|
85
85
|
|
|
86
86
|
torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
|
|
87
87
|
return tensors
|
|
88
88
|
|
|
89
|
-
def get_H(self,
|
|
90
|
-
return
|
|
89
|
+
def get_H(self, objective):
|
|
90
|
+
return _get_scaled_identity_H(self, objective)
|
|
91
91
|
|
|
92
92
|
|
|
93
93
|
def _bb_short(s: TensorList, y: TensorList, sy, eps):
|
|
@@ -116,7 +116,7 @@ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
|
|
|
116
116
|
return None
|
|
117
117
|
return (short * long) ** 0.5
|
|
118
118
|
|
|
119
|
-
class BarzilaiBorwein(
|
|
119
|
+
class BarzilaiBorwein(TensorTransform):
|
|
120
120
|
"""Barzilai-Borwein step size method.
|
|
121
121
|
|
|
122
122
|
Args:
|
|
@@ -144,7 +144,7 @@ class BarzilaiBorwein(Transform):
|
|
|
144
144
|
self.global_state['reset'] = True
|
|
145
145
|
|
|
146
146
|
@torch.no_grad
|
|
147
|
-
def
|
|
147
|
+
def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
|
|
148
148
|
step = self.global_state.get('step', 0)
|
|
149
149
|
self.global_state['step'] = step + 1
|
|
150
150
|
|
|
@@ -175,11 +175,11 @@ class BarzilaiBorwein(Transform):
|
|
|
175
175
|
prev_p.copy_(params)
|
|
176
176
|
prev_g.copy_(g)
|
|
177
177
|
|
|
178
|
-
def get_H(self,
|
|
179
|
-
return
|
|
178
|
+
def get_H(self, objective):
|
|
179
|
+
return _get_scaled_identity_H(self, objective)
|
|
180
180
|
|
|
181
181
|
@torch.no_grad
|
|
182
|
-
def
|
|
182
|
+
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
183
183
|
alpha = self.global_state.get('alpha', None)
|
|
184
184
|
|
|
185
185
|
if not _acceptable_alpha(alpha, tensors[0]):
|
|
@@ -189,7 +189,7 @@ class BarzilaiBorwein(Transform):
|
|
|
189
189
|
return tensors
|
|
190
190
|
|
|
191
191
|
|
|
192
|
-
class BBStab(
|
|
192
|
+
class BBStab(TensorTransform):
|
|
193
193
|
"""Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
|
|
194
194
|
|
|
195
195
|
This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
|
|
@@ -228,7 +228,7 @@ class BBStab(Transform):
|
|
|
228
228
|
self.global_state['reset'] = True
|
|
229
229
|
|
|
230
230
|
@torch.no_grad
|
|
231
|
-
def
|
|
231
|
+
def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
|
|
232
232
|
step = self.global_state.get('step', 0)
|
|
233
233
|
self.global_state['step'] = step + 1
|
|
234
234
|
|
|
@@ -287,11 +287,11 @@ class BBStab(Transform):
|
|
|
287
287
|
prev_p.copy_(params)
|
|
288
288
|
prev_g.copy_(g)
|
|
289
289
|
|
|
290
|
-
def get_H(self,
|
|
291
|
-
return
|
|
290
|
+
def get_H(self, objective):
|
|
291
|
+
return _get_scaled_identity_H(self, objective)
|
|
292
292
|
|
|
293
293
|
@torch.no_grad
|
|
294
|
-
def
|
|
294
|
+
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
295
295
|
alpha = self.global_state.get('alpha', None)
|
|
296
296
|
|
|
297
297
|
if not _acceptable_alpha(alpha, tensors[0]):
|
|
@@ -301,7 +301,7 @@ class BBStab(Transform):
|
|
|
301
301
|
return tensors
|
|
302
302
|
|
|
303
303
|
|
|
304
|
-
class AdGD(
|
|
304
|
+
class AdGD(TensorTransform):
|
|
305
305
|
"""AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
|
|
306
306
|
def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
|
|
307
307
|
defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
|
|
@@ -313,7 +313,7 @@ class AdGD(Transform):
|
|
|
313
313
|
self.global_state['reset'] = True
|
|
314
314
|
|
|
315
315
|
@torch.no_grad
|
|
316
|
-
def
|
|
316
|
+
def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
|
|
317
317
|
variant = settings[0]['variant']
|
|
318
318
|
theta_0 = 0 if variant == 1 else 1/3
|
|
319
319
|
theta = self.global_state.get('theta', theta_0)
|
|
@@ -371,7 +371,7 @@ class AdGD(Transform):
|
|
|
371
371
|
prev_g.copy_(g)
|
|
372
372
|
|
|
373
373
|
@torch.no_grad
|
|
374
|
-
def
|
|
374
|
+
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
|
|
375
375
|
alpha = self.global_state.get('alpha', None)
|
|
376
376
|
|
|
377
377
|
if not _acceptable_alpha(alpha, tensors[0]):
|
|
@@ -383,5 +383,5 @@ class AdGD(Transform):
|
|
|
383
383
|
torch._foreach_mul_(tensors, alpha)
|
|
384
384
|
return tensors
|
|
385
385
|
|
|
386
|
-
def get_H(self,
|
|
387
|
-
return
|
|
386
|
+
def get_H(self, objective):
|
|
387
|
+
return _get_scaled_identity_H(self, objective)
|
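The step-size modules above implement classical adaptive rules: Polyak's rule `alpha = (f(x) - f*) / ||g||^2`, and the Barzilai-Borwein long and short step sizes with their geometric mean (cf. `_bb_short` and `_bb_geom`). A minimal sketch of those formulas on flattened tensors; the function names are illustrative and not part of torchzero's API.

```python
import torch

def polyak_step_size(loss: float, f_star: float, g: torch.Tensor, eps: float = 1e-12) -> float:
    # Polyak's rule: alpha = (f(x) - f*) / ||g||^2
    return max(loss - f_star, 0.0) / (g.dot(g).item() + eps)

def bb_step_sizes(s: torch.Tensor, y: torch.Tensor, eps: float = 1e-12):
    # s = x_k - x_{k-1}, y = g_k - g_{k-1}
    sy = s.dot(y).item()
    if sy <= eps:
        return None, None, None       # curvature condition failed, caller falls back
    bb_long = s.dot(s).item() / sy    # "long" BB1 step size
    bb_short = sy / y.dot(y).item()   # "short" BB2 step size
    bb_geom = (bb_long * bb_short) ** 0.5
    return bb_long, bb_short, bb_geom
```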