torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/line_search.py

````diff
@@ -3,20 +3,21 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import partial
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import torch
 
 from ...core import Module, Target, Var
-from ...utils import tofloat
+from ...utils import tofloat, set_storage_
 
 
 class MaxLineSearchItersReached(Exception): pass
 
 
-class LineSearch(Module, ABC):
+class LineSearchBase(Module, ABC):
     """Base class for line searches.
+
     This is an abstract class, to use it, subclass it and override `search`.
 
     Args:
@@ -26,6 +27,61 @@ class LineSearch(Module, ABC):
             the objective this many times, and step size with the lowest loss value will be used.
             This is useful when passing `make_objective` to an external library which
             doesn't have a maxiter option. Defaults to None.
+
+    Other useful methods:
+
+    * ``evaluate_f`` - returns loss with a given scalar step size
+    * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
+    * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+    * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+
+    Examples:
+
+    #### Basic line search
+
+    This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+    ```python
+    class GridLineSearch(LineSearch):
+        def __init__(self, start, end, num):
+            defaults = dict(start=start,end=end,num=num)
+            super().__init__(defaults)
+
+        @torch.no_grad
+        def search(self, update, var):
+
+            start = self.defaults["start"]
+            end = self.defaults["end"]
+            num = self.defaults["num"]
+
+            lowest_loss = float("inf")
+            best_step_size = best_step_size
+
+            for step_size in torch.linspace(start,end,num):
+                loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                if loss < lowest_loss:
+                    lowest_loss = loss
+                    best_step_size = step_size
+
+            return best_step_size
+    ```
+
+    #### Using external solver via self.make_objective
+
+    Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+
+    ```python
+    class ScipyMinimizeScalar(LineSearch):
+        def __init__(self, method: str | None = None):
+            defaults = dict(method=method)
+            super().__init__(defaults)
+
+        @torch.no_grad
+        def search(self, update, var):
+            objective = self.make_objective(var=var)
+            method = self.defaults["method"]
+
+            res = self.scopt.minimize_scalar(objective, method=method)
+            return res.x
+    ```
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -37,6 +93,7 @@ class LineSearch(Module, ABC):
         self._lowest_loss = float('inf')
         self._best_step_size: float = 0
         self._current_iter = 0
+        self._initial_params = None
 
     def set_step_size_(
         self,
@@ -45,10 +102,27 @@ class LineSearch(Module, ABC):
         update: list[torch.Tensor],
     ):
         if not math.isfinite(step_size): return
-
-
-
-
+
+        # fixes overflow when backtracking keeps increasing alpha after converging
+        step_size = max(min(tofloat(step_size), 1e36), -1e36)
+
+        # skip is parameters are already at suggested step size
+        if self._current_step_size == step_size: return
+
+        # this was basically causing floating point imprecision to build up
+        #if False:
+        #     if abs(alpha) < abs(step_size) and step_size != 0:
+        #         torch._foreach_add_(params, update, alpha=alpha)
+
+        # else:
+        assert self._initial_params is not None
+        if step_size == 0:
+            new_params = [p.clone() for p in self._initial_params]
+        else:
+            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
+
         self._current_step_size = step_size
 
     def _set_per_parameter_step_size_(
@@ -57,10 +131,20 @@ class LineSearch(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-        if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-        alpha = [self._current_step_size - s for s in step_size]
-        if any(a!=0 for a in alpha):
-
+        # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
+        # alpha = [self._current_step_size - s for s in step_size]
+        # if any(a!=0 for a in alpha):
+        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        assert self._initial_params is not None
+        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
+
+        if any(s!=0 for s in step_size):
+            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
+        else:
+            new_params = [p.clone() for p in self._initial_params]
+
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
 
     def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
@@ -92,7 +176,7 @@ class LineSearch(Module, ABC):
 
         return tofloat(loss)
 
-    def
+    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                                   params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
         if (var.grad is not None) and (step_size == 0):
@@ -107,18 +191,31 @@ class LineSearch(Module, ABC):
             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                     else torch.zeros_like(p) for p in params], update))
 
-
+        assert var.grad is not None
+        return loss, tofloat(derivative), var.grad
 
-    def
+    def _loss_derivative(self, step_size: float, var: Var, closure,
+                         params: list[torch.Tensor], update: list[torch.Tensor]):
+        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
+
+    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+        """evaluate function value at alpha `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
 
-    def
+    def evaluate_f_d(self, step_size: float, var: Var):
+        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
 
+    def evaluate_f_d_g(self, step_size: float, var: Var):
+        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+
     def make_objective(self, var: Var, backward:bool=False):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
@@ -129,6 +226,11 @@ class LineSearch(Module, ABC):
         if closure is None: raise RuntimeError('line search requires closure')
         return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
 
+    def make_objective_with_derivative_and_gradient(self, var: Var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+
     @abstractmethod
     def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
@@ -136,7 +238,9 @@ class LineSearch(Module, ABC):
     @torch.no_grad
     def step(self, var: Var) -> Var:
        self._reset()
+
         params = var.params
+        self._initial_params = [p.clone() for p in params]
         update = var.get_update()
 
         try:
@@ -149,7 +253,6 @@ class LineSearch(Module, ABC):
 
         # this is last module - set step size to found step_size times lr
         if var.is_last:
-
             if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)
 
@@ -165,17 +268,63 @@ class LineSearch(Module, ABC):
         return var
 
 
-
-
+
+class GridLineSearch(LineSearchBase):
+    """"""
     def __init__(self, start, end, num):
         defaults = dict(start=start,end=end,num=num)
         super().__init__(defaults)
 
     @torch.no_grad
     def search(self, update, var):
-        start,end,num=itemgetter('start','end','num')(self.
+        start,end,num=itemgetter('start','end','num')(self.defaults)
 
         for lr in torch.linspace(start,end,num):
-            self.
-
-        return self._best_step_size
+            self.evaluate_f(lr.item(), var=var, backward=False)
+
+        return self._best_step_size
+
+
+def sufficient_decrease(f_0, g_0, f_a, a, c):
+    return f_a < f_0 + c*a*min(g_0, 0)
+
+def curvature(g_0, g_a, c):
+    if g_0 > 0: return True
+    return g_a >= c * g_0
+
+def strong_curvature(g_0, g_a, c):
+    """same as curvature condition except curvature can't be too positive (which indicates overstep)"""
+    if g_0 > 0: return True
+    return abs(g_a) <= c * abs(g_0)
+
+def wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and curvature(g_0, g_a, c2)

+def strong_wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and strong_curvature(g_0, g_a, c2)
+
+def goldstein(f_0, g_0, f_a, a, c):
+    """same as armijo (sufficient_decrease) but additional lower bound"""
+    g_0 = min(g_0, 0)
+    return f_0 + (1-c)*a*g_0 < f_a < f_0 + c*a*g_0
+
+TerminationCondition = Literal["armijo", "curvature", "strong_curvature", "wolfe", "strong_wolfe", "goldstein", "decrease"]
+def termination_condition(
+    condition: TerminationCondition,
+    f_0,
+    g_0,
+    f_a,
+    g_a: Any | None,
+    a,
+    c,
+    c2=None,
+):
+    if not math.isfinite(f_a): return False
+    if condition == 'armijo': return sufficient_decrease(f_0, g_0, f_a, a, c)
+    if condition == 'curvature': return curvature(g_0, g_a, c)
+    if condition == 'strong_curvature': return strong_curvature(g_0, g_a, c)
+    if condition == 'wolfe': return wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'strong_wolfe': return strong_wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'goldstein': return goldstein(f_0, g_0, f_a, a, c)
+    if condition == 'decrease': return f_a < f_0
+    raise ValueError(f"unknown condition {condition}")
````
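The new module also exposes standalone helpers for the classic acceptance tests (Armijo sufficient decrease, curvature, weak/strong Wolfe, Goldstein) and a `termination_condition` dispatcher. The short check below illustrates how they compose; the numbers are invented and the import path is an assumption based on the file list above, not something the package documents here.

```python
# Illustration only: values are made up, and the import path is assumed from
# torchzero/modules/line_search/line_search.py in the file list above.
from torchzero.modules.line_search.line_search import strong_wolfe, termination_condition

f_0, g_0 = 1.0, -2.0   # loss and directional derivative at step size 0
a = 0.5                # trial step size
f_a, g_a = 0.2, -0.1   # loss and directional derivative at the trial step

# Armijo: f_a < f_0 + c*a*min(g_0, 0)  ->  0.2 < 1.0 + 1e-4 * 0.5 * (-2.0)
print(termination_condition("armijo", f_0, g_0, f_a, g_a, a, c=1e-4))   # True

# strong Wolfe additionally requires |g_a| <= c2 * |g_0|  ->  0.1 <= 0.9 * 2.0
print(strong_wolfe(f_0, g_0, f_a, g_a, a, c1=1e-4, c2=0.9))             # True
```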
torchzero/modules/line_search/scipy.py

````diff
@@ -3,10 +3,10 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import
+from .line_search import LineSearchBase
 
 
-class ScipyMinimizeScalar(
+class ScipyMinimizeScalar(LineSearchBase):
     """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.
 
     Args:
@@ -42,7 +42,7 @@ class ScipyMinimizeScalar(LineSearch):
     def search(self, update, var):
         objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)
 
         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
````
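Both diffs above show `LineSearchBase` subclasses that only implement `search` and return a scalar step size, while the base class applies it by cloning `self._initial_params` and subtracting `step_size * update`. Here is a minimal sketch of a custom subclass written against that API; it is an illustration under the signatures visible in the diff (`search`, `evaluate_f`, `evaluate_f_d`, `self.defaults`), with the import path assumed from the file list, not code from the package.

```python
# Sketch, not from torchzero: a backtracking Armijo line search on top of the
# LineSearchBase API shown above. Import path is assumed from the file list.
import torch
from torchzero.modules.line_search.line_search import LineSearchBase, sufficient_decrease

class ArmijoBacktracking(LineSearchBase):
    def __init__(self, init: float = 1.0, beta: float = 0.5, c: float = 1e-4, maxiter: int = 25):
        defaults = dict(init=init, beta=beta, c=c, maxiter=maxiter)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, var):
        init, beta, c, maxiter = (self.defaults[k] for k in ("init", "beta", "c", "maxiter"))

        # loss and directional derivative along the update at step size 0
        f_0, g_0 = self.evaluate_f_d(0, var=var)

        a = init
        for _ in range(maxiter):
            f_a = self.evaluate_f(a, var=var, backward=False)
            if sufficient_decrease(f_0, g_0, f_a, a, c):
                return a
            a *= beta
        return 0.0  # no acceptable step found; a zero step keeps the initial parameters
```

Returning 0 leaves the parameters unchanged, matching how the updated `set_step_size_` restores the clones in `self._initial_params` when the step size is zero.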