torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

--- a/torchzero/modules/grad_approximation/rfdm.py
+++ b/torchzero/modules/grad_approximation/rfdm.py
@@ -3,104 +3,274 @@ from typing import Any
 from functools import partial
 import torch
 
-from ...utils import TensorList, Distributions, NumberList
+from ...utils import TensorList, Distributions, NumberList
 from .grad_approximator import GradApproximator, GradTarget, _FD_Formula
 
-
-def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
     """p_fn is a function that returns the perturbation.
     It may return pre-generated one or generate one deterministically from a seed as in MeZO.
     Returned perturbation must be multiplied by `h`."""
-    if
+    if f_0 is None: f_0 = closure(False)
     params += p_fn()
-
+    f_1 = closure(False)
     params -= p_fn()
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (f_1 - f_0) / h # (loss, loss_approx, grad)
 
-def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
-    if
+def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
     params -= p_fn()
-
+    f_m1 = closure(False)
     params += p_fn()
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (f_0 - f_m1) / h
 
-def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
+def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: Any):
     params += p_fn()
-
+    f_1 = closure(False)
 
     params -= p_fn() * 2
-
+    f_m1 = closure(False)
 
     params += p_fn()
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_1, (f_1 - f_m1) / (2 * h)
 
-def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
-    if
+def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
     params += p_fn()
-
+    f_1 = closure(False)
 
     params += p_fn()
-
+    f_2 = closure(False)
 
     params -= p_fn() * 2
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (-3*f_0 + 4*f_1 - f_2) / (2 * h)
 
-def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
-    if
+def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
 
     params -= p_fn()
-
+    f_m1 = closure(False)
 
     params -= p_fn()
-
+    f_m2 = closure(False)
 
     params += p_fn() * 2
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (f_m2 - 4*f_m1 + 3*f_0) / (2 * h)
 
-def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
+def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
     params += p_fn()
-
+    f_1 = closure(False)
 
     params += p_fn()
-
+    f_2 = closure(False)
 
     params -= p_fn() * 3
-
+    f_m1 = closure(False)
 
     params -= p_fn()
-
+    f_m2 = closure(False)
+
+    params += p_fn() * 2
+    h = h**2 # because perturbation already multiplied by h
+    return f_0, f_1, (f_m2 - 8*f_m1 + 8*f_1 - f_2) / (12 * h)
+
+# some good ones
+# Pachalyl S. et al. Generalized simultaneous perturbation-based gradient search with reduced estimator bias //IEEE Transactions on Automatic Control. – 2025.
+# Three measurements GSPSA is _rforward3
+# Four measurements GSPSA
+def _rforward4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
+    params += p_fn()
+    f_1 = closure(False)
+
+    params += p_fn()
+    f_2 = closure(False)
+
+    params += p_fn()
+    f_3 = closure(False)
+
+    params -= p_fn() * 3
+    h = h**2 # because perturbation already multiplied by h
+    return f_0, f_0, (2*f_3 - 9*f_2 + 18*f_1 - 11*f_0) / (6 * h)
+
+def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
+    params += p_fn()
+    f_1 = closure(False)
+
+    params += p_fn()
+    f_2 = closure(False)
+
+    params += p_fn()
+    f_3 = closure(False)
+
+    params += p_fn()
+    f_4 = closure(False)
+
+    params -= p_fn() * 4
+    h = h**2 # because perturbation already multiplied by h
+    return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)
+
+# another central4
+def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    params += p_fn()
+    f_1 = closure(False)
 
     params += p_fn() * 2
+    f_3 = closure(False)
+
+    params -= p_fn() * 4
+    f_m1 = closure(False)
+
+    params -= p_fn() * 2
+    f_m3 = closure(False)
+
+    params += p_fn() * 3
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+
 
 _RFD_FUNCS = {
+    "forward": _rforward2,
     "forward2": _rforward2,
+    "backward": _rbackward2,
     "backward2": _rbackward2,
+    "central": _rcentral2,
     "central2": _rcentral2,
+    "central3": _rcentral2,
     "forward3": _rforward3,
     "backward3": _rbackward3,
     "central4": _rcentral4,
+    "forward4": _rforward4,
+    "forward5": _rforward5,
+    "bspsa4": _bgspsa4,
 }
 
 
 class RandomizedFDM(GradApproximator):
+    """Gradient approximation via a randomized finite-difference method.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "rademacher".
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    Examples:
+        #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+        SPSA is randomized finite differnce with rademacher distribution and central formula.
+
+        .. code-block:: python
+
+            spsa = tz.Modular(
+                model.parameters(),
+                tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+                tz.m.LR(1e-2)
+            )
+
+        #### Random-direction stochastic approximation (RDSA) method
+
+        RDSA is randomized finite differnce with usually gaussian distribution and central formula.
+
+        .. code-block:: python
+
+            rdsa = tz.Modular(
+                model.parameters(),
+                tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+                tz.m.LR(1e-2)
+            )
+
+        #### RandomizedFDM with momentum
+
+        Momentum might help by reducing the variance of the estimated gradients.
+
+        .. code-block:: python
+
+            momentum_spsa = tz.Modular(
+                model.parameters(),
+                tz.m.RandomizedFDM(),
+                tz.m.HeavyBall(0.9),
+                tz.m.LR(1e-3)
+            )
+
+        #### Gaussian smoothing method
+
+        GS uses many gaussian samples with possibly a larger finite difference step size.
+
+        .. code-block:: python
+
+            gs = tz.Modular(
+                model.parameters(),
+                tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+                tz.m.NewtonCG(hvp_method="forward"),
+                tz.m.Backtracking()
+            )
+
+        #### SPSA-NewtonCG
+
+        NewtonCG with hessian-vector product estimated via gradient difference
+        calls closure multiple times per step. If each closure call estimates gradients
+        with different perturbations, NewtonCG is unable to produce useful directions.
+
+        By setting pre_generate to True, perturbations are generated once before each step,
+        and each closure call estimates gradients using the same pre-generated perturbations.
+        This way closure-based algorithms are able to use gradients estimated in a consistent way.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RandomizedFDM(n_samples=10),
+                tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
+                tz.m.Backtracking()
+            )
+
+        #### SPSA-BFGS
+
+        L-BFGS uses a memory of past parameter and gradient differences. If past gradients
+        were estimated with different perturbations, L-BFGS directions will be useless.
+
+        To alleviate this momentum can be added to random perturbations to make sure they only
+        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
+        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
+
+        Additionally we will reset BFGS memory every 100 steps to remove influence from old gradient estimates.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99),
+                tz.m.BFGS(reset_interval=100),
+                tz.m.Backtracking()
+            )
+    """
     PRE_MULTIPLY_BY_H = True
     def __init__(
         self,
         h: float = 1e-3,
         n_samples: int = 1,
-        formula: _FD_Formula = "
+        formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
         beta: float = 0,
         pre_generate = True,
-        target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
+        target: GradTarget = "closure",
     ):
         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
         super().__init__(defaults, target=target)
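
All of the `_r*` helpers above follow the same recipe: evaluate the closure a few times along one random perturbation, combine the values into a directional-derivative estimate `d`, and let the caller push `d` back along the perturbation (`grad = prt * d`, averaged over `n_samples`). A minimal, self-contained sketch of that recipe is below; it is an editor's illustration rather than torchzero code, and the function name and quadratic test objective are made up.

```python
import torch

def rfd_central_grad(f, x: torch.Tensor, h: float = 1e-3, n_samples: int = 1) -> torch.Tensor:
    """Two-point central randomized finite differences (SPSA-style), averaged over n_samples directions."""
    grad = torch.zeros_like(x)
    for _ in range(n_samples):
        u = torch.randint(0, 2, x.shape, dtype=x.dtype).mul_(2).sub_(1)  # Rademacher +/-1 direction
        d = (f(x + h * u) - f(x - h * u)) / (2 * h)                      # directional derivative along u
        grad += u * d                                                    # push the scalar back along u
    return grad / n_samples

# sanity check on f(x) = sum(x^2), whose true gradient is 2*x
x = torch.tensor([1.0, -2.0, 3.0])
print(rfd_central_grad(lambda v: (v * v).sum(), x, n_samples=200))
```

For a Rademacher direction, `E[u uᵀ] = I`, so the averaged estimate recovers the true gradient in expectation; that is why the docstring's SPSA example pairs `formula="central"` with `distribution="rademacher"`.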
@@ -118,16 +288,16 @@ class RandomizedFDM(GradApproximator):
         else: self.global_state['generator'] = None
         return self.global_state['generator']
 
-    def pre_step(self,
-        h, beta = self.get_settings('h', 'beta'
-        settings = self.settings[
+    def pre_step(self, var):
+        h, beta = self.get_settings(var.params, 'h', 'beta')
+        settings = self.settings[var.params[0]]
         n_samples = settings['n_samples']
         distribution = settings['distribution']
         pre_generate = settings['pre_generate']
 
         if pre_generate:
-            params = TensorList(
-            generator = self._get_generator(settings['seed'],
+            params = TensorList(var.params)
+            generator = self._get_generator(settings['seed'], var.params)
             perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
 
             if self.PRE_MULTIPLY_BY_H:
@@ -152,11 +322,12 @@ class RandomizedFDM(GradApproximator):
             torch._foreach_lerp_(cur_flat, new_flat, betas)
 
     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         params = TensorList(params)
+        orig_params = params.clone() # store to avoid small changes due to float imprecision
         loss_approx = None
 
-        h = self.
+        h = NumberList(self.settings[p]['h'] for p in params)
         settings = self.settings[params[0]]
         n_samples = settings['n_samples']
         fd_fn = _RFD_FUNCS[settings['formula']]
@@ -171,17 +342,64 @@ class RandomizedFDM(GradApproximator):
             if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
             else: prt = TensorList(prt)
 
-            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h,
+            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
             if grad is None: grad = prt * d
             else: grad += prt * d
 
+        params.set_(orig_params)
         assert grad is not None
         if n_samples > 1: grad.div_(n_samples)
         return grad, loss, loss_approx
 
-SPSA
+class SPSA(RandomizedFDM):
+    """
+    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "rademacher".
+        beta (float, optional):
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+    """
 
 class RDSA(RandomizedFDM):
+    """
+    Gradient approximation via Random-direction stochastic approximation (RDSA) method.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "gaussian".
+        beta (float, optional):
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+
+    """
     def __init__(
         self,
         h: float = 1e-3,
@@ -196,11 +414,34 @@ class RDSA(RandomizedFDM):
         super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
 
 class GaussianSmoothing(RandomizedFDM):
+    """
+    Gradient approximation via Gaussian smoothing method.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-2.
+        n_samples (int, optional): number of random gradient samples. Defaults to 100.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
+        distribution (Distributions, optional): distribution. Defaults to "gaussian".
+        beta (float, optional):
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+
+    References:
+        Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
+    """
     def __init__(
         self,
         h: float = 1e-2,
         n_samples: int = 100,
-        formula: _FD_Formula = "
+        formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
         beta: float = 0,
         pre_generate = True,
@@ -210,8 +451,27 @@ class GaussianSmoothing(RandomizedFDM):
         super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
 
 class MeZO(GradApproximator):
+    """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
+
+    .. note::
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "rademacher".
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333
+    """
+
     def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
                  distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+
         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
         super().__init__(defaults, target=target)
 
@@ -220,29 +480,29 @@ class MeZO(GradApproximator):
             distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
         ).mul_(h)
 
-    def pre_step(self,
-        h = self.
-        settings = self.settings[
+    def pre_step(self, var):
+        h = NumberList(self.settings[p]['h'] for p in var.params)
+        settings = self.settings[var.params[0]]
         n_samples = settings['n_samples']
         distribution = settings['distribution']
 
-        step =
+        step = var.current_step
 
         # create functions that generate a deterministic perturbation from seed based on current step
         prt_fns = []
         for i in range(n_samples):
 
-            prt_fn = partial(self._seeded_perturbation, params=
+            prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
            prt_fns.append(prt_fn)
 
         self.global_state['prt_fns'] = prt_fns
 
     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         params = TensorList(params)
         loss_approx = None
 
-        h = self.
+        h = NumberList(self.settings[p]['h'] for p in params)
         settings = self.settings[params[0]]
         n_samples = settings['n_samples']
         fd_fn = _RFD_FUNCS[settings['formula']]
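
The `pre_step` above shows the MeZO trick: instead of storing each perturbation, it stores only a callable that re-creates the perturbation from a deterministic seed (`1_000_000*step + i`), so the same direction can be re-materialized for every closure evaluation without keeping the tensors in memory. Below is a small illustrative sketch of that idea; it assumes nothing about torchzero's internals and all names in it are made up.

```python
import torch

def seeded_rademacher(shape, seed: int, h: float, device: str = "cpu") -> torch.Tensor:
    """Recreate the same +/-h perturbation on demand from a seed instead of storing it."""
    gen = torch.Generator(device=device).manual_seed(seed)
    z = torch.randint(0, 2, shape, generator=gen, device=device, dtype=torch.float32)
    return (z * 2 - 1) * h

step, i = 7, 0
seed = 1_000_000 * step + i                    # same seeding scheme as pre_step above
p_a = seeded_rademacher((3,), seed, h=1e-3)
p_b = seeded_rademacher((3,), seed, h=1e-3)
assert torch.equal(p_a, p_b)                   # identical perturbation, never stored
```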
@@ -250,7 +510,7 @@ class MeZO(GradApproximator):
 
         grad = None
         for i in range(n_samples):
-            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h,
+            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, f_0=loss)
             if grad is None: grad = prt_fns[i]().mul_(d)
             else: grad += prt_fns[i]().mul_(d)
 
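
Taken together, the rfdm.py hunks turn SPSA, RDSA, GaussianSmoothing and MeZO into documented gradient approximators. A hedged usage sketch in the style of the docstring examples above follows; whether the new classes are re-exported under `tz.m` and the exact training-loop conventions are assumptions here, not facts confirmed by the diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

# assumed export: the SPSA class added above, reachable as tz.m.SPSA like tz.m.RandomizedFDM
opt = tz.Modular(
    model.parameters(),
    tz.m.SPSA(n_samples=4),
    tz.m.LR(1e-2),
)

def closure(backward=True):
    # gradient-free objective: the approximator replaces backward() with extra closure calls
    return torch.nn.functional.mse_loss(model(X), y)

for _ in range(100):
    opt.step(closure)
```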
--- /dev/null
+++ b/torchzero/modules/higher_order/__init__.py
@@ -0,0 +1 @@
+from .higher_order_newton import HigherOrderNewton