torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. tests/test_opts.py +4 -10
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +12 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/conjugate_gradient/cg.py +16 -16
  16. torchzero/modules/experimental/__init__.py +1 -0
  17. torchzero/modules/experimental/newtonnewton.py +5 -5
  18. torchzero/modules/experimental/spsa1.py +93 -0
  19. torchzero/modules/functional.py +7 -0
  20. torchzero/modules/grad_approximation/__init__.py +1 -1
  21. torchzero/modules/grad_approximation/forward_gradient.py +2 -5
  22. torchzero/modules/grad_approximation/rfdm.py +27 -110
  23. torchzero/modules/line_search/__init__.py +1 -1
  24. torchzero/modules/line_search/_polyinterp.py +3 -1
  25. torchzero/modules/line_search/adaptive.py +3 -3
  26. torchzero/modules/line_search/backtracking.py +1 -1
  27. torchzero/modules/line_search/interpolation.py +160 -0
  28. torchzero/modules/line_search/line_search.py +11 -20
  29. torchzero/modules/line_search/scipy.py +15 -3
  30. torchzero/modules/line_search/strong_wolfe.py +3 -5
  31. torchzero/modules/misc/misc.py +2 -2
  32. torchzero/modules/misc/multistep.py +13 -13
  33. torchzero/modules/quasi_newton/__init__.py +2 -0
  34. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  35. torchzero/modules/quasi_newton/sg2.py +292 -0
  36. torchzero/modules/restarts/restars.py +5 -4
  37. torchzero/modules/second_order/__init__.py +6 -3
  38. torchzero/modules/second_order/ifn.py +89 -0
  39. torchzero/modules/second_order/inm.py +105 -0
  40. torchzero/modules/second_order/newton.py +103 -193
  41. torchzero/modules/second_order/newton_cg.py +86 -110
  42. torchzero/modules/second_order/nystrom.py +1 -1
  43. torchzero/modules/second_order/rsn.py +227 -0
  44. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  45. torchzero/modules/trust_region/trust_cg.py +6 -4
  46. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  47. torchzero/modules/zeroth_order/__init__.py +1 -1
  48. torchzero/modules/zeroth_order/cd.py +1 -238
  49. torchzero/utils/derivatives.py +19 -19
  50. torchzero/utils/linalg/linear_operator.py +50 -2
  51. torchzero/utils/optimizer.py +2 -2
  52. torchzero/utils/python_tools.py +1 -0
  53. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  54. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
  55. torchzero/modules/higher_order/__init__.py +0 -1
  56. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  57. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  58. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,7 @@ import torch
 
  from ...core import Module, Target, Var
  from ...utils import tofloat, set_storage_
+ from ..functional import clip_by_finfo
 
 
  class MaxLineSearchItersReached(Exception): pass
@@ -103,23 +104,18 @@ class LineSearchBase(Module, ABC):
  ):
  if not math.isfinite(step_size): return
 
- # fixes overflow when backtracking keeps increasing alpha after converging
- step_size = max(min(tofloat(step_size), 1e36), -1e36)
+ # avoid overflow error
+ step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))
 
  # skip if parameters are already at suggested step size
  if self._current_step_size == step_size: return
 
- # this was basically causing floating point imprecision to build up
- #if False:
- # if abs(alpha) < abs(step_size) and step_size != 0:
- # torch._foreach_add_(params, update, alpha=alpha)
-
- # else:
  assert self._initial_params is not None
  if step_size == 0:
  new_params = [p.clone() for p in self._initial_params]
  else:
  new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+
  for c, n in zip(params, new_params):
  set_storage_(c, n)
 
@@ -131,10 +127,7 @@ class LineSearchBase(Module, ABC):
  params: list[torch.Tensor],
  update: list[torch.Tensor],
  ):
- # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
- # alpha = [self._current_step_size - s for s in step_size]
- # if any(a!=0 for a in alpha):
- # torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
  assert self._initial_params is not None
  if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
 
@@ -248,16 +241,14 @@ class LineSearchBase(Module, ABC):
  except MaxLineSearchItersReached:
  step_size = self._best_step_size
 
+ step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
+
  # set loss_approx
  if var.loss_approx is None: var.loss_approx = self._lowest_loss
 
- # this is last module - set step size to found step_size times lr
- if var.is_last:
- if var.last_module_lrs is None:
- self.set_step_size_(step_size, params=params, update=update)
-
- else:
- self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
+ # if this is last module, directly update parameters to avoid redundant operations
+ if var.modular is not None and self is var.modular.modules[-1]:
+ self.set_step_size_(step_size, params=params, update=update)
 
  var.stop = True; var.skip_update = True
  return var
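The `clip_by_finfo` helper that replaces the hard-coded ±1e36 clamp is added in `torchzero/modules/functional.py` (+7 lines), but its body is not part of this diff. Judging only from the call sites above, which pass a float and a `torch.finfo`, a minimal sketch of what it presumably does (name and behavior inferred, not the actual implementation):

```python
import math
import torch

def clip_by_finfo(x: float, finfo) -> float:
    # Hypothetical sketch: clamp a finite float into the representable range
    # of the dtype described by `finfo` (e.g. torch.finfo(torch.float32)).
    if not math.isfinite(x):
        return x
    return max(min(x, finfo.max), finfo.min)
```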
@@ -277,7 +268,7 @@ class GridLineSearch(LineSearchBase):
 
  @torch.no_grad
  def search(self, update, var):
- start,end,num=itemgetter('start','end','num')(self.defaults)
+ start, end, num = itemgetter('start', 'end', 'num')(self.defaults)
 
  for lr in torch.linspace(start,end,num):
  self.evaluate_f(lr.item(), var=var, backward=False)
@@ -1,3 +1,4 @@
+ import math
  from collections.abc import Mapping
  from operator import itemgetter
 
@@ -17,6 +18,7 @@ class ScipyMinimizeScalar(LineSearchBase):
  bounds (Sequence | None, optional):
  For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
  tol (float | None, optional): Tolerance for termination. Defaults to None.
+ prev_init (bool, optional): if True, uses the previous step size as the initial guess for the line search. Defaults to False.
  options (dict | None, optional): A dictionary of solver options. Defaults to None.
 
  For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
@@ -29,9 +31,10 @@ class ScipyMinimizeScalar(LineSearchBase):
  bracket=None,
  bounds=None,
  tol: float | None = None,
+ prev_init: bool = False,
  options=None,
  ):
- defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
+ defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
  super().__init__(defaults)
 
  import scipy.optimize
@@ -48,5 +51,14 @@ class ScipyMinimizeScalar(LineSearchBase):
  options = dict(options) if isinstance(options, Mapping) else {}
  options['maxiter'] = maxiter
 
- res = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
- return res.x
+ if self.defaults["prev_init"] and "x_prev" in self.global_state:
+ if bracket is None: bracket = (0, 1)
+ bracket = (*bracket[:-1], self.global_state["x_prev"])
+
+ x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]
+
+ max = torch.finfo(var.params[0].dtype).max / 2
+ if (not math.isfinite(x)) or abs(x) >= max: x = 0
+
+ self.global_state['x_prev'] = x
+ return x
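With `prev_init=True`, the previously accepted step size becomes the last element of the bracket passed to `scipy.optimize.minimize_scalar`, warm-starting the scalar search. A minimal construction sketch, assuming the class is exposed as `tz.m.ScipyMinimizeScalar` like the other modules in these docstrings (the model and the preceding direction module are placeholders):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.Newton(),  # any module that produces a search direction
    tz.m.ScipyMinimizeScalar(method="brent", prev_init=True),  # reuse previous step size in the bracket
)
```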
@@ -7,7 +7,7 @@ import numpy as np
  import torch
  from torch.optim.lbfgs import _cubic_interpolate
 
- from ...utils import as_tensorlist, totensor
+ from ...utils import as_tensorlist, totensor, tofloat
  from ._polyinterp import polyinterp, polyinterp2
  from .line_search import LineSearchBase, TerminationCondition, termination_condition
  from ..step_size.adaptive import _bb_geom
@@ -92,7 +92,7 @@ class _StrongWolfe:
  return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
 
  if self.interpolation in ('polynomial', 'polynomial2'):
- finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+ finite_history = [(tofloat(a), tofloat(f), tofloat(g)) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
  if bounds is None: bounds = (None, None)
  polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
  try:
@@ -330,7 +330,6 @@ class StrongWolfe(LineSearchBase):
  if adaptive:
  a_init *= self.global_state.get('initial_scale', 1)
 
-
  strong_wolfe = _StrongWolfe(
  f=objective,
  f_0=f_0,
@@ -360,7 +359,6 @@ class StrongWolfe(LineSearchBase):
  if inverted: a = -a
 
  if a is not None and a != 0 and math.isfinite(a):
- #self.global_state['initial_scale'] = min(1.0, self.global_state.get('initial_scale', 1) * math.sqrt(2))
  self.global_state['initial_scale'] = 1
  self.global_state['a_prev'] = a
  self.global_state['f_prev'] = f_0
@@ -372,6 +370,6 @@ class StrongWolfe(LineSearchBase):
  self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
  finfo = torch.finfo(dir[0].dtype)
  if self.global_state['initial_scale'] < finfo.tiny * 2:
- self.global_state['initial_scale'] = finfo.max / 2
+ self.global_state['initial_scale'] = init_value * 2
 
  return 0
@@ -306,8 +306,8 @@ class RandomHvp(Module):
  for i in range(n_samples):
  u = params.sample_like(distribution=distribution, variance=1)
 
- Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
- h=h, normalize=True, retain_grad=i < n_samples-1)
+ Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+ h=h, normalize=True, retain_graph=i < n_samples-1)
 
  if D is None: D = Hvp
  else: torch._foreach_add_(D, Hvp)
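For context on the `hvp_method`/`h` arguments: besides autograd, a Hessian-vector product can be approximated by a central finite difference of gradients. A generic sketch of that idea (not torchzero's `hessian_vector_product`, whose full signature is not visible in this diff):

```python
def finite_difference_hvp(grad_fn, params, v, h=1e-3):
    # Central difference: H @ v ≈ (g(x + h*v) - g(x - h*v)) / (2*h).
    # grad_fn(list_of_params) -> list of gradient tensors; purely illustrative.
    g_plus = grad_fn([p + h * u for p, u in zip(params, v)])
    g_minus = grad_fn([p - h * u for p, u in zip(params, v)])
    return [(gp - gm) / (2 * h) for gp, gm in zip(g_plus, g_minus)]
```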
@@ -15,7 +15,7 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
  if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
 
  # store original params unless this is last module and can update params directly
- params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+ params_before_steps = [p.clone() for p in params]
 
  # first step - pass var as usual
  var = modules[0].step(var)
@@ -27,8 +27,8 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
 
  # update params
  if (not new_var.skip_update):
- if new_var.last_module_lrs is not None:
- torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+ # if new_var.last_module_lrs is not None:
+ # torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
  torch._foreach_sub_(params, new_var.get_update())
 
@@ -41,16 +41,16 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
 
  # final parameter update
  if (not new_var.skip_update):
- if new_var.last_module_lrs is not None:
- torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+ # if new_var.last_module_lrs is not None:
+ # torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
  torch._foreach_sub_(params, new_var.get_update())
 
  # if last module, update is applied so return new var
- if params_before_steps is None:
- new_var.stop = True
- new_var.skip_update = True
- return new_var
+ # if params_before_steps is None:
+ # new_var.stop = True
+ # new_var.skip_update = True
+ # return new_var
 
  # otherwise use parameter difference as update
  var.update = list(torch._foreach_sub(params_before_steps, params))
@@ -106,10 +106,10 @@ class NegateOnLossIncrease(Module):
  f_1 = closure(False)
 
  if f_1 <= f_0:
- if var.is_last and var.last_module_lrs is None:
- var.stop = True
- var.skip_update = True
- return var
+ # if var.is_last and var.last_module_lrs is None:
+ # var.stop = True
+ # var.skip_update = True
+ # return var
 
  torch._foreach_add_(var.params, update)
  return var
@@ -29,3 +29,5 @@ from .quasi_newton import (
  ShorR,
  ThomasOptimalMethod,
  )
+
+ from .sg2 import SG2, SPSA2
@@ -1182,16 +1182,19 @@ class ShorR(HessianUpdateStrategy):
  """Shor’s r-algorithm.
 
  Note:
- A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
- Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
- so setting ``a_init`` in the line search is recommended.
+ - A line search such as ``[tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)]`` is required. As with conjugate gradient, ShorR has no automatic step size scaling, so setting ``a_init`` in the line search is recommended.
+
+ - The line search should overstep slightly, so it can help to multiply the direction found by the line search by a value slightly larger than 1, such as 1.2.
 
  References:
- S HOR , N. Z. (1985) Minimization Methods for Non-differentiable Functions. New York: Springer.
+ These are the original references, but neither seems to be available online:
+ - Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.
+
+ - Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.
 
- Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.
+ An overview is available in [Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720](https://sites.math.washington.edu/~burke/papers/reprints/60-speed-Shor-R.pdf).
 
- Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
+ The more efficient formula described by Skokov can also be found in [Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998](https://camo.ici.ro/books/thesis/th.pdf).
  """
 
  def __init__(
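Written out as a full optimizer construction, the setup recommended by the updated note looks like this (the model is a placeholder; module names follow the docstring itself):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.ShorR(),
    tz.m.StrongWolfe(a_init="quadratic", fallback=True),
    tz.m.Mul(1.2),  # overstep slightly, as the note suggests
)
```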
@@ -1229,3 +1232,9 @@ class ShorR(HessianUpdateStrategy):
 
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
  return shor_r_(H=H, y=y, alpha=setting['alpha'])
+
+
+ # Todd, Michael J. "The symmetric rank-one quasi-Newton method is a space-dilation subgradient algorithm." Operations research letters 5.5 (1986): 217-219.
+ # TODO
+
+ # Sorensen, D. C. "The q-superlinear convergence of a collinear scaling algorithm for unconstrained optimization." SIAM Journal on Numerical Analysis 17.1 (1980): 84-114.
@@ -0,0 +1,292 @@
+ import torch
+
+ from ...core import Module, Chainable, apply_transform
+ from ...utils import TensorList, vec_to_tensors
+ from ..second_order.newton import _newton_step, _get_H
+
+ def sg2_(
+ delta_g: torch.Tensor,
+ cd: torch.Tensor,
+ ) -> torch.Tensor:
+ """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
+ (or divide delta_g by two)."""
+
+ M = torch.outer(1.0 / cd, delta_g)
+ H_hat = 0.5 * (M + M.T)
+
+ return H_hat
+
+
+
+ class SG2(Module):
+ """second-order stochastic gradient
+
+ SG2 with line search
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.SG2(),
+ tz.m.Backtracking()
+ )
+ ```
+
+ SG2 with trust region
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LevenbergMarquardt(tz.m.SG2()),
+ )
+ ```
+
+ """
+
+ def __init__(
+ self,
+ n_samples: int = 1,
+ h: float = 1e-2,
+ beta: float | None = None,
+ damping: float = 0,
+ eigval_fn=None,
+ one_sided: bool = False, # one-sided hessian
+ use_lstsq: bool = True,
+ seed=None,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
+ super().__init__(defaults)
+
+ if inner is not None: self.set_child('inner', inner)
+
+ @torch.no_grad
+ def update(self, var):
+ k = self.global_state.get('step', 0) + 1
+ self.global_state["step"] = k
+
+ params = TensorList(var.params)
+ closure = var.closure
+ if closure is None:
+ raise RuntimeError("closure is required for SG2")
+ generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+ h = self.get_settings(params, "h")
+ x_0 = params.clone()
+ n_samples = self.defaults["n_samples"]
+ H_hat = None
+
+ for i in range(n_samples):
+ # generate perturbation
+ cd = params.rademacher_like(generator=generator).mul_(h)
+
+ # one sided
+ if self.defaults["one_sided"]:
+ g_0 = TensorList(var.get_grad())
+ params.add_(cd)
+ closure()
+
+ g_p = params.grad.fill_none_(params)
+ delta_g = (g_p - g_0) * 2
+
+ # two sided
+ else:
+ params.add_(cd)
+ closure()
+ g_p = params.grad.fill_none_(params)
+
+ params.copy_(x_0)
+ params.sub_(cd)
+ closure()
+ g_n = params.grad.fill_none_(params)
+
+ delta_g = g_p - g_n
+
+ # restore params
+ params.set_(x_0)
+
+ # compute H hat
+ H_i = sg2_(
+ delta_g = delta_g.to_vec(),
+ cd = cd.to_vec(),
+ )
+
+ if H_hat is None: H_hat = H_i
+ else: H_hat += H_i
+
+ assert H_hat is not None
+ if n_samples > 1: H_hat /= n_samples
+
+ # update H
+ H = self.global_state.get("H", None)
+ if H is None: H = H_hat
+ else:
+ beta = self.defaults["beta"]
+ if beta is None: beta = k / (k+1)
+ H.lerp_(H_hat, 1-beta)
+
+ self.global_state["H"] = H
+
+
+ @torch.no_grad
+ def apply(self, var):
+ dir = _newton_step(
+ var=var,
+ H = self.global_state["H"],
+ damping = self.defaults["damping"],
+ inner = self.children.get("inner", None),
+ H_tfm=None,
+ eigval_fn=self.defaults["eigval_fn"],
+ use_lstsq=self.defaults["use_lstsq"],
+ g_proj=None,
+ )
+
+ var.update = vec_to_tensors(dir, var.params)
+ return var
+
+ def get_H(self,var=...):
+ return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+
+
+
+
+ # two sided
+ # we have g via x + d, x - d
+ # H via g(x + d), g(x - d)
+ # 1 is x, x+2d
+ # 2 is x, x-2d
+ # 5 evals in total
+
+ # one sided
+ # g via x, x + d
+ # 1 is x, x + d
+ # 2 is x, x - d
+ # 3 evals and can use two sided for g_0
+
+ class SPSA2(Module):
+ """second-order SPSA
+
+ SPSA2 with line search
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.SPSA2(),
+ tz.m.Backtracking()
+ )
+ ```
+
+ SPSA2 with trust region
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LevenbergMarquardt(tz.m.SPSA2()),
+ )
+ ```
+ """
+
+ def __init__(
+ self,
+ n_samples: int = 1,
+ h: float = 1e-2,
+ beta: float | None = None,
+ damping: float = 0,
+ eigval_fn=None,
+ use_lstsq: bool = True,
+ seed=None,
+ inner: Chainable | None = None,
+ ):
+ defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
+ super().__init__(defaults)
+
+ if inner is not None: self.set_child('inner', inner)
+
+ @torch.no_grad
+ def update(self, var):
+ k = self.global_state.get('step', 0) + 1
+ self.global_state["step"] = k
+
+ params = TensorList(var.params)
+ closure = var.closure
+ if closure is None:
+ raise RuntimeError("closure is required for SPSA2")
+
+ generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+ h = self.get_settings(params, "h")
+ x_0 = params.clone()
+ n_samples = self.defaults["n_samples"]
+ H_hat = None
+ g_0 = None
+
+ for i in range(n_samples):
+ # perturbations for g and H
+ cd_g = params.rademacher_like(generator=generator).mul_(h)
+ cd_H = params.rademacher_like(generator=generator).mul_(h)
+
+ # evaluate 4 points
+ x_p = x_0 + cd_g
+ x_n = x_0 - cd_g
+
+ params.set_(x_p)
+ f_p = closure(False)
+ params.add_(cd_H)
+ f_pp = closure(False)
+
+ params.set_(x_n)
+ f_n = closure(False)
+ params.add_(cd_H)
+ f_np = closure(False)
+
+ g_p_vec = (f_pp - f_p) / cd_H
+ g_n_vec = (f_np - f_n) / cd_H
+ delta_g = g_p_vec - g_n_vec
+
+ # restore params
+ params.set_(x_0)
+
+ # compute grad
+ g_i = (f_p - f_n) / (2 * cd_g)
+ if g_0 is None: g_0 = g_i
+ else: g_0 += g_i
+
+ # compute H hat
+ H_i = sg2_(
+ delta_g = delta_g.to_vec().div_(2.0),
+ cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
+ )
+ if H_hat is None: H_hat = H_i
+ else: H_hat += H_i
+
+ assert g_0 is not None and H_hat is not None
+ if n_samples > 1:
+ g_0 /= n_samples
+ H_hat /= n_samples
+
+ # set grad to approximated grad
+ var.grad = g_0
+
+ # update H
+ H = self.global_state.get("H", None)
+ if H is None: H = H_hat
+ else:
+ beta = self.defaults["beta"]
+ if beta is None: beta = k / (k+1)
+ H.lerp_(H_hat, 1-beta)
+
+ self.global_state["H"] = H
+
+ @torch.no_grad
+ def apply(self, var):
+ dir = _newton_step(
+ var=var,
+ H = self.global_state["H"],
+ damping = self.defaults["damping"],
+ inner = self.children.get("inner", None),
+ H_tfm=None,
+ eigval_fn=self.defaults["eigval_fn"],
+ use_lstsq=self.defaults["use_lstsq"],
+ g_proj=None,
+ )
+
+ var.update = vec_to_tensors(dir, var.params)
+ return var
+
+ def get_H(self,var=...):
+ return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
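The `sg2_` helper above is the symmetrized outer-product Hessian estimate used in Spall-style second-order SPSA. In the notation of the code, with `delta_g` the difference of gradient estimates and `cd` the scaled perturbation $c\Delta$:

```latex
\hat{M}_{ij} = \frac{\delta g_j}{(c\Delta)_i},
\qquad
\hat{H} = \tfrac{1}{2}\left(\hat{M} + \hat{M}^{\top}\right).
```

For the two-sided estimate $\delta g = g(x + c\Delta) - g(x - c\Delta)$ the effective denominator should be $2c\Delta$, which is why the docstring asks callers to either double `cd` or halve `delta_g` (as `SPSA2` does with `.div_(2.0)`).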
@@ -60,18 +60,18 @@ class RestartStrategyBase(Module, ABC):
 
 
  class RestartOnStuck(RestartStrategyBase):
- """Resets the state when update (difference in parameters) is close to zero for multiple steps in a row.
+ """Resets the state when update (difference in parameters) is zero for multiple steps in a row.
 
  Args:
  modules (Chainable | None):
  modules to reset. If None, resets all modules.
  tol (float, optional):
- step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to 1e-10.
+ step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to None (uses twice the smallest representable number).
  n_tol (int, optional):
- number of failed consecutive steps required to trigger a reset. Defaults to 4.
+ number of failed consecutive steps required to trigger a reset. Defaults to 10.
 
  """
- def __init__(self, modules: Chainable | None, tol: float = 1e-10, n_tol: int = 4):
+ def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
  defaults = dict(tol=tol, n_tol=n_tol)
  super().__init__(defaults, modules)
 
@@ -82,6 +82,7 @@ class RestartOnStuck(RestartStrategyBase):
 
  params = TensorList(var.params)
  tol = self.defaults['tol']
+ if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
  n_tol = self.defaults['n_tol']
  n_bad = self.global_state.get('n_bad', 0)
 
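For reference, the new default `tol` is twice the smallest positive normal number of the parameter dtype; a quick check of the magnitude involved (illustrative, values from `torch.finfo`):

```python
import torch

print(torch.finfo(torch.float32).tiny)      # 1.1754944e-38
print(torch.finfo(torch.float32).tiny * 2)  # ~2.35e-38, the default tol for float32 parameters
```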
@@ -1,4 +1,7 @@
- from .newton import Newton, InverseFreeNewton
+ from .ifn import InverseFreeNewton
+ from .inm import INM
+ from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
+ from .newton import Newton
  from .newton_cg import NewtonCG, NewtonCGSteihaug
- from .nystrom import NystromSketchAndSolve, NystromPCG
- from .multipoint import SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2
+ from .nystrom import NystromPCG, NystromSketchAndSolve
+ from .rsn import RSN