tigramite-fast 5.2.10.1 (tigramite_fast-5.2.10.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigramite/__init__.py +0 -0
- tigramite/causal_effects.py +1525 -0
- tigramite/causal_mediation.py +1592 -0
- tigramite/data_processing.py +1574 -0
- tigramite/graphs.py +1509 -0
- tigramite/independence_tests/LBFGS.py +1114 -0
- tigramite/independence_tests/__init__.py +0 -0
- tigramite/independence_tests/cmiknn.py +661 -0
- tigramite/independence_tests/cmiknn_mixed.py +1397 -0
- tigramite/independence_tests/cmisymb.py +286 -0
- tigramite/independence_tests/gpdc.py +664 -0
- tigramite/independence_tests/gpdc_torch.py +820 -0
- tigramite/independence_tests/gsquared.py +190 -0
- tigramite/independence_tests/independence_tests_base.py +1310 -0
- tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
- tigramite/independence_tests/pairwise_CI.py +383 -0
- tigramite/independence_tests/parcorr.py +369 -0
- tigramite/independence_tests/parcorr_mult.py +485 -0
- tigramite/independence_tests/parcorr_wls.py +451 -0
- tigramite/independence_tests/regressionCI.py +403 -0
- tigramite/independence_tests/robust_parcorr.py +403 -0
- tigramite/jpcmciplus.py +966 -0
- tigramite/lpcmci.py +3649 -0
- tigramite/models.py +2257 -0
- tigramite/pcmci.py +3935 -0
- tigramite/pcmci_base.py +1218 -0
- tigramite/plotting.py +4735 -0
- tigramite/rpcmci.py +467 -0
- tigramite/toymodels/__init__.py +0 -0
- tigramite/toymodels/context_model.py +261 -0
- tigramite/toymodels/non_additive.py +1231 -0
- tigramite/toymodels/structural_causal_processes.py +1201 -0
- tigramite/toymodels/surrogate_generator.py +319 -0
- tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
- tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
- tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
- tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
- tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1114 @@
"""
MIT License

Copyright (c) 2017 Jake Gardner

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import torch
import numpy as np
# import matplotlib.pyplot as plt
from functools import reduce
from copy import deepcopy
from torch.optim import Optimizer


def is_legal(v):
    """
    Checks that tensor is not NaN or Inf.
    Inputs:
        v (tensor): tensor to be checked
    """
    legal = not torch.isnan(v).any() and not torch.isinf(v).any()

    return legal


def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
    """
    Gives the minimizer and minimum of the interpolating polynomial over given points
    based on function and derivative information. Defaults to bisection if no critical
    points are valid.
    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
    modifications.
    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.
    Inputs:
        points (nparray): two-dimensional array with each point of form [x f g]
        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
        plot (bool): plot interpolating polynomial
    Outputs:
        x_sol (float): minimizer of interpolating polynomial
        F_min (float): minimum of interpolating polynomial
    Note:
      . Set f or g to np.nan if they are unknown
    """
    no_points = points.shape[0]
    order = np.sum(1 - np.isnan(points[:, 1:3]).astype("int")) - 1

    x_min = np.min(points[:, 0])
    x_max = np.max(points[:, 0])

    # compute bounds of interpolation area
    if x_min_bound is None:
        x_min_bound = x_min
    if x_max_bound is None:
        x_max_bound = x_max

    # explicit formula for quadratic interpolation
    if no_points == 2 and order == 2 and plot is False:
        # Solution to quadratic interpolation is given by:
        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
        # x_min = x1 - g1/(2a)
        # if x1 = 0, then is given by:
        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))

        if points[0, 0] == 0:
            x_sol = (
                -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
            )
        else:
            a = (
                -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0]))
                / (points[0, 0] - points[1, 0]) ** 2
            )
            x_sol = points[0, 0] - points[0, 2] / (2 * a)

        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)

    # explicit formula for cubic interpolation
    elif no_points == 2 and order == 3 and plot is False:
        # Solution to cubic interpolation is given by:
        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
        # d2 = sqrt(d1^2 - g1*g2)
        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
        d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
        if np.isreal(d2):
            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * (
                (points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2)
            )
            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
        else:
            x_sol = (x_max_bound + x_min_bound) / 2

    # solve linear system
    else:
        # define linear constraints
        A = np.zeros((0, order + 1))
        b = np.zeros((0, 1))

        # add linear constraints on function values
        for i in range(no_points):
            if not np.isnan(points[i, 1]):
                constraint = np.zeros((1, order + 1))
                for j in range(order, -1, -1):
                    constraint[0, order - j] = points[i, 0] ** j
                A = np.append(A, constraint, 0)
                b = np.append(b, points[i, 1])

        # add linear constraints on gradient values
        for i in range(no_points):
            if not np.isnan(points[i, 2]):
                constraint = np.zeros((1, order + 1))
                for j in range(order):
                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
                A = np.append(A, constraint, 0)
                b = np.append(b, points[i, 2])

        # check if system is solvable
        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
            x_sol = (x_min_bound + x_max_bound) / 2
            f_min = np.inf
        else:
            # solve linear system for interpolating polynomial
            coeff = np.linalg.solve(A, b)

            # compute critical points
            dcoeff = np.zeros(order)
            for i in range(len(coeff) - 1):
                dcoeff[i] = coeff[i] * (order - i)

            crit_pts = np.array([x_min_bound, x_max_bound])
            crit_pts = np.append(crit_pts, points[:, 0])

            if not np.isinf(dcoeff).any():
                roots = np.roots(dcoeff)
                crit_pts = np.append(crit_pts, roots)

            # test critical points
            f_min = np.inf
            x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
            for crit_pt in crit_pts:
                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
                    F_cp = np.polyval(coeff, crit_pt)
                    if np.isreal(F_cp) and F_cp < f_min:
                        x_sol = np.real(crit_pt)
                        f_min = np.real(F_cp)

            # if plot:
            #     plt.figure()
            #     x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound) / 10000)
            #     f = np.polyval(coeff, x)
            #     plt.plot(x, f)
            #     plt.plot(x_sol, f_min, "x")

    return x_sol
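
# A small numeric sketch (not part of the upstream file) of how polyinterp is called by
# the line searches further below: each row of `points` is [x, f, g], with unknown
# entries set to np.nan. For f(0) = 1, f'(0) = -2 and f(1) = 0.5 the fitted quadratic
# 1 - 2x + 1.5x^2 has its minimizer at x = 2/3, which is what polyinterp returns
# (clipped to the bracketing bounds):
#
#     pts = np.array([[0.0, 1.0, -2.0],
#                     [1.0, 0.5, np.nan]])
#     t = polyinterp(pts)  # ~0.6667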


class LBFGS(Optimizer):
    """
    Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
    L-BFGS implementations and (stochastic) Powell damping. Partly based on the
    original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
    and Michael Overton's weak Wolfe line search MATLAB code.
    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.
    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.
    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode
    References:
    [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS
        Method for Machine Learning." Advances in Neural Information Processing
        Systems. 2016.
    [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
        Learning." International Conference on Machine Learning. 2018.
    [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
        Methods." Mathematical Programming 141.1-2 (2013): 135-163.
    [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
        Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
    [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
        Mathematics of Computation 35.151 (1980): 773-782.
    [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
        2006.
    [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
        in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
        (2005).
    [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
        Method for Online Convex Optimization." Artificial Intelligence and Statistics.
        2007.
    [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
        Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.
    """

    def __init__(self, params, lr=1, history_size=10, line_search="Wolfe", dtype=torch.float, debug=False):

        # ensure inputs are valid
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0 <= history_size:
            raise ValueError("Invalid history size: {}".format(history_size))
        if line_search not in ["Armijo", "Wolfe", "None"]:
            raise ValueError("Invalid line search: {}".format(line_search))

        defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("L-BFGS doesn't support per-parameter options (parameter groups)")

        self._params = self.param_groups[0]["params"]
        self._numel_cache = None

        state = self.state["global_state"]
        state.setdefault("n_iter", 0)
        state.setdefault("curv_skips", 0)
        state.setdefault("fail_skips", 0)
        state.setdefault("H_diag", 1)
        state.setdefault("fail", True)

        state["old_dirs"] = []
        state["old_stps"] = []

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_update(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.data.add_(update[offset : offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _copy_params(self):
        current_params = []
        for param in self._params:
            current_params.append(deepcopy(param.data))
        return current_params

    def _load_params(self, current_params):
        i = 0
        for param in self._params:
            param.data[:] = current_params[i]
            i += 1

    def line_search(self, line_search):
        """
        Switches line search option.
        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses Armijo-Wolfe bracketing line search
        """

        group = self.param_groups[0]
        group["line_search"] = line_search

        return

    def two_loop_recursion(self, vec):
        """
        Performs two-loop recursion on given vector to obtain Hv.
        Inputs:
            vec (tensor): 1-D tensor to apply two-loop recursion to
        Output:
            r (tensor): matrix-vector product Hv
        """

        group = self.param_groups[0]
        history_size = group["history_size"]

        state = self.state["global_state"]
        old_dirs = state.get("old_dirs")  # changes in iterates (s vectors)
        old_stps = state.get("old_stps")  # changes in gradients (y vectors)
        H_diag = state.get("H_diag")

        # compute the product of the inverse Hessian approximation and the gradient
        num_old = len(old_dirs)

        if "rho" not in state:
            state["rho"] = [None] * history_size
            state["alpha"] = [None] * history_size
        rho = state["rho"]
        alpha = state["alpha"]

        for i in range(num_old):
            rho[i] = 1.0 / old_stps[i].dot(old_dirs[i])

        q = vec
        for i in range(num_old - 1, -1, -1):
            alpha[i] = old_dirs[i].dot(q) * rho[i]
            q.add_(old_stps[i], alpha=-alpha[i])

        # multiply by initial Hessian
        # r/d is the final direction
        r = torch.mul(q, H_diag)
        for i in range(num_old):
            beta = old_stps[i].dot(r) * rho[i]
            r.add_(old_dirs[i], alpha=(alpha[i] - beta))

        return r
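
    # Worked form of the recursion above (a reading aid, not part of the upstream file):
    # with s_i = old_dirs[i], y_i = old_stps[i], rho_i = 1/(y_i^T s_i) and H_0 = H_diag * I,
    # the backward pass computes alpha_i = rho_i * s_i^T q and sets q <- q - alpha_i * y_i;
    # the forward pass starts from r = H_0 q and sets r <- r + (alpha_i - beta_i) * s_i with
    # beta_i = rho_i * y_i^T r, so that r approximates H_k @ vec. Passing vec = -gradient
    # therefore yields the quasi-Newton search direction consumed by step() below.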

    def curvature_update(self, flat_grad, eps=1e-2, damping=False):
        """
        Performs curvature update.
        Inputs:
            flat_grad (tensor): 1-D tensor of flattened gradient for computing
                gradient difference with previously stored gradient
            eps (float): constant for curvature pair rejection or damping (default: 1e-2)
            damping (bool): flag for using Powell damping (default: False)
        """

        assert len(self.param_groups) == 1

        # load parameters
        if eps <= 0:
            raise (ValueError("Invalid eps; must be positive."))

        group = self.param_groups[0]
        history_size = group["history_size"]
        debug = group["debug"]

        # variables cached in state (for tracing)
        state = self.state["global_state"]
        fail = state.get("fail")

        # check if line search failed
        if not fail:

            d = state.get("d")
            t = state.get("t")
            old_dirs = state.get("old_dirs")
            old_stps = state.get("old_stps")
            H_diag = state.get("H_diag")
            prev_flat_grad = state.get("prev_flat_grad")
            Bs = state.get("Bs")

            # compute y's
            y = flat_grad.sub(prev_flat_grad)
            s = d.mul(t)
            sBs = s.dot(Bs)
            ys = y.dot(s)  # y*s

            # update L-BFGS matrix
            if ys > eps * sBs or damping == True:

                # perform Powell damping
                if damping == True and ys < eps * sBs:
                    if debug:
                        print("Applying Powell damping...")
                    theta = ((1 - eps) * sBs) / (sBs - ys)
                    y = theta * y + (1 - theta) * Bs

                # updating memory
                if len(old_dirs) == history_size:
                    # shift history by one (limited-memory)
                    old_dirs.pop(0)
                    old_stps.pop(0)

                # store new direction/step
                old_dirs.append(s)
                old_stps.append(y)

                # update scale of initial Hessian approximation
                H_diag = ys / y.dot(y)  # (y*y)

                state["old_dirs"] = old_dirs
                state["old_stps"] = old_stps
                state["H_diag"] = H_diag

            else:
                # save skip
                state["curv_skips"] += 1
                if debug:
                    print("Curvature pair skipped due to failed criterion")

        else:
            # save skip
            state["fail_skips"] += 1
            if debug:
                print("Line search failed; curvature pair update skipped")

        return

    def _step(self, p_k, g_Ok, g_Sk=None, options={}):
        """
        Performs a single optimization step.
        Inputs:
            p_k (tensor): 1-D tensor specifying search direction
            g_Ok (tensor): 1-D tensor of flattened gradient over overlap O_k used
                for gradient differencing in curvature pair update
            g_Sk (tensor): 1-D tensor of flattened gradient over full sample S_k
                used for curvature pair damping or rejection criterion,
                if None, will use g_Ok (default: None)
            options (dict): contains options for performing line search
            Options for Armijo backtracking line search:
                'closure' (callable): reevaluates model and returns function value
                'current_loss' (tensor): objective value at current iterate (default: F(x_k))
                'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
                'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
                'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
                'max_ls' (int): maximum number of line search steps permitted (default: 10)
                'interpolate' (bool): flag for using interpolation (default: True)
                'inplace' (bool): flag for inplace operations (default: True)
                'ls_debug' (bool): debugging mode for line search
            Options for Wolfe line search:
                'closure' (callable): reevaluates model and returns function value
                'current_loss' (tensor): objective value at current iterate (default: F(x_k))
                'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
                'eta' (float): factor for extrapolation (default: 2)
                'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
                'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
                'max_ls' (int): maximum number of line search steps permitted (default: 10)
                'interpolate' (bool): flag for using interpolation (default: True)
                'inplace' (bool): flag for inplace operations (default: True)
                'ls_debug' (bool): debugging mode for line search
        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.
        """

        assert len(self.param_groups) == 1

        # load parameter options
        group = self.param_groups[0]
        lr = group["lr"]
        line_search = group["line_search"]
        dtype = group["dtype"]
        debug = group["debug"]

        # variables cached in state (for tracing)
        state = self.state["global_state"]
        d = state.get("d")
        t = state.get("t")
        prev_flat_grad = state.get("prev_flat_grad")
        Bs = state.get("Bs")

        # keep track of nb of iterations
        state["n_iter"] += 1

        # set search direction
        d = p_k

        # modify previous gradient
        if prev_flat_grad is None:
            prev_flat_grad = g_Ok.clone()
        else:
            prev_flat_grad.copy_(g_Ok)

        # set initial step size
        t = lr

        # closure evaluation counter
        closure_eval = 0

        if g_Sk is None:
            g_Sk = g_Ok.clone()

        # perform Armijo backtracking line search
        if line_search == "Armijo":

            # load options
            if options:
                if "closure" not in options.keys():
                    raise (ValueError("closure option not specified."))
                else:
                    closure = options["closure"]

                if "gtd" not in options.keys():
                    gtd = g_Ok.dot(d)
                else:
                    gtd = options["gtd"]

                if "current_loss" not in options.keys():
                    F_k = closure()
                    closure_eval += 1
                else:
                    F_k = options["current_loss"]

                if "eta" not in options.keys():
                    eta = 2
                elif options["eta"] <= 0:
                    raise (ValueError("Invalid eta; must be positive."))
                else:
                    eta = options["eta"]

                if "c1" not in options.keys():
                    c1 = 1e-4
                elif options["c1"] >= 1 or options["c1"] <= 0:
                    raise (ValueError("Invalid c1; must be strictly between 0 and 1."))
                else:
                    c1 = options["c1"]

                if "max_ls" not in options.keys():
                    max_ls = 10
                elif options["max_ls"] <= 0:
                    raise (ValueError("Invalid max_ls; must be positive."))
                else:
                    max_ls = options["max_ls"]

                if "interpolate" not in options.keys():
                    interpolate = True
                else:
                    interpolate = options["interpolate"]

                if "inplace" not in options.keys():
                    inplace = True
                else:
                    inplace = options["inplace"]

                if "ls_debug" not in options.keys():
                    ls_debug = False
                else:
                    ls_debug = options["ls_debug"]

            else:
                raise (ValueError("Options are not specified; need closure evaluating function."))

            # initialize values
            if interpolate:
                if torch.cuda.is_available():
                    F_prev = torch.tensor(np.nan, dtype=dtype).cuda()
                else:
                    F_prev = torch.tensor(np.nan, dtype=dtype)

            ls_step = 0
            t_prev = 0  # old steplength
            fail = False  # failure flag

            # begin print for debug mode
            if ls_debug:
                print(
                    "==================================== Begin Armijo line search ==================================="
                )
                print("F(x): %.8e g*d: %.8e" % (F_k, gtd))

            # check if search direction is descent direction
            if gtd >= 0:
                desc_dir = False
                if debug:
                    print("Not a descent direction!")
            else:
                desc_dir = True

            # store values if not in-place
            if not inplace:
                current_params = self._copy_params()

            # update and evaluate at new point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # print info if debugging
            if ls_debug:
                print(
                    "LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e"
                    % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k)
                )

            # check Armijo condition
            while F_new > F_k + c1 * t * gtd or not is_legal(F_new):

                # check if maximum number of iterations reached
                if ls_step >= max_ls:
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    closure_eval += 1
                    fail = True
                    break

                else:
                    # store current steplength
                    t_new = t

                    # compute new steplength

                    # if first step or not interpolating, then multiply by factor
                    if ls_step == 0 or not interpolate or not is_legal(F_new):
                        t = t / eta

                    # if second step, use function value at new point along with
                    # gradient and function at current iterate
                    elif ls_step == 1 or not is_legal(F_prev):
                        t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan]]))

                    # otherwise, use function values at new point, previous point,
                    # and gradient and function at current iterate
                    else:
                        t = polyinterp(
                            np.array(
                                [
                                    [0, F_k.item(), gtd.item()],
                                    [t_new, F_new.item(), np.nan],
                                    [t_prev, F_prev.item(), np.nan],
                                ]
                            )
                        )

                    # if values are too extreme, adjust t
                    if interpolate:
                        if t < 1e-3 * t_new:
                            t = 1e-3 * t_new
                        elif t > 0.6 * t_new:
                            t = 0.6 * t_new

                        # store old point
                        F_prev = F_new
                        t_prev = t_new

                    # update iterate and reevaluate
                    if inplace:
                        self._add_update(t - t_new, d)
                    else:
                        self._load_params(current_params)
                        self._add_update(t, d)

                    F_new = closure()
                    closure_eval += 1
                    ls_step += 1  # iterate

                    # print info if debugging
                    if ls_debug:
                        print(
                            "LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e"
                            % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k)
                        )

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            # print final steplength
            if ls_debug:
                print("Final Steplength:", t)
                print(
                    "===================================== End Armijo line search ===================================="
                )

            state["d"] = d
            state["prev_flat_grad"] = prev_flat_grad
            state["t"] = t
            state["Bs"] = Bs
            state["fail"] = fail

            return F_new, t, ls_step, closure_eval, desc_dir, fail

        # perform weak Wolfe line search
        elif line_search == "Wolfe":

            # load options
            if options:
                if "closure" not in options.keys():
                    raise (ValueError("closure option not specified."))
                else:
                    closure = options["closure"]

                if "current_loss" not in options.keys():
                    F_k = closure()
                    closure_eval += 1
                else:
                    F_k = options["current_loss"]

                if "gtd" not in options.keys():
                    gtd = g_Ok.dot(d)
                else:
                    gtd = options["gtd"]

                if "eta" not in options.keys():
                    eta = 2
                elif options["eta"] <= 1:
                    raise (ValueError("Invalid eta; must be greater than 1."))
                else:
                    eta = options["eta"]

                if "c1" not in options.keys():
                    c1 = 1e-4
                elif options["c1"] >= 1 or options["c1"] <= 0:
                    raise (ValueError("Invalid c1; must be strictly between 0 and 1."))
                else:
                    c1 = options["c1"]

                if "c2" not in options.keys():
                    c2 = 0.9
                elif options["c2"] >= 1 or options["c2"] <= 0:
                    raise (ValueError("Invalid c2; must be strictly between 0 and 1."))
                elif options["c2"] <= c1:
                    raise (ValueError("Invalid c2; must be strictly larger than c1."))
                else:
                    c2 = options["c2"]

                if "max_ls" not in options.keys():
                    max_ls = 10
                elif options["max_ls"] <= 0:
                    raise (ValueError("Invalid max_ls; must be positive."))
                else:
                    max_ls = options["max_ls"]

                if "interpolate" not in options.keys():
                    interpolate = True
                else:
                    interpolate = options["interpolate"]

                if "inplace" not in options.keys():
                    inplace = True
                else:
                    inplace = options["inplace"]

                if "ls_debug" not in options.keys():
                    ls_debug = False
                else:
                    ls_debug = options["ls_debug"]

            else:
                raise (ValueError("Options are not specified; need closure evaluating function."))

            # initialize counters
            ls_step = 0
            grad_eval = 0  # tracks gradient evaluations
            t_prev = 0  # old steplength

            # initialize bracketing variables and flag
            alpha = 0
            beta = float("Inf")
            fail = False

            # initialize values for line search
            if interpolate:
                F_a = F_k
                g_a = gtd

                if torch.cuda.is_available():
                    F_b = torch.tensor(np.nan, dtype=dtype).cuda()
                    g_b = torch.tensor(np.nan, dtype=dtype).cuda()
                else:
                    F_b = torch.tensor(np.nan, dtype=dtype)
                    g_b = torch.tensor(np.nan, dtype=dtype)

            # begin print for debug mode
            if ls_debug:
                print(
                    "==================================== Begin Wolfe line search ===================================="
                )
                print("F(x): %.8e g*d: %.8e" % (F_k, gtd))

            # check if search direction is descent direction
            if gtd >= 0:
                desc_dir = False
                if debug:
                    print("Not a descent direction!")
            else:
                desc_dir = True

            # store values if not in-place
            if not inplace:
                current_params = self._copy_params()

            # update and evaluate at new point
            self._add_update(t, d)
            F_new = closure()
            closure_eval += 1

            # main loop
            while True:

                # check if maximum number of line search steps have been reached
                if ls_step >= max_ls:
                    if inplace:
                        self._add_update(-t, d)
                    else:
                        self._load_params(current_params)

                    t = 0
                    F_new = closure()
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    closure_eval += 1
                    grad_eval += 1
                    fail = True
                    break

                # print info if debugging
                if ls_debug:
                    print("LS Step: %d t: %.8e alpha: %.8e beta: %.8e" % (ls_step, t, alpha, beta))
                    print("Armijo: F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e" % (F_new, F_k + c1 * t * gtd, F_k))

                # check Armijo condition
                if F_new > F_k + c1 * t * gtd:

                    # set upper bound
                    beta = t
                    t_prev = t

                    # update interpolation quantities
                    if interpolate:
                        F_b = F_new
                        if torch.cuda.is_available():
                            g_b = torch.tensor(np.nan, dtype=dtype).cuda()
                        else:
                            g_b = torch.tensor(np.nan, dtype=dtype)

                else:

                    # compute gradient
                    F_new.backward()
                    g_new = self._gather_flat_grad()
                    grad_eval += 1
                    gtd_new = g_new.dot(d)

                    # print info if debugging
                    if ls_debug:
                        print("Wolfe: g(x+td)*d: %.8e c2*g*d: %.8e gtd: %.8e" % (gtd_new, c2 * gtd, gtd))

                    # check curvature condition
                    if gtd_new < c2 * gtd:

                        # set lower bound
                        alpha = t
                        t_prev = t

                        # update interpolation quantities
                        if interpolate:
                            F_a = F_new
                            g_a = gtd_new

                    else:
                        break

                # compute new steplength

                # if first step or not interpolating, then bisect or multiply by factor
                if not interpolate or not is_legal(F_b):
                    if beta == float("Inf"):
                        t = eta * t
                    else:
                        t = (alpha + beta) / 2.0

                # otherwise interpolate between a and b
                else:
                    t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()], [beta, F_b.item(), g_b.item()]]))

                    # if values are too extreme, adjust t
                    if beta == float("Inf"):
                        if t > 2 * eta * t_prev:
                            t = 2 * eta * t_prev
                        elif t < eta * t_prev:
                            t = eta * t_prev
                    else:
                        if t < alpha + 0.2 * (beta - alpha):
                            t = alpha + 0.2 * (beta - alpha)
                        elif t > (beta - alpha) / 2.0:
                            t = (beta - alpha) / 2.0

                    # if we obtain nonsensical value from interpolation
                    if t <= 0:
                        t = (beta - alpha) / 2.0

                # update parameters
                if inplace:
                    self._add_update(t - t_prev, d)
                else:
                    self._load_params(current_params)
                    self._add_update(t, d)

                # evaluate closure
                F_new = closure()
                closure_eval += 1
                ls_step += 1

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            # print final steplength
            if ls_debug:
                print("Final Steplength:", t)
                print(
                    "===================================== End Wolfe line search ====================================="
                )

            state["d"] = d
            state["prev_flat_grad"] = prev_flat_grad
            state["t"] = t
            state["Bs"] = Bs
            state["fail"] = fail

            return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

        else:

            # perform update
            self._add_update(t, d)

            # store Bs
            if Bs is None:
                Bs = (g_Sk.mul(-t)).clone()
            else:
                Bs.copy_(g_Sk.mul(-t))

            state["d"] = d
            state["prev_flat_grad"] = prev_flat_grad
            state["t"] = t
            state["Bs"] = Bs
            state["fail"] = False

            return t

    def step(self, p_k, g_Ok, g_Sk=None, options={}):
        return self._step(p_k, g_Ok, g_Sk, options)
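
# A hedged sketch (not part of the upstream file) of one way to drive the multi-batch
# API above after the first iteration: update the curvature pairs with the current
# overlap gradient, form the search direction with the two-loop recursion, then take a
# line-search step. `model`, `closure` and `loss` are hypothetical placeholders.
#
#     optimizer = LBFGS(model.parameters(), lr=1.0, history_size=10, line_search="Wolfe")
#     ...
#     g_Ok = optimizer._gather_flat_grad()        # flattened gradient over overlap O_k
#     optimizer.curvature_update(g_Ok)            # rejects or damps pairs as configured
#     p = optimizer.two_loop_recursion(-g_Ok)     # quasi-Newton search direction
#     opts = {"closure": closure, "current_loss": loss, "max_ls": 10}
#     F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail = \
#         optimizer.step(p, g_Ok, options=opts)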


#%% Full-Batch (Deterministic) L-BFGS Optimizer (Wrapper)


class FullBatchLBFGS(LBFGS):
    """
    Implements full-batch or deterministic L-BFGS algorithm. Compatible with
    Powell damping. Can be used when evaluating a deterministic function and
    gradient. Wraps the LBFGS optimizer. Performs the two-loop recursion,
    updating, and curvature updating in a single step.
    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 11/15/18.
    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.
    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10)
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode
    """

    def __init__(self, params, lr=1, history_size=10, line_search="Wolfe", dtype=torch.float, debug=False):
        super(FullBatchLBFGS, self).__init__(params, lr, history_size, line_search, dtype, debug)

    def step(self, options={}):
        """
        Performs a single optimization step.
        Inputs:
            options (dict): contains options for performing line search
            General Options:
                'eps' (float): constant for curvature pair rejection or damping (default: 1e-2)
                'damping' (bool): flag for using Powell damping (default: False)
            Options for Armijo backtracking line search:
                'closure' (callable): reevaluates model and returns function value
                'current_loss' (tensor): objective value at current iterate (default: F(x_k))
                'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
                'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
                'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
                'max_ls' (int): maximum number of line search steps permitted (default: 10)
                'interpolate' (bool): flag for using interpolation (default: True)
                'inplace' (bool): flag for inplace operations (default: True)
                'ls_debug' (bool): debugging mode for line search
            Options for Wolfe line search:
                'closure' (callable): reevaluates model and returns function value
                'current_loss' (tensor): objective value at current iterate (default: F(x_k))
                'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
                'eta' (float): factor for extrapolation (default: 2)
                'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
                'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
                'max_ls' (int): maximum number of line search steps permitted (default: 10)
                'interpolate' (bool): flag for using interpolation (default: True)
                'inplace' (bool): flag for inplace operations (default: True)
                'ls_debug' (bool): debugging mode for line search
        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.
        """

        # load options for damping and eps
        if "damping" not in options.keys():
            damping = False
        else:
            damping = options["damping"]

        if "eps" not in options.keys():
            eps = 1e-2
        else:
            eps = options["eps"]

        # gather gradient
        grad = self._gather_flat_grad()

        # update curvature if after 1st iteration
        state = self.state["global_state"]
        if state["n_iter"] > 0:
            self.curvature_update(grad, eps, damping)

        # compute search direction
        p = self.two_loop_recursion(-grad)

        # take step
        return self._step(p, grad, options=options)
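
The diff adds the optimizer without a usage reference, so the following is a minimal full-batch training sketch consistent with the FullBatchLBFGS API it defines; model, loss_fn, X and y are hypothetical placeholders rather than objects shipped in this package. The closure returns the loss without calling backward(); the initial gradient is computed once before the loop, and the Wolfe line search recomputes gradients internally on subsequent steps.

import torch
from tigramite.independence_tests.LBFGS import FullBatchLBFGS

optimizer = FullBatchLBFGS(model.parameters(), lr=1.0, history_size=10, line_search="Wolfe")

def closure():
    # re-evaluate the objective at the current parameters
    optimizer.zero_grad()
    return loss_fn(model(X), y)

loss = closure()   # initial objective
loss.backward()    # initial gradient, gathered inside step()

for _ in range(100):
    options = {"closure": closure, "current_loss": loss, "max_ls": 15}
    loss, grad, t, ls_step, closure_eval, grad_eval, desc_dir, fail = optimizer.step(options)
    if fail:
        break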