torchzero 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (45)
  1. tests/test_opts.py +4 -3
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +8 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/experimental/__init__.py +1 -0
  16. torchzero/modules/experimental/newtonnewton.py +5 -5
  17. torchzero/modules/experimental/spsa1.py +2 -2
  18. torchzero/modules/functional.py +7 -0
  19. torchzero/modules/line_search/__init__.py +1 -1
  20. torchzero/modules/line_search/_polyinterp.py +3 -1
  21. torchzero/modules/line_search/adaptive.py +3 -3
  22. torchzero/modules/line_search/backtracking.py +1 -1
  23. torchzero/modules/line_search/interpolation.py +160 -0
  24. torchzero/modules/line_search/line_search.py +11 -20
  25. torchzero/modules/line_search/strong_wolfe.py +3 -3
  26. torchzero/modules/misc/misc.py +2 -2
  27. torchzero/modules/misc/multistep.py +13 -13
  28. torchzero/modules/quasi_newton/__init__.py +2 -0
  29. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  30. torchzero/modules/quasi_newton/sg2.py +292 -0
  31. torchzero/modules/second_order/__init__.py +6 -3
  32. torchzero/modules/second_order/ifn.py +89 -0
  33. torchzero/modules/second_order/inm.py +105 -0
  34. torchzero/modules/second_order/newton.py +103 -193
  35. torchzero/modules/second_order/nystrom.py +1 -1
  36. torchzero/modules/second_order/rsn.py +227 -0
  37. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  38. torchzero/utils/derivatives.py +19 -19
  39. torchzero/utils/linalg/linear_operator.py +50 -2
  40. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  41. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/RECORD +44 -36
  42. torchzero/modules/higher_order/__init__.py +0 -1
  43. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  44. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  45. {torchzero-0.3.14.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,160 @@
+ import math
+ from bisect import insort
+
+ import numpy as np
+ from numpy.polynomial import Polynomial
+
+
+ # we have a list of points in ascending order of their `y` value
+ class Point:
+     __slots__ = ("x", "y", "d")
+     def __init__(self, x, y, d):
+         self.x = x
+         self.y = y
+         self.d = d
+
+     def __lt__(self, other):
+         return self.y < other.y
+
+ def _get_dpoint(points: list[Point]):
+     """returns lowest point with derivative and list of other points"""
+     for i,p in enumerate(points):
+         if p.d is not None:
+             cpoints = points.copy()
+             del cpoints[i]
+             return p, cpoints
+     return None, points
+
+ # -------------------------------- quadratic2 -------------------------------- #
+ def _fitmin_quadratic2(x1, y1, d1, x2, y2):
+
+     a = (y2 - y1 - d1*(x2 - x1)) / (x2 - x1)**2
+     if a <= 0: return None
+
+     b = d1 - 2*a*x1
+     # c = y_1 - d_1*x_1 + a*x_1**2
+
+     return -b / (2*a)
+
+ def quadratic2(points:list[Point]):
+     pd, points = _get_dpoint(points)
+     if pd is None: return None
+     if len(points) == 0: return None
+
+     pn = points[0]
+     return _fitmin_quadratic2(pd.x, pd.y, pd.d, pn.x, pn.y)
+
+ # -------------------------------- quadratic3 -------------------------------- #
+ def _fitmin_quadratic3(x1, y1, x2, y2, x3, y3):
+     quad = Polynomial.fit([x1,x2,x3], [y1,y2,y3], deg=2)
+     a,b,c = quad.coef
+     if a <= 0: return None
+     return -b / (2*a)
+
+ def quadratic3(points:list[Point]):
+     if len(points) < 3: return None
+
+     p1,p2,p3 = points[:3]
+     return _fitmin_quadratic3(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y)
+
+ # ---------------------------------- cubic3 ---------------------------------- #
+ def _minimize_polynomial(poly: Polynomial):
+     roots = poly.deriv().roots()
+     vals = poly(roots)
+     argmin = np.argmin(vals)
+     return roots[argmin], vals[argmin]
+
+
+ def _fitmin_cubic3(x1,y1,x2,y2,x3,y3,x4,d4):
+     """x4 is allowed to be equal to x1"""
+
+     A = np.array([
+         [x1**3, x1**2, x1, 1],
+         [x2**3, x2**2, x2, 1],
+         [x3**3, x3**2, x3, 1],
+         [3*x4**2, 2*x4, 1, 0]
+     ])
+
+     B = np.array([y1, y2, y3, d4])
+
+     try:
+         coeffs = np.linalg.solve(A, B)
+     except np.linalg.LinAlgError:
+         return None
+
+     cubic = Polynomial(coeffs)
+     x_min, y_min = _minimize_polynomial(cubic)
+     if y_min < min(y1,y2,y3): return x_min
+     return None
+
+ def cubic3(points: list[Point]):
+     pd, points = _get_dpoint(points)
+     if pd is None: return None
+     if len(points) < 2: return None
+     p1, p2 = points[:2]
+     return _fitmin_cubic3(pd.x, pd.y, p1.x, p1.y, p2.x, p2.y, pd.x, pd.d)
+
+ # ---------------------------------- cubic4 ---------------------------------- #
+ def _fitmin_cubic4(x1, y1, x2, y2, x3, y3, x4, y4):
+     cubic = Polynomial.fit([x1,x2,x3,x4], [y1,y2,y3,y4], deg=3)
+     x_min, y_min = _minimize_polynomial(cubic)
+     if y_min < min(y1,y2,y3,y4): return x_min
+     return None
+
+ def cubic4(points:list[Point]):
+     if len(points) < 4: return None
+
+     p1,p2,p3,p4 = points[:4]
+     return _fitmin_cubic4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
+
+ # ---------------------------------- linear3 --------------------------------- #
+ def _linear_intersection(x1,y1,s1,x2,y2,s2):
+     if s1 == 0 or s2 == 0 or s1 == s2: return None
+     return (y1 - s1*x1 - y2 + s2*x2) / (s2 - s1)
+
+ def _fitmin_linear3(x1, y1, d1, x2, y2, x3, y3):
+     # we have that
+     # s2 = (y2 - y3) / (x2 - x3) # slope origin in x2 y2
+     # f1(x) = y1 + d1 * (x - x1)
+     # f2(x) = y2 + s2 * (x - x2)
+     # y1 + d1 * (x - x1) = y2 + s2 * (x - x2)
+     # y1 + d1 x - d1 x1 - y2 - s2 x + s2 x2 = 0
+     # s2 x - d1 x = y1 - d1 x1 - y2 + s2 x2
+     # x = (y1 - d1 x1 - y2 + s2 x2) / (s2 - d1)
+
+     if x2 < x1 < x3 or x3 < x1 < x2: # point with derivative in between
+         return None
+
+     if d1 > 0:
+         if x2 > x1 or x3 > x1: return None # intersection is above to the right
+         if x2 > x3: x2,y2,x3,y3 = x3,y3,x2,y2
+     if d1 < 0:
+         if x2 < x1 or x3 < x1: return None # intersection is above to the left
+         if x2 < x3: x2,y2,x3,y3 = x3,y3,x2,y2
+
+     s2 = (y2 - y3) / (x2 - x3)
+     return _linear_intersection(x1,y1,d1,x2,y2,s2)
+
+ def linear3(points:list[Point]):
+     pd, points = _get_dpoint(points)
+     if pd is None: return None
+     if len(points) < 2: return None
+     p1, p2 = points[:2]
+     return _fitmin_linear3(pd.x, pd.y, pd.d, p1.x, p1.y, p2.x, p2.y)
+
+ # ---------------------------------- linear4 --------------------------------- #
+ def _fitmin_linear4(x1, y1, x2, y2, x3, y3, x4, y4):
+     # sort by x
+     points = ((x1,y1), (x2,y2), (x3,y3), (x4,y4))
+     points = sorted(points, key=lambda x: x[0])
+
+     (x1,y1), (x2,y2), (x3,y3), (x4,y4) = points
+     s1 = (y1 - y2) / (x1 - x2)
+     s3 = (y3 - y4) / (x3 - x4)
+
+     return _linear_intersection(x1,y1,s1,x3,y3,s3)
+
+ def linear4(points:list[Point]):
+     if len(points) < 4: return None
+     p1,p2,p3,p4 = points[:4]
+     return _fitmin_linear4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
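
The fitting helpers in this new file are self-contained and can be sanity-checked in isolation. A minimal example on f(x) = (x - 2)**2, using only the names defined above (how the line-search modules assemble the `points` list is not shown in this diff):

```python
from torchzero.modules.line_search.interpolation import Point, quadratic2

# f(x) = (x - 2)**2 sampled at x=1 (no derivative) and at x=0 with f'(0) = -4,
# kept in ascending order of y as the module comment requires
points = [Point(x=1.0, y=1.0, d=None), Point(x=0.0, y=4.0, d=-4.0)]

x_min = quadratic2(points)  # fits a quadratic through (0, 4) with slope -4 and (1, 1)
print(x_min)                # 2.0, the exact minimizer of (x - 2)**2
```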
@@ -10,6 +10,7 @@ import torch
  
  from ...core import Module, Target, Var
  from ...utils import tofloat, set_storage_
+ from ..functional import clip_by_finfo
  
  
  class MaxLineSearchItersReached(Exception): pass
@@ -103,23 +104,18 @@ class LineSearchBase(Module, ABC):
      ):
          if not math.isfinite(step_size): return
  
-         # fixes overflow when backtracking keeps increasing alpha after converging
-         step_size = max(min(tofloat(step_size), 1e36), -1e36)
+         # avoid overflow error
+         step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))
  
          # skip is parameters are already at suggested step size
          if self._current_step_size == step_size: return
  
-         # this was basically causing floating point imprecision to build up
-         #if False:
-         #     if abs(alpha) < abs(step_size) and step_size != 0:
-         #         torch._foreach_add_(params, update, alpha=alpha)
-
-         # else:
          assert self._initial_params is not None
          if step_size == 0:
              new_params = [p.clone() for p in self._initial_params]
          else:
              new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+
          for c, n in zip(params, new_params):
              set_storage_(c, n)
  
@@ -131,10 +127,7 @@ class LineSearchBase(Module, ABC):
          params: list[torch.Tensor],
          update: list[torch.Tensor],
      ):
-         # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-         # alpha = [self._current_step_size - s for s in step_size]
-         # if any(a!=0 for a in alpha):
-         #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
          assert self._initial_params is not None
          if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
  
@@ -248,16 +241,14 @@ class LineSearchBase(Module, ABC):
          except MaxLineSearchItersReached:
              step_size = self._best_step_size
  
+         step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
+
          # set loss_approx
          if var.loss_approx is None: var.loss_approx = self._lowest_loss
  
-         # this is last module - set step size to found step_size times lr
-         if var.is_last:
-             if var.last_module_lrs is None:
-                 self.set_step_size_(step_size, params=params, update=update)
-
-             else:
-                 self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
+         # if this is last module, directly update parameters to avoid redundant operations
+         if var.modular is not None and self is var.modular.modules[-1]:
+             self.set_step_size_(step_size, params=params, update=update)
  
          var.stop = True; var.skip_update = True
          return var
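
`clip_by_finfo` is imported from `torchzero/modules/functional.py` (the `+7` lines added to that file are not shown in this diff), so its exact body is unknown here. A minimal sketch of what such a helper could look like, assuming it simply clamps a float into the finite range of the given dtype:

```python
import torch

def clip_by_finfo_sketch(x: float, finfo: torch.finfo) -> float:
    # hypothetical stand-in for torchzero's clip_by_finfo: clamp into [finfo.min, finfo.max]
    return max(min(x, finfo.max), finfo.min)

# mirrors the call site above
step_size = clip_by_finfo_sketch(1e40, torch.finfo(torch.float32))  # ~3.4028e38
```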
@@ -277,7 +268,7 @@ class GridLineSearch(LineSearchBase):
  
      @torch.no_grad
      def search(self, update, var):
-         start,end,num=itemgetter('start','end','num')(self.defaults)
+         start, end, num = itemgetter('start', 'end', 'num')(self.defaults)
  
          for lr in torch.linspace(start,end,num):
              self.evaluate_f(lr.item(), var=var, backward=False)
@@ -7,7 +7,7 @@ import numpy as np
  import torch
  from torch.optim.lbfgs import _cubic_interpolate
  
- from ...utils import as_tensorlist, totensor
+ from ...utils import as_tensorlist, totensor, tofloat
  from ._polyinterp import polyinterp, polyinterp2
  from .line_search import LineSearchBase, TerminationCondition, termination_condition
  from ..step_size.adaptive import _bb_geom
@@ -92,7 +92,7 @@ class _StrongWolfe:
              return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
  
          if self.interpolation in ('polynomial', 'polynomial2'):
-             finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+             finite_history = [(tofloat(a), tofloat(f), tofloat(g)) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
              if bounds is None: bounds = (None, None)
              polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
              try:
@@ -370,6 +370,6 @@ class StrongWolfe(LineSearchBase):
          self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
          finfo = torch.finfo(dir[0].dtype)
          if self.global_state['initial_scale'] < finfo.tiny * 2:
-             self.global_state['initial_scale'] = finfo.max / 2
+             self.global_state['initial_scale'] = init_value * 2
  
          return 0
@@ -306,8 +306,8 @@ class RandomHvp(Module):
          for i in range(n_samples):
              u = params.sample_like(distribution=distribution, variance=1)
  
-             Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
-                                   h=h, normalize=True, retain_grad=i < n_samples-1)
+             Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
+                                                     h=h, normalize=True, retain_graph=i < n_samples-1)
  
              if D is None: D = Hvp
              else: torch._foreach_add_(D, Hvp)
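
The call now goes through `var.hessian_vector_product`, added with the new `core/var.py` in this release; only the call site is visible here. For orientation, a standalone sketch of a Hessian-vector product via double backward in plain PyTorch (independent of torchzero's `Var` API and its `hvp_method` options):

```python
import torch

def hvp(f, params, v):
    """Hessian-vector product H @ v via double backward."""
    loss = f()
    # gradient with create_graph=True so it can be differentiated again
    grads = torch.autograd.grad(loss, params, create_graph=True)
    dot = sum((g * u).sum() for g, u in zip(grads, v))
    return torch.autograd.grad(dot, params)

# f(x) = sum(x**2) has Hessian 2*I, so H @ v == 2*v
x = torch.randn(3, requires_grad=True)
v = [torch.randn(3)]
(Hv,) = hvp(lambda: (x * x).sum(), [x], v)
print(torch.allclose(Hv, 2 * v[0]))  # True
```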
@@ -15,7 +15,7 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
      if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
  
      # store original params unless this is last module and can update params directly
-     params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+     params_before_steps = [p.clone() for p in params]
  
      # first step - pass var as usual
      var = modules[0].step(var)
@@ -27,8 +27,8 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
  
          # update params
          if (not new_var.skip_update):
-             if new_var.last_module_lrs is not None:
-                 torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+             # if new_var.last_module_lrs is not None:
+             #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
  
              torch._foreach_sub_(params, new_var.get_update())
  
@@ -41,16 +41,16 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
  
      # final parameter update
      if (not new_var.skip_update):
-         if new_var.last_module_lrs is not None:
-             torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+         # if new_var.last_module_lrs is not None:
+         #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
  
          torch._foreach_sub_(params, new_var.get_update())
  
      # if last module, update is applied so return new var
-     if params_before_steps is None:
-         new_var.stop = True
-         new_var.skip_update = True
-         return new_var
+     # if params_before_steps is None:
+     #     new_var.stop = True
+     #     new_var.skip_update = True
+     #     return new_var
  
      # otherwise use parameter difference as update
      var.update = list(torch._foreach_sub(params_before_steps, params))
@@ -106,10 +106,10 @@ class NegateOnLossIncrease(Module):
          f_1 = closure(False)
  
          if f_1 <= f_0:
-             if var.is_last and var.last_module_lrs is None:
-                 var.stop = True
-                 var.skip_update = True
-                 return var
+             # if var.is_last and var.last_module_lrs is None:
+             #     var.stop = True
+             #     var.skip_update = True
+             #     return var
  
              torch._foreach_add_(var.params, update)
              return var
@@ -29,3 +29,5 @@ from .quasi_newton import (
      ShorR,
      ThomasOptimalMethod,
  )
+
+ from .sg2 import SG2, SPSA2
@@ -1182,16 +1182,19 @@ class ShorR(HessianUpdateStrategy):
      """Shor’s r-algorithm.
  
      Note:
-         A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
-         Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
-         so setting ``a_init`` in the line search is recommended.
+         - A line search such as ``[tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)]`` is required. Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling, so setting ``a_init`` in the line search is recommended.
+
+         - The line search should try to overstep by a little, therefore it can help to multiply direction given by a line search by some value slightly larger than 1 such as 1.2.
  
      References:
-         S HOR , N. Z. (1985) Minimization Methods for Non-differentiable Functions. New York: Springer.
+         Those are the original references, but neither seem to be available online:
+         - Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.
+
+         - Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.
  
-         Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.
+         An overview is available in [Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720](https://sites.math.washington.edu/~burke/papers/reprints/60-speed-Shor-R.pdf).
  
-         Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
+         Reference by Skokov, V. A. describes a more efficient formula which can be found here [Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.](https://camo.ici.ro/books/thesis/th.pdf)
      """
  
      def __init__(
@@ -1229,3 +1232,9 @@ class ShorR(HessianUpdateStrategy):
  
      def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
          return shor_r_(H=H, y=y, alpha=setting['alpha'])
+
+
+ # Todd, Michael J. "The symmetric rank-one quasi-Newton method is a space-dilation subgradient algorithm." Operations research letters 5.5 (1986): 217-219.
+ # TODO
+
+ # Sorensen, D. C. "The q-superlinear convergence of a collinear scaling algorithm for unconstrained optimization." SIAM Journal on Numerical Analysis 17.1 (1980): 84-114.
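
Putting the updated docstring advice together, a configuration sketch (the `tz.Modular` / `tz.m` usage follows the pattern from the package's own docstrings; ShorR's constructor arguments are not shown in this hunk, so defaults are assumed):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.ShorR(),                                         # defaults assumed
    tz.m.StrongWolfe(a_init="quadratic", fallback=True),  # line search from the docstring note
    tz.m.Mul(1.2),                                        # overstep slightly, as the note suggests
)
```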
@@ -0,0 +1,292 @@
+ import torch
+
+ from ...core import Module, Chainable, apply_transform
+ from ...utils import TensorList, vec_to_tensors
+ from ..second_order.newton import _newton_step, _get_H
+
+ def sg2_(
+     delta_g: torch.Tensor,
+     cd: torch.Tensor,
+ ) -> torch.Tensor:
+     """cd is c * perturbation, and must be multiplied by two if hessian estimate is two-sided
+     (or divide delta_g by two)."""
+
+     M = torch.outer(1.0 / cd, delta_g)
+     H_hat = 0.5 * (M + M.T)
+
+     return H_hat
+
+
+
+ class SG2(Module):
+     """second-order stochastic gradient
+
+     SG2 with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SG2(),
+         tz.m.Backtracking()
+     )
+     ```
+
+     SG2 with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LevenbergMarquardt(tz.m.SG2()),
+     )
+     ```
+
+     """
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         h: float = 1e-2,
+         beta: float | None = None,
+         damping: float = 0,
+         eigval_fn=None,
+         one_sided: bool = False, # one-sided hessian
+         use_lstsq: bool = True,
+         seed=None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, one_sided=one_sided, seed=seed, use_lstsq=use_lstsq)
+         super().__init__(defaults)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def update(self, var):
+         k = self.global_state.get('step', 0) + 1
+         self.global_state["step"] = k
+
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None:
+             raise RuntimeError("closure is required for SG2")
+         generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+         h = self.get_settings(params, "h")
+         x_0 = params.clone()
+         n_samples = self.defaults["n_samples"]
+         H_hat = None
+
+         for i in range(n_samples):
+             # generate perturbation
+             cd = params.rademacher_like(generator=generator).mul_(h)
+
+             # one sided
+             if self.defaults["one_sided"]:
+                 g_0 = TensorList(var.get_grad())
+                 params.add_(cd)
+                 closure()
+
+                 g_p = params.grad.fill_none_(params)
+                 delta_g = (g_p - g_0) * 2
+
+             # two sided
+             else:
+                 params.add_(cd)
+                 closure()
+                 g_p = params.grad.fill_none_(params)
+
+                 params.copy_(x_0)
+                 params.sub_(cd)
+                 closure()
+                 g_n = params.grad.fill_none_(params)
+
+                 delta_g = g_p - g_n
+
+             # restore params
+             params.set_(x_0)
+
+             # compute H hat
+             H_i = sg2_(
+                 delta_g = delta_g.to_vec(),
+                 cd = cd.to_vec(),
+             )
+
+             if H_hat is None: H_hat = H_i
+             else: H_hat += H_i
+
+         assert H_hat is not None
+         if n_samples > 1: H_hat /= n_samples
+
+         # update H
+         H = self.global_state.get("H", None)
+         if H is None: H = H_hat
+         else:
+             beta = self.defaults["beta"]
+             if beta is None: beta = k / (k+1)
+             H.lerp_(H_hat, 1-beta)
+
+         self.global_state["H"] = H
+
+
+     @torch.no_grad
+     def apply(self, var):
+         dir = _newton_step(
+             var=var,
+             H = self.global_state["H"],
+             damping = self.defaults["damping"],
+             inner = self.children.get("inner", None),
+             H_tfm=None,
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+             g_proj=None,
+         )
+
+         var.update = vec_to_tensors(dir, var.params)
+         return var
+
+     def get_H(self,var=...):
+         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
+
+
+
+
+ # two sided
+ # we have g via x + d, x - d
+ # H via g(x + d), g(x - d)
+ # 1 is x, x+2d
+ # 2 is x, x-2d
+ # 5 evals in total
+
+ # one sided
+ # g via x, x + d
+ # 1 is x, x + d
+ # 2 is x, x - d
+ # 3 evals and can use two sided for g_0
+
+ class SPSA2(Module):
+     """second-order SPSA
+
+     SPSA2 with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SPSA2(),
+         tz.m.Backtracking()
+     )
+     ```
+
+     SPSA2 with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LevenbergMarquardt(tz.m.SPSA2()),
+     )
+     ```
+     """
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         h: float = 1e-2,
+         beta: float | None = None,
+         damping: float = 0,
+         eigval_fn=None,
+         use_lstsq: bool = True,
+         seed=None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, eigval_fn=eigval_fn, seed=seed, use_lstsq=use_lstsq)
+         super().__init__(defaults)
+
+         if inner is not None: self.set_child('inner', inner)
+
+     @torch.no_grad
+     def update(self, var):
+         k = self.global_state.get('step', 0) + 1
+         self.global_state["step"] = k
+
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None:
+             raise RuntimeError("closure is required for SPSA2")
+
+         generator = self.get_generator(params[0].device, self.defaults["seed"])
+
+         h = self.get_settings(params, "h")
+         x_0 = params.clone()
+         n_samples = self.defaults["n_samples"]
+         H_hat = None
+         g_0 = None
+
+         for i in range(n_samples):
+             # perturbations for g and H
+             cd_g = params.rademacher_like(generator=generator).mul_(h)
+             cd_H = params.rademacher_like(generator=generator).mul_(h)
+
+             # evaluate 4 points
+             x_p = x_0 + cd_g
+             x_n = x_0 - cd_g
+
+             params.set_(x_p)
+             f_p = closure(False)
+             params.add_(cd_H)
+             f_pp = closure(False)
+
+             params.set_(x_n)
+             f_n = closure(False)
+             params.add_(cd_H)
+             f_np = closure(False)
+
+             g_p_vec = (f_pp - f_p) / cd_H
+             g_n_vec = (f_np - f_n) / cd_H
+             delta_g = g_p_vec - g_n_vec
+
+             # restore params
+             params.set_(x_0)
+
+             # compute grad
+             g_i = (f_p - f_n) / (2 * cd_g)
+             if g_0 is None: g_0 = g_i
+             else: g_0 += g_i
+
+             # compute H hat
+             H_i = sg2_(
+                 delta_g = delta_g.to_vec().div_(2.0),
+                 cd = cd_g.to_vec(), # The interval is measured by the original 'cd'
+             )
+             if H_hat is None: H_hat = H_i
+             else: H_hat += H_i
+
+         assert g_0 is not None and H_hat is not None
+         if n_samples > 1:
+             g_0 /= n_samples
+             H_hat /= n_samples
+
+         # set grad to approximated grad
+         var.grad = g_0
+
+         # update H
+         H = self.global_state.get("H", None)
+         if H is None: H = H_hat
+         else:
+             beta = self.defaults["beta"]
+             if beta is None: beta = k / (k+1)
+             H.lerp_(H_hat, 1-beta)
+
+         self.global_state["H"] = H
+
+     @torch.no_grad
+     def apply(self, var):
+         dir = _newton_step(
+             var=var,
+             H = self.global_state["H"],
+             damping = self.defaults["damping"],
+             inner = self.children.get("inner", None),
+             H_tfm=None,
+             eigval_fn=self.defaults["eigval_fn"],
+             use_lstsq=self.defaults["use_lstsq"],
+             g_proj=None,
+         )
+
+         var.update = vec_to_tensors(dir, var.params)
+         return var
+
+     def get_H(self,var=...):
+         return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
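
In math form, `sg2_` above builds the symmetrized finite-difference estimate

$$M_{ij} = \frac{\delta g_j}{(cd)_i}, \qquad \hat H = \tfrac{1}{2}\left(M + M^{\top}\right),$$

where $\delta g$ is the measured gradient (or gradient-approximation) difference and $cd$ is the scaled perturbation, with the factor-of-two handling for two-sided estimates noted in its docstring. `update` then keeps a running average with the default $\beta_k = k/(k+1)$:

$$\bar H_k = \beta_k \bar H_{k-1} + (1 - \beta_k)\,\hat H_k.$$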
@@ -1,4 +1,7 @@
- from .newton import Newton, InverseFreeNewton
+ from .ifn import InverseFreeNewton
+ from .inm import INM
+ from .multipoint import SixthOrder3P, SixthOrder3PM2, SixthOrder5P, TwoPointNewton
+ from .newton import Newton
  from .newton_cg import NewtonCG, NewtonCGSteihaug
- from .nystrom import NystromSketchAndSolve, NystromPCG
- from .multipoint import SixthOrder3P, SixthOrder5P, TwoPointNewton, SixthOrder3PM2
+ from .nystrom import NystromPCG, NystromSketchAndSolve
+ from .rsn import RSN
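
For reference, after this reorganization the new and relocated second-order modules are exposed at the subpackage level (names taken from the import lines above; an installed 0.3.15 package is assumed):

```python
from torchzero.modules.second_order import INM, RSN, InverseFreeNewton, Newton
```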