torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
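Most of the churn above is a reorganisation: the contents of `torchzero/modules/optimizers/` moved to a new `torchzero/modules/adaptive/` package, `modules/lr/` was replaced by `modules/step_size/`, and several experimental files were relocated. The sketch below is only an illustration inferred from those renames, not an official migration guide; code that imports through the top-level `torchzero.modules` or `tz.m` namespaces is presumably unaffected, and the class name used here is just an example.

```python
# Hypothetical deep-import update implied by the optimizers/ -> adaptive/ rename.
# `Shampoo` is used as an example name; adjust to whichever class you actually import.
try:
    # 0.3.13 layout: torchzero/modules/adaptive/shampoo.py
    from torchzero.modules.adaptive.shampoo import Shampoo
except ImportError:
    # 0.3.10 layout: torchzero/modules/optimizers/shampoo.py
    from torchzero.modules.optimizers.shampoo import Shampoo
```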
torchzero/modules/grad_approximation/rfdm.py

````diff
--- a/torchzero/modules/grad_approximation/rfdm.py
+++ b/torchzero/modules/grad_approximation/rfdm.py
@@ -3,94 +3,160 @@ from typing import Any
 from functools import partial
 import torch
 
-from ...utils import TensorList, Distributions, NumberList
+from ...utils import TensorList, Distributions, NumberList
 from .grad_approximator import GradApproximator, GradTarget, _FD_Formula
 
-
-def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
     """p_fn is a function that returns the perturbation.
     It may return pre-generated one or generate one deterministically from a seed as in MeZO.
     Returned perturbation must be multiplied by `h`."""
-    if
+    if f_0 is None: f_0 = closure(False)
     params += p_fn()
-
+    f_1 = closure(False)
     params -= p_fn()
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (f_1 - f_0) / h # (loss, loss_approx, grad)
 
-def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
-    if
+def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
     params -= p_fn()
-
+    f_m1 = closure(False)
     params += p_fn()
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (f_0 - f_m1) / h
 
-def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
+def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: Any):
     params += p_fn()
-
+    f_1 = closure(False)
 
     params -= p_fn() * 2
-
+    f_m1 = closure(False)
 
     params += p_fn()
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_1, (f_1 - f_m1) / (2 * h)
 
-def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
-    if
+def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
     params += p_fn()
-
+    f_1 = closure(False)
 
     params += p_fn()
-
+    f_2 = closure(False)
 
     params -= p_fn() * 2
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (-3*f_0 + 4*f_1 - f_2) / (2 * h)
 
-def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
-    if
+def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
 
     params -= p_fn()
-
+    f_m1 = closure(False)
 
     params -= p_fn()
-
+    f_m2 = closure(False)
 
     params += p_fn() * 2
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_0, (f_m2 - 4*f_m1 + 3*f_0) / (2 * h)
 
-def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h,
+def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
     params += p_fn()
-
+    f_1 = closure(False)
 
     params += p_fn()
-
+    f_2 = closure(False)
 
     params -= p_fn() * 3
-
+    f_m1 = closure(False)
 
     params -= p_fn()
-
+    f_m2 = closure(False)
 
     params += p_fn() * 2
     h = h**2 # because perturbation already multiplied by h
-    return
+    return f_0, f_1, (f_m2 - 8*f_m1 + 8*f_1 - f_2) / (12 * h)
+
+# some good ones
+# Pachalyl S. et al. Generalized simultaneous perturbation-based gradient search with reduced estimator bias //IEEE Transactions on Automatic Control. – 2025.
+# Three measurements GSPSA is _rforward3
+# Four measurements GSPSA
+def _rforward4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
+    params += p_fn()
+    f_1 = closure(False)
+
+    params += p_fn()
+    f_2 = closure(False)
+
+    params += p_fn()
+    f_3 = closure(False)
+
+    params -= p_fn() * 3
+    h = h**2 # because perturbation already multiplied by h
+    return f_0, f_0, (2*f_3 - 9*f_2 + 18*f_1 - 11*f_0) / (6 * h)
+
+def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+    if f_0 is None: f_0 = closure(False)
+    params += p_fn()
+    f_1 = closure(False)
+
+    params += p_fn()
+    f_2 = closure(False)
+
+    params += p_fn()
+    f_3 = closure(False)
+
+    params += p_fn()
+    f_4 = closure(False)
+
+    params -= p_fn() * 4
+    h = h**2 # because perturbation already multiplied by h
+    return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)
+
+# # another central4
+# def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+#     params += p_fn()
+#     f_1 = closure(False)
+
+#     params += p_fn() * 2
+#     f_3 = closure(False)
+
+#     params -= p_fn() * 4
+#     f_m1 = closure(False)
+
+#     params -= p_fn() * 2
+#     f_m3 = closure(False)
+
+#     params += p_fn() * 3
+#     h = h**2 # because perturbation already multiplied by h
+#     return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+
 
-_RFD_FUNCS = {
+_RFD_FUNCS: dict[_FD_Formula, Callable] = {
+    "forward": _rforward2,
     "forward2": _rforward2,
+    "backward": _rbackward2,
     "backward2": _rbackward2,
+    "central": _rcentral2,
     "central2": _rcentral2,
+    "central3": _rcentral2,
     "forward3": _rforward3,
     "backward3": _rbackward3,
     "central4": _rcentral4,
+    "forward4": _rforward4,
+    "forward5": _rforward5,
+    # "bspsa4": _bgspsa4,
 }
 
 
 class RandomizedFDM(GradApproximator):
-    """
+    """Gradient approximation via a randomized finite-difference method.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
     Args:
         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
````
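Before the rest of the class diff, a note on what the helpers added above compute (my reading of the code, not text from the package). Each `_r*` function receives a perturbation that was already multiplied by `h` (hence the `h = h**2` lines) and returns `(loss, loss_approx, d)`, where `d` is a finite-difference estimate of the directional derivative along that perturbation:

$$
f_k = f(x + k\,h\,u), \qquad
d_{\text{forward2}} = \frac{f_1 - f_0}{h^2}, \qquad
d_{\text{central2}} = \frac{f_1 - f_{-1}}{2h^2}, \qquad
d_{\text{forward3}} = \frac{-3 f_0 + 4 f_1 - f_2}{2h^2}.
$$

Since the `approximate` method shown further down multiplies `d` by the scaled perturbation $h\,u$, each sample contributes $u\,(f_1 - f_0)/h \approx u\,(\nabla f(x)\cdot u)$ for the forward formula, i.e. the usual SPSA-style estimator.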
````diff
@@ -98,17 +164,109 @@ class RandomizedFDM(GradApproximator):
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "rademacher".
             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    Examples:
+        #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+        SPSA is randomized finite differnce with rademacher distribution and central formula.
+        ```py
+        spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### Random-direction stochastic approximation (RDSA) method
+
+        RDSA is randomized finite differnce with usually gaussian distribution and central formula.
+
+        ```
+        rdsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### RandomizedFDM with momentum
+
+        Momentum might help by reducing the variance of the estimated gradients.
+
+        ```
+        momentum_spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(),
+            tz.m.HeavyBall(0.9),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        #### Gaussian smoothing method
+
+        GS uses many gaussian samples with possibly a larger finite difference step size.
+
+        ```
+        gs = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### SPSA-NewtonCG
+
+        NewtonCG with hessian-vector product estimated via gradient difference
+        calls closure multiple times per step. If each closure call estimates gradients
+        with different perturbations, NewtonCG is unable to produce useful directions.
+
+        By setting pre_generate to True, perturbations are generated once before each step,
+        and each closure call estimates gradients using the same pre-generated perturbations.
+        This way closure-based algorithms are able to use gradients estimated in a consistent way.
+
+        ```
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=10),
+            tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### SPSA-LBFGS
+
+        LBFGS uses a memory of past parameter and gradient differences. If past gradients
+        were estimated with different perturbations, LBFGS directions will be useless.
+
+        To alleviate this momentum can be added to random perturbations to make sure they only
+        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
+        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
+
+        Additionally we will reset SPSA and LBFGS memory every 100 steps to remove influence from old gradient estimates.
+
+        ```
+        opt = tz.Modular(
+            bench.parameters(),
+            tz.m.ResetEvery(
+                [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
+                steps = 100,
+            ),
+            tz.m.Backtracking()
+        )
+        ```
     """
     PRE_MULTIPLY_BY_H = True
     def __init__(
         self,
         h: float = 1e-3,
         n_samples: int = 1,
-        formula: _FD_Formula = "
+        formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
         beta: float = 0,
         pre_generate = True,
@@ -123,6 +281,7 @@ class RandomizedFDM(GradApproximator):
         generator = self.global_state.get('generator', None) # avoid resetting generator
         self.global_state.clear()
         if generator is not None: self.global_state['generator'] = generator
+        for c in self.children.values(): c.reset()
 
     def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
         if 'generator' not in self.global_state:
@@ -133,15 +292,15 @@ class RandomizedFDM(GradApproximator):
 
     def pre_step(self, var):
         h, beta = self.get_settings(var.params, 'h', 'beta')
-
-        n_samples =
-        distribution =
-        pre_generate =
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
+        pre_generate = self.defaults['pre_generate']
 
         if pre_generate:
             params = TensorList(var.params)
-            generator = self._get_generator(
-            perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+            generator = self._get_generator(self.defaults['seed'], var.params)
+            perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]
 
             if self.PRE_MULTIPLY_BY_H:
                 torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
@@ -165,8 +324,9 @@ class RandomizedFDM(GradApproximator):
             torch._foreach_lerp_(cur_flat, new_flat, betas)
 
     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         params = TensorList(params)
+        orig_params = params.clone() # store to avoid small changes due to float imprecision
         loss_approx = None
 
         h = NumberList(self.settings[p]['h'] for p in params)
@@ -181,20 +341,84 @@ class RandomizedFDM(GradApproximator):
         grad = None
         for i in range(n_samples):
             prt = perturbations[i]
-
+
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)
+
             else: prt = TensorList(prt)
 
-            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h,
+            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
+            # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.
+
+            # support for per-sample values which gives better estimate
+            if d[0].numel() > 1: d = d.map(torch.mean)
+
             if grad is None: grad = prt * d
             else: grad += prt * d
 
+        params.set_(orig_params)
         assert grad is not None
         if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        if loss is not None:
+            if loss.numel() > 1:
+                loss = loss.mean()
+
+        if loss_approx is not None:
+            if loss_approx.numel() > 1:
+                loss_approx = loss_approx.mean()
+
         return grad, loss, loss_approx
 
-SPSA
+class SPSA(RandomizedFDM):
+    """
+    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "rademacher".
+        beta (float, optional):
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+    """
 
 class RDSA(RandomizedFDM):
+    """
+    Gradient approximation via Random-direction stochastic approximation (RDSA) method.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "gaussian".
+        beta (float, optional):
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+
+    """
     def __init__(
         self,
         h: float = 1e-3,
@@ -209,11 +433,34 @@ class RDSA(RandomizedFDM):
         super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
 
 class GaussianSmoothing(RandomizedFDM):
+    """
+    Gradient approximation via Gaussian smoothing method.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-2.
+        n_samples (int, optional): number of random gradient samples. Defaults to 100.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
+        distribution (Distributions, optional): distribution. Defaults to "gaussian".
+        beta (float, optional):
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+
+    References:
+        Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
+    """
     def __init__(
         self,
         h: float = 1e-2,
         n_samples: int = 100,
-        formula: _FD_Formula = "
+        formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
         beta: float = 0,
         pre_generate = True,
@@ -223,21 +470,43 @@ class GaussianSmoothing(RandomizedFDM):
         super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
 
 class MeZO(GradApproximator):
+    """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
+
+    Note:
+        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+    Args:
+        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        distribution (Distributions, optional): distribution. Defaults to "rademacher".
+            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+        target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+    References:
+        Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333
+    """
+
     def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
                  distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+
         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
         super().__init__(defaults, target=target)
 
     def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
-
-            distribution=distribution,
-
+        prt = TensorList(params).sample_like(
+            distribution=distribution,
+            variance=h,
+            generator=torch.Generator(params[0].device).manual_seed(seed)
+        )
+        return prt
 
     def pre_step(self, var):
         h = NumberList(self.settings[p]['h'] for p in var.params)
-
-        n_samples =
-        distribution =
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
 
         step = var.current_step
 
@@ -251,7 +520,7 @@ class MeZO(GradApproximator):
         self.global_state['prt_fns'] = prt_fns
 
     @torch.no_grad
-    def approximate(self, closure, params, loss
+    def approximate(self, closure, params, loss):
         params = TensorList(params)
         loss_approx = None
 
@@ -263,7 +532,7 @@ class MeZO(GradApproximator):
 
         grad = None
         for i in range(n_samples):
-            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h,
+            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, f_0=loss)
             if grad is None: grad = prt_fns[i]().mul_(d)
             else: grad += prt_fns[i]().mul_(d)
 
````
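Taken together, the `approximate` loop above averages `perturbation * directional_derivative` over `n_samples`. Below is a standalone sketch of the equivalent computation in plain PyTorch, using a Rademacher perturbation and the two-point forward formula; `rfd_grad` and all names here are illustrative, not part of the torchzero API. (In the package the perturbations are pre-scaled by `h` and the helpers divide by `h**2`, which nets out to the same estimator.)

```python
import torch

def rfd_grad(f, x: torch.Tensor, h: float = 1e-3, n_samples: int = 1) -> torch.Tensor:
    """Randomized forward-difference gradient estimate: mean of u * (f(x + h*u) - f(x)) / h."""
    f0 = f(x)
    g = torch.zeros_like(x)
    for _ in range(n_samples):
        u = torch.empty_like(x).bernoulli_(0.5).mul_(2).sub_(1)  # Rademacher +-1 perturbation
        d = (f(x + h * u) - f0) / h                              # directional derivative along u
        g += u * d
    return g / n_samples

# sanity check on a quadratic: the estimate fluctuates around the true gradient 2*x
x = torch.tensor([1.0, -2.0, 3.0])
print(rfd_grad(lambda v: (v ** 2).sum(), x, n_samples=1000))
```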
torchzero/modules/higher_order/__init__.py

````diff
--- a/torchzero/modules/higher_order/__init__.py
+++ b/torchzero/modules/higher_order/__init__.py
@@ -1 +1 @@
-from .higher_order_newton import HigherOrderNewton
+from .higher_order_newton import HigherOrderNewton
````
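The MeZO docstring added above carries no usage example. By analogy with the `RandomizedFDM` examples earlier in the diff, usage would presumably look like the following hedged sketch: it assumes `MeZO` is exported under `tz.m` like the other modules in this file, and that `tz.Modular` follows the usual torchzero `closure(backward)` convention implied by the `closure(False)` calls in the diff.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(32, 4), torch.randn(32, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.MeZO(h=1e-3, n_samples=1, formula="central2", distribution="rademacher"),
    tz.m.LR(1e-3),
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:  # MeZO itself only calls closure(False), so no real backward pass is required
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```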