torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
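
Most of the moves in the listing above come from renaming the `torchzero/modules/optimizers` subpackage to `torchzero/modules/adaptive`. A minimal sketch of how downstream imports of those submodules would change, assuming the submodule names stay as listed (the public re-exports in `torchzero.modules.__init__` are not shown in this diff):

```python
# Hypothetical import update for the optimizers -> adaptive rename;
# module names are taken from the file listing above.

# 0.3.11 layout:
# from torchzero.modules.optimizers import muon, shampoo, soap

# 0.3.13 layout (same submodules, new subpackage):
from torchzero.modules.adaptive import muon, shampoo, soap
```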
torchzero/modules/line_search/line_search.py

````diff
@@ -3,13 +3,13 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import partial
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import torch
 
 from ...core import Module, Target, Var
-from ...utils import tofloat
+from ...utils import tofloat, set_storage_
 
 
 class MaxLineSearchItersReached(Exception): pass
@@ -29,60 +29,59 @@ class LineSearchBase(Module, ABC):
             doesn't have a maxiter option. Defaults to None.
 
     Other useful methods:
-        *
-        *
-        *
-        *
+        * ``evaluate_f`` - returns loss with a given scalar step size
+        * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
+        * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+        * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
 
     Examples:
-    #### Basic line search
 
-
+    #### Basic line search
 
-
+    This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+    ```python
+    class GridLineSearch(LineSearch):
+        def __init__(self, start, end, num):
+            defaults = dict(start=start,end=end,num=num)
+            super().__init__(defaults)
 
-
-
-            defaults = dict(start=start,end=end,num=num)
-            super().__init__(defaults)
+        @torch.no_grad
+        def search(self, update, var):
 
-
-
-
-            start = settings["start"]
-            end = settings["end"]
-            num = settings["num"]
+            start = self.defaults["start"]
+            end = self.defaults["end"]
+            num = self.defaults["num"]
 
-
-
+            lowest_loss = float("inf")
+            best_step_size = best_step_size
 
-
-
-
-
-
+            for step_size in torch.linspace(start,end,num):
+                loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                if loss < lowest_loss:
+                    lowest_loss = loss
+                    best_step_size = step_size
 
-
+            return best_step_size
+    ```
 
-
+    #### Using external solver via self.make_objective
 
-
+    Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
 
-
+    ```python
+    class ScipyMinimizeScalar(LineSearch):
+        def __init__(self, method: str | None = None):
+            defaults = dict(method=method)
+            super().__init__(defaults)
 
-
-
-
-
-
-        @torch.no_grad
-        def search(self, update, var):
-            objective = self.make_objective(var=var)
-            method = self.settings[var.params[0]]["method"]
-
-            res = self.scopt.minimize_scalar(objective, method=method)
-            return res.x
+        @torch.no_grad
+        def search(self, update, var):
+            objective = self.make_objective(var=var)
+            method = self.defaults["method"]
 
+            res = self.scopt.minimize_scalar(objective, method=method)
+            return res.x
+    ```
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -94,6 +93,7 @@ class LineSearchBase(Module, ABC):
         self._lowest_loss = float('inf')
         self._best_step_size: float = 0
         self._current_iter = 0
+        self._initial_params = None
 
     def set_step_size_(
         self,
@@ -102,10 +102,27 @@ class LineSearchBase(Module, ABC):
         update: list[torch.Tensor],
     ):
         if not math.isfinite(step_size): return
-
-
-
-
+
+        # fixes overflow when backtracking keeps increasing alpha after converging
+        step_size = max(min(tofloat(step_size), 1e36), -1e36)
+
+        # skip is parameters are already at suggested step size
+        if self._current_step_size == step_size: return
+
+        # this was basically causing floating point imprecision to build up
+        #if False:
+        #     if abs(alpha) < abs(step_size) and step_size != 0:
+        #         torch._foreach_add_(params, update, alpha=alpha)
+
+        # else:
+        assert self._initial_params is not None
+        if step_size == 0:
+            new_params = [p.clone() for p in self._initial_params]
+        else:
+            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
+
         self._current_step_size = step_size
 
     def _set_per_parameter_step_size_(
@@ -114,10 +131,20 @@ class LineSearchBase(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-        if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-        alpha = [self._current_step_size - s for s in step_size]
-        if any(a!=0 for a in alpha):
-
+        # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
+        # alpha = [self._current_step_size - s for s in step_size]
+        # if any(a!=0 for a in alpha):
+        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        assert self._initial_params is not None
+        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
+
+        if any(s!=0 for s in step_size):
+            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
+        else:
+            new_params = [p.clone() for p in self._initial_params]
+
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
 
     def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
@@ -149,7 +176,7 @@ class LineSearchBase(Module, ABC):
 
         return tofloat(loss)
 
-    def
+    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                                   params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
         if (var.grad is not None) and (step_size == 0):
@@ -164,18 +191,31 @@ class LineSearchBase(Module, ABC):
         derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                 else torch.zeros_like(p) for p in params], update))
 
-
+        assert var.grad is not None
+        return loss, tofloat(derivative), var.grad
 
-    def
+    def _loss_derivative(self, step_size: float, var: Var, closure,
+                         params: list[torch.Tensor], update: list[torch.Tensor]):
+        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
+
+    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+        """evaluate function value at alpha `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
 
-    def
+    def evaluate_f_d(self, step_size: float, var: Var):
+        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
 
+    def evaluate_f_d_g(self, step_size: float, var: Var):
+        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+
     def make_objective(self, var: Var, backward:bool=False):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
@@ -186,6 +226,11 @@ class LineSearchBase(Module, ABC):
         if closure is None: raise RuntimeError('line search requires closure')
         return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
 
+    def make_objective_with_derivative_and_gradient(self, var: Var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+
     @abstractmethod
     def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
@@ -193,7 +238,9 @@ class LineSearchBase(Module, ABC):
     @torch.no_grad
     def step(self, var: Var) -> Var:
         self._reset()
+
         params = var.params
+        self._initial_params = [p.clone() for p in params]
         update = var.get_update()
 
         try:
@@ -206,7 +253,6 @@ class LineSearchBase(Module, ABC):
 
         # this is last module - set step size to found step_size times lr
         if var.is_last:
-
             if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)
 
@@ -223,17 +269,62 @@ class LineSearchBase(Module, ABC):
 
 
 
-
-
-
-
-
-
-    # @torch.no_grad
-    # def search(self, update, var):
-    #     start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
-
-    #     for lr in torch.linspace(start,end,num):
-    #         self.evaluate_step_size(lr.item(), var=var, backward=False)
+class GridLineSearch(LineSearchBase):
+    """"""
+    def __init__(self, start, end, num):
+        defaults = dict(start=start,end=end,num=num)
+        super().__init__(defaults)
 
-
+    @torch.no_grad
+    def search(self, update, var):
+        start,end,num=itemgetter('start','end','num')(self.defaults)
+
+        for lr in torch.linspace(start,end,num):
+            self.evaluate_f(lr.item(), var=var, backward=False)
+
+        return self._best_step_size
+
+
+def sufficient_decrease(f_0, g_0, f_a, a, c):
+    return f_a < f_0 + c*a*min(g_0, 0)
+
+def curvature(g_0, g_a, c):
+    if g_0 > 0: return True
+    return g_a >= c * g_0
+
+def strong_curvature(g_0, g_a, c):
+    """same as curvature condition except curvature can't be too positive (which indicates overstep)"""
+    if g_0 > 0: return True
+    return abs(g_a) <= c * abs(g_0)
+
+def wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and curvature(g_0, g_a, c2)
+
+def strong_wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and strong_curvature(g_0, g_a, c2)
+
+def goldstein(f_0, g_0, f_a, a, c):
+    """same as armijo (sufficient_decrease) but additional lower bound"""
+    g_0 = min(g_0, 0)
+    return f_0 + (1-c)*a*g_0 < f_a < f_0 + c*a*g_0
+
+TerminationCondition = Literal["armijo", "curvature", "strong_curvature", "wolfe", "strong_wolfe", "goldstein", "decrease"]
+def termination_condition(
+    condition: TerminationCondition,
+    f_0,
+    g_0,
+    f_a,
+    g_a: Any | None,
+    a,
+    c,
+    c2=None,
+):
+    if not math.isfinite(f_a): return False
+    if condition == 'armijo': return sufficient_decrease(f_0, g_0, f_a, a, c)
+    if condition == 'curvature': return curvature(g_0, g_a, c)
+    if condition == 'strong_curvature': return strong_curvature(g_0, g_a, c)
+    if condition == 'wolfe': return wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'strong_wolfe': return strong_wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'goldstein': return goldstein(f_0, g_0, f_a, a, c)
+    if condition == 'decrease': return f_a < f_0
+    raise ValueError(f"unknown condition {condition}")
````
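
The helpers added at the end of `line_search.py` (`sufficient_decrease`, `curvature`, `strong_curvature`, `wolfe`, `strong_wolfe`, `goldstein`, `termination_condition`) implement standard line-search acceptance tests on scalar values. A small usage sketch with made-up numbers, assuming the function is importable from the module path shown above:

```python
# Signatures taken from the hunk above; the values are illustrative only.
from torchzero.modules.line_search.line_search import termination_condition

f_0, g_0 = 1.0, -2.0   # loss and directional derivative at step size 0 (descent direction)
f_a, g_a = 0.7, -0.5   # loss and directional derivative at trial step size a
a = 0.1

# Armijo / sufficient decrease: f_a < f_0 + c*a*min(g_0, 0)
ok_armijo = termination_condition("armijo", f_0, g_0, f_a, g_a=None, a=a, c=1e-4)

# Strong Wolfe: sufficient decrease and |g_a| <= c2*|g_0|
ok_strong_wolfe = termination_condition("strong_wolfe", f_0, g_0, f_a, g_a, a=a, c=1e-4, c2=0.9)

print(ok_armijo, ok_strong_wolfe)  # True True for these values
```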
torchzero/modules/line_search/scipy.py

````diff
@@ -42,7 +42,7 @@ class ScipyMinimizeScalar(LineSearchBase):
     def search(self, update, var):
         objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)
 
         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
````