torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/spsa1.py
ADDED
@@ -0,0 +1,93 @@
+from collections.abc import Callable
+from typing import Any
+from functools import partial
+import torch
+
+from ...utils import TensorList, NumberList
+from ..grad_approximation.grad_approximator import GradApproximator, GradTarget
+
+class SPSA1(GradApproximator):
+    """One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated
+    gradient often won't be a descent direction, however the expectation is biased towards
+    the descent direction. Therefore this variant of SPSA is only recommended for a specific
+    class of problems where the objective function changes on each evaluation,
+    for example feedback control problems.
+
+    Args:
+        h (float, optional):
+            finite difference step size, recommended to set to same value as learning rate. Defaults to 1e-3.
+        n_samples (int, optional): number of random samples. Defaults to 1.
+        eps (float, optional): measurement noise estimate. Defaults to 1e-8.
+        seed (int | None | torch.Generator, optional): random seed. Defaults to None.
+        target (GradTarget, optional): what to set on closure. Defaults to "closure".
+
+    Reference:
+        [SPALL, JAMES C. "A One-measurement Form of Simultaneous Stochastic Approximation](https://www.jhuapl.edu/spsa/PDF-SPSA/automatica97_one_measSPSA.pdf)."
+    """
+
+    def __init__(
+        self,
+        h: float = 1e-3,
+        n_samples: int = 1,
+        eps: float = 1e-8, # measurement noise
+        pre_generate = False,
+        seed: int | None | torch.Generator = None,
+        target: GradTarget = "closure",
+    ):
+        defaults = dict(h=h, eps=eps, n_samples=n_samples, pre_generate=pre_generate, seed=seed)
+        super().__init__(defaults, target=target)
+
+
+    def pre_step(self, var):
+
+        if self.defaults['pre_generate']:
+
+            params = TensorList(var.params)
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+            n_samples = self.defaults['n_samples']
+            h = self.get_settings(var.params, 'h')
+
+            perturbations = [params.sample_like(distribution='rademacher', generator=generator) for _ in range(n_samples)]
+            torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
+
+            for param, prt in zip(params, zip(*perturbations)):
+                self.state[param]['perturbations'] = prt
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss):
+        generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+        params = TensorList(params)
+        orig_params = params.clone() # store to avoid small changes due to float imprecision
+        loss_approx = None
+
+        h, eps = self.get_settings(params, "h", "eps", cls=NumberList)
+        n_samples = self.defaults['n_samples']
+
+        default = [None]*n_samples
+        # perturbations are pre-multiplied by h
+        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+
+        grad = None
+        for i in range(n_samples):
+            prt = perturbations[i]
+
+            if prt[0] is None:
+                prt = params.sample_like('rademacher', generator=generator).mul_(h)
+
+            else: prt = TensorList(prt)
+
+            params += prt
+            L = closure(False)
+            params.copy_(orig_params)
+
+            sample = prt * ((L + eps) / h)
+            if grad is None: grad = sample
+            else: grad += sample
+
+        assert grad is not None
+        if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        return grad, loss, loss_approx
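The new `SPSA1` module plugs into the same modular pipeline as the other gradient approximators. A minimal usage sketch, assuming the class is exported as `tz.m.SPSA1` like the other modules (mirroring the `FDM` examples further down in this diff); the model, data, and exact closure signature here are illustrative placeholders:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# one-measurement SPSA as the gradient source, followed by a fixed step size;
# the docstring recommends setting h to roughly the same value as the learning rate
opt = tz.Modular(
    model.parameters(),
    tz.m.SPSA1(h=1e-2, n_samples=4),
    tz.m.LR(1e-2),
)

def closure(backward=True):
    # SPSA1 never calls backward(); it only re-evaluates the loss at perturbed parameters
    return model(torch.randn(8, 10)).pow(2).mean()

opt.step(closure)
```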
torchzero/modules/functional.py
CHANGED
@@ -9,9 +9,17 @@ Additional functional variants are present in most module files, e.g. `adam_`, `
 """
 from collections.abc import Callable
 from typing import overload
+
 import torch
 
-from ..utils import
+from ..utils import (
+    NumberList,
+    TensorList,
+    generic_finfo_eps,
+    generic_max,
+    generic_sum,
+    tofloat,
+)
 
 inf = float('inf')
 
@@ -87,10 +95,10 @@ def root(tensors_:TensorList, p:float, inplace: bool):
         if p == 1: return tensors_.abs_()
         if p == 2: return tensors_.sqrt_()
         return tensors_.pow_(1/p)
-
-
-
-
+
+    if p == 1: return tensors_.abs()
+    if p == 2: return tensors_.sqrt()
+    return tensors_.pow(1/p)
 
 
 def ema_(
@@ -207,13 +215,41 @@ def sqrt_centered_ema_sq_(
         ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
     )
 
-
-
-
-
-
-
-
-
-
+def initial_step_size(tensors: torch.Tensor | TensorList, eps=None) -> float:
+    """initial scaling taken from pytorch L-BFGS to avoid requiring a lot of line search iterations,
+    this version is safer and makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_sum = generic_sum(tensors_abs)
+    tensors_max = generic_max(tensors_abs)
+
+    feps = generic_finfo_eps(tensors)
+    if eps is None: eps = feps
+    else: eps = max(eps, feps)
+
+    # scale should not make largest value smaller than epsilon
+    min = eps / tensors_max
+    if min >= 1: return 1.0
+
+    scale = 1 / tensors_sum
+    scale = scale.clip(min=min.item(), max=1)
+    return scale.item()
+
+
+def epsilon_step_size(tensors: torch.Tensor | TensorList, alpha=1e-7) -> float:
+    """makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_max = generic_max(tensors_abs)
+    if tensors_max < alpha: return 1.0
+
+    if tensors_max < 1: alpha = alpha / tensors_max
+    return tofloat(alpha)
+
+
+
+def safe_clip(x: torch.Tensor, min=None):
+    """makes sure absolute value of scalar tensor x is not smaller than min"""
+    assert x.numel() == 1, x.shape
+    if min is None: min = torch.finfo(x.dtype).tiny * 2
 
+    if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
+    return x
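The new `initial_step_size` helper reproduces the PyTorch L-BFGS heuristic of scaling the very first step by `1 / sum(|g|)`, clamped so that the largest component of the scaled tensor cannot drop below machine epsilon. A standalone sketch of the same rule for a single tensor (the torchzero version above works on tensors or `TensorList`s through the `generic_*` helpers it imports):

```python
import torch

def lbfgs_style_initial_step(grad: torch.Tensor, eps: float | None = None) -> float:
    # step = clip(1 / sum(|g|), eps / max(|g|), 1), matching initial_step_size above
    g = grad.abs()
    feps = torch.finfo(grad.dtype).eps
    eps = feps if eps is None else max(eps, feps)

    lo = eps / g.max()  # scaling below this would push max(|g|) under eps
    if lo >= 1:
        return 1.0

    scale = (1.0 / g.sum()).clip(min=lo.item(), max=1.0)
    return scale.item()

print(lbfgs_style_initial_step(torch.tensor([0.5, -2.0, 1.5])))  # 1 / 4.0 = 0.25
```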
torchzero/modules/grad_approximation/fdm.py
CHANGED
@@ -93,7 +93,7 @@ _FD_FUNCS = {
 class FDM(GradApproximator):
     """Approximate gradients via finite difference method.
 
-
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
@@ -103,24 +103,23 @@ class FDM(GradApproximator):
         target (GradTarget, optional): what to set on var. Defaults to 'closure'.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
+        plain FDM:
+
+        ```python
+        fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+        ```
+
+        Any gradient-based method can use FDM-estimated gradients.
+        ```python
+        fdm_ncg = tz.Modular(
+            model.parameters(),
+            tz.m.FDM(),
+            # set hvp_method to "forward" so that it
+            # uses gradient difference instead of autograd
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
     """
     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
         defaults = dict(h=h, formula=formula)
@@ -139,7 +138,7 @@ class FDM(GradApproximator):
             h = settings['h']
            fd_fn = _FD_FUNCS[settings['formula']]
 
-            p_flat = p.
+            p_flat = p.ravel(); g_flat = g.ravel()
             for i in range(len(p_flat)):
                 loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
                 g_flat[i] = d
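For reference, the default `'central'` formula that this loop dispatches to estimates each partial derivative from two perturbed evaluations per coordinate. A minimal sketch of that rule on a flattened parameter (the real `_FD_FUNCS` entries take the closure, the flattened parameter, an index, and the cached `v_0` loss, as the loop above shows; this standalone version is only illustrative):

```python
import torch

def central_difference(f, x: torch.Tensor, i: int, h: float = 1e-3) -> float:
    # df/dx_i ≈ (f(x + h*e_i) - f(x - h*e_i)) / (2h)
    orig = x[i].item()
    x[i] = orig + h
    f_plus = f(x)
    x[i] = orig - h
    f_minus = f(x)
    x[i] = orig  # restore the original value
    return (f_plus - f_minus) / (2 * h)

x = torch.tensor([1.0, 2.0])
g0 = central_difference(lambda t: (t ** 2).sum().item(), x, i=0)  # ≈ 2.0
```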
torchzero/modules/grad_approximation/forward_gradient.py
CHANGED
@@ -15,7 +15,7 @@ class ForwardGradient(RandomizedFDM):
 
     This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
 
-
+    Note:
         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
@@ -23,8 +23,6 @@ class ForwardGradient(RandomizedFDM):
     Args:
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         jvp_method (str, optional):
@@ -40,14 +38,13 @@ class ForwardGradient(RandomizedFDM):
         self,
         n_samples: int = 1,
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
         h: float = 1e-3,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples, distribution=distribution,
+        super().__init__(h=h, n_samples=n_samples, distribution=distribution, target=target, pre_generate=pre_generate, seed=seed)
         self.defaults['jvp_method'] = jvp_method
 
     @torch.no_grad
@@ -62,12 +59,14 @@ class ForwardGradient(RandomizedFDM):
         distribution = settings['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-        generator = self.
+        generator = self.get_generator(params[0].device, self.defaults['seed'])
 
         grad = None
         for i in range(n_samples):
             prt = perturbations[i]
-            if prt[0] is None:
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, variance=1, generator=generator)
+
             else: prt = TensorList(prt)
 
             if jvp_method == 'autograd':
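With `jvp_method='autograd'`, the directional derivative is computed by a true forward-mode Jacobian-vector product instead of a finite difference. A minimal sketch of the underlying forward-gradient estimator using `torch.func.jvp` (the module itself works on whole parameter lists and supports `n_samples` averaging and pre-generated perturbations; this snippet only illustrates the single-tensor, single-sample case):

```python
import torch
from torch.func import jvp

def forward_gradient(f, x: torch.Tensor) -> torch.Tensor:
    # sample a random direction v and evaluate the directional derivative <grad f(x), v>
    v = torch.randn_like(x)
    _, dir_deriv = jvp(f, (x,), (v,))
    # v * <grad f(x), v> is an unbiased estimate of grad f(x) for standard gaussian v
    return v * dir_deriv

x = torch.tensor([1.0, 2.0, 3.0])
g_est = forward_gradient(lambda t: (t ** 2).sum(), x)  # noisy estimate of [2., 4., 6.]
```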
torchzero/modules/grad_approximation/grad_approximator.py
CHANGED
@@ -24,63 +24,59 @@ class GradApproximator(Module, ABC):
 
     Example:
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    """
+        Basic SPSA method implementation.
+        ```python
+        class SPSA(GradApproximator):
+            def __init__(self, h=1e-3):
+                defaults = dict(h=h)
+                super().__init__(defaults)
+
+            @torch.no_grad
+            def approximate(self, closure, params, loss):
+                perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                # evaluate params + perturbation
+                torch._foreach_add_(params, perturbation)
+                loss_plus = closure(False)
+
+                # evaluate params - perturbation
+                torch._foreach_sub_(params, perturbation)
+                torch._foreach_sub_(params, perturbation)
+                loss_minus = closure(False)
+
+                # restore original params
+                torch._foreach_add_(params, perturbation)
+
+                # calculate SPSA gradients
+                spsa_grads = []
+                for p, pert in zip(params, perturbation):
+                    settings = self.settings[p]
+                    h = settings['h']
+                    d = (loss_plus - loss_minus) / (2*(h**2))
+                    spsa_grads.append(pert * d)
+
+                # returns tuple: (grads, loss, loss_approx)
+                # loss must be with initial parameters
+                # since we only evaluated loss with perturbed parameters
+                # we only have loss_approx
+                return spsa_grads, None, loss_plus
+        ```
+    """
     def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
         super().__init__(defaults)
         self._target: GradTarget = target
 
     @abstractmethod
-    def approximate(self, closure: Callable, params: list[torch.Tensor], loss:
-        """Returns a tuple: (grad, loss, loss_approx)
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
+        """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""
 
-    def pre_step(self, var: Var) ->
+    def pre_step(self, var: Var) -> None:
         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return var
 
     @torch.no_grad
     def step(self, var):
-
-        if isinstance(ret, Var): var = ret
+        self.pre_step(var)
 
         if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
         params, closure, loss = var.params, var.closure, var.loss
@@ -108,4 +104,4 @@ class GradApproximator(Module, ABC):
         else: raise ValueError(self._target)
         return var
 
-_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', '
+_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']