torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. The information is provided for informational purposes only.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -1,122 +1,387 @@
1
1
  """Various step size strategies"""
2
- from typing import Any, Literal
2
+ import math
3
3
  from operator import itemgetter
4
+ from typing import Any, Literal
5
+
4
6
  import torch
5
7
 
6
- from ...core import Transform, Chainable
7
- from ...utils import TensorList, unpack_dicts, unpack_states, NumberList
8
+ from ...core import Chainable, Transform
9
+ from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
10
+ from ...utils.linalg.linear_operator import ScaledIdentity
11
+ from ..functional import epsilon_step_size
12
+
13
+ def _acceptable_alpha(alpha, param:torch.Tensor):
14
+ finfo = torch.finfo(param.dtype)
15
+ if (alpha is None) or (alpha < finfo.tiny*2) or (not math.isfinite(alpha)) or (alpha > finfo.max/2):
16
+ return False
17
+ return True
18
+
19
+ def _get_H(self: Transform, var):
20
+ n = sum(p.numel() for p in var.params)
21
+ p = var.params[0]
22
+ alpha = self.global_state.get('alpha', 1)
23
+ if not _acceptable_alpha(alpha, p): alpha = 1
24
+
25
+ return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
8
26
 
9
27
 
10
28
  class PolyakStepSize(Transform):
11
- """Polyak's subgradient method.
29
+ """Polyak's subgradient method with known or unknown f*.
12
30
 
13
31
  Args:
14
- f_star (int, optional):
15
- (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
32
+ f_star (float | None, optional):
33
+ minimal possible value of the objective function. If not known, set to ``None``. Defaults to 0.
34
+ y (float, optional):
35
+ when ``f_star`` is set to None, it is calculated as ``f_best - y``.
36
+ y_decay (float, optional):
37
+ ``y`` is multiplied by ``(1 - y_decay)`` after each step. Defaults to 1e-3.
16
38
  max (float | None, optional): maximum possible step size. Defaults to None.
17
39
  use_grad (bool, optional):
18
40
  if True, uses dot product of update and gradient to compute the step size.
19
- Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
20
- Defaults to False.
41
+ Otherwise, dot product of update with itself is used.
21
42
  alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
22
43
  """
23
- def __init__(self, f_star: float = 0, max: float | None = None, use_grad=False, alpha: float = 1, inner: Chainable | None = None):
44
+ def __init__(self, f_star: float | None = 0, y: float = 1, y_decay: float = 1e-3, max: float | None = None, use_grad=True, alpha: float = 1, inner: Chainable | None = None):
24
45
 
25
- defaults = dict(alpha=alpha, max=max, f_star=f_star, use_grad=use_grad)
46
+ defaults = dict(alpha=alpha, max=max, f_star=f_star, y=y, y_decay=y_decay)
26
47
  super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
27
48
 
49
+ @torch.no_grad
28
50
  def update_tensors(self, tensors, params, grads, loss, states, settings):
29
51
  assert grads is not None and loss is not None
30
52
  tensors = TensorList(tensors)
31
53
  grads = TensorList(grads)
32
54
 
33
- use_grad, max, f_star = itemgetter('use_grad', 'max', 'f_star')(settings[0])
55
+ # load variables
56
+ max, f_star, y, y_decay = itemgetter('max', 'f_star', 'y', 'y_decay')(settings[0])
57
+ y_val = self.global_state.get('y_val', y)
58
+ f_best = self.global_state.get('f_best', None)
34
59
 
35
- if use_grad: gg = tensors.dot(grads)
60
+ # gg
61
+ if self._uses_grad: gg = tensors.dot(grads)
36
62
  else: gg = tensors.dot(tensors)
37
63
 
38
- if gg.abs() <= torch.finfo(gg.dtype).eps: step_size = 0 # converged
39
- else: step_size = (loss - f_star) / gg
64
+ # store loss
65
+ if f_best is None or loss < f_best: f_best = tofloat(loss)
66
+ if f_star is None: f_star = f_best - y_val
67
+
68
+ # calculate the step size
69
+ if gg <= torch.finfo(gg.dtype).tiny * 2: alpha = 0 # converged
70
+ else: alpha = (loss - f_star) / gg
40
71
 
72
+ # clip
41
73
  if max is not None:
42
- if step_size > max: step_size = max
74
+ if alpha > max: alpha = max
43
75
 
44
- self.global_state['step_size'] = step_size
76
+ # store state
77
+ self.global_state['f_best'] = f_best
78
+ self.global_state['y_val'] = y_val * (1 - y_decay)
79
+ self.global_state['alpha'] = alpha
45
80
 
46
81
  @torch.no_grad
47
82
  def apply_tensors(self, tensors, params, grads, loss, states, settings):
48
- step_size = self.global_state.get('step_size', 1)
49
- torch._foreach_mul_(tensors, step_size * unpack_dicts(settings, 'alpha', cls=NumberList))
83
+ alpha = self.global_state.get('alpha', 1)
84
+ if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
85
+
86
+ torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
50
87
  return tensors
51
88
 
89
+ def get_H(self, var):
90
+ return _get_H(self, var)
52
91
 
53
92
 
54
- def _bb_short(s: TensorList, y: TensorList, sy, eps, fallback):
93
+ def _bb_short(s: TensorList, y: TensorList, sy, eps):
55
94
  yy = y.dot(y)
56
95
  if yy < eps:
57
- if sy < eps: return fallback # try to fallback on long
96
+ if sy < eps: return None # try to fallback on long
58
97
  ss = s.dot(s)
59
98
  return ss/sy
60
99
  return sy/yy
61
100
 
62
- def _bb_long(s: TensorList, y: TensorList, sy, eps, fallback):
101
+ def _bb_long(s: TensorList, y: TensorList, sy, eps):
63
102
  ss = s.dot(s)
64
103
  if sy < eps:
65
104
  yy = y.dot(y) # try to fallback on short
66
- if yy < eps: return fallback
105
+ if yy < eps: return None
67
106
  return sy/yy
68
107
  return ss/sy
69
108
 
70
- def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback):
71
- short = _bb_short(s, y, sy, eps, fallback)
72
- long = _bb_long(s, y, sy, eps, fallback)
109
+ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
110
+ short = _bb_short(s, y, sy, eps)
111
+ long = _bb_long(s, y, sy, eps)
112
+ if long is None or short is None:
113
+ if fallback:
114
+ if short is not None: return short
115
+ if long is not None: return long
116
+ return None
73
117
  return (short * long) ** 0.5
74
118
 
75
119
  class BarzilaiBorwein(Transform):
76
- """Barzilai-Borwein method.
120
+ """Barzilai-Borwein step size method.
77
121
 
78
122
  Args:
79
123
  type (str, optional):
80
124
  one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
81
- Defaults to 'geom'.
82
- scale_first (bool, optional):
83
- whether to make first step very small when previous gradient is not available. Defaults to True.
125
+ Defaults to "geom".
84
126
  fallback (float, optional): step size when denominator is less than 0 (will happen on negative curvature). Defaults to 1e-3.
85
127
  inner (Chainable | None, optional):
86
128
  step size will be applied to outputs of this module. Defaults to None.
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
134
+ alpha_0: float = 1e-7,
135
+ use_grad=True,
136
+ inner: Chainable | None = None,
137
+ ):
138
+ defaults = dict(type=type, alpha_0=alpha_0)
139
+ super().__init__(defaults, uses_grad=use_grad, inner=inner)
140
+
141
+ def reset_for_online(self):
142
+ super().reset_for_online()
143
+ self.clear_state_keys('prev_g')
144
+ self.global_state['reset'] = True
145
+
146
+ @torch.no_grad
147
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
148
+ step = self.global_state.get('step', 0)
149
+ self.global_state['step'] = step + 1
150
+
151
+ prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
152
+ type = self.defaults['type']
153
+
154
+ g = grads if self._uses_grad else tensors
155
+ assert g is not None
156
+
157
+ reset = self.global_state.get('reset', False)
158
+ self.global_state.pop('reset', None)
159
+
160
+ if step != 0 and not reset:
161
+ s = params-prev_p
162
+ y = g-prev_g
163
+ sy = s.dot(y)
164
+ eps = torch.finfo(sy.dtype).tiny * 2
165
+
166
+ if type == 'short': alpha = _bb_short(s, y, sy, eps)
167
+ elif type == 'long': alpha = _bb_long(s, y, sy, eps)
168
+ elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
169
+ elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
170
+ else: raise ValueError(type)
171
+
172
+ # if alpha is not None:
173
+ self.global_state['alpha'] = alpha
174
+
175
+ prev_p.copy_(params)
176
+ prev_g.copy_(g)
177
+
178
+ def get_H(self, var):
179
+ return _get_H(self, var)
180
+
181
+ @torch.no_grad
182
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
183
+ alpha = self.global_state.get('alpha', None)
184
+
185
+ if not _acceptable_alpha(alpha, tensors[0]):
186
+ alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
187
+
188
+ torch._foreach_mul_(tensors, alpha)
189
+ return tensors
190
+
191
+
192
+ class BBStab(Transform):
193
+ """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
194
+
195
+ This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
196
+
197
+ Args:
198
+ c (float, optional):
199
+ adaptive delta parameter. If ``delta`` is set to None, first ``inf_iters`` updates are performed
200
+ with non-stabilized Barzilai-Borwein step size. Then delta is set to norm of
201
+ the update that had the smallest norm, and multiplied by ``c``. Defaults to 0.2.
202
+ delta (float | None, optional):
203
+ Barzilai-Borwein update is clipped to this value. Set to ``None`` to use an adaptive choice. Defaults to None.
204
+ type (str, optional):
205
+ one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
206
+ Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab,
207
+ however I found that "geom" works really well.
208
+ inner (Chainable | None, optional):
209
+ step size will be applied to outputs of this module. Defaults to None.
87
210
 
88
211
  """
89
- def __init__(self, type: Literal['long', 'short', 'geom'] = 'geom', scale_first:bool=True, fallback:float=1e-3, inner:Chainable|None = None):
90
- defaults = dict(type=type, fallback=fallback)
91
- super().__init__(defaults, uses_grad=False, scale_first=scale_first, inner=inner)
212
+ def __init__(
213
+ self,
214
+ c=0.2,
215
+ delta:float | None = None,
216
+ type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
217
+ alpha_0: float = 1e-7,
218
+ use_grad=True,
219
+ inf_iters: int = 3,
220
+ inner: Chainable | None = None,
221
+ ):
222
+ defaults = dict(type=type,alpha_0=alpha_0, c=c, delta=delta, inf_iters=inf_iters)
223
+ super().__init__(defaults, uses_grad=use_grad, inner=inner)
92
224
 
93
225
  def reset_for_online(self):
94
226
  super().reset_for_online()
95
- self.clear_state_keys('prev_p', 'prev_g')
227
+ self.clear_state_keys('prev_g')
228
+ self.global_state['reset'] = True
96
229
 
97
230
  @torch.no_grad
98
231
  def update_tensors(self, tensors, params, grads, loss, states, settings):
232
+ step = self.global_state.get('step', 0)
233
+ self.global_state['step'] = step + 1
234
+
99
235
  prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
100
- fallback = unpack_dicts(settings, 'fallback', cls=NumberList)
101
- type = settings[0]['type']
236
+ type = self.defaults['type']
237
+ c = self.defaults['c']
238
+ delta = self.defaults['delta']
239
+ inf_iters = self.defaults['inf_iters']
240
+
241
+ g = grads if self._uses_grad else tensors
242
+ assert g is not None
243
+ g = TensorList(g)
244
+
245
+ reset = self.global_state.get('reset', False)
246
+ self.global_state.pop('reset', None)
247
+
248
+ if step != 0 and not reset:
249
+ s = params-prev_p
250
+ y = g-prev_g
251
+ sy = s.dot(y)
252
+ eps = torch.finfo(sy.dtype).tiny
253
+
254
+ if type == 'short': alpha = _bb_short(s, y, sy, eps)
255
+ elif type == 'long': alpha = _bb_long(s, y, sy, eps)
256
+ elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
257
+ elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
258
+ else: raise ValueError(type)
259
+
260
+ if alpha is not None:
261
+
262
+ # adaptive delta
263
+ if delta is None:
264
+ niters = self.global_state.get('niters', 0) # this accounts for skipped negative curvature steps
265
+ self.global_state['niters'] = niters + 1
266
+
267
+
268
+ if niters == 0: pass # 1st iteration is scaled GD step, shouldn't be used to find s_norm_min
269
+ elif niters <= inf_iters:
270
+ s_norm_min = self.global_state.get('s_norm_min', None)
271
+ if s_norm_min is None: s_norm_min = s.global_vector_norm()
272
+ else: s_norm_min = min(s_norm_min, s.global_vector_norm())
273
+ self.global_state['s_norm_min'] = s_norm_min
274
+ # first few steps use delta=inf, so delta remains None
102
275
 
103
- s = params-prev_p
104
- y = tensors-prev_g
105
- sy = s.dot(y)
106
- eps = torch.finfo(sy.dtype).eps
276
+ else:
277
+ delta = c * self.global_state['s_norm_min']
107
278
 
108
- if type == 'short': step_size = _bb_short(s, y, sy, eps, fallback)
109
- elif type == 'long': step_size = _bb_long(s, y, sy, eps, fallback)
110
- elif type == 'geom': step_size = _bb_geom(s, y, sy, eps, fallback)
111
- else: raise ValueError(type)
279
+ if delta is None: # delta is inf for first few steps
280
+ self.global_state['alpha'] = alpha
112
281
 
113
- self.global_state['step_size'] = step_size
282
+ # BBStab step size
283
+ else:
284
+ a_stab = delta / g.global_vector_norm()
285
+ self.global_state['alpha'] = min(alpha, a_stab)
114
286
 
115
287
  prev_p.copy_(params)
116
- prev_g.copy_(tensors)
288
+ prev_g.copy_(g)
289
+
290
+ def get_H(self, var):
291
+ return _get_H(self, var)
292
+
293
+ @torch.no_grad
294
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
295
+ alpha = self.global_state.get('alpha', None)
296
+
297
+ if not _acceptable_alpha(alpha, tensors[0]):
298
+ alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
299
+
300
+ torch._foreach_mul_(tensors, alpha)
301
+ return tensors
302
+
303
+
304
+ class AdGD(Transform):
305
+ """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
306
+ def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
307
+ defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
308
+ super().__init__(defaults, uses_grad=use_grad, inner=inner,)
309
+
310
+ def reset_for_online(self):
311
+ super().reset_for_online()
312
+ self.clear_state_keys('prev_g')
313
+ self.global_state['reset'] = True
314
+
315
+ @torch.no_grad
316
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
317
+ variant = settings[0]['variant']
318
+ theta_0 = 0 if variant == 1 else 1/3
319
+ theta = self.global_state.get('theta', theta_0)
320
+
321
+ step = self.global_state.get('step', 0)
322
+ self.global_state['step'] = step + 1
323
+
324
+ p = TensorList(params)
325
+ g = grads if self._uses_grad else tensors
326
+ assert g is not None
327
+ g = TensorList(g)
328
+
329
+ prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
330
+
331
+ # online
332
+ if self.global_state.get('reset', False):
333
+ del self.global_state['reset']
334
+ prev_p.copy_(p)
335
+ prev_g.copy_(g)
336
+ return
117
337
 
338
+ if step == 0:
339
+ alpha_0 = settings[0]['alpha_0']
340
+ if alpha_0 is None: alpha_0 = epsilon_step_size(g)
341
+ self.global_state['alpha'] = alpha_0
342
+ prev_p.copy_(p)
343
+ prev_g.copy_(g)
344
+ return
345
+
346
+ sqrt = settings[0]['sqrt']
347
+ alpha = self.global_state.get('alpha', math.inf)
348
+ L = (g - prev_g).global_vector_norm() / (p - prev_p).global_vector_norm()
349
+ eps = torch.finfo(L.dtype).tiny * 2
350
+
351
+ if variant == 1:
352
+ a1 = math.sqrt(1 + theta)*alpha
353
+ val = math.sqrt(2) if sqrt else 2
354
+ if L > eps: a2 = 1 / (val*L)
355
+ else: a2 = math.inf
356
+
357
+ elif variant == 2:
358
+ a1 = math.sqrt(2/3 + theta)*alpha
359
+ a2 = alpha / math.sqrt(max(eps, 2 * alpha**2 * L**2 - 1))
360
+
361
+ else:
362
+ raise ValueError(variant)
363
+
364
+ alpha_new = min(a1, a2)
365
+ if alpha_new < 0: alpha_new = max(a1, a2)
366
+ if alpha_new > eps:
367
+ self.global_state['theta'] = alpha_new/alpha
368
+ self.global_state['alpha'] = alpha_new
369
+
370
+ prev_p.copy_(p)
371
+ prev_g.copy_(g)
372
+
373
+ @torch.no_grad
118
374
  def apply_tensors(self, tensors, params, grads, loss, states, settings):
119
- step_size = self.global_state.get('step_size', 1)
120
- torch._foreach_mul_(tensors, step_size)
375
+ alpha = self.global_state.get('alpha', None)
376
+
377
+ if not _acceptable_alpha(alpha, tensors[0]):
378
+ # alpha isn't None on 1st step
379
+ self.state.clear()
380
+ self.global_state.clear()
381
+ alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
382
+
383
+ torch._foreach_mul_(tensors, alpha)
121
384
  return tensors
122
385
 
386
+ def get_H(self, var):
387
+ return _get_H(self, var)
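The step-size modules above all reduce to a scalar ``alpha`` that is multiplied into the incoming update. A minimal standalone sketch of the two main rules, written against plain torch tensors rather than torchzero's Transform API (the helper names below are illustrative, not part of the package):

import torch

def polyak_step_size(loss: float, g: torch.Tensor, f_star: float | None = 0.0,
                     f_best: float | None = None, y: float = 1.0) -> float:
    # alpha = (f(x) - f*) / g.g ; when f* is unknown it is estimated as f_best - y
    if f_star is None:
        f_star = (f_best if f_best is not None else loss) - y
    gg = g.dot(g).item()
    if gg <= torch.finfo(g.dtype).tiny * 2:
        return 0.0  # gradient has vanished, treat as converged
    return (loss - f_star) / gg

def bb_step_size(s: torch.Tensor, y: torch.Tensor, kind: str = "geom") -> float | None:
    # s = x_k - x_{k-1}, y = g_k - g_{k-1}
    # "long" = s.s / s.y, "short" = s.y / y.y, "geom" = sqrt(long * short)
    sy = s.dot(y).item()
    if sy <= torch.finfo(s.dtype).tiny * 2:
        return None  # negative curvature; caller falls back to a tiny epsilon step
    long = s.dot(s).item() / sy
    short = sy / y.dot(y).item()
    return {"long": long, "short": short, "geom": (long * short) ** 0.5}[kind]

This sketch omits the module's extra bookkeeping (online resets, ``get_H``, the adaptive ``delta`` clipping of BBStab, and the finite-value checks in ``_acceptable_alpha``), but the returned scalars follow the same formulas.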
@@ -0,0 +1,14 @@
1
+ from .termination import (
2
+ TerminateAfterNEvaluations,
3
+ TerminateAfterNSeconds,
4
+ TerminateAfterNSteps,
5
+ TerminateAll,
6
+ TerminateAny,
7
+ TerminateByGradientNorm,
8
+ TerminateByUpdateNorm,
9
+ TerminateOnLossReached,
10
+ TerminateOnNoImprovement,
11
+ TerminationCriteriaBase,
12
+ TerminateNever,
13
+ make_termination_criteria
14
+ )
@@ -0,0 +1,207 @@
1
+ import time
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import Sequence
4
+ from typing import cast
5
+
6
+ import torch
7
+
8
+ from ...core import Module, Var
9
+ from ...utils import Metrics, TensorList, safe_dict_update_, tofloat
10
+
11
+
12
+ class TerminationCriteriaBase(Module):
13
+ def __init__(self, defaults:dict | None = None, n: int = 1):
14
+ if defaults is None: defaults = {}
15
+ safe_dict_update_(defaults, {"_n": n})
16
+ super().__init__(defaults)
17
+
18
+ @abstractmethod
19
+ def termination_criteria(self, var: Var) -> bool:
20
+ ...
21
+
22
+ def should_terminate(self, var: Var) -> bool:
23
+ n_bad = self.global_state.get('_n_bad', 0)
24
+ n = self.defaults['_n']
25
+
26
+ if self.termination_criteria(var):
27
+ n_bad += 1
28
+ if n_bad >= n:
29
+ self.global_state['_n_bad'] = 0
30
+ return True
31
+
32
+ else:
33
+ n_bad = 0
34
+
35
+ self.global_state['_n_bad'] = n_bad
36
+ return False
37
+
38
+
39
+ def update(self, var):
40
+ var.should_terminate = self.should_terminate(var)
41
+ if var.should_terminate: self.global_state['_n_bad'] = 0
42
+
43
+ def apply(self, var):
44
+ return var
45
+
46
+
47
+ class TerminateAfterNSteps(TerminationCriteriaBase):
48
+ def __init__(self, steps:int):
49
+ defaults = dict(steps=steps)
50
+ super().__init__(defaults)
51
+
52
+ def termination_criteria(self, var):
53
+ step = self.global_state.get('step', 0)
54
+ self.global_state['step'] = step + 1
55
+
56
+ max_steps = self.defaults['steps']
57
+ return step >= max_steps
58
+
59
+ class TerminateAfterNEvaluations(TerminationCriteriaBase):
60
+ def __init__(self, maxevals:int):
61
+ defaults = dict(maxevals=maxevals)
62
+ super().__init__(defaults)
63
+
64
+ def termination_criteria(self, var):
65
+ maxevals = self.defaults['maxevals']
66
+ return var.modular.num_evaluations >= maxevals
67
+
68
+ class TerminateAfterNSeconds(TerminationCriteriaBase):
69
+ def __init__(self, seconds:float, sec_fn = time.time):
70
+ defaults = dict(seconds=seconds, sec_fn=sec_fn)
71
+ super().__init__(defaults)
72
+
73
+ def termination_criteria(self, var):
74
+ max_seconds = self.defaults['seconds']
75
+ sec_fn = self.defaults['sec_fn']
76
+
77
+ if 'start' not in self.global_state:
78
+ self.global_state['start'] = sec_fn()
79
+ return False
80
+
81
+ seconds_passed = sec_fn() - self.global_state['start']
82
+ return seconds_passed >= max_seconds
83
+
84
+
85
+
86
+ class TerminateByGradientNorm(TerminationCriteriaBase):
87
+ def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
88
+ defaults = dict(tol=tol, ord=ord)
89
+ super().__init__(defaults, n=n)
90
+
91
+ def termination_criteria(self, var):
92
+ tol = self.defaults['tol']
93
+ ord = self.defaults['ord']
94
+ return TensorList(var.get_grad()).global_metric(ord) <= tol
95
+
96
+
97
+ class TerminateByUpdateNorm(TerminationCriteriaBase):
98
+ """update is calculated as parameter difference"""
99
+ def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
100
+ defaults = dict(tol=tol, ord=ord)
101
+ super().__init__(defaults, n=n)
102
+
103
+ def termination_criteria(self, var):
104
+ step = self.global_state.get('step', 0)
105
+ self.global_state['step'] = step + 1
106
+
107
+ tol = self.defaults['tol']
108
+ ord = self.defaults['ord']
109
+
110
+ p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
111
+ if step == 0:
112
+ p_prev.copy_(var.params)
113
+ return False
114
+
115
+ should_terminate = (p_prev - var.params).global_metric(ord) <= tol
116
+ p_prev.copy_(var.params)
117
+ return should_terminate
118
+
119
+
120
+ class TerminateOnNoImprovement(TerminationCriteriaBase):
121
+ def __init__(self, tol:float = 1e-8, n: int = 10):
122
+ defaults = dict(tol=tol)
123
+ super().__init__(defaults, n=n)
124
+
125
+ def termination_criteria(self, var):
126
+ tol = self.defaults['tol']
127
+
128
+ f = tofloat(var.get_loss(False))
129
+ if 'f_min' not in self.global_state:
130
+ self.global_state['f_min'] = f
131
+ return False
132
+
133
+ f_min = self.global_state['f_min']
134
+ d = f_min - f
135
+ should_terminate = d <= tol
136
+ self.global_state['f_min'] = min(f, f_min)
137
+ return should_terminate
138
+
139
+ class TerminateOnLossReached(TerminationCriteriaBase):
140
+ def __init__(self, value: float):
141
+ defaults = dict(value=value)
142
+ super().__init__(defaults)
143
+
144
+ def termination_criteria(self, var):
145
+ value = self.defaults['value']
146
+ return var.get_loss(False) <= value
147
+
148
+ class TerminateAny(TerminationCriteriaBase):
149
+ def __init__(self, *criteria: TerminationCriteriaBase):
150
+ super().__init__()
151
+
152
+ self.set_children_sequence(criteria)
153
+
154
+ def termination_criteria(self, var: Var) -> bool:
155
+ for c in self.get_children_sequence():
156
+ if cast(TerminationCriteriaBase, c).termination_criteria(var): return True
157
+
158
+ return False
159
+
160
+ class TerminateAll(TerminationCriteriaBase):
161
+ def __init__(self, *criteria: TerminationCriteriaBase):
162
+ super().__init__()
163
+
164
+ self.set_children_sequence(criteria)
165
+
166
+ def termination_criteria(self, var: Var) -> bool:
167
+ for c in self.get_children_sequence():
168
+ if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False
169
+
170
+ return True
171
+
172
+ class TerminateNever(TerminationCriteriaBase):
173
+ def __init__(self):
174
+ super().__init__()
175
+
176
+ def termination_criteria(self, var): return False
177
+
178
+ def make_termination_criteria(
179
+ ftol: float | None = None,
180
+ gtol: float | None = None,
181
+ stol: float | None = None,
182
+ maxiter: int | None = None,
183
+ maxeval: int | None = None,
184
+ maxsec: float | None = None,
185
+ target_loss: float | None = None,
186
+ extra: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
187
+ n: int = 3,
188
+ ):
189
+ criteria: list[TerminationCriteriaBase] = []
190
+
191
+ if ftol is not None: criteria.append(TerminateOnNoImprovement(ftol, n=n))
192
+ if gtol is not None: criteria.append(TerminateByGradientNorm(gtol, n=n))
193
+ if stol is not None: criteria.append(TerminateByUpdateNorm(stol, n=n))
194
+
195
+ if maxiter is not None: criteria.append(TerminateAfterNSteps(maxiter))
196
+ if maxeval is not None: criteria.append(TerminateAfterNEvaluations(maxeval))
197
+ if maxsec is not None: criteria.append(TerminateAfterNSeconds(maxsec))
198
+
199
+ if target_loss is not None: criteria.append(TerminateOnLossReached(target_loss))
200
+
201
+ if extra is not None:
202
+ if isinstance(extra, TerminationCriteriaBase): criteria.append(extra)
203
+ else: criteria.extend(extra)
204
+
205
+ if len(criteria) == 0: return TerminateNever()
206
+ if len(criteria) == 1: return criteria[0]
207
+ return TerminateAny(*criteria)
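For context, ``make_termination_criteria`` returns ``TerminateNever`` when nothing is requested, the single criterion when only one is requested, and otherwise wraps them in ``TerminateAny``. The ``n`` argument is the consecutive-trigger count used by ``TerminationCriteriaBase.should_terminate``; a minimal standalone sketch of that debouncing logic (names here are illustrative only):

def fires_after_n(checks: list[bool], n: int = 3) -> bool:
    # terminate only once the criterion has fired on n consecutive checks
    n_bad = 0
    for fired in checks:
        n_bad = n_bad + 1 if fired else 0
        if n_bad >= n:
            return True
    return False

assert fires_after_n([True, False, True, True, True], n=3) is True
assert fires_after_n([True, True, False, True, True], n=3) is False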
@@ -0,0 +1,5 @@
1
+ from .trust_region import TrustRegionBase
2
+ from .cubic_regularization import CubicRegularization
3
+ from .trust_cg import TrustCG
4
+ from .levenberg_marquardt import LevenbergMarquardt
5
+ from .dogleg import Dogleg