torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/line_search.py

@@ -3,13 +3,13 @@ from abc import ABC, abstractmethod
  from collections.abc import Sequence
  from functools import partial
  from operator import itemgetter
- from typing import Any
+ from typing import Any, Literal

  import numpy as np
  import torch

  from ...core import Module, Target, Var
- from ...utils import tofloat
+ from ...utils import tofloat, set_storage_


  class MaxLineSearchItersReached(Exception): pass
@@ -29,60 +29,59 @@ class LineSearchBase(Module, ABC):
  doesn't have a maxiter option. Defaults to None.

  Other useful methods:
- * `evaluate_step_size` - returns loss with a given scalar step size
- * `evaluate_step_size_loss_and_derivative` - returns loss and directional derivative with a given scalar step size
- * `make_objective` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
- * `make_objective_with_derivative` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+ * ``evaluate_f`` - returns loss with a given scalar step size
+ * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
+ * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+ * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.

  Examples:
- #### Basic line search

- This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+ #### Basic line search

- .. code-block:: python
+ This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+ ```python
+ class GridLineSearch(LineSearch):
+     def __init__(self, start, end, num):
+         defaults = dict(start=start,end=end,num=num)
+         super().__init__(defaults)

- class GridLineSearch(LineSearch):
-     def __init__(self, start, end, num):
-         defaults = dict(start=start,end=end,num=num)
-         super().__init__(defaults)
+     @torch.no_grad
+     def search(self, update, var):

-     @torch.no_grad
-     def search(self, update, var):
-         settings = self.settings[var.params[0]]
-         start = settings["start"]
-         end = settings["end"]
-         num = settings["num"]
+         start = self.defaults["start"]
+         end = self.defaults["end"]
+         num = self.defaults["num"]

-         lowest_loss = float("inf")
-         best_step_size = best_step_size
+         lowest_loss = float("inf")
+         best_step_size = best_step_size

-         for step_size in torch.linspace(start,end,num):
-             loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
-             if loss < lowest_loss:
-                 lowest_loss = loss
-                 best_step_size = step_size
+         for step_size in torch.linspace(start,end,num):
+             loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+             if loss < lowest_loss:
+                 lowest_loss = loss
+                 best_step_size = step_size

-         return best_step_size
+         return best_step_size
+ ```

- #### Using external solver via self.make_objective
+ #### Using external solver via self.make_objective

- Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+ Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`

- .. code-block:: python
+ ```python
+ class ScipyMinimizeScalar(LineSearch):
+     def __init__(self, method: str | None = None):
+         defaults = dict(method=method)
+         super().__init__(defaults)

- class ScipyMinimizeScalar(LineSearch):
-     def __init__(self, method: str | None = None):
-         defaults = dict(method=method)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def search(self, update, var):
-         objective = self.make_objective(var=var)
-         method = self.settings[var.params[0]]["method"]
-
-         res = self.scopt.minimize_scalar(objective, method=method)
-         return res.x
+     @torch.no_grad
+     def search(self, update, var):
+         objective = self.make_objective(var=var)
+         method = self.defaults["method"]

+         res = self.scopt.minimize_scalar(objective, method=method)
+         return res.x
+ ```
  """
  def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
      super().__init__(defaults)
@@ -94,6 +93,7 @@ class LineSearchBase(Module, ABC):
      self._lowest_loss = float('inf')
      self._best_step_size: float = 0
      self._current_iter = 0
+     self._initial_params = None

  def set_step_size_(
      self,
@@ -102,10 +102,27 @@ class LineSearchBase(Module, ABC):
      update: list[torch.Tensor],
  ):
      if not math.isfinite(step_size): return
-     step_size = max(min(tofloat(step_size), 1e36), -1e36) # fixes overflow when backtracking keeps increasing alpha after converging
-     alpha = self._current_step_size - step_size
-     if alpha != 0:
-         torch._foreach_add_(params, update, alpha=alpha)
+
+     # fixes overflow when backtracking keeps increasing alpha after converging
+     step_size = max(min(tofloat(step_size), 1e36), -1e36)
+
+     # skip is parameters are already at suggested step size
+     if self._current_step_size == step_size: return
+
+     # this was basically causing floating point imprecision to build up
+     #if False:
+     # if abs(alpha) < abs(step_size) and step_size != 0:
+     # torch._foreach_add_(params, update, alpha=alpha)
+
+     # else:
+     assert self._initial_params is not None
+     if step_size == 0:
+         new_params = [p.clone() for p in self._initial_params]
+     else:
+         new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+     for c, n in zip(params, new_params):
+         set_storage_(c, n)
+
      self._current_step_size = step_size

  def _set_per_parameter_step_size_(
@@ -114,10 +131,20 @@ class LineSearchBase(Module, ABC):
      params: list[torch.Tensor],
      update: list[torch.Tensor],
  ):
-     if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-     alpha = [self._current_step_size - s for s in step_size]
-     if any(a!=0 for a in alpha):
-         torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+     # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
+     # alpha = [self._current_step_size - s for s in step_size]
+     # if any(a!=0 for a in alpha):
+     # torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+     assert self._initial_params is not None
+     if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
+
+     if any(s!=0 for s in step_size):
+         new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
+     else:
+         new_params = [p.clone() for p in self._initial_params]
+
+     for c, n in zip(params, new_params):
+         set_storage_(c, n)

  def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
            update: list[torch.Tensor], backward:bool=False) -> float:
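Editorial note on the two hunks above: instead of nudging parameters in place with `torch._foreach_add_` and a differential alpha, the new code rebuilds each trial point from a saved `self._initial_params` copy and writes it back with `set_storage_`, per the in-code comment about floating-point imprecision. A minimal standalone sketch (not part of the package) of the drift that incremental stepping accumulates:

```python
# Standalone sketch (not from the package): why rebuilding trial parameters from a
# saved copy avoids the float drift that incremental in-place updates build up.
import torch

p0 = torch.randn(1000)   # saved "initial" parameters (float32)
u = torch.randn(1000)    # search direction (the update)

# old pattern: move to each trial step size by adding the *difference* in alpha in place
p_incremental = p0.clone()
current = 0.0
for trial in [0.5, 0.25, 0.125, 0.33, 0.0]:
    p_incremental.add_(u, alpha=current - trial)   # params += update * (current - trial)
    current = trial

# new pattern: always reconstruct the trial point from the saved initial parameters
p_rebuilt = p0 - u * current

# the incremental version carries accumulated rounding error even though it ends at alpha = 0
print((p_incremental - p_rebuilt).abs().max())     # small but typically nonzero in float32
```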
@@ -149,7 +176,7 @@ class LineSearchBase(Module, ABC):

      return tofloat(loss)

- def _loss_derivative(self, step_size: float, var: Var, closure,
+ def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                        params: list[torch.Tensor], update: list[torch.Tensor]):
      # if step_size is 0, we might already know the derivative
      if (var.grad is not None) and (step_size == 0):
@@ -164,18 +191,31 @@ class LineSearchBase(Module, ABC):
      derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
          else torch.zeros_like(p) for p in params], update))

-     return loss, tofloat(derivative)
+     assert var.grad is not None
+     return loss, tofloat(derivative), var.grad

- def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
+ def _loss_derivative(self, step_size: float, var: Var, closure,
+                      params: list[torch.Tensor], update: list[torch.Tensor]):
+     return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
+
+ def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+     """evaluate function value at alpha `step_size`."""
      closure = var.closure
      if closure is None: raise RuntimeError('line search requires closure')
      return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)

- def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
+ def evaluate_f_d(self, step_size: float, var: Var):
+     """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
      closure = var.closure
      if closure is None: raise RuntimeError('line search requires closure')
      return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

+ def evaluate_f_d_g(self, step_size: float, var: Var):
+     """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
+     closure = var.closure
+     if closure is None: raise RuntimeError('line search requires closure')
+     return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+
  def make_objective(self, var: Var, backward:bool=False):
      closure = var.closure
      if closure is None: raise RuntimeError('line search requires closure')
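Editorial note: the hunk above renames the public helpers to `evaluate_f` / `evaluate_f_d` and adds `evaluate_f_d_g`. A hypothetical subclass (not part of the package; the import path is assumed from the file list above) showing how `evaluate_f_d` could drive a custom `search`, in the style of the docstring examples:

```python
# Hypothetical illustration, not from the package: a line search that takes one secant
# step toward a zero of the directional derivative, using the new evaluate_f_d helper.
import torch
from torchzero.modules.line_search.line_search import LineSearchBase  # path assumed from the file list

class SecantLineSearch(LineSearchBase):
    def __init__(self, h: float = 1e-3):
        defaults = dict(h=h)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, var):
        h = self.defaults["h"]
        f0, d0 = self.evaluate_f_d(0.0, var=var)  # loss and directional derivative at step 0
        f1, d1 = self.evaluate_f_d(h, var=var)    # same quantities a small step h away
        if d1 == d0:                              # derivative unchanged, no curvature info
            return h if f1 < f0 else 0.0
        return h * d0 / (d0 - d1)                 # secant estimate of where the derivative vanishes
```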
@@ -186,6 +226,11 @@ class LineSearchBase(Module, ABC):
      if closure is None: raise RuntimeError('line search requires closure')
      return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())

+ def make_objective_with_derivative_and_gradient(self, var: Var):
+     closure = var.closure
+     if closure is None: raise RuntimeError('line search requires closure')
+     return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+
  @abstractmethod
  def search(self, update: list[torch.Tensor], var: Var) -> float:
      """Finds the step size to use"""
@@ -193,7 +238,9 @@ class LineSearchBase(Module, ABC):
  @torch.no_grad
  def step(self, var: Var) -> Var:
      self._reset()
+
      params = var.params
+     self._initial_params = [p.clone() for p in params]
      update = var.get_update()

      try:
@@ -206,7 +253,6 @@ class LineSearchBase(Module, ABC):

      # this is last module - set step size to found step_size times lr
      if var.is_last:
-
          if var.last_module_lrs is None:
              self.set_step_size_(step_size, params=params, update=update)

@@ -223,17 +269,62 @@ class LineSearchBase(Module, ABC):



- # class GridLineSearch(LineSearch):
- # """Mostly for testing, this is not practical"""
- # def __init__(self, start, end, num):
- # defaults = dict(start=start,end=end,num=num)
- # super().__init__(defaults)
- #
- # @torch.no_grad
- # def search(self, update, var):
- # start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
- #
- # for lr in torch.linspace(start,end,num):
- # self.evaluate_step_size(lr.item(), var=var, backward=False)
+ class GridLineSearch(LineSearchBase):
+     """"""
+     def __init__(self, start, end, num):
+         defaults = dict(start=start,end=end,num=num)
+         super().__init__(defaults)

- # return self._best_step_size
+     @torch.no_grad
+     def search(self, update, var):
+         start,end,num=itemgetter('start','end','num')(self.defaults)
+
+         for lr in torch.linspace(start,end,num):
+             self.evaluate_f(lr.item(), var=var, backward=False)
+
+         return self._best_step_size
+
+
+ def sufficient_decrease(f_0, g_0, f_a, a, c):
+     return f_a < f_0 + c*a*min(g_0, 0)
+
+ def curvature(g_0, g_a, c):
+     if g_0 > 0: return True
+     return g_a >= c * g_0
+
+ def strong_curvature(g_0, g_a, c):
+     """same as curvature condition except curvature can't be too positive (which indicates overstep)"""
+     if g_0 > 0: return True
+     return abs(g_a) <= c * abs(g_0)
+
+ def wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+     return sufficient_decrease(f_0, g_0, f_a, a, c1) and curvature(g_0, g_a, c2)
+
+ def strong_wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+     return sufficient_decrease(f_0, g_0, f_a, a, c1) and strong_curvature(g_0, g_a, c2)
+
+ def goldstein(f_0, g_0, f_a, a, c):
+     """same as armijo (sufficient_decrease) but additional lower bound"""
+     g_0 = min(g_0, 0)
+     return f_0 + (1-c)*a*g_0 < f_a < f_0 + c*a*g_0
+
+ TerminationCondition = Literal["armijo", "curvature", "strong_curvature", "wolfe", "strong_wolfe", "goldstein", "decrease"]
+ def termination_condition(
+     condition: TerminationCondition,
+     f_0,
+     g_0,
+     f_a,
+     g_a: Any | None,
+     a,
+     c,
+     c2=None,
+ ):
+     if not math.isfinite(f_a): return False
+     if condition == 'armijo': return sufficient_decrease(f_0, g_0, f_a, a, c)
+     if condition == 'curvature': return curvature(g_0, g_a, c)
+     if condition == 'strong_curvature': return strong_curvature(g_0, g_a, c)
+     if condition == 'wolfe': return wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+     if condition == 'strong_wolfe': return strong_wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+     if condition == 'goldstein': return goldstein(f_0, g_0, f_a, a, c)
+     if condition == 'decrease': return f_a < f_0
+     raise ValueError(f"unknown condition {condition}")
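Editorial note: the functions added above are the standard Armijo (sufficient decrease), curvature, strong/weak Wolfe and Goldstein tests, with `termination_condition` dispatching on a string name. A self-contained toy sketch (not from the package) of the Armijo test driving a backtracking loop; the predicate repeats the `sufficient_decrease` formula above:

```python
# Standalone sketch (not from the package): backtracking on a 1-D quadratic until the
# Armijo / sufficient-decrease test used by the module above is satisfied.
def f(x): return (x - 3.0) ** 2          # objective, minimum at x = 3
def df(x): return 2.0 * (x - 3.0)        # its derivative

def sufficient_decrease(f_0, g_0, f_a, a, c):
    # identical formula to the module's helper
    return f_a < f_0 + c * a * min(g_0, 0)

x = 0.0
d = -df(x)                               # steepest-descent direction
f_0, g_0 = f(x), df(x) * d               # value and directional derivative at a = 0
a, c1 = 1.0, 1e-4
while not sufficient_decrease(f_0, g_0, f(x + a * d), a, c1):
    a *= 0.5                             # halve the step until sufficient decrease holds
print(a, f(x + a * d))                   # -> 0.5 0.0 for this quadratic
```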
torchzero/modules/line_search/scipy.py

@@ -42,7 +42,7 @@ class ScipyMinimizeScalar(LineSearchBase):
  def search(self, update, var):
      objective = self.make_objective(var=var)
      method, bracket, bounds, tol, options, maxiter = itemgetter(
-         'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+         'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)

      if maxiter is not None:
          options = dict(options) if isinstance(options, Mapping) else {}