torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/line_search.py

@@ -3,20 +3,21 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import partial
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal

 import numpy as np
 import torch

 from ...core import Module, Target, Var
-from ...utils import tofloat
+from ...utils import tofloat, set_storage_


 class MaxLineSearchItersReached(Exception): pass


-class LineSearch(Module, ABC):
+class LineSearchBase(Module, ABC):
     """Base class for line searches.
+
     This is an abstract class, to use it, subclass it and override `search`.

     Args:
@@ -26,6 +27,61 @@ class LineSearch(Module, ABC):
             the objective this many times, and step size with the lowest loss value will be used.
             This is useful when passing `make_objective` to an external library which
            doesn't have a maxiter option. Defaults to None.
+
+    Other useful methods:
+    * ``evaluate_f`` - returns loss with a given scalar step size
+    * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
+    * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+    * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+
+    Examples:
+
+    #### Basic line search
+
+    This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+    ```python
+    class GridLineSearch(LineSearch):
+        def __init__(self, start, end, num):
+            defaults = dict(start=start,end=end,num=num)
+            super().__init__(defaults)
+
+        @torch.no_grad
+        def search(self, update, var):
+
+            start = self.defaults["start"]
+            end = self.defaults["end"]
+            num = self.defaults["num"]
+
+            lowest_loss = float("inf")
+            best_step_size = best_step_size
+
+            for step_size in torch.linspace(start,end,num):
+                loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                if loss < lowest_loss:
+                    lowest_loss = loss
+                    best_step_size = step_size
+
+            return best_step_size
+    ```
+
+    #### Using external solver via self.make_objective
+
+    Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+
+    ```python
+    class ScipyMinimizeScalar(LineSearch):
+        def __init__(self, method: str | None = None):
+            defaults = dict(method=method)
+            super().__init__(defaults)
+
+        @torch.no_grad
+        def search(self, update, var):
+            objective = self.make_objective(var=var)
+            method = self.defaults["method"]
+
+            res = self.scopt.minimize_scalar(objective, method=method)
+            return res.x
+    ```
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -37,6 +93,7 @@ class LineSearch(Module, ABC):
         self._lowest_loss = float('inf')
         self._best_step_size: float = 0
         self._current_iter = 0
+        self._initial_params = None

     def set_step_size_(
         self,
@@ -45,10 +102,27 @@ class LineSearch(Module, ABC):
         update: list[torch.Tensor],
     ):
         if not math.isfinite(step_size): return
-        step_size = max(min(tofloat(step_size), 1e36), -1e36) # fixes overflow when backtracking keeps increasing alpha after converging
-        alpha = self._current_step_size - step_size
-        if alpha != 0:
-            torch._foreach_add_(params, update, alpha=alpha)
+
+        # fixes overflow when backtracking keeps increasing alpha after converging
+        step_size = max(min(tofloat(step_size), 1e36), -1e36)
+
+        # skip is parameters are already at suggested step size
+        if self._current_step_size == step_size: return
+
+        # this was basically causing floating point imprecision to build up
+        #if False:
+        #    if abs(alpha) < abs(step_size) and step_size != 0:
+        #        torch._foreach_add_(params, update, alpha=alpha)
+
+        # else:
+        assert self._initial_params is not None
+        if step_size == 0:
+            new_params = [p.clone() for p in self._initial_params]
+        else:
+            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
+
         self._current_step_size = step_size

     def _set_per_parameter_step_size_(
@@ -57,10 +131,20 @@ class LineSearch(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-        if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-        alpha = [self._current_step_size - s for s in step_size]
-        if any(a!=0 for a in alpha):
-            torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
+        # alpha = [self._current_step_size - s for s in step_size]
+        # if any(a!=0 for a in alpha):
+        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        assert self._initial_params is not None
+        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
+
+        if any(s!=0 for s in step_size):
+            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
+        else:
+            new_params = [p.clone() for p in self._initial_params]
+
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)

     def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
@@ -92,7 +176,7 @@ class LineSearch(Module, ABC):

         return tofloat(loss)

-    def _loss_derivative(self, step_size: float, var: Var, closure,
+    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                          params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
         if (var.grad is not None) and (step_size == 0):
@@ -107,18 +191,31 @@ class LineSearch(Module, ABC):
             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                     else torch.zeros_like(p) for p in params], update))

-        return loss, tofloat(derivative)
+        assert var.grad is not None
+        return loss, tofloat(derivative), var.grad

-    def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
+    def _loss_derivative(self, step_size: float, var: Var, closure,
+                         params: list[torch.Tensor], update: list[torch.Tensor]):
+        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
+
+    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+        """evaluate function value at alpha `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)

-    def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
+    def evaluate_f_d(self, step_size: float, var: Var):
+        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

+    def evaluate_f_d_g(self, step_size: float, var: Var):
+        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+
     def make_objective(self, var: Var, backward:bool=False):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
@@ -129,6 +226,11 @@ class LineSearch(Module, ABC):
         if closure is None: raise RuntimeError('line search requires closure')
         return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())

+    def make_objective_with_derivative_and_gradient(self, var: Var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+
     @abstractmethod
     def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
@@ -136,7 +238,9 @@ class LineSearch(Module, ABC):
     @torch.no_grad
     def step(self, var: Var) -> Var:
         self._reset()
+
         params = var.params
+        self._initial_params = [p.clone() for p in params]
         update = var.get_update()

         try:
@@ -149,7 +253,6 @@ class LineSearch(Module, ABC):

         # this is last module - set step size to found step_size times lr
         if var.is_last:
-
             if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)

@@ -165,17 +268,63 @@ class LineSearch(Module, ABC):
         return var


-class GridLineSearch(LineSearch):
-    """Mostly for testing, this is not practical"""
+
+class GridLineSearch(LineSearchBase):
+    """"""
     def __init__(self, start, end, num):
         defaults = dict(start=start,end=end,num=num)
         super().__init__(defaults)

     @torch.no_grad
     def search(self, update, var):
-        start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
+        start,end,num=itemgetter('start','end','num')(self.defaults)

         for lr in torch.linspace(start,end,num):
-            self.evaluate_step_size(lr.item(), var=var, backward=False)
+            self.evaluate_f(lr.item(), var=var, backward=False)
-
-        return self._best_step_size
+
+        return self._best_step_size
+
+
+def sufficient_decrease(f_0, g_0, f_a, a, c):
+    return f_a < f_0 + c*a*min(g_0, 0)
+
+def curvature(g_0, g_a, c):
+    if g_0 > 0: return True
+    return g_a >= c * g_0
+
+def strong_curvature(g_0, g_a, c):
+    """same as curvature condition except curvature can't be too positive (which indicates overstep)"""
+    if g_0 > 0: return True
+    return abs(g_a) <= c * abs(g_0)
+
+def wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and curvature(g_0, g_a, c2)
+
+def strong_wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and strong_curvature(g_0, g_a, c2)
+
+def goldstein(f_0, g_0, f_a, a, c):
+    """same as armijo (sufficient_decrease) but additional lower bound"""
+    g_0 = min(g_0, 0)
+    return f_0 + (1-c)*a*g_0 < f_a < f_0 + c*a*g_0
+
+TerminationCondition = Literal["armijo", "curvature", "strong_curvature", "wolfe", "strong_wolfe", "goldstein", "decrease"]
+def termination_condition(
+    condition: TerminationCondition,
+    f_0,
+    g_0,
+    f_a,
+    g_a: Any | None,
+    a,
+    c,
+    c2=None,
+):
+    if not math.isfinite(f_a): return False
+    if condition == 'armijo': return sufficient_decrease(f_0, g_0, f_a, a, c)
+    if condition == 'curvature': return curvature(g_0, g_a, c)
+    if condition == 'strong_curvature': return strong_curvature(g_0, g_a, c)
+    if condition == 'wolfe': return wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'strong_wolfe': return strong_wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'goldstein': return goldstein(f_0, g_0, f_a, a, c)
+    if condition == 'decrease': return f_a < f_0
+    raise ValueError(f"unknown condition {condition}")
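For reference, the helper functions added at the end of `line_search.py` (`sufficient_decrease`, `curvature`, `strong_curvature`, `wolfe`, `strong_wolfe`, `goldstein`) implement the standard step-size acceptance tests. Writing f(alpha) for the loss and g(alpha) for the directional derivative along the update at step size alpha, with c1 and c2 the usual constants, the conditions correspond to the formulas below; note that the code additionally clamps g(0) to min(g(0), 0) in the Armijo and Goldstein tests and accepts the curvature tests unconditionally when g(0) > 0 (i.e. when the update is not a descent direction).

```latex
\begin{aligned}
\text{Armijo (sufficient decrease):}\quad & f(\alpha) \le f(0) + c_1\,\alpha\, g(0) \\
\text{curvature (Wolfe):}\quad & g(\alpha) \ge c_2\, g(0) \\
\text{strong curvature (strong Wolfe):}\quad & \lvert g(\alpha)\rvert \le c_2\,\lvert g(0)\rvert \\
\text{Goldstein:}\quad & f(0) + (1 - c_1)\,\alpha\, g(0) \le f(\alpha) \le f(0) + c_1\,\alpha\, g(0)
\end{aligned}
```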
torchzero/modules/line_search/scipy.py

@@ -3,10 +3,10 @@ from operator import itemgetter

 import torch

-from .line_search import LineSearch
+from .line_search import LineSearchBase


-class ScipyMinimizeScalar(LineSearch):
+class ScipyMinimizeScalar(LineSearchBase):
     """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.

     Args:
@@ -42,7 +42,7 @@ class ScipyMinimizeScalar(LineSearch):
     def search(self, update, var):
         objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)

         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
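Since several line-search modules are touched by the rename from `LineSearch` to `LineSearchBase` and from `evaluate_step_size` to `evaluate_f`, here is a minimal sketch of a custom line search written against the renamed API shown in the diffs above. The class name `HalvingLineSearch` and the import path are illustrative assumptions, not part of the package; only `LineSearchBase`, the `defaults` dict, `search`, and `evaluate_f` come from the diff.

```python
# Hedged sketch: a tiny custom line search against the renamed 0.3.13 API.
# The import path below is an assumption based on the file layout shown in
# the file list; the class may also be re-exported elsewhere in torchzero.
import torch

from torchzero.modules.line_search.line_search import LineSearchBase


class HalvingLineSearch(LineSearchBase):
    """Hypothetical example: start from `init` and halve the step size `num`
    times, returning the candidate with the lowest loss."""

    def __init__(self, init: float = 1.0, num: int = 10):
        super().__init__(defaults=dict(init=init, num=num))

    @torch.no_grad
    def search(self, update, var):
        init = self.defaults["init"]
        num = self.defaults["num"]

        best_loss, best_step = float("inf"), 0.0
        step = init
        for _ in range(num):
            # evaluate_f returns the loss at this scalar step size
            loss = self.evaluate_f(step, var=var, backward=False)
            if loss < best_loss:
                best_loss, best_step = loss, step
            step *= 0.5
        return best_step
```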