torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/least_squares/gn.py (new file)

```diff
@@ -0,0 +1,161 @@
+import torch
+from ...core import Module
+
+from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+from ...utils import vec_to_tensors
+from ...utils.linalg import linear_operator
+class SumOfSquares(Module):
+    """Sets loss to be the sum of squares of the values returned by the closure.
+
+    This is meant to be used to test least-squares methods against ordinary minimization methods.
+
+    To use this, the closure should return a vector of values to minimize the sum of squares of.
+    Please add the `backward` argument; it will always be False, but it is required.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+
+        if closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False)
+                        loss = loss.pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False)
+                return loss.pow(2).sum()
+
+            var.closure = sos_closure
+
+        if var.loss is not None:
+            var.loss = var.loss.pow(2).sum()
+
+        if var.loss_approx is not None:
+            var.loss_approx = var.loss_approx.pow(2).sum()
+
+        return var
+
+
+class GaussNewton(Module):
+    """Gauss-Newton method.
+
+    To use this, the closure should return a vector of values to minimize the sum of squares of.
+    Please add the ``backward`` argument; it will always be False, but it is required.
+    Gradients will be calculated via batched autograd within this module; you don't need to
+    implement the backward pass. Please see below for an example.
+
+    Note:
+        This method requires ``ndim^2`` memory; however, if it is used within ``tz.m.TrustCG`` trust region,
+        the memory requirement is ``ndim*m``, where ``m`` is the number of values in the output.
+
+    Args:
+        reg (float, optional): regularization parameter. Defaults to 1e-8.
+        batched (bool, optional): whether to use vmapping. Defaults to True.
+
+    Examples:
+
+    minimizing the rosenbrock function:
+    ```python
+    def rosenbrock(X):
+        x1, x2 = X
+        return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
+
+    X = torch.tensor([-1.1, 2.5], requires_grad=True)
+    opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+
+    # define the closure for line search
+    def closure(backward=True):
+        return rosenbrock(X)
+
+    # minimize
+    for iter in range(10):
+        loss = opt.step(closure)
+        print(f'{loss = }')
+    ```
+
+    training a neural network with a matrix-free GN trust region:
+    ```python
+    X = torch.randn(64, 20)
+    y = torch.randn(64, 10)
+
+    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.GaussNewton()),
+    )
+
+    def closure(backward=True):
+        y_hat = model(X)  # (64, 10)
+        return (y_hat - y).pow(2).mean(0)  # (10, )
+
+    for i in range(100):
+        losses = opt.step(closure)
+        if i % 10 == 0:
+            print(f'{losses.mean() = }')
+    ```
+    """
+    def __init__(self, reg: float = 1e-8, batched: bool = True):
+        super().__init__(defaults=dict(batched=batched, reg=reg))
+
+    @torch.no_grad
+    def update(self, var):
+        params = var.params
+        batched = self.defaults['batched']
+
+        closure = var.closure
+        assert closure is not None
+
+        # Gauss-Newton direction
+        with torch.enable_grad():
+            f = var.get_loss(backward=False)  # n_out
+            assert isinstance(f, torch.Tensor)
+            G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+        var.loss = f.pow(2).sum()
+
+        G = self.global_state["G"] = flatten_jacobian(G_list)  # (n_out, ndim)
+        Gtf = G.T @ f.detach()  # (ndim)
+        self.global_state["Gtf"] = Gtf
+        var.grad = vec_to_tensors(Gtf, var.params)
+
+        # set closure to calculate sum of squares for line searches etc.
+        if var.closure is not None:
+            def sos_closure(backward=True):
+                if backward:
+                    var.zero_grad()
+                    with torch.enable_grad():
+                        loss = closure(False).pow(2).sum()
+                        loss.backward()
+                    return loss
+
+                loss = closure(False).pow(2).sum()
+                return loss
+
+            var.closure = sos_closure
+
+    @torch.no_grad
+    def apply(self, var):
+        reg = self.defaults['reg']
+
+        G = self.global_state['G']
+        Gtf = self.global_state['Gtf']
+
+        GtG = G.T @ G  # (ndim, ndim)
+        if reg != 0:
+            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+
+        v = torch.linalg.lstsq(GtG, Gtf).solution  # pylint:disable=not-callable
+
+        var.update = vec_to_tensors(v, var.params)
+        return var
+
+    def get_H(self, var):
+        G = self.global_state['G']
+        return linear_operator.AtA(G)
```
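For orientation, `apply` solves the regularized normal equations (GᵀG + reg·I)v = Gᵀf for the Gauss-Newton direction. Below is a minimal standalone sketch of the same iteration on the Rosenbrock residuals from the docstring, using only torch; `residuals` and the loop are illustrative, not part of the package, and `torch.linalg.solve` stands in for the module's `lstsq` call:

```python
import torch

def residuals(X):
    # Rosenbrock as a least-squares problem: minimize ||r(X)||^2
    x1, x2 = X
    return torch.stack([1 - x1, 100 * (x2 - x1 ** 2)])

X = torch.tensor([-1.1, 2.5])

for _ in range(4):
    r = residuals(X)
    G = torch.autograd.functional.jacobian(residuals, X)  # (n_out, ndim), like flatten_jacobian
    Gtr = G.T @ r                                         # gradient of 0.5 * ||r||^2
    GtG = G.T @ G + 1e-8 * torch.eye(G.shape[1])          # regularized Gauss-Newton matrix
    v = torch.linalg.solve(GtG, Gtr)                      # Gauss-Newton direction
    X = X - v                                             # raw full step
    print(residuals(X).pow(2).sum().item())
```

The printed sum-of-squares loss rises on the first step before dropping to (numerically) zero on the second: the raw Gauss-Newton step is not monotone, which is why the module is paired with a line search or a trust region in the docstring examples.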
torchzero/modules/line_search/__init__.py

```diff
@@ -1,5 +1,5 @@
-from .adaptive import AdaptiveLineSearch
+from .adaptive import AdaptiveTracking
 from .backtracking import AdaptiveBacktracking, Backtracking
 from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .strong_wolfe import StrongWolfe
+from .strong_wolfe import StrongWolfe
```
torchzero/modules/line_search/_polyinterp.py (new file)

```diff
@@ -0,0 +1,289 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+      . Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt):
+                    if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+                    if crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                        F_cp = np.polyval(coeff, crit_pt)
+                        if np.isreal(F_cp) and F_cp < f_min:
+                            x_sol = np.real(crit_pt)
+                            f_min = np.real(F_cp)
+
+            if plot:
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+# polynomial interpolation
+# this code is based on https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+# this one is modified where instead of clipping the solution by bounds, it tries a lower degree polynomial
+# all the way to bisection
+def _within_bounds(x, lb, ub):
+    if lb is not None and x < lb: return False
+    if ub is not None and x > ub: return False
+    return True
+
+def _quad_interp(points):
+    assert points.shape[0] == 2, points.shape
+    if points[0, 0] == 0:
+        denom = 2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])
+        if abs(denom) > 1e-32:
+            return -points[0, 2] * points[1, 0] ** 2 / denom
+    else:
+        denom = (points[0, 0] - points[1, 0]) ** 2
+        if denom > 1e-32:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / denom
+            if a > 1e-32:
+                return points[0, 0] - points[0, 2]/(2*a)
+    return None
+
+def _cubic_interp(points, lb, ub):
+    assert points.shape[0] == 2, points.shape
+    denom = points[0, 0] - points[1, 0]
+    if abs(denom) > 1e-32:
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / denom)
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            denom = points[1, 2] - points[0, 2] + 2 * d2
+            if abs(denom) > 1e-32:
+                x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / denom)
+                if _within_bounds(x_sol, lb, ub): return x_sol
+
+    # try quadratic interpolations
+    x_sol = _quad_interp(points)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    return None
+
+def _poly_interp(points, lb, ub):
+    no_points = points.shape[0]
+    assert no_points > 2, points.shape
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    # define linear constraints
+    A = np.zeros((0, order + 1))
+    b = np.zeros((0, 1))
+
+    # add linear constraints on function values
+    for i in range(no_points):
+        if not np.isnan(points[i, 1]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order, -1, -1):
+                constraint[0, order - j] = points[i, 0] ** j
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 1])
+
+    # add linear constraints on gradient values
+    for i in range(no_points):
+        if not np.isnan(points[i, 2]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order):
+                constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 2])
+
+    # check if system is solvable
+    if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+        return None
+
+    # solve linear system for interpolating polynomial
+    coeff = np.linalg.solve(A, b)
+
+    # compute critical points
+    dcoeff = np.zeros(order)
+    for i in range(len(coeff) - 1):
+        dcoeff[i] = coeff[i] * (order - i)
+
+    lower = np.min(points[:, 0]) if lb is None else lb
+    upper = np.max(points[:, 0]) if ub is None else ub
+
+    crit_pts = np.array([lower, upper])
+    crit_pts = np.append(crit_pts, points[:, 0])
+
+    if not np.isinf(dcoeff).any():
+        roots = np.roots(dcoeff)
+        crit_pts = np.append(crit_pts, roots)
+
+    # test critical points
+    f_min = np.inf
+    x_sol = None
+    for crit_pt in crit_pts:
+        if np.isreal(crit_pt):
+            if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+            if _within_bounds(crit_pt, lb, ub):
+                F_cp = np.polyval(coeff, crit_pt)
+                if np.isreal(F_cp) and F_cp < f_min:
+                    x_sol = np.real(crit_pt)
+                    f_min = np.real(F_cp)
+
+    return x_sol
+
+def polyinterp2(points, lb, ub, unbounded: bool = False):
+    no_points = points.shape[0]
+    if no_points <= 1:
+        return (lb + ub)/2
+
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if not unbounded:
+        if lb is None:
+            lb = x_min
+        if ub is None:
+            ub = x_max
+
+    if no_points == 2 and order == 2:
+        x_sol = _quad_interp(points)
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points == 2 and order == 3:
+        x_sol = _cubic_interp(points, lb, ub)  # includes fallback on _quad_interp
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points <= 2:  # order < 2
+        return (lb + ub)/2
+
+    if no_points == 3:
+        for p in (points[:2], points[1:], points[::2]):
+            x_sol = _cubic_interp(p, lb, ub)
+            if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    x_sol = _poly_interp(points, lb, ub)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+    return polyinterp2(points[1:], lb, ub)
```
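A quick sanity check of the two-point quadratic path in `polyinterp2`, assuming the `[x, f, g]` row convention documented above; the sample function is illustrative, and the import path follows the new module's location in the wheel:

```python
import numpy as np
from torchzero.modules.line_search._polyinterp import polyinterp2

# two samples of f(x) = (x - 2)^2, rows are [x, f(x), f'(x)];
# the unknown derivative at x=1 is np.nan, so order == 2 (quadratic)
points = np.array([
    [0.0, 4.0, -4.0],
    [1.0, 1.0, np.nan],
])

print(polyinterp2(points, lb=0.0, ub=3.0))  # -> 2.0, the exact minimizer
# out-of-bounds candidates fall back to bisection instead of being clipped:
print(polyinterp2(points, lb=0.0, ub=1.5))  # -> 0.75, i.e. (lb + ub) / 2
```

The second call is the behavioral difference from the upstream `polyinterp`, which would have clipped the 2.0 candidate to the upper bound.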
torchzero/modules/line_search/adaptive.py

```diff
@@ -1,58 +1,73 @@
 import math
+from bisect import insort
+from collections import deque
 from collections.abc import Callable
 from operator import itemgetter
 
+import numpy as np
 import torch
 
-from .line_search import LineSearchBase
-
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def adaptive_tracking(
     f,
-
+    a_init,
     maxiter: int,
     nplus: float = 2,
     nminus: float = 0.5,
+    f_0 = None,
 ):
-
+    niter = 0
+    if f_0 is None: f_0 = f(0)
 
-
-
+    a = a_init
+    f_a = f(a)
 
     # backtrack
-
-
+    a_prev = a
+    f_prev = math.inf
+    if (f_a > f_0) or (not math.isfinite(f_a)):
+        while (f_a < f_prev) or not math.isfinite(f_a):
+            a_prev, f_prev = a, f_a
             maxiter -= 1
-            if maxiter < 0:
-
-
-
+            if maxiter < 0: break
+
+            a = a*nminus
+            f_a = f(a)
+            niter += 1
+
+        if f_prev < f_0: return a_prev, f_prev, niter
+        return 0, f_0, niter
 
     # forwardtrack
-
-
-
-
-    while f_prev >= f_t:
+    a_prev = a
+    f_prev = math.inf
+    while (f_a <= f_prev) and math.isfinite(f_a):
+        a_prev, f_prev = a, f_a
         maxiter -= 1
-        if maxiter < 0:
-
-
-
-
+        if maxiter < 0: break
+
+        a *= nplus
+        f_a = f(a)
+        niter += 1
+
+    if f_prev < f_0: return a_prev, f_prev, niter
+    return 0, f_0, niter
+
 
-class
-    """
-
+class AdaptiveTracking(LineSearchBase):
+    """A line search that evaluates the previous step size; if the value increased, it backtracks until the value
+    stops decreasing, otherwise it forward-tracks until the value stops decreasing.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
-
-
+        nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
+        nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
-            when enabled, if line search failed,
-            Otherwise it
+            when enabled, if the line search failed, the step size will continue decreasing on the next step.
+            Otherwise the line search restarts from the ``init`` step size. Defaults to True.
     """
     def __init__(
         self,
```
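The rewritten `adaptive_tracking` is a plain scalar routine, so both branches can be exercised directly. A small sketch against a 1-D quadratic, assuming exactly the implementation above (`phi` is illustrative; the expected tuples follow from tracing the code by hand):

```python
phi = lambda a: (a - 0.25) ** 2  # loss along the search ray, minimized at a = 0.25

# phi(1.0) > phi(0), so the backtracking branch halves the step while it keeps improving:
print(adaptive_tracking(phi, a_init=1.0, maxiter=10))   # (0.25, 0.0, 3)

# phi(0.01) < phi(0), so the forward-tracking branch doubles the step instead:
print(adaptive_tracking(phi, a_init=0.01, maxiter=10))  # ~ (0.32, 0.0049, 6)
```

In both branches the routine returns the last step size that was still improving, its value, and the number of evaluations spent; a return of `(0, f_0, niter)` signals failure to the caller.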
```diff
@@ -62,38 +77,48 @@ class AdaptiveLineSearch(LineSearchBase):
         maxiter: int = 10,
         adaptive=True,
     ):
-        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
         init, nplus, nminus, maxiter, adaptive = itemgetter(
-            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
-        #
-
+        # scale a_prev
+        a_prev = self.global_state.get('a_prev', init)
+        if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)
 
-
-
-
+        a_init = a_prev
+        if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
+            a_init = torch.finfo(var.params[0].dtype).max / 2
 
-
-
-
-
+        step_size, f, niter = adaptive_tracking(
+            objective,
+            a_init=a_init,
+            maxiter=maxiter,
+            nplus=nplus,
+            nminus=nminus,
+        )
 
         # found an alpha that reduces loss
         if step_size != 0:
-
+            assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
+            self.global_state['init_scale'] = 1
+
+            # if niter == 1, forward tracking failed to decrease function value compared to f_a_prev
+            if niter == 1 and step_size >= a_init: step_size *= nminus
+
+            self.global_state['a_prev'] = step_size
             return step_size
 
         # on fail reduce beta scale value
-        self.global_state['
+        self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
+        self.global_state['a_prev'] = init
        return 0
+
```
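And how the renamed class is meant to slot in: a hedged sketch mirroring the `GaussNewton` docstring example above, with `Backtracking` swapped for the new `AdaptiveTracking`, assuming it is exposed under `tz.m` like the other line searches:

```python
import torch
import torchzero as tz

def rosenbrock(X):
    x1, x2 = X
    return torch.stack([(1 - x1), 100 * (x2 - x1**2)])  # residual vector

X = torch.tensor([-1.1, 2.5], requires_grad=True)
# AdaptiveTracking re-tries the last accepted step size, then backtracks or
# forward-tracks from it, so consecutive steps share line-search work
opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.AdaptiveTracking())

def closure(backward=True):
    return rosenbrock(X)

for i in range(10):
    loss = opt.step(closure)
    print(f'{loss = }')
```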