torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/functional.py
CHANGED

@@ -9,9 +9,17 @@ Additional functional variants are present in most module files, e.g. `adam_`, …
 """
 from collections.abc import Callable
 from typing import overload
+
 import torch

-from ..utils import …
+from ..utils import (
+    NumberList,
+    TensorList,
+    generic_finfo_eps,
+    generic_max,
+    generic_sum,
+    tofloat,
+)

 inf = float('inf')

@@ -87,10 +95,10 @@ def root(tensors_:TensorList, p:float, inplace: bool):
         if p == 1: return tensors_.abs_()
         if p == 2: return tensors_.sqrt_()
         return tensors_.pow_(1/p)
- … (4 removed lines not captured in this view)
+
+    if p == 1: return tensors_.abs()
+    if p == 2: return tensors_.sqrt()
+    return tensors_.pow(1/p)


 def ema_(
@@ -207,13 +215,41 @@ def sqrt_centered_ema_sq_(
         ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
     )

- … (9 removed lines not captured in this view)
+def initial_step_size(tensors: torch.Tensor | TensorList, eps=None) -> float:
+    """initial scaling taken from pytorch L-BFGS to avoid requiring a lot of line search iterations,
+    this version is safer and makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_sum = generic_sum(tensors_abs)
+    tensors_max = generic_max(tensors_abs)
+
+    feps = generic_finfo_eps(tensors)
+    if eps is None: eps = feps
+    else: eps = max(eps, feps)
+
+    # scale should not make largest value smaller than epsilon
+    min = eps / tensors_max
+    if min >= 1: return 1.0
+
+    scale = 1 / tensors_sum
+    scale = scale.clip(min=min.item(), max=1)
+    return scale.item()
+
+
+def epsilon_step_size(tensors: torch.Tensor | TensorList, alpha=1e-7) -> float:
+    """makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_max = generic_max(tensors_abs)
+    if tensors_max < alpha: return 1.0
+
+    if tensors_max < 1: alpha = alpha / tensors_max
+    return tofloat(alpha)
+
+
+
+def safe_clip(x: torch.Tensor, min=None):
+    """makes sure absolute value of scalar tensor x is not smaller than min"""
+    assert x.numel() == 1, x.shape
+    if min is None: min = torch.finfo(x.dtype).tiny * 2

+    if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
+    return x
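The new `initial_step_size` helper borrows the `1/sum(|grad|)` initial scaling used by PyTorch's L-BFGS, but clips the scale so the largest element of the scaled tensor cannot drop below machine epsilon. A minimal standalone sketch of that rule for a single dense tensor, using plain PyTorch instead of the package's `TensorList`/`generic_*` helpers (the function name here is illustrative, not part of the package):

```python
import torch

def initial_step_size_sketch(grad: torch.Tensor, eps: float | None = None) -> float:
    """Roughly 1/sum(|g|), clipped so that max(|g|) * scale stays above eps."""
    feps = torch.finfo(grad.dtype).eps
    eps = feps if eps is None else max(eps, feps)

    g_abs = grad.abs()
    lo = eps / g_abs.max()      # smallest scale that keeps the largest entry above eps
    if lo >= 1:                 # gradient is already tiny, do not shrink it further
        return 1.0

    scale = (1.0 / g_abs.sum()).clamp(min=lo.item(), max=1.0)
    return scale.item()

print(initial_step_size_sketch(torch.randn(10_000)))  # well below 1 for a large gradient
```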
torchzero/modules/grad_approximation/fdm.py
CHANGED

@@ -93,7 +93,7 @@ _FD_FUNCS = {
 class FDM(GradApproximator):
     """Approximate gradients via finite difference method.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -103,24 +103,23 @@ class FDM(GradApproximator):
         target (GradTarget, optional): what to set on var. Defaults to 'closure'.

     Examples:
- … (17 removed lines not captured in this view)
-        )
+        plain FDM:
+
+        ```python
+        fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+        ```
+
+        Any gradient-based method can use FDM-estimated gradients.
+        ```python
+        fdm_ncg = tz.Modular(
+            model.parameters(),
+            tz.m.FDM(),
+            # set hvp_method to "forward" so that it
+            # uses gradient difference instead of autograd
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
     """
     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
         defaults = dict(h=h, formula=formula)
@@ -139,7 +138,7 @@ class FDM(GradApproximator):
             h = settings['h']
             fd_fn = _FD_FUNCS[settings['formula']]

-            p_flat = p. …
+            p_flat = p.ravel(); g_flat = g.ravel()
             for i in range(len(p_flat)):
                 loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
                 g_flat[i] = d
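The last hunk shows `FDM` walking a flattened view of each parameter (`p.ravel()`) and writing one estimated derivative per coordinate into a flattened gradient. A self-contained sketch of that per-coordinate pattern with a central-difference formula; names are illustrative, and the package actually dispatches to whichever formula `_FD_FUNCS[...]` selects:

```python
import torch

def central_fdm_grad(f, param: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
    """Estimate df/dparam one coordinate at a time with central differences."""
    grad = torch.zeros_like(param)
    p_flat, g_flat = param.view(-1), grad.view(-1)   # flat views share storage
    for i in range(p_flat.numel()):
        orig = p_flat[i].item()
        p_flat[i] = orig + h
        f_plus = f()
        p_flat[i] = orig - h
        f_minus = f()
        p_flat[i] = orig                              # restore the coordinate
        g_flat[i] = (f_plus - f_minus) / (2 * h)
    return grad

x = torch.randn(4)
g = central_fdm_grad(lambda: (x ** 3).sum().item(), x)   # approximately 3 * x**2
```

This costs two function evaluations per coordinate, which is why the randomized variants further down are usually preferred for large parameter counts.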
torchzero/modules/grad_approximation/forward_gradient.py
CHANGED

@@ -15,7 +15,7 @@ class ForwardGradient(RandomizedFDM):

     This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -67,7 +67,9 @@ class ForwardGradient(RandomizedFDM):
         grad = None
         for i in range(n_samples):
             prt = perturbations[i]
-            if prt[0] is None: …
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, variance=1, generator=generator)
+
             else: prt = TensorList(prt)

             if jvp_method == 'autograd':
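`ForwardGradient` samples a perturbation (or reuses a pre-generated one, as the hunk above shows) and evaluates the directional derivative with an autograd jacobian-vector product. A hedged, plain-`torch.func` sketch of the underlying forward-gradient estimator, independent of the package's `TensorList` machinery:

```python
import torch

def forward_gradient(f, x: torch.Tensor) -> torch.Tensor:
    """(directional derivative along a random v) * v is an unbiased gradient estimate."""
    v = torch.randn_like(x)                   # probe direction with E[v v^T] = I
    _, dfdv = torch.func.jvp(f, (x,), (v,))   # forward-mode JVP, no backward pass needed
    return dfdv * v

x = torch.randn(8)
g_est = forward_gradient(lambda t: (t ** 2).sum(), x)   # unbiased estimate of 2 * x
```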
torchzero/modules/grad_approximation/grad_approximator.py
CHANGED

@@ -24,63 +24,59 @@ class GradApproximator(Module, ABC):

     Example:

- … (39 removed lines not captured in this view)
-    """
+        Basic SPSA method implementation.
+        ```python
+        class SPSA(GradApproximator):
+            def __init__(self, h=1e-3):
+                defaults = dict(h=h)
+                super().__init__(defaults)
+
+            @torch.no_grad
+            def approximate(self, closure, params, loss):
+                perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                # evaluate params + perturbation
+                torch._foreach_add_(params, perturbation)
+                loss_plus = closure(False)
+
+                # evaluate params - perturbation
+                torch._foreach_sub_(params, perturbation)
+                torch._foreach_sub_(params, perturbation)
+                loss_minus = closure(False)
+
+                # restore original params
+                torch._foreach_add_(params, perturbation)
+
+                # calculate SPSA gradients
+                spsa_grads = []
+                for p, pert in zip(params, perturbation):
+                    settings = self.settings[p]
+                    h = settings['h']
+                    d = (loss_plus - loss_minus) / (2*(h**2))
+                    spsa_grads.append(pert * d)
+
+                # returns tuple: (grads, loss, loss_approx)
+                # loss must be with initial parameters
+                # since we only evaluated loss with perturbed parameters
+                # we only have loss_approx
+                return spsa_grads, None, loss_plus
+        ```
+    """
     def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target

     @abstractmethod
-    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: …
-        """Returns a tuple: (grad, loss, loss_approx) …
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
+        """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

-    def pre_step(self, var: Var) -> …
+    def pre_step(self, var: Var) -> None:
         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return var

     @torch.no_grad
     def step(self, var):
- … (1 removed line not captured in this view)
-        if isinstance(ret, Var): var = ret
+        self.pre_step(var)

         if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
         params, closure, loss = var.params, var.closure, var.loss
@@ -108,4 +104,4 @@ class GradApproximator(Module, ABC):
         else: raise ValueError(self._target)
         return var

-_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', '…
+_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']
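The docstring's SPSA example returns `(grads, loss, loss_approx)` from `approximate`; in practice such a module sits first in a `tz.Modular` chain so every later module consumes the estimated gradients. A hedged usage sketch, assuming the closure convention the example relies on (`closure(False)` evaluates the loss without a backward pass) and that `tz.Modular` behaves like a regular `torch.optim.Optimizer` here; `SPSA` refers to the subclass defined in the docstring above:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

opt = tz.Modular(model.parameters(), SPSA(h=1e-3), tz.m.LR(1e-2))

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:          # gradient approximators call closure(False),
        opt.zero_grad()   # so no backward pass is actually taken here
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)
```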
torchzero/modules/grad_approximation/rfdm.py
CHANGED

@@ -115,26 +115,26 @@ def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[…
     h = h**2 # because perturbation already multiplied by h
     return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)

-# another central4
-def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
- … (2 removed lines not captured in this view)
+# # another central4
+# def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+#     params += p_fn()
+#     f_1 = closure(False)

- … (2 removed lines not captured in this view)
+#     params += p_fn() * 2
+#     f_3 = closure(False)

- … (2 removed lines not captured in this view)
+#     params -= p_fn() * 4
+#     f_m1 = closure(False)

- … (2 removed lines not captured in this view)
+#     params -= p_fn() * 2
+#     f_m3 = closure(False)

- … (3 removed lines not captured in this view)
+#     params += p_fn() * 3
+#     h = h**2 # because perturbation already multiplied by h
+#     return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)


-_RFD_FUNCS = {
+_RFD_FUNCS: dict[_FD_Formula, Callable] = {
     "forward": _rforward2,
     "forward2": _rforward2,
     "backward": _rbackward2,
@@ -147,14 +147,14 @@ _RFD_FUNCS = {
     "central4": _rcentral4,
     "forward4": _rforward4,
     "forward5": _rforward5,
-    "bspsa4": _bgspsa4,
+    # "bspsa4": _bgspsa4,
 }


 class RandomizedFDM(GradApproximator):
     """Gradient approximation via a randomized finite-difference method.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -171,94 +171,95 @@ class RandomizedFDM(GradApproximator):
         target (GradTarget, optional): what to set on var. Defaults to "closure".

     Examples:
- … (88 removed lines not captured in this view)
+        #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+        SPSA is randomized finite differnce with rademacher distribution and central formula.
+        ```py
+        spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### Random-direction stochastic approximation (RDSA) method
+
+        RDSA is randomized finite differnce with usually gaussian distribution and central formula.
+
+        ```
+        rdsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### RandomizedFDM with momentum
+
+        Momentum might help by reducing the variance of the estimated gradients.
+
+        ```
+        momentum_spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(),
+            tz.m.HeavyBall(0.9),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        #### Gaussian smoothing method
+
+        GS uses many gaussian samples with possibly a larger finite difference step size.
+
+        ```
+        gs = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### SPSA-NewtonCG
+
+        NewtonCG with hessian-vector product estimated via gradient difference
+        calls closure multiple times per step. If each closure call estimates gradients
+        with different perturbations, NewtonCG is unable to produce useful directions.
+
+        By setting pre_generate to True, perturbations are generated once before each step,
+        and each closure call estimates gradients using the same pre-generated perturbations.
+        This way closure-based algorithms are able to use gradients estimated in a consistent way.
+
+        ```
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=10),
+            tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### SPSA-LBFGS
+
+        LBFGS uses a memory of past parameter and gradient differences. If past gradients
+        were estimated with different perturbations, LBFGS directions will be useless.
+
+        To alleviate this momentum can be added to random perturbations to make sure they only
+        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
+        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
+
+        Additionally we will reset SPSA and LBFGS memory every 100 steps to remove influence from old gradient estimates.
+
+        ```
+        opt = tz.Modular(
+            bench.parameters(),
+            tz.m.ResetEvery(
+                [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
+                steps = 100,
+            ),
+            tz.m.Backtracking()
+        )
+        ```
     """
     PRE_MULTIPLY_BY_H = True
     def __init__(
@@ -280,6 +281,7 @@ class RandomizedFDM(GradApproximator):
         generator = self.global_state.get('generator', None) # avoid resetting generator
         self.global_state.clear()
         if generator is not None: self.global_state['generator'] = generator
+        for c in self.children.values(): c.reset()

     def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
         if 'generator' not in self.global_state:
@@ -290,15 +292,15 @@ class RandomizedFDM(GradApproximator):

     def pre_step(self, var):
         h, beta = self.get_settings(var.params, 'h', 'beta')
- … (1 removed line not captured in this view)
-        n_samples = …
-        distribution = …
-        pre_generate = …
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
+        pre_generate = self.defaults['pre_generate']

         if pre_generate:
             params = TensorList(var.params)
-            generator = self._get_generator(…
-            perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+            generator = self._get_generator(self.defaults['seed'], var.params)
+            perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

             if self.PRE_MULTIPLY_BY_H:
                 torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
@@ -339,27 +341,44 @@ class RandomizedFDM(GradApproximator):
         grad = None
         for i in range(n_samples):
             prt = perturbations[i]
- … (1 removed line not captured in this view)
+
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)
+
             else: prt = TensorList(prt)

             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
+            # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.
+
+            # support for per-sample values which gives better estimate
+            if d[0].numel() > 1: d = d.map(torch.mean)
+
             if grad is None: grad = prt * d
             else: grad += prt * d

         params.set_(orig_params)
         assert grad is not None
         if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        if loss is not None:
+            if loss.numel() > 1:
+                loss = loss.mean()
+
+        if loss_approx is not None:
+            if loss_approx.numel() > 1:
+                loss_approx = loss_approx.mean()
+
         return grad, loss, loss_approx

 class SPSA(RandomizedFDM):
     """
     Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

- … (1 removed line not captured in this view)
     Args:
         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
@@ -380,7 +399,7 @@ class RDSA(RandomizedFDM):
     """
     Gradient approximation via Random-direction stochastic approximation (RDSA) method.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -417,7 +436,7 @@ class GaussianSmoothing(RandomizedFDM):
     """
     Gradient approximation via Gaussian smoothing method.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -453,7 +472,7 @@ class GaussianSmoothing(RandomizedFDM):
 class MeZO(GradApproximator):
     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

- … (1 removed line not captured in this view)
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -476,15 +495,18 @@ class MeZO(GradApproximator):
         super().__init__(defaults, target=target)

     def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
- … (1 removed line not captured in this view)
-            distribution=distribution,
- … (1 removed line not captured in this view)
+        prt = TensorList(params).sample_like(
+            distribution=distribution,
+            variance=h,
+            generator=torch.Generator(params[0].device).manual_seed(seed)
+        )
+        return prt

     def pre_step(self, var):
         h = NumberList(self.settings[p]['h'] for p in var.params)
- … (1 removed line not captured in this view)
-        n_samples = …
-        distribution = …
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']

         step = var.current_step

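`MeZO._seeded_perturbation` above is the memory-saving trick from the MeZO paper: instead of storing the perturbation tensors, only a seed is kept and the exact same random perturbation is regenerated on demand. A standalone sketch of that idea for a single tensor (plain PyTorch, illustrative names):

```python
import torch

def seeded_perturbation(shape, seed: int, std: float, device: str = "cpu") -> torch.Tensor:
    """Re-create exactly the same gaussian perturbation from a stored seed."""
    gen = torch.Generator(device).manual_seed(seed)
    return torch.randn(shape, generator=gen, device=device) * std

seed = 1234
z1 = seeded_perturbation((5,), seed, std=1e-3)
z2 = seeded_perturbation((5,), seed, std=1e-3)
assert torch.equal(z1, z2)  # identical, so only the integer seed needs to be stored
```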
torchzero/modules/higher_order/__init__.py
CHANGED

@@ -1 +1 @@
-from .higher_order_newton import HigherOrderNewton
+from .higher_order_newton import HigherOrderNewton