torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
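Most of the churn in this release comes from moving the optimizer modules from torchzero/modules/optimizers/ to torchzero/modules/adaptive/, visible in the renamed paths above. A minimal sketch of what that means for a deep import; the class name `Lion` is assumed for illustration, only the file move itself is confirmed by the listing:

```python
# 0.3.11 layout (hypothetical deep import):
#   from torchzero.modules.optimizers.lion import Lion
# 0.3.14 layout, after the move into the `adaptive` package (class name assumed):
from torchzero.modules.adaptive.lion import Lion
```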
torchzero/modules/grad_approximation/rfdm.py

@@ -115,26 +115,26 @@ def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[
    h = h**2 # because perturbation already multiplied by h
    return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)

-# another central4
-def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
-    params += p_fn()
-    f_1 = closure(False)
+# # another central4
+# def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+# params += p_fn()
+# f_1 = closure(False)

-    params += p_fn() * 2
-    f_3 = closure(False)
+# params += p_fn() * 2
+# f_3 = closure(False)

-    params -= p_fn() * 4
-    f_m1 = closure(False)
+# params -= p_fn() * 4
+# f_m1 = closure(False)

-    params -= p_fn() * 2
-    f_m3 = closure(False)
+# params -= p_fn() * 2
+# f_m3 = closure(False)

-    params += p_fn() * 3
-    h = h**2 # because perturbation already multiplied by h
-    return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+# params += p_fn() * 3
+# h = h**2 # because perturbation already multiplied by h
+# return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)


-_RFD_FUNCS = {
+_RFD_FUNCS: dict[_FD_Formula, Callable] = {
    "forward": _rforward2,
    "forward2": _rforward2,
    "backward": _rbackward2,
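The context lines in the hunk above carry the five-point forward-difference coefficients used by `_rforward5`. A quick standalone check of those coefficients on a known derivative (plain Python, independent of torchzero):

```python
import math

def forward5(f, x, h):
    # five-point one-sided formula: f'(x) ≈ (-25*f0 + 48*f1 - 36*f2 + 16*f3 - 3*f4) / (12*h)
    f0, f1, f2, f3, f4 = (f(x + i * h) for i in range(5))
    return (-3 * f4 + 16 * f3 - 36 * f2 + 48 * f1 - 25 * f0) / (12 * h)

print(forward5(math.sin, 0.5, 1e-2), math.cos(0.5))  # both ≈ 0.87758; the error is O(h**4)
```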
@@ -147,14 +147,14 @@ _RFD_FUNCS = {
    "central4": _rcentral4,
    "forward4": _rforward4,
    "forward5": _rforward5,
-    "bspsa4": _bgspsa4,
+    # "bspsa4": _bgspsa4,
 }


class RandomizedFDM(GradApproximator):
    """Gradient approximation via a randomized finite-difference method.

-
+    Note:
    This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
    and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -164,101 +164,57 @@ class RandomizedFDM(GradApproximator):
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution. Defaults to "rademacher".
            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
-        beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    Examples:
        [old docstring lines 174-218 not recoverable from this diff view]
-            tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
-            tz.m.NewtonCG(hvp_method="forward"),
-            tz.m.Backtracking()
-        )
-
-        #### SPSA-NewtonCG
-
-        NewtonCG with hessian-vector product estimated via gradient difference
-        calls closure multiple times per step. If each closure call estimates gradients
-        with different perturbations, NewtonCG is unable to produce useful directions.
-
-        By setting pre_generate to True, perturbations are generated once before each step,
-        and each closure call estimates gradients using the same pre-generated perturbations.
-        This way closure-based algorithms are able to use gradients estimated in a consistent way.
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(n_samples=10),
-                tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
-                tz.m.Backtracking()
-            )
-
-        #### SPSA-BFGS
-
-        L-BFGS uses a memory of past parameter and gradient differences. If past gradients
-        were estimated with different perturbations, L-BFGS directions will be useless.
-
-        To alleviate this momentum can be added to random perturbations to make sure they only
-        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
-        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
-
-        Additionally we will reset BFGS memory every 100 steps to remove influence from old gradient estimates.
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99),
-                tz.m.BFGS(reset_interval=100),
-                tz.m.Backtracking()
-            )
+        #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+        SPSA is randomized FDM with rademacher distribution and central formula.
+        ```py
+        spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### Random-direction stochastic approximation (RDSA) method
+
+        RDSA is randomized FDM with usually gaussian distribution and central formula.
+        ```
+        rdsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### Gaussian smoothing method
+
+        GS uses many gaussian samples with possibly a larger finite difference step size.
+        ```
+        gs = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### RandomizedFDM with momentum
+
+        Momentum might help by reducing the variance of the estimated gradients.
+        ```
+        momentum_spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(),
+            tz.m.HeavyBall(0.9),
+            tz.m.LR(1e-3)
+        )
+        ```
    """
    PRE_MULTIPLY_BY_H = True
    def __init__(
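The rewritten docstring above describes SPSA as randomized finite differences with a Rademacher perturbation and a central formula. For reference, the estimator those examples rely on can be written as a standalone sketch in plain torch (an illustration, not torchzero's internal code):

```python
import torch

def spsa_grad(f, x, h=1e-3):
    # one SPSA sample: central difference along a random Rademacher direction u
    u = (torch.randint(0, 2, x.shape) * 2 - 1).to(x.dtype)   # entries in {-1, +1}
    d = (f(x + h * u) - f(x - h * u)) / (2 * h)              # ≈ directional derivative grad·u
    return d * u                                             # gradient estimate, unbiased up to O(h**2)

x = torch.tensor([1.0, 2.0, 3.0])
print(spsa_grad(lambda v: (v ** 2).sum(), x))  # compare against the true gradient 2*x
```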
@@ -267,106 +223,92 @@ class RandomizedFDM(GradApproximator):
        n_samples: int = 1,
        formula: _FD_Formula = "central",
        distribution: Distributions = "rademacher",
-        beta: float = 0,
        pre_generate = True,
        seed: int | None | torch.Generator = None,
        target: GradTarget = "closure",
    ):
-        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution,
+        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
        super().__init__(defaults, target=target)

-    def reset(self):
-        self.state.clear()
-        generator = self.global_state.get('generator', None) # avoid resetting generator
-        self.global_state.clear()
-        if generator is not None: self.global_state['generator'] = generator
-
-    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
-        if 'generator' not in self.global_state:
-            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
-            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            else: self.global_state['generator'] = None
-        return self.global_state['generator']

    def pre_step(self, var):
-        h
-
-        n_samples = settings['n_samples']
-        distribution = settings['distribution']
-        pre_generate = settings['pre_generate']
+        h = self.get_settings(var.params, 'h')
+        pre_generate = self.defaults['pre_generate']

        if pre_generate:
+            n_samples = self.defaults['n_samples']
+            distribution = self.defaults['distribution']
+
            params = TensorList(var.params)
-            generator = self.
-            perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
+            perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

+            # this is false for ForwardGradient where h isn't used and it subclasses this
            if self.PRE_MULTIPLY_BY_H:
                torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

-
-
-            for param, prt in zip(params, zip(*perturbations)):
-                self.state[param]['perturbations'] = prt
-
-        else:
-            # lerp old and new perturbations. This makes the subspace change gradually
-            # which in theory might improve algorithms with history
-            for i,p in enumerate(params):
-                state = self.state[p]
-                if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]
-
-            cur = [self.state[p]['perturbations'][:n_samples] for p in params]
-            cur_flat = [p for l in cur for p in l]
-            new_flat = [p for l in zip(*perturbations) for p in l]
-            betas = [1-v for b in beta for v in [b]*n_samples]
-            torch._foreach_lerp_(cur_flat, new_flat, betas)
+            for param, prt in zip(params, zip(*perturbations)):
+                self.state[param]['perturbations'] = prt

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
-        orig_params = params.clone() # store to avoid small changes due to float imprecision
        loss_approx = None

        h = NumberList(self.settings[p]['h'] for p in params)
-
-
-        fd_fn = _RFD_FUNCS[
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
+        fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
        default = [None]*n_samples
        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-
-        generator = self._get_generator(settings['seed'], params)
+        generator = self.get_generator(params[0].device, self.defaults['seed'])

        grad = None
        for i in range(n_samples):
            prt = perturbations[i]
-
+
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)
+
            else: prt = TensorList(prt)

            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
+            # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.
+
+            # support for per-sample values which gives better estimate
+            if d[0].numel() > 1: d = d.map(torch.mean)
+
            if grad is None: grad = prt * d
            else: grad += prt * d

-        params.set_(orig_params)
        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        if loss is not None:
+            if loss.numel() > 1:
+                loss = loss.mean()
+
+        if loss_approx is not None:
+            if loss_approx.numel() > 1:
+                loss_approx = loss_approx.mean()
+
        return grad, loss, loss_approx

class SPSA(RandomizedFDM):
    """
    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

-
+    Note:
    This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
    and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

-
    Args:
        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution. Defaults to "rademacher".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -380,7 +322,7 @@ class RDSA(RandomizedFDM):
    """
    Gradient approximation via Random-direction stochastic approximation (RDSA) method.

-
+    Note:
    This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
    and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -389,8 +331,6 @@ class RDSA(RandomizedFDM):
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.

@@ -406,18 +346,17 @@ class RDSA(RandomizedFDM):
        n_samples: int = 1,
        formula: _FD_Formula = "central2",
        distribution: Distributions = "gaussian",
-        beta: float = 0,
        pre_generate = True,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

class GaussianSmoothing(RandomizedFDM):
    """
    Gradient approximation via Gaussian smoothing method.

-
+    Note:
    This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
    and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -426,8 +365,6 @@ class GaussianSmoothing(RandomizedFDM):
        n_samples (int, optional): number of random gradient samples. Defaults to 100.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
        distribution (Distributions, optional): distribution. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.

@@ -443,17 +380,16 @@ class GaussianSmoothing(RandomizedFDM):
        n_samples: int = 100,
        formula: _FD_Formula = "forward2",
        distribution: Distributions = "gaussian",
-        beta: float = 0,
        pre_generate = True,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

class MeZO(GradApproximator):
    """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

-
+    Note:
    This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
    and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -476,15 +412,18 @@ class MeZO(GradApproximator):
        super().__init__(defaults, target=target)

    def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
-
-            distribution=distribution,
-
+        prt = TensorList(params).sample_like(
+            distribution=distribution,
+            variance=h,
+            generator=torch.Generator(params[0].device).manual_seed(seed)
+        )
+        return prt

    def pre_step(self, var):
        h = NumberList(self.settings[p]['h'] for p in var.params)
-
-        n_samples =
-        distribution =
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']

        step = var.current_step

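The rewritten `_seeded_perturbation` above regenerates the noise from a stored seed instead of keeping the perturbation tensors around, which is the memory-saving idea behind MeZO. A minimal standalone illustration of why that works (hypothetical helper, not the module's API):

```python
import torch

def perturbation(shape, seed, device="cpu"):
    # rebuilding a generator from the same seed reproduces the exact same noise,
    # so only the integer seed has to be stored between closure evaluations
    gen = torch.Generator(device).manual_seed(seed)
    return torch.randn(shape, generator=gen, device=device)

assert torch.equal(perturbation((3,), 0), perturbation((3,), 0))
```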
@@ -503,9 +442,9 @@
        loss_approx = None

        h = NumberList(self.settings[p]['h'] for p in params)
-
-
-
+        n_samples = self.defaults['n_samples']
+        fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
        prt_fns = self.global_state['prt_fns']

        grad = None

torchzero/modules/higher_order/__init__.py

@@ -1 +1 @@
-from .higher_order_newton import HigherOrderNewton
+from .higher_order_newton import HigherOrderNewton

torchzero/modules/higher_order/higher_order_newton.py

@@ -13,7 +13,7 @@ import torch
from ...core import Chainable, Module, apply_transform
from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
from ...utils.derivatives import (
-
+    flatten_jacobian,
    jacobian_wrt,
)

@@ -148,21 +148,16 @@ class HigherOrderNewton(Module):
    """A basic arbitrary order newton's method with optional trust region and proximal penalty.

    This constructs an nth order taylor approximation via autograd and minimizes it with
-    scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
+    ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.

-
-
+    The hessian of taylor approximation is easier to evaluate, plus it can be evaluated in a batched mode,
+    so it can be more efficient in very specific instances.

-
-
-    as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
-
-
-    .. warning::
-        this uses roughly O(N^order) memory and solving the subproblem can be very expensive.
-
-    .. warning::
-        "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
+    Notes:
+        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
+        - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
+        - this uses roughly O(N^order) memory and solving the subproblem is very expensive.
+        - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

    Args:

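The docstring above says the module builds an nth-order Taylor approximation via autograd. As a point of reference, the first two Taylor coefficients of a scalar function can be obtained with stock PyTorch autograd like this (an illustration of the general idea, not the module's implementation, which applies `jacobian_wrt` repeatedly to reach higher orders):

```python
import torch

def quadratic_model_coeffs(f, x):
    # gradient and Hessian of a scalar function; the Hessian is the Jacobian of the
    # gradient, and repeating that "Jacobian of the previous derivative" step yields
    # the (ndim,)*order tensors a higher-order Taylor model needs
    g = torch.autograd.functional.jacobian(f, x)
    H = torch.autograd.functional.hessian(f, x)
    return g, H

x = torch.tensor([1.0, 2.0])
g, H = quadratic_model_coeffs(lambda v: (v ** 3).sum(), x)   # g = 3*x**2, H = diag(6*x)
```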
@@ -178,7 +173,7 @@ class HigherOrderNewton(Module):
        increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
        decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
        trust_init (float | None, optional):
-            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on
+            initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on ``"proximal"``. Defaults to None.
        trust_tol (float, optional):
            Maximum ratio of expected loss reduction to actual reduction for trust region increase.
            Should 1 or higer. Defaults to 2.

@@ -191,11 +186,14 @@ class HigherOrderNewton(Module):
        self,
        order: int = 4,
        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        nplus: float =
+        nplus: float = 3.5,
        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
        init: float | None = None,
        eta: float = 1e-6,
        max_attempts = 10,
+        boundary_tol: float = 1e-2,
        de_iters: int | None = None,
        vectorize: bool = True,
    ):

@@ -203,7 +201,7 @@ class HigherOrderNewton(Module):
            if trust_method == 'bounds': init = 1
            else: init = 0.1

-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
        super().__init__(defaults)

    @torch.no_grad

@@ -222,6 +220,9 @@ class HigherOrderNewton(Module):
        de_iters = settings['de_iters']
        max_attempts = settings['max_attempts']
        vectorize = settings['vectorize']
+        boundary_tol = settings['boundary_tol']
+        rho_good = settings['rho_good']
+        rho_bad = settings['rho_bad']

        # ------------------------ calculate grad and hessian ------------------------ #
        with torch.enable_grad():

@@ -241,7 +242,7 @@ class HigherOrderNewton(Module):
                T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
                with torch.no_grad() if is_last else nullcontext():
                    # the shape is (ndim, ) * order
-                    T =
+                    T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                    derivatives.append(T)

        x0 = torch.cat([p.ravel() for p in params])

@@ -254,8 +255,13 @@ class HigherOrderNewton(Module):

        # load trust region value
        trust_value = self.global_state.get('trust_region', init)
-        if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']

+        # make sure its not too small or too large
+        finfo = torch.finfo(x0.dtype)
+        if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
+            trust_value = self.global_state['trust_region'] = settings['init']
+
+        # determine tr and prox values
        if trust_method is None: trust_method = 'none'
        else: trust_method = trust_method.lower()

@@ -297,13 +303,15 @@ class HigherOrderNewton(Module):

        rho = reduction / (max(pred_reduction, 1e-8))
        # failed step
-        if rho <
+        if rho < rho_bad:
            self.global_state['trust_region'] = trust_value * nminus

        # very good step
-        elif rho >
-
-
+        elif rho > rho_good:
+            step = (x_star - x0)
+            magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
+            if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
+                # close to boundary
                self.global_state['trust_region'] = trust_value * nplus

        # if the ratio is high enough then accept the proposed step
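The hunk above wires the new `rho_bad`, `rho_good` and `boundary_tol` settings into the classic trust-region radius update. A standalone sketch of that update rule (parameter names mirror the diff; this is an illustration, not the module's exact code):

```python
def update_trust_radius(radius, rho, step_norm, *, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-2, proximal=False):
    if rho < rho_bad:
        # the Taylor model predicted the reduction poorly: shrink the region
        return radius * nminus
    if rho > rho_good and (proximal or (radius - step_norm) / radius <= boundary_tol):
        # accurate model and the step ran into the boundary: expand the region
        return radius * nplus
    return radius
```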
torchzero/modules/least_squares/__init__.py

@@ -0,0 +1 @@
+from .gn import SumOfSquares, GaussNewton
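0.3.14 adds a least-squares package whose `__init__` above exposes `SumOfSquares` and `GaussNewton`. For context, a textbook Gauss-Newton step on a residual vector looks like this (generic sketch in plain torch; torchzero's actual module API is not shown in this diff):

```python
import torch

def gauss_newton_step(residual_fn, x):
    # linearize the residuals r(x) ≈ r + J d and solve the normal equations (JᵀJ) d = -Jᵀ r
    r = residual_fn(x)
    J = torch.autograd.functional.jacobian(residual_fn, x)   # shape (m, n)
    d = torch.linalg.solve(J.T @ J, -J.T @ r)
    return x + d

x = torch.tensor([2.0, 1.0])
x = gauss_newton_step(lambda v: v ** 2 - torch.tensor([4.0, 9.0]), x)   # one step toward v_i**2 = c_i
```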