torchzero 0.3.13__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_opts.py CHANGED
@@ -400,13 +400,6 @@ RandomizedFDM_4samples = Run(
     func='booth', steps=50, loss=1e-5, merge_invariant=True,
     sphere_steps=100, sphere_loss=400,
 )
-RandomizedFDM_4samples_lerp = Run(
-    func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.99, seed=0), tz.m.LR(0.1)),
-    sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, beta=0.9, seed=0), tz.m.LR(0.001)),
-    needs_closure=True,
-    func='booth', steps=50, loss=1e-5, merge_invariant=True,
-    sphere_steps=100, sphere_loss=505,
-)
 RandomizedFDM_4samples_no_pre_generate = Run(
     func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.1)),
     sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(n_samples=4, pre_generate=False, seed=0), tz.m.LR(0.001)),
torchzero/core/module.py CHANGED
@@ -531,7 +531,11 @@ class Module(ABC):
     def reset(self):
         """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
         self.state.clear()
+
+        generator = self.global_state.get("generator", None)
         self.global_state.clear()
+        if generator is not None: self.global_state["generator"] = generator
+
         for c in self.children.values(): c.reset()

     def reset_for_online(self):
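With this change, reset() keeps the "generator" entry in global_state while clearing everything else, so seeded modules keep drawing from the same RNG stream after a reset. A minimal sketch of the pattern this enables, built from the module names used in the docstring examples elsewhere in this diff (the closure signature and Modular/step interface are assumptions):

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.ResetEvery([tz.m.RandomizedFDM(n_samples=2, seed=0)], steps=10),
    tz.m.LR(1e-2),
)

def closure(backward=True):
    loss = model(torch.ones(8, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(25):
    opt.step(closure)  # resets at steps 10 and 20 no longer discard the seeded generator
```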
@@ -50,7 +50,7 @@ class ConguateGradientBase(Transform, ABC):
     ```

     """
-    def __init__(self, defaults = None, clip_beta: bool = False, restart_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
+    def __init__(self, defaults, clip_beta: bool, restart_interval: int | None | Literal['auto'], inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['restart_interval'] = restart_interval
         defaults['clip_beta'] = clip_beta
@@ -140,8 +140,8 @@ class PolakRibiere(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, clip_beta=True, restart_interval: int | None = None, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, clip_beta=True, restart_interval: int | None | Literal['auto'] = 'auto', inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return polak_ribiere_beta(g, prev_g)
@@ -158,7 +158,7 @@ class FletcherReeves(ConguateGradientBase):
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
     def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def initialize(self, p, g):
         self.global_state['prev_gg'] = g.dot(g)
@@ -183,8 +183,8 @@ class HestenesStiefel(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hestenes_stiefel_beta(g, prev_d, prev_g)
@@ -202,8 +202,8 @@ class DaiYuan(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1)`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return dai_yuan_beta(g, prev_d, prev_g)
@@ -221,8 +221,8 @@ class LiuStorey(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return liu_storey_beta(g, prev_d, prev_g)
@@ -239,8 +239,8 @@ class ConjugateDescent(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return conjugate_descent_beta(g, prev_d, prev_g)
@@ -264,8 +264,8 @@ class HagerZhang(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hager_zhang_beta(g, prev_d, prev_g)
@@ -291,8 +291,8 @@ class DYHS(ConguateGradientBase):
     Note:
         This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
     """
-    def __init__(self, restart_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
-        super().__init__(clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)
+    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
+        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return dyhs_beta(g, prev_d, prev_g)
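All of these conjugate-gradient constructors now pass an explicit `{}` defaults dict to the base class and default `restart_interval` to `'auto'`. Their docstring Notes recommend pairing them with a strong-Wolfe line search; a hedged sketch of that pairing, assuming the classes are exposed under `tz.m` like the other modules in this diff:

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),                             # clip_beta=True, restart_interval='auto' by default
    tz.m.StrongWolfe(c2=0.1, a_init="first-order"),  # line search determines the step size
)
```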
@@ -0,0 +1,93 @@
+from collections.abc import Callable
+from typing import Any
+from functools import partial
+import torch
+
+from ...utils import TensorList, NumberList
+from ..grad_approximation.grad_approximator import GradApproximator, GradTarget
+
+class SPSA1(GradApproximator):
+    """One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated
+    gradient often won't be a descent direction, however the expectation is biased towards
+    the descent direction. Therefore this variant of SPSA is only recommended for a specific
+    class of problems where the objective function changes on each evaluation,
+    for example feedback control problems.
+
+    Args:
+        h (float, optional):
+            finite difference step size, recommended to set to same value as learning rate. Defaults to 1e-3.
+        n_samples (int, optional): number of random samples. Defaults to 1.
+        eps (float, optional): measurement noise estimate. Defaults to 1e-8.
+        seed (int | None | torch.Generator, optional): random seed. Defaults to None.
+        target (GradTarget, optional): what to set on closure. Defaults to "closure".
+
+    Reference:
+        [SPALL, JAMES C. "A One-measurement Form of Simultaneous Stochastic Approximation](https://www.jhuapl.edu/spsa/PDF-SPSA/automatica97_one_measSPSA.pdf)."
+    """
+
+    def __init__(
+        self,
+        h: float = 1e-3,
+        n_samples: int = 1,
+        eps: float = 1e-8, # measurement noise
+        pre_generate = False,
+        seed: int | None | torch.Generator = None,
+        target: GradTarget = "closure",
+    ):
+        defaults = dict(h=h, eps=eps, n_samples=n_samples, pre_generate=pre_generate, seed=seed)
+        super().__init__(defaults, target=target)
+
+
+    def pre_step(self, var):
+
+        if self.defaults['pre_generate']:
+
+            params = TensorList(var.params)
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+            n_samples = self.defaults['n_samples']
+            h = self.get_settings(var.params, 'h')
+
+            perturbations = [params.sample_like(distribution='rademacher', generator=generator) for _ in range(n_samples)]
+            torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
+
+            for param, prt in zip(params, zip(*perturbations)):
+                self.state[param]['perturbations'] = prt
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss):
+        generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+        params = TensorList(params)
+        orig_params = params.clone() # store to avoid small changes due to float imprecision
+        loss_approx = None
+
+        h, eps = self.get_settings(params, "h", "eps", cls=NumberList)
+        n_samples = self.defaults['n_samples']
+
+        default = [None]*n_samples
+        # perturbations are pre-multiplied by h
+        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+
+        grad = None
+        for i in range(n_samples):
+            prt = perturbations[i]
+
+            if prt[0] is None:
+                prt = params.sample_like('rademacher', generator=generator).mul_(h)
+
+            else: prt = TensorList(prt)
+
+            params += prt
+            L = closure(False)
+            params.copy_(orig_params)
+
+            sample = prt * ((L + eps) / h)
+            if grad is None: grad = sample
+            else: grad += sample
+
+        assert grad is not None
+        if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        return grad, loss, loss_approx
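A hedged usage sketch for this new one-measurement SPSA module. Only the class itself appears in the diff, so the `tz.m.SPSA1` export path and the `closure(backward)` convention are assumptions:

```py
import torch
import torchzero as tz

model = torch.nn.Linear(8, 1)
opt = tz.Modular(model.parameters(), tz.m.SPSA1(h=1e-3), tz.m.LR(1e-3))

x, y = torch.randn(32, 8), torch.randn(32, 1)

def closure(backward=True):
    # SPSA1 targets the closure, so it only ever calls closure(False):
    # one perturbed loss evaluation per sample, no backward pass required.
    loss = (model(x) - y).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```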
@@ -1,4 +1,4 @@
 from .grad_approximator import GradApproximator, GradTarget
 from .fdm import FDM
 from .rfdm import RandomizedFDM, MeZO, SPSA, RDSA, GaussianSmoothing
-from .forward_gradient import ForwardGradient
+from .forward_gradient import ForwardGradient
@@ -23,8 +23,6 @@ class ForwardGradient(RandomizedFDM):
     Args:
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         jvp_method (str, optional):
@@ -40,14 +38,13 @@ class ForwardGradient(RandomizedFDM):
         self,
         n_samples: int = 1,
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
         h: float = 1e-3,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples, distribution=distribution, beta=beta, target=target, pre_generate=pre_generate, seed=seed)
+        super().__init__(h=h, n_samples=n_samples, distribution=distribution, target=target, pre_generate=pre_generate, seed=seed)
         self.defaults['jvp_method'] = jvp_method

     @torch.no_grad
@@ -62,7 +59,7 @@ class ForwardGradient(RandomizedFDM):
         distribution = settings['distribution']
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-        generator = self._get_generator(settings['seed'], params)
+        generator = self.get_generator(params[0].device, self.defaults['seed'])

         grad = None
         for i in range(n_samples):
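ForwardGradient now fetches its generator through get_generator and no longer accepts a beta momentum argument. A hedged construction sketch (the `tz.m.ForwardGradient` export path is an assumption):

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.ForwardGradient(n_samples=4, jvp_method="autograd"),  # gradient from JVPs along random directions
    tz.m.LR(1e-3),
)
```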
@@ -164,7 +164,6 @@ class RandomizedFDM(GradApproximator):
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "rademacher".
             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
-        beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -173,7 +172,7 @@ class RandomizedFDM(GradApproximator):
     Examples:
         #### Simultaneous perturbation stochastic approximation (SPSA) method

-        SPSA is randomized finite differnce with rademacher distribution and central formula.
+        SPSA is randomized FDM with rademacher distribution and central formula.
         ```py
         spsa = tz.Modular(
             model.parameters(),
@@ -184,8 +183,7 @@ class RandomizedFDM(GradApproximator):

         #### Random-direction stochastic approximation (RDSA) method

-        RDSA is randomized finite differnce with usually gaussian distribution and central formula.
-
+        RDSA is randomized FDM with usually gaussian distribution and central formula.
         ```
         rdsa = tz.Modular(
             model.parameters(),
@@ -194,23 +192,9 @@ class RandomizedFDM(GradApproximator):
         )
         ```

-        #### RandomizedFDM with momentum
-
-        Momentum might help by reducing the variance of the estimated gradients.
-
-        ```
-        momentum_spsa = tz.Modular(
-            model.parameters(),
-            tz.m.RandomizedFDM(),
-            tz.m.HeavyBall(0.9),
-            tz.m.LR(1e-3)
-        )
-        ```
-
         #### Gaussian smoothing method

         GS uses many gaussian samples with possibly a larger finite difference step size.
-
         ```
         gs = tz.Modular(
             model.parameters(),
@@ -220,44 +204,15 @@ class RandomizedFDM(GradApproximator):
         )
         ```

-        #### SPSA-NewtonCG
-
-        NewtonCG with hessian-vector product estimated via gradient difference
-        calls closure multiple times per step. If each closure call estimates gradients
-        with different perturbations, NewtonCG is unable to produce useful directions.
-
-        By setting pre_generate to True, perturbations are generated once before each step,
-        and each closure call estimates gradients using the same pre-generated perturbations.
-        This way closure-based algorithms are able to use gradients estimated in a consistent way.
+        #### RandomizedFDM with momentum

+        Momentum might help by reducing the variance of the estimated gradients.
         ```
-        opt = tz.Modular(
+        momentum_spsa = tz.Modular(
             model.parameters(),
-            tz.m.RandomizedFDM(n_samples=10),
-            tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
-            tz.m.Backtracking()
-        )
-        ```
-
-        #### SPSA-LBFGS
-
-        LBFGS uses a memory of past parameter and gradient differences. If past gradients
-        were estimated with different perturbations, LBFGS directions will be useless.
-
-        To alleviate this momentum can be added to random perturbations to make sure they only
-        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
-        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
-
-        Additionally we will reset SPSA and LBFGS memory every 100 steps to remove influence from old gradient estimates.
-
-        ```
-        opt = tz.Modular(
-            bench.parameters(),
-            tz.m.ResetEvery(
-                [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
-                steps = 100,
-            ),
-            tz.m.Backtracking()
+            tz.m.RandomizedFDM(),
+            tz.m.HeavyBall(0.9),
+            tz.m.LR(1e-3)
         )
         ```
     """
@@ -268,75 +223,46 @@ class RandomizedFDM(GradApproximator):
         n_samples: int = 1,
         formula: _FD_Formula = "central",
         distribution: Distributions = "rademacher",
-        beta: float = 0,
         pre_generate = True,
         seed: int | None | torch.Generator = None,
         target: GradTarget = "closure",
     ):
-        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
+        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
         super().__init__(defaults, target=target)

-    def reset(self):
-        self.state.clear()
-        generator = self.global_state.get('generator', None) # avoid resetting generator
-        self.global_state.clear()
-        if generator is not None: self.global_state['generator'] = generator
-        for c in self.children.values(): c.reset()
-
-    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
-        if 'generator' not in self.global_state:
-            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
-            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            else: self.global_state['generator'] = None
-        return self.global_state['generator']

     def pre_step(self, var):
-        h, beta = self.get_settings(var.params, 'h', 'beta')
-
-        n_samples = self.defaults['n_samples']
-        distribution = self.defaults['distribution']
+        h = self.get_settings(var.params, 'h')
         pre_generate = self.defaults['pre_generate']

         if pre_generate:
+            n_samples = self.defaults['n_samples']
+            distribution = self.defaults['distribution']
+
             params = TensorList(var.params)
-            generator = self._get_generator(self.defaults['seed'], var.params)
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
             perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

+            # this is false for ForwardGradient where h isn't used and it subclasses this
             if self.PRE_MULTIPLY_BY_H:
                 torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

-            if all(i==0 for i in beta):
-                # just use pre-generated perturbations
-                for param, prt in zip(params, zip(*perturbations)):
-                    self.state[param]['perturbations'] = prt
-
-            else:
-                # lerp old and new perturbations. This makes the subspace change gradually
-                # which in theory might improve algorithms with history
-                for i,p in enumerate(params):
-                    state = self.state[p]
-                    if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]
-
-                cur = [self.state[p]['perturbations'][:n_samples] for p in params]
-                cur_flat = [p for l in cur for p in l]
-                new_flat = [p for l in zip(*perturbations) for p in l]
-                betas = [1-v for b in beta for v in [b]*n_samples]
-                torch._foreach_lerp_(cur_flat, new_flat, betas)
+            for param, prt in zip(params, zip(*perturbations)):
+                self.state[param]['perturbations'] = prt

     @torch.no_grad
     def approximate(self, closure, params, loss):
         params = TensorList(params)
-        orig_params = params.clone() # store to avoid small changes due to float imprecision
         loss_approx = None

         h = NumberList(self.settings[p]['h'] for p in params)
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        fd_fn = _RFD_FUNCS[settings['formula']]
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
+        fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
         default = [None]*n_samples
         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-        distribution = settings['distribution']
-        generator = self._get_generator(settings['seed'], params)
+        generator = self.get_generator(params[0].device, self.defaults['seed'])

         grad = None
         for i in range(n_samples):
@@ -356,7 +282,6 @@ class RandomizedFDM(GradApproximator):
             if grad is None: grad = prt * d
             else: grad += prt * d

-        params.set_(orig_params)
         assert grad is not None
         if n_samples > 1: grad.div_(n_samples)
@@ -384,8 +309,6 @@ class SPSA(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "rademacher".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -408,8 +331,6 @@ class RDSA(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 1.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
         distribution (Distributions, optional): distribution. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -425,12 +346,11 @@ class RDSA(RandomizedFDM):
         n_samples: int = 1,
         formula: _FD_Formula = "central2",
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

 class GaussianSmoothing(RandomizedFDM):
     """
@@ -445,8 +365,6 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples (int, optional): number of random gradient samples. Defaults to 100.
         formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
         distribution (Distributions, optional): distribution. Defaults to "gaussian".
-        beta (float, optional):
-            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
         pre_generate (bool, optional):
             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -462,12 +380,11 @@ class GaussianSmoothing(RandomizedFDM):
         n_samples: int = 100,
         formula: _FD_Formula = "forward2",
         distribution: Distributions = "gaussian",
-        beta: float = 0,
         pre_generate = True,
         target: GradTarget = "closure",
         seed: int | None | torch.Generator = None,
     ):
-        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

 class MeZO(GradApproximator):
     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
@@ -525,9 +442,9 @@ class MeZO(GradApproximator):
         loss_approx = None

         h = NumberList(self.settings[p]['h'] for p in params)
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        fd_fn = _RFD_FUNCS[settings['formula']]
+        n_samples = self.defaults['n_samples']
+        fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
         prt_fns = self.global_state['prt_fns']

         grad = None
@@ -1,3 +1,4 @@
+import math
 from collections.abc import Mapping
 from operator import itemgetter

@@ -17,6 +18,7 @@ class ScipyMinimizeScalar(LineSearchBase):
         bounds (Sequence | None, optional):
             For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
         tol (float | None, optional): Tolerance for termination. Defaults to None.
+        prev_init (bool, optional): uses previous step size as initial guess for the line search.
         options (dict | None, optional): A dictionary of solver options. Defaults to None.

     For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
@@ -29,9 +31,10 @@ class ScipyMinimizeScalar(LineSearchBase):
         bracket=None,
         bounds=None,
         tol: float | None = None,
+        prev_init: bool = False,
         options=None,
     ):
-        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
+        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
         super().__init__(defaults)

         import scipy.optimize
@@ -48,5 +51,14 @@ class ScipyMinimizeScalar(LineSearchBase):
         options = dict(options) if isinstance(options, Mapping) else {}
         options['maxiter'] = maxiter

-        res = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
-        return res.x
+        if self.defaults["prev_init"] and "x_prev" in self.global_state:
+            if bracket is None: bracket = (0, 1)
+            bracket = (*bracket[:-1], self.global_state["x_prev"])
+
+        x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]
+
+        max = torch.finfo(var.params[0].dtype).max / 2
+        if (not math.isfinite(x)) or abs(x) >= max: x = 0
+
+        self.global_state['x_prev'] = x
+        return x
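With prev_init=True, the previously accepted step size is reused as the last bracket point of the next scalar search, and non-finite or overflowing results now fall back to 0. A hedged sketch of enabling it (export paths under `tz.m` are assumptions):

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.FletcherReeves(),                                     # direction module
    tz.m.ScipyMinimizeScalar(method="brent", prev_init=True),  # warm-started scalar line search
)
```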
@@ -330,7 +330,6 @@ class StrongWolfe(LineSearchBase):
         if adaptive:
             a_init *= self.global_state.get('initial_scale', 1)

-
         strong_wolfe = _StrongWolfe(
             f=objective,
             f_0=f_0,
@@ -360,7 +359,6 @@ class StrongWolfe(LineSearchBase):
         if inverted: a = -a

         if a is not None and a != 0 and math.isfinite(a):
-            #self.global_state['initial_scale'] = min(1.0, self.global_state.get('initial_scale', 1) * math.sqrt(2))
             self.global_state['initial_scale'] = 1
             self.global_state['a_prev'] = a
             self.global_state['f_prev'] = f_0
@@ -60,18 +60,18 @@ class RestartStrategyBase(Module, ABC):


 class RestartOnStuck(RestartStrategyBase):
-    """Resets the state when update (difference in parameters) is close to zero for multiple steps in a row.
+    """Resets the state when update (difference in parameters) is zero for multiple steps in a row.

     Args:
         modules (Chainable | None):
             modules to reset. If None, resets all modules.
         tol (float, optional):
-            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to 1e-10.
+            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to None (uses twice the smallest respresentable number)
         n_tol (int, optional):
-            number of failed consequtive steps required to trigger a reset. Defaults to 4.
+            number of failed consequtive steps required to trigger a reset. Defaults to 10.

     """
-    def __init__(self, modules: Chainable | None, tol: float = 1e-10, n_tol: int = 4):
+    def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
         defaults = dict(tol=tol, n_tol=n_tol)
         super().__init__(defaults, modules)

@@ -82,6 +82,7 @@ class RestartOnStuck(RestartStrategyBase):

         params = TensorList(var.params)
         tol = self.defaults['tol']
+        if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
         n_tol = self.defaults['n_tol']
         n_bad = self.global_state.get('n_bad', 0)
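RestartOnStuck now treats tol=None as twice the smallest representable positive value for the parameter dtype and requires 10 consecutive stuck steps before resetting. A hedged sketch of wrapping a quasi-Newton chain with it, following the docstring examples elsewhere in this diff:

```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.RestartOnStuck([tz.m.LBFGS()]),  # tol=None -> 2 * torch.finfo(dtype).tiny, n_tol=10
    tz.m.Backtracking(),
)
```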