tigramite_fast-5.2.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. tigramite/__init__.py +0 -0
  2. tigramite/causal_effects.py +1525 -0
  3. tigramite/causal_mediation.py +1592 -0
  4. tigramite/data_processing.py +1574 -0
  5. tigramite/graphs.py +1509 -0
  6. tigramite/independence_tests/LBFGS.py +1114 -0
  7. tigramite/independence_tests/__init__.py +0 -0
  8. tigramite/independence_tests/cmiknn.py +661 -0
  9. tigramite/independence_tests/cmiknn_mixed.py +1397 -0
  10. tigramite/independence_tests/cmisymb.py +286 -0
  11. tigramite/independence_tests/gpdc.py +664 -0
  12. tigramite/independence_tests/gpdc_torch.py +820 -0
  13. tigramite/independence_tests/gsquared.py +190 -0
  14. tigramite/independence_tests/independence_tests_base.py +1310 -0
  15. tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
  16. tigramite/independence_tests/pairwise_CI.py +383 -0
  17. tigramite/independence_tests/parcorr.py +369 -0
  18. tigramite/independence_tests/parcorr_mult.py +485 -0
  19. tigramite/independence_tests/parcorr_wls.py +451 -0
  20. tigramite/independence_tests/regressionCI.py +403 -0
  21. tigramite/independence_tests/robust_parcorr.py +403 -0
  22. tigramite/jpcmciplus.py +966 -0
  23. tigramite/lpcmci.py +3649 -0
  24. tigramite/models.py +2257 -0
  25. tigramite/pcmci.py +3935 -0
  26. tigramite/pcmci_base.py +1218 -0
  27. tigramite/plotting.py +4735 -0
  28. tigramite/rpcmci.py +467 -0
  29. tigramite/toymodels/__init__.py +0 -0
  30. tigramite/toymodels/context_model.py +261 -0
  31. tigramite/toymodels/non_additive.py +1231 -0
  32. tigramite/toymodels/structural_causal_processes.py +1201 -0
  33. tigramite/toymodels/surrogate_generator.py +319 -0
  34. tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
  35. tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
  36. tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
  37. tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
  38. tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1114 @@
1
+ """
2
+ MIT License
3
+
4
+ Copyright (c) 2017 Jake Gardner
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+ """
24
+
25
+ import torch
26
+ import numpy as np
27
+ # import matplotlib.pyplot as plt
28
+ from functools import reduce
29
+ from copy import deepcopy
30
+ from torch.optim import Optimizer
31
+
32
+
33
+ def is_legal(v):
34
+ """
35
+ Checks that tensor is not NaN or Inf.
36
+ Inputs:
37
+ v (tensor): tensor to be checked
38
+ """
39
+ legal = not torch.isnan(v).any() and not torch.isinf(v).any()
40
+
41
+ return legal
42
+
43
+
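# Quick illustration of the check above (editorial sketch, not part of the
# original file): is_legal is used below to reject NaN/Inf loss values during
# the line searches.
#
#   is_legal(torch.tensor(0.5))            # -> True
#   is_legal(torch.tensor(float("nan")))   # -> False
#   is_legal(torch.tensor(float("inf")))   # -> False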
44
+ def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
45
+ """
46
+ Gives the minimizer and minimum of the interpolating polynomial over given points
47
+ based on function and derivative information. Defaults to bisection if no critical
48
+ points are valid.
49
+ Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
50
+ modifications.
51
+ Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
52
+ Last edited 12/6/18.
53
+ Inputs:
54
+ points (nparray): two-dimensional array with each point of form [x f g]
55
+ x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
56
+ x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
57
+ plot (bool): plot interpolating polynomial
58
+ Outputs:
59
+ x_sol (float): minimizer of interpolating polynomial
60
+ F_min (float): minimum of interpolating polynomial
61
+ Note:
62
+ . Set f or g to np.nan if they are unknown
63
+ """
64
+ no_points = points.shape[0]
65
+ order = np.sum(1 - np.isnan(points[:, 1:3]).astype("int")) - 1
66
+
67
+ x_min = np.min(points[:, 0])
68
+ x_max = np.max(points[:, 0])
69
+
70
+ # compute bounds of interpolation area
71
+ if x_min_bound is None:
72
+ x_min_bound = x_min
73
+ if x_max_bound is None:
74
+ x_max_bound = x_max
75
+
76
+ # explicit formula for quadratic interpolation
77
+ if no_points == 2 and order == 2 and plot is False:
78
+ # Solution to quadratic interpolation is given by:
79
+ # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
80
+ # x_min = x1 - g1/(2a)
81
+ # if x1 = 0, then is given by:
82
+ # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
83
+
84
+ if points[0, 0] == 0:
85
+ x_sol = (
86
+ -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
87
+ )
88
+ else:
89
+ a = (
90
+ -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0]))
91
+ / (points[0, 0] - points[1, 0]) ** 2
92
+ )
93
+ x_sol = points[0, 0] - points[0, 2] / (2 * a)
94
+
95
+ x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
96
+
97
+ # explicit formula for cubic interpolation
98
+ elif no_points == 2 and order == 3 and plot is False:
99
+ # Solution to cubic interpolation is given by:
100
+ # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
101
+ # d2 = sqrt(d1^2 - g1*g2)
102
+ # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
103
+ d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
104
+ d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
105
+ if np.isreal(d2):
106
+ x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * (
107
+ (points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2)
108
+ )
109
+ x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
110
+ else:
111
+ x_sol = (x_max_bound + x_min_bound) / 2
112
+
113
+ # solve linear system
114
+ else:
115
+ # define linear constraints
116
+ A = np.zeros((0, order + 1))
117
+ b = np.zeros((0, 1))
118
+
119
+ # add linear constraints on function values
120
+ for i in range(no_points):
121
+ if not np.isnan(points[i, 1]):
122
+ constraint = np.zeros((1, order + 1))
123
+ for j in range(order, -1, -1):
124
+ constraint[0, order - j] = points[i, 0] ** j
125
+ A = np.append(A, constraint, 0)
126
+ b = np.append(b, points[i, 1])
127
+
128
+ # add linear constraints on gradient values
129
+ for i in range(no_points):
130
+ if not np.isnan(points[i, 2]):
131
+ constraint = np.zeros((1, order + 1))
132
+ for j in range(order):
133
+ constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
134
+ A = np.append(A, constraint, 0)
135
+ b = np.append(b, points[i, 2])
136
+
137
+ # check if system is solvable
138
+ if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
139
+ x_sol = (x_min_bound + x_max_bound) / 2
140
+ f_min = np.inf
141
+ else:
142
+ # solve linear system for interpolating polynomial
143
+ coeff = np.linalg.solve(A, b)
144
+
145
+ # compute critical points
146
+ dcoeff = np.zeros(order)
147
+ for i in range(len(coeff) - 1):
148
+ dcoeff[i] = coeff[i] * (order - i)
149
+
150
+ crit_pts = np.array([x_min_bound, x_max_bound])
151
+ crit_pts = np.append(crit_pts, points[:, 0])
152
+
153
+ if not np.isinf(dcoeff).any():
154
+ roots = np.roots(dcoeff)
155
+ crit_pts = np.append(crit_pts, roots)
156
+
157
+ # test critical points
158
+ f_min = np.inf
159
+ x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
160
+ for crit_pt in crit_pts:
161
+ if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
162
+ F_cp = np.polyval(coeff, crit_pt)
163
+ if np.isreal(F_cp) and F_cp < f_min:
164
+ x_sol = np.real(crit_pt)
165
+ f_min = np.real(F_cp)
166
+
167
+ # if plot:
168
+ # plt.figure()
169
+ # x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound) / 10000)
170
+ # f = np.polyval(coeff, x)
171
+ # plt.plot(x, f)
172
+ # plt.plot(x_sol, f_min, "x")
173
+
174
+ return x_sol
175
+
176
+
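# Minimal worked example for polyinterp (editorial sketch, not part of the
# original file): two points of the form [x, f, g] with the gradient at the
# second point unknown select the explicit quadratic branch above.
#
#   pts = np.array([[0.0, 1.0, -2.0],      # x=0, f(0)=1,  f'(0)=-2
#                   [2.0, 1.0, np.nan]])   # x=2, f(2)=1,  gradient unknown
#   polyinterp(pts)                        # -> 1.0, minimizer of (x - 1)**2
#
# The interpolating quadratic through these values is f(x) = (x - 1)**2, so the
# returned minimizer 1.0 lies inside the default bounds [0, 2].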
177
+ class LBFGS(Optimizer):
178
+ """
179
+ Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
180
+ L-BFGS implementations and (stochastic) Powell damping. Partly based on the
181
+ original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
182
+ and Michael Overton's weak Wolfe line search MATLAB code.
183
+ Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
184
+ Last edited 12/6/18.
185
+ Warnings:
186
+ . Does not support per-parameter options and parameter groups.
187
+ . All parameters have to be on a single device.
188
+ Inputs:
189
+ lr (float): steplength or learning rate (default: 1)
190
+ history_size (int): update history size (default: 10)
191
+ line_search (str): designates line search to use (default: 'Wolfe')
192
+ Options:
193
+ 'None': uses steplength designated in algorithm
194
+ 'Armijo': uses Armijo backtracking line search
195
+ 'Wolfe': uses Armijo-Wolfe bracketing line search
196
+ dtype: data type (default: torch.float)
197
+ debug (bool): debugging mode
198
+ References:
199
+ [1] Berahas, Albert S., Jorge Nocedal, and Martin Takáč. "A Multi-Batch L-BFGS
200
+ Method for Machine Learning." Advances in Neural Information Processing
201
+ Systems. 2016.
202
+ [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
203
+ Learning." International Conference on Machine Learning. 2018.
204
+ [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
205
+ Methods." Mathematical Programming 141.1-2 (2013): 135-163.
206
+ [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
207
+ Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
208
+ [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
209
+ Mathematics of Computation 35.151 (1980): 773-782.
210
+ [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
211
+ 2006.
212
+ [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
213
+ in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
214
+ (2005).
215
+ [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
216
+ Method for Online Convex Optimization." Artificial Intelligence and Statistics.
217
+ 2007.
218
+ [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
219
+ Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.
220
+ """
221
+
222
+ def __init__(self, params, lr=1, history_size=10, line_search="Wolfe", dtype=torch.float, debug=False):
223
+
224
+ # ensure inputs are valid
225
+ if not 0.0 <= lr:
226
+ raise ValueError("Invalid learning rate: {}".format(lr))
227
+ if not 0 <= history_size:
228
+ raise ValueError("Invalid history size: {}".format(history_size))
229
+ if line_search not in ["Armijo", "Wolfe", "None"]:
230
+ raise ValueError("Invalid line search: {}".format(line_search))
231
+
232
+ defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug)
233
+ super(LBFGS, self).__init__(params, defaults)
234
+
235
+ if len(self.param_groups) != 1:
236
+ raise ValueError("L-BFGS doesn't support per-parameter options " "(parameter groups)")
237
+
238
+ self._params = self.param_groups[0]["params"]
239
+ self._numel_cache = None
240
+
241
+ state = self.state["global_state"]
242
+ state.setdefault("n_iter", 0)
243
+ state.setdefault("curv_skips", 0)
244
+ state.setdefault("fail_skips", 0)
245
+ state.setdefault("H_diag", 1)
246
+ state.setdefault("fail", True)
247
+
248
+ state["old_dirs"] = []
249
+ state["old_stps"] = []
250
+
251
+ def _numel(self):
252
+ if self._numel_cache is None:
253
+ self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
254
+ return self._numel_cache
255
+
256
+ def _gather_flat_grad(self):
257
+ views = []
258
+ for p in self._params:
259
+ if p.grad is None:
260
+ view = p.data.new(p.data.numel()).zero_()
261
+ elif p.grad.data.is_sparse:
262
+ view = p.grad.data.to_dense().view(-1)
263
+ else:
264
+ view = p.grad.data.view(-1)
265
+ views.append(view)
266
+ return torch.cat(views, 0)
267
+
268
+ def _add_update(self, step_size, update):
269
+ offset = 0
270
+ for p in self._params:
271
+ numel = p.numel()
272
+ # view as to avoid deprecated pointwise semantics
273
+ p.data.add_(update[offset : offset + numel].view_as(p.data), alpha=step_size)
274
+ offset += numel
275
+ assert offset == self._numel()
276
+
277
+ def _copy_params(self):
278
+ current_params = []
279
+ for param in self._params:
280
+ current_params.append(deepcopy(param.data))
281
+ return current_params
282
+
283
+ def _load_params(self, current_params):
284
+ i = 0
285
+ for param in self._params:
286
+ param.data[:] = current_params[i]
287
+ i += 1
288
+
289
+ def line_search(self, line_search):
290
+ """
291
+ Switches line search option.
292
+ Inputs:
293
+ line_search (str): designates line search to use
294
+ Options:
295
+ 'None': uses steplength designated in algorithm
296
+ 'Armijo': uses Armijo backtracking line search
297
+ 'Wolfe': uses Armijo-Wolfe bracketing line search
298
+ """
299
+
300
+ group = self.param_groups[0]
301
+ group["line_search"] = line_search
302
+
303
+ return
304
+
305
+ def two_loop_recursion(self, vec):
306
+ """
307
+ Performs two-loop recursion on given vector to obtain Hv.
308
+ Inputs:
309
+ vec (tensor): 1-D tensor to apply two-loop recursion to
310
+ Output:
311
+ r (tensor): matrix-vector product Hv
312
+ """
313
+
314
+ group = self.param_groups[0]
315
+ history_size = group["history_size"]
316
+
317
+ state = self.state["global_state"]
318
+ old_dirs = state.get("old_dirs") # change in gradients
319
+ old_stps = state.get("old_stps") # change in iterates
320
+ H_diag = state.get("H_diag")
321
+
322
+ # compute the product of the inverse Hessian approximation and the gradient
323
+ num_old = len(old_dirs)
324
+
325
+ if "rho" not in state:
326
+ state["rho"] = [None] * history_size
327
+ state["alpha"] = [None] * history_size
328
+ rho = state["rho"]
329
+ alpha = state["alpha"]
330
+
331
+ for i in range(num_old):
332
+ rho[i] = 1.0 / old_stps[i].dot(old_dirs[i])
333
+
334
+ q = vec
335
+ for i in range(num_old - 1, -1, -1):
336
+ alpha[i] = old_dirs[i].dot(q) * rho[i]
337
+ q.add_(old_stps[i], alpha=-alpha[i])
338
+
339
+ # multiply by initial Hessian
340
+ # r/d is the final direction
341
+ r = torch.mul(q, H_diag)
342
+ for i in range(num_old):
343
+ beta = old_stps[i].dot(r) * rho[i]
344
+ r.add_(old_dirs[i], alpha=(alpha[i] - beta))
345
+
346
+ return r
347
+
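# For reference, the loop above is the standard L-BFGS two-loop recursion
# (Nocedal [5]); an editorial summary, not part of the original file. With
# s_i the stored steps (old_dirs), y_i the stored gradient differences
# (old_stps) and rho_i = 1 / (y_i^T s_i):
#
#   q = v
#   for i = m-1, ..., 0:  alpha_i = rho_i * s_i^T q;   q = q - alpha_i * y_i
#   r = H_diag * q                      # H_diag = y^T s / y^T y (set in curvature_update)
#   for i = 0, ..., m-1:  beta = rho_i * y_i^T r;      r = r + s_i * (alpha_i - beta)
#
# so the returned r equals H_k v, the product of the current inverse-Hessian
# approximation with the input vector.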
348
+ def curvature_update(self, flat_grad, eps=1e-2, damping=False):
349
+ """
350
+ Performs curvature update.
351
+ Inputs:
352
+ flat_grad (tensor): 1-D tensor of flattened gradient for computing
353
+ gradient difference with previously stored gradient
354
+ eps (float): constant for curvature pair rejection or damping (default: 1e-2)
355
+ damping (bool): flag for using Powell damping (default: False)
356
+ """
357
+
358
+ assert len(self.param_groups) == 1
359
+
360
+ # load parameters
361
+ if eps <= 0:
362
+ raise (ValueError("Invalid eps; must be positive."))
363
+
364
+ group = self.param_groups[0]
365
+ history_size = group["history_size"]
366
+ debug = group["debug"]
367
+
368
+ # variables cached in state (for tracing)
369
+ state = self.state["global_state"]
370
+ fail = state.get("fail")
371
+
372
+ # check if line search failed
373
+ if not fail:
374
+
375
+ d = state.get("d")
376
+ t = state.get("t")
377
+ old_dirs = state.get("old_dirs")
378
+ old_stps = state.get("old_stps")
379
+ H_diag = state.get("H_diag")
380
+ prev_flat_grad = state.get("prev_flat_grad")
381
+ Bs = state.get("Bs")
382
+
383
+ # compute y's
384
+ y = flat_grad.sub(prev_flat_grad)
385
+ s = d.mul(t)
386
+ sBs = s.dot(Bs)
387
+ ys = y.dot(s) # y*s
388
+
389
+ # update L-BFGS matrix
390
+ if ys > eps * sBs or damping == True:
391
+
392
+ # perform Powell damping
393
+ if damping == True and ys < eps * sBs:
394
+ if debug:
395
+ print("Applying Powell damping...")
396
+ theta = ((1 - eps) * sBs) / (sBs - ys)
397
+ y = theta * y + (1 - theta) * Bs
398
+
399
+ # updating memory
400
+ if len(old_dirs) == history_size:
401
+ # shift history by one (limited-memory)
402
+ old_dirs.pop(0)
403
+ old_stps.pop(0)
404
+
405
+ # store new direction/step
406
+ old_dirs.append(s)
407
+ old_stps.append(y)
408
+
409
+ # update scale of initial Hessian approximation
410
+ H_diag = ys / y.dot(y) # (y*y)
411
+
412
+ state["old_dirs"] = old_dirs
413
+ state["old_stps"] = old_stps
414
+ state["H_diag"] = H_diag
415
+
416
+ else:
417
+ # save skip
418
+ state["curv_skips"] += 1
419
+ if debug:
420
+ print("Curvature pair skipped due to failed criterion")
421
+
422
+ else:
423
+ # save skip
424
+ state["fail_skips"] += 1
425
+ if debug:
426
+ print("Line search failed; curvature pair update skipped")
427
+
428
+ return
429
+
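# Note on the damping step above (editorial sketch, not part of the original
# file): when y^T s < eps * s^T B s, Powell damping replaces y by
#
#   y_d = theta * y + (1 - theta) * B s,
#   theta = (1 - eps) * s^T B s / (s^T B s - y^T s)
#
# which gives s^T y_d = theta * y^T s + (1 - theta) * s^T B s = eps * s^T B s,
# i.e. the damped pair meets the curvature criterion with equality, so the
# memory update can always be applied when damping=True.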
430
+ def _step(self, p_k, g_Ok, g_Sk=None, options={}):
431
+ """
432
+ Performs a single optimization step.
433
+ Inputs:
434
+ p_k (tensor): 1-D tensor specifying search direction
435
+ g_Ok (tensor): 1-D tensor of flattened gradient over overlap O_k used
436
+ for gradient differencing in curvature pair update
437
+ g_Sk (tensor): 1-D tensor of flattened gradient over full sample S_k
438
+ used for curvature pair damping or rejection criterion,
439
+ if None, will use g_Ok (default: None)
440
+ options (dict): contains options for performing line search
441
+ Options for Armijo backtracking line search:
442
+ 'closure' (callable): reevaluates model and returns function value
443
+ 'current_loss' (tensor): objective value at current iterate (default: F(x_k))
444
+ 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
445
+ 'eta' (float): factor for decreasing steplength > 0 (default: 2)
446
+ 'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
447
+ 'max_ls' (int): maximum number of line search steps permitted (default: 10)
448
+ 'interpolate' (bool): flag for using interpolation (default: True)
449
+ 'inplace' (bool): flag for inplace operations (default: True)
450
+ 'ls_debug' (bool): debugging mode for line search
451
+ Options for Wolfe line search:
452
+ 'closure' (callable): reevaluates model and returns function value
453
+ 'current_loss' (tensor): objective value at current iterate (default: F(x_k))
454
+ 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
455
+ 'eta' (float): factor for extrapolation (default: 2)
456
+ 'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
457
+ 'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
458
+ 'max_ls' (int): maximum number of line search steps permitted (default: 10)
459
+ 'interpolate' (bool): flag for using interpolation (default: True)
460
+ 'inplace' (bool): flag for inplace operations (default: True)
461
+ 'ls_debug' (bool): debugging mode for line search
462
+ Outputs (depends on line search):
463
+ . No line search:
464
+ t (float): steplength
465
+ . Armijo backtracking line search:
466
+ F_new (tensor): loss function at new iterate
467
+ t (tensor): final steplength
468
+ ls_step (int): number of backtracks
469
+ closure_eval (int): number of closure evaluations
470
+ desc_dir (bool): descent direction flag
471
+ True: p_k is descent direction with respect to the line search
472
+ function
473
+ False: p_k is not a descent direction with respect to the line
474
+ search function
475
+ fail (bool): failure flag
476
+ True: line search reached maximum number of iterations, failed
477
+ False: line search succeeded
478
+ . Wolfe line search:
479
+ F_new (tensor): loss function at new iterate
480
+ g_new (tensor): gradient at new iterate
481
+ t (float): final steplength
482
+ ls_step (int): number of backtracks
483
+ closure_eval (int): number of closure evaluations
484
+ grad_eval (int): number of gradient evaluations
485
+ desc_dir (bool): descent direction flag
486
+ True: p_k is descent direction with respect to the line search
487
+ function
488
+ False: p_k is not a descent direction with respect to the line
489
+ search function
490
+ fail (bool): failure flag
491
+ True: line search reached maximum number of iterations, failed
492
+ False: line search succeeded
493
+ Notes:
494
+ . If encountering line search failure in the deterministic setting, one
495
+ should try increasing the maximum number of line search steps max_ls.
496
+ """
497
+
498
+ assert len(self.param_groups) == 1
499
+
500
+ # load parameter options
501
+ group = self.param_groups[0]
502
+ lr = group["lr"]
503
+ line_search = group["line_search"]
504
+ dtype = group["dtype"]
505
+ debug = group["debug"]
506
+
507
+ # variables cached in state (for tracing)
508
+ state = self.state["global_state"]
509
+ d = state.get("d")
510
+ t = state.get("t")
511
+ prev_flat_grad = state.get("prev_flat_grad")
512
+ Bs = state.get("Bs")
513
+
514
+ # keep track of nb of iterations
515
+ state["n_iter"] += 1
516
+
517
+ # set search direction
518
+ d = p_k
519
+
520
+ # modify previous gradient
521
+ if prev_flat_grad is None:
522
+ prev_flat_grad = g_Ok.clone()
523
+ else:
524
+ prev_flat_grad.copy_(g_Ok)
525
+
526
+ # set initial step size
527
+ t = lr
528
+
529
+ # closure evaluation counter
530
+ closure_eval = 0
531
+
532
+ if g_Sk is None:
533
+ g_Sk = g_Ok.clone()
534
+
535
+ # perform Armijo backtracking line search
536
+ if line_search == "Armijo":
537
+
538
+ # load options
539
+ if options:
540
+ if "closure" not in options.keys():
541
+ raise (ValueError("closure option not specified."))
542
+ else:
543
+ closure = options["closure"]
544
+
545
+ if "gtd" not in options.keys():
546
+ gtd = g_Ok.dot(d)
547
+ else:
548
+ gtd = options["gtd"]
549
+
550
+ if "current_loss" not in options.keys():
551
+ F_k = closure()
552
+ closure_eval += 1
553
+ else:
554
+ F_k = options["current_loss"]
555
+
556
+ if "eta" not in options.keys():
557
+ eta = 2
558
+ elif options["eta"] <= 0:
559
+ raise (ValueError("Invalid eta; must be positive."))
560
+ else:
561
+ eta = options["eta"]
562
+
563
+ if "c1" not in options.keys():
564
+ c1 = 1e-4
565
+ elif options["c1"] >= 1 or options["c1"] <= 0:
566
+ raise (ValueError("Invalid c1; must be strictly between 0 and 1."))
567
+ else:
568
+ c1 = options["c1"]
569
+
570
+ if "max_ls" not in options.keys():
571
+ max_ls = 10
572
+ elif options["max_ls"] <= 0:
573
+ raise (ValueError("Invalid max_ls; must be positive."))
574
+ else:
575
+ max_ls = options["max_ls"]
576
+
577
+ if "interpolate" not in options.keys():
578
+ interpolate = True
579
+ else:
580
+ interpolate = options["interpolate"]
581
+
582
+ if "inplace" not in options.keys():
583
+ inplace = True
584
+ else:
585
+ inplace = options["inplace"]
586
+
587
+ if "ls_debug" not in options.keys():
588
+ ls_debug = False
589
+ else:
590
+ ls_debug = options["ls_debug"]
591
+
592
+ else:
593
+ raise (ValueError("Options are not specified; need closure evaluating function."))
594
+
595
+ # initialize values
596
+ if interpolate:
597
+ if torch.cuda.is_available():
598
+ F_prev = torch.tensor(np.nan, dtype=dtype).cuda()
599
+ else:
600
+ F_prev = torch.tensor(np.nan, dtype=dtype)
601
+
602
+ ls_step = 0
603
+ t_prev = 0 # old steplength
604
+ fail = False # failure flag
605
+
606
+ # begin print for debug mode
607
+ if ls_debug:
608
+ print(
609
+ "==================================== Begin Armijo line search ==================================="
610
+ )
611
+ print("F(x): %.8e g*d: %.8e" % (F_k, gtd))
612
+
613
+ # check if search direction is descent direction
614
+ if gtd >= 0:
615
+ desc_dir = False
616
+ if debug:
617
+ print("Not a descent direction!")
618
+ else:
619
+ desc_dir = True
620
+
621
+ # store values if not in-place
622
+ if not inplace:
623
+ current_params = self._copy_params()
624
+
625
+ # update and evaluate at new point
626
+ self._add_update(t, d)
627
+ F_new = closure()
628
+ closure_eval += 1
629
+
630
+ # print info if debugging
631
+ if ls_debug:
632
+ print(
633
+ "LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e"
634
+ % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k)
635
+ )
636
+
637
+ # check Armijo condition
638
+ while F_new > F_k + c1 * t * gtd or not is_legal(F_new):
639
+
640
+ # check if maximum number of iterations reached
641
+ if ls_step >= max_ls:
642
+ if inplace:
643
+ self._add_update(-t, d)
644
+ else:
645
+ self._load_params(current_params)
646
+
647
+ t = 0
648
+ F_new = closure()
649
+ closure_eval += 1
650
+ fail = True
651
+ break
652
+
653
+ else:
654
+ # store current steplength
655
+ t_new = t
656
+
657
+ # compute new steplength
658
+
659
+ # if first step or not interpolating, then multiply by factor
660
+ if ls_step == 0 or not interpolate or not is_legal(F_new):
661
+ t = t / eta
662
+
663
+ # if second step, use function value at new point along with
664
+ # gradient and function at current iterate
665
+ elif ls_step == 1 or not is_legal(F_prev):
666
+ t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan]]))
667
+
668
+ # otherwise, use function values at new point, previous point,
669
+ # and gradient and function at current iterate
670
+ else:
671
+ t = polyinterp(
672
+ np.array(
673
+ [
674
+ [0, F_k.item(), gtd.item()],
675
+ [t_new, F_new.item(), np.nan],
676
+ [t_prev, F_prev.item(), np.nan],
677
+ ]
678
+ )
679
+ )
680
+
681
+ # if values are too extreme, adjust t
682
+ if interpolate:
683
+ if t < 1e-3 * t_new:
684
+ t = 1e-3 * t_new
685
+ elif t > 0.6 * t_new:
686
+ t = 0.6 * t_new
687
+
688
+ # store old point
689
+ F_prev = F_new
690
+ t_prev = t_new
691
+
692
+ # update iterate and reevaluate
693
+ if inplace:
694
+ self._add_update(t - t_new, d)
695
+ else:
696
+ self._load_params(current_params)
697
+ self._add_update(t, d)
698
+
699
+ F_new = closure()
700
+ closure_eval += 1
701
+ ls_step += 1 # iterate
702
+
703
+ # print info if debugging
704
+ if ls_debug:
705
+ print(
706
+ "LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e"
707
+ % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k)
708
+ )
709
+
710
+ # store Bs
711
+ if Bs is None:
712
+ Bs = (g_Sk.mul(-t)).clone()
713
+ else:
714
+ Bs.copy_(g_Sk.mul(-t))
715
+
716
+ # print final steplength
717
+ if ls_debug:
718
+ print("Final Steplength:", t)
719
+ print(
720
+ "===================================== End Armijo line search ===================================="
721
+ )
722
+
723
+ state["d"] = d
724
+ state["prev_flat_grad"] = prev_flat_grad
725
+ state["t"] = t
726
+ state["Bs"] = Bs
727
+ state["fail"] = fail
728
+
729
+ return F_new, t, ls_step, closure_eval, desc_dir, fail
730
+
731
+ # perform weak Wolfe line search
732
+ elif line_search == "Wolfe":
733
+
734
+ # load options
735
+ if options:
736
+ if "closure" not in options.keys():
737
+ raise (ValueError("closure option not specified."))
738
+ else:
739
+ closure = options["closure"]
740
+
741
+ if "current_loss" not in options.keys():
742
+ F_k = closure()
743
+ closure_eval += 1
744
+ else:
745
+ F_k = options["current_loss"]
746
+
747
+ if "gtd" not in options.keys():
748
+ gtd = g_Ok.dot(d)
749
+ else:
750
+ gtd = options["gtd"]
751
+
752
+ if "eta" not in options.keys():
753
+ eta = 2
754
+ elif options["eta"] <= 1:
755
+ raise (ValueError("Invalid eta; must be greater than 1."))
756
+ else:
757
+ eta = options["eta"]
758
+
759
+ if "c1" not in options.keys():
760
+ c1 = 1e-4
761
+ elif options["c1"] >= 1 or options["c1"] <= 0:
762
+ raise (ValueError("Invalid c1; must be strictly between 0 and 1."))
763
+ else:
764
+ c1 = options["c1"]
765
+
766
+ if "c2" not in options.keys():
767
+ c2 = 0.9
768
+ elif options["c2"] >= 1 or options["c2"] <= 0:
769
+ raise (ValueError("Invalid c2; must be strictly between 0 and 1."))
770
+ elif options["c2"] <= c1:
771
+ raise (ValueError("Invalid c2; must be strictly larger than c1."))
772
+ else:
773
+ c2 = options["c2"]
774
+
775
+ if "max_ls" not in options.keys():
776
+ max_ls = 10
777
+ elif options["max_ls"] <= 0:
778
+ raise (ValueError("Invalid max_ls; must be positive."))
779
+ else:
780
+ max_ls = options["max_ls"]
781
+
782
+ if "interpolate" not in options.keys():
783
+ interpolate = True
784
+ else:
785
+ interpolate = options["interpolate"]
786
+
787
+ if "inplace" not in options.keys():
788
+ inplace = True
789
+ else:
790
+ inplace = options["inplace"]
791
+
792
+ if "ls_debug" not in options.keys():
793
+ ls_debug = False
794
+ else:
795
+ ls_debug = options["ls_debug"]
796
+
797
+ else:
798
+ raise (ValueError("Options are not specified; need closure evaluating function."))
799
+
800
+ # initialize counters
801
+ ls_step = 0
802
+ grad_eval = 0 # tracks gradient evaluations
803
+ t_prev = 0 # old steplength
804
+
805
+ # initialize bracketing variables and flag
806
+ alpha = 0
807
+ beta = float("Inf")
808
+ fail = False
809
+
810
+ # initialize values for line search
811
+ if interpolate:
812
+ F_a = F_k
813
+ g_a = gtd
814
+
815
+ if torch.cuda.is_available():
816
+ F_b = torch.tensor(np.nan, dtype=dtype).cuda()
817
+ g_b = torch.tensor(np.nan, dtype=dtype).cuda()
818
+ else:
819
+ F_b = torch.tensor(np.nan, dtype=dtype)
820
+ g_b = torch.tensor(np.nan, dtype=dtype)
821
+
822
+ # begin print for debug mode
823
+ if ls_debug:
824
+ print(
825
+ "==================================== Begin Wolfe line search ===================================="
826
+ )
827
+ print("F(x): %.8e g*d: %.8e" % (F_k, gtd))
828
+
829
+ # check if search direction is descent direction
830
+ if gtd >= 0:
831
+ desc_dir = False
832
+ if debug:
833
+ print("Not a descent direction!")
834
+ else:
835
+ desc_dir = True
836
+
837
+ # store values if not in-place
838
+ if not inplace:
839
+ current_params = self._copy_params()
840
+
841
+ # update and evaluate at new point
842
+ self._add_update(t, d)
843
+ F_new = closure()
844
+ closure_eval += 1
845
+
846
+ # main loop
847
+ while True:
848
+
849
+ # check if maximum number of line search steps have been reached
850
+ if ls_step >= max_ls:
851
+ if inplace:
852
+ self._add_update(-t, d)
853
+ else:
854
+ self._load_params(current_params)
855
+
856
+ t = 0
857
+ F_new = closure()
858
+ F_new.backward()
859
+ g_new = self._gather_flat_grad()
860
+ closure_eval += 1
861
+ grad_eval += 1
862
+ fail = True
863
+ break
864
+
865
+ # print info if debugging
866
+ if ls_debug:
867
+ print("LS Step: %d t: %.8e alpha: %.8e beta: %.8e" % (ls_step, t, alpha, beta))
868
+ print("Armijo: F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e" % (F_new, F_k + c1 * t * gtd, F_k))
869
+
870
+ # check Armijo condition
871
+ if F_new > F_k + c1 * t * gtd:
872
+
873
+ # set upper bound
874
+ beta = t
875
+ t_prev = t
876
+
877
+ # update interpolation quantities
878
+ if interpolate:
879
+ F_b = F_new
880
+ if torch.cuda.is_available():
881
+ g_b = torch.tensor(np.nan, dtype=dtype).cuda()
882
+ else:
883
+ g_b = torch.tensor(np.nan, dtype=dtype)
884
+
885
+ else:
886
+
887
+ # compute gradient
888
+ F_new.backward()
889
+ g_new = self._gather_flat_grad()
890
+ grad_eval += 1
891
+ gtd_new = g_new.dot(d)
892
+
893
+ # print info if debugging
894
+ if ls_debug:
895
+ print("Wolfe: g(x+td)*d: %.8e c2*g*d: %.8e gtd: %.8e" % (gtd_new, c2 * gtd, gtd))
896
+
897
+ # check curvature condition
898
+ if gtd_new < c2 * gtd:
899
+
900
+ # set lower bound
901
+ alpha = t
902
+ t_prev = t
903
+
904
+ # update interpolation quantities
905
+ if interpolate:
906
+ F_a = F_new
907
+ g_a = gtd_new
908
+
909
+ else:
910
+ break
911
+
912
+ # compute new steplength
913
+
914
+ # if first step or not interpolating, then bisect or multiply by factor
915
+ if not interpolate or not is_legal(F_b):
916
+ if beta == float("Inf"):
917
+ t = eta * t
918
+ else:
919
+ t = (alpha + beta) / 2.0
920
+
921
+ # otherwise interpolate between a and b
922
+ else:
923
+ t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()], [beta, F_b.item(), g_b.item()]]))
924
+
925
+ # if values are too extreme, adjust t
926
+ if beta == float("Inf"):
927
+ if t > 2 * eta * t_prev:
928
+ t = 2 * eta * t_prev
929
+ elif t < eta * t_prev:
930
+ t = eta * t_prev
931
+ else:
932
+ if t < alpha + 0.2 * (beta - alpha):
933
+ t = alpha + 0.2 * (beta - alpha)
934
+ elif t > (beta - alpha) / 2.0:
935
+ t = (beta - alpha) / 2.0
936
+
937
+ # if we obtain nonsensical value from interpolation
938
+ if t <= 0:
939
+ t = (beta - alpha) / 2.0
940
+
941
+ # update parameters
942
+ if inplace:
943
+ self._add_update(t - t_prev, d)
944
+ else:
945
+ self._load_params(current_params)
946
+ self._add_update(t, d)
947
+
948
+ # evaluate closure
949
+ F_new = closure()
950
+ closure_eval += 1
951
+ ls_step += 1
952
+
953
+ # store Bs
954
+ if Bs is None:
955
+ Bs = (g_Sk.mul(-t)).clone()
956
+ else:
957
+ Bs.copy_(g_Sk.mul(-t))
958
+
959
+ # print final steplength
960
+ if ls_debug:
961
+ print("Final Steplength:", t)
962
+ print(
963
+ "===================================== End Wolfe line search ====================================="
964
+ )
965
+
966
+ state["d"] = d
967
+ state["prev_flat_grad"] = prev_flat_grad
968
+ state["t"] = t
969
+ state["Bs"] = Bs
970
+ state["fail"] = fail
971
+
972
+ return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail
973
+
974
+ else:
975
+
976
+ # perform update
977
+ self._add_update(t, d)
978
+
979
+ # store Bs
980
+ if Bs is None:
981
+ Bs = (g_Sk.mul(-t)).clone()
982
+ else:
983
+ Bs.copy_(g_Sk.mul(-t))
984
+
985
+ state["d"] = d
986
+ state["prev_flat_grad"] = prev_flat_grad
987
+ state["t"] = t
988
+ state["Bs"] = Bs
989
+ state["fail"] = False
990
+
991
+ return t
992
+
993
+ def step(self, p_k, g_Ok, g_Sk=None, options={}):
994
+ return self._step(p_k, g_Ok, g_Sk, options)
995
+
996
+
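# Typical multi-batch use of the optimizer above (editorial sketch under the
# assumed names `model`, `loss_fn`, `loader`; not part of the original file).
# It mirrors the order used by FullBatchLBFGS.step below: curvature update,
# two-loop recursion, then a line-search step.
#
#   optimizer = LBFGS(model.parameters(), lr=1., history_size=10, line_search='Wolfe')
#   for inputs, targets in loader:
#       def closure():
#           optimizer.zero_grad()
#           return loss_fn(model(inputs), targets)
#       loss = closure()
#       loss.backward()
#       grad = optimizer._gather_flat_grad()
#       # update curvature pairs with the new gradient (skipped on the first pass)
#       if optimizer.state['global_state']['n_iter'] > 0:
#           optimizer.curvature_update(grad, eps=1e-2, damping=True)
#       # search direction from the two-loop recursion, then a Wolfe line-search step
#       p = optimizer.two_loop_recursion(-grad)
#       options = {'closure': closure, 'current_loss': loss}
#       loss, _, t, ls_step, c_eval, g_eval, desc_dir, fail = optimizer.step(p, grad, options=options)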
997
+ #%% Full-Batch (Deterministic) L-BFGS Optimizer (Wrapper)
998
+
999
+
1000
+ class FullBatchLBFGS(LBFGS):
1001
+ """
1002
+ Implements full-batch or deterministic L-BFGS algorithm. Compatible with
1003
+ Powell damping. Can be used when evaluating a deterministic function and
1004
+ gradient. Wraps the LBFGS optimizer. Performs the two-loop recursion,
1005
+ updating, and curvature updating in a single step.
1006
+ Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
1007
+ Last edited 11/15/18.
1008
+ Warnings:
1009
+ . Does not support per-parameter options and parameter groups.
1010
+ . All parameters have to be on a single device.
1011
+ Inputs:
1012
+ lr (float): steplength or learning rate (default: 1)
1013
+ history_size (int): update history size (default: 10)
1014
+ line_search (str): designates line search to use (default: 'Wolfe')
1015
+ Options:
1016
+ 'None': uses steplength designated in algorithm
1017
+ 'Armijo': uses Armijo backtracking line search
1018
+ 'Wolfe': uses Armijo-Wolfe bracketing line search
1019
+ dtype: data type (default: torch.float)
1020
+ debug (bool): debugging mode
1021
+ """
1022
+
1023
+ def __init__(self, params, lr=1, history_size=10, line_search="Wolfe", dtype=torch.float, debug=False):
1024
+ super(FullBatchLBFGS, self).__init__(params, lr, history_size, line_search, dtype, debug)
1025
+
1026
+ def step(self, options={}):
1027
+ """
1028
+ Performs a single optimization step.
1029
+ Inputs:
1030
+ options (dict): contains options for performing line search
1031
+ General Options:
1032
+ 'eps' (float): constant for curvature pair rejection or damping (default: 1e-2)
1033
+ 'damping' (bool): flag for using Powell damping (default: False)
1034
+ Options for Armijo backtracking line search:
1035
+ 'closure' (callable): reevaluates model and returns function value
1036
+ 'current_loss' (tensor): objective value at current iterate (default: F(x_k))
1037
+ 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
1038
+ 'eta' (float): factor for decreasing steplength > 0 (default: 2)
1039
+ 'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
1040
+ 'max_ls' (int): maximum number of line search steps permitted (default: 10)
1041
+ 'interpolate' (bool): flag for using interpolation (default: True)
1042
+ 'inplace' (bool): flag for inplace operations (default: True)
1043
+ 'ls_debug' (bool): debugging mode for line search
1044
+ Options for Wolfe line search:
1045
+ 'closure' (callable): reevaluates model and returns function value
1046
+ 'current_loss' (tensor): objective value at current iterate (default: F(x_k))
1047
+ 'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
1048
+ 'eta' (float): factor for extrapolation (default: 2)
1049
+ 'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
1050
+ 'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
1051
+ 'max_ls' (int): maximum number of line search steps permitted (default: 10)
1052
+ 'interpolate' (bool): flag for using interpolation (default: True)
1053
+ 'inplace' (bool): flag for inplace operations (default: True)
1054
+ 'ls_debug' (bool): debugging mode for line search
1055
+ Outputs (depends on line search):
1056
+ . No line search:
1057
+ t (float): steplength
1058
+ . Armijo backtracking line search:
1059
+ F_new (tensor): loss function at new iterate
1060
+ t (tensor): final steplength
1061
+ ls_step (int): number of backtracks
1062
+ closure_eval (int): number of closure evaluations
1063
+ desc_dir (bool): descent direction flag
1064
+ True: p_k is descent direction with respect to the line search
1065
+ function
1066
+ False: p_k is not a descent direction with respect to the line
1067
+ search function
1068
+ fail (bool): failure flag
1069
+ True: line search reached maximum number of iterations, failed
1070
+ False: line search succeeded
1071
+ . Wolfe line search:
1072
+ F_new (tensor): loss function at new iterate
1073
+ g_new (tensor): gradient at new iterate
1074
+ t (float): final steplength
1075
+ ls_step (int): number of backtracks
1076
+ closure_eval (int): number of closure evaluations
1077
+ grad_eval (int): number of gradient evaluations
1078
+ desc_dir (bool): descent direction flag
1079
+ True: p_k is descent direction with respect to the line search
1080
+ function
1081
+ False: p_k is not a descent direction with respect to the line
1082
+ search function
1083
+ fail (bool): failure flag
1084
+ True: line search reached maximum number of iterations, failed
1085
+ False: line search succeeded
1086
+ Notes:
1087
+ . If encountering line search failure in the deterministic setting, one
1088
+ should try increasing the maximum number of line search steps max_ls.
1089
+ """
1090
+
1091
+ # load options for damping and eps
1092
+ if "damping" not in options.keys():
1093
+ damping = False
1094
+ else:
1095
+ damping = options["damping"]
1096
+
1097
+ if "eps" not in options.keys():
1098
+ eps = 1e-2
1099
+ else:
1100
+ eps = options["eps"]
1101
+
1102
+ # gather gradient
1103
+ grad = self._gather_flat_grad()
1104
+
1105
+ # update curvature if after 1st iteration
1106
+ state = self.state["global_state"]
1107
+ if state["n_iter"] > 0:
1108
+ self.curvature_update(grad, eps, damping)
1109
+
1110
+ # compute search direction
1111
+ p = self.two_loop_recursion(-grad)
1112
+
1113
+ # take step
1114
+ return self._step(p, grad, options=options)
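# Typical full-batch usage of FullBatchLBFGS (editorial sketch under the
# assumed names `model`, `loss_fn`, `X`, `y`, `max_iter`; not part of the
# original file). With the default Wolfe line search, step() returns the
# 8-tuple documented above.
#
#   optimizer = FullBatchLBFGS(model.parameters(), lr=1., line_search='Wolfe')
#
#   def closure():
#       optimizer.zero_grad()
#       return loss_fn(model(X), y)
#
#   loss = closure()
#   loss.backward()
#   for _ in range(max_iter):
#       options = {'closure': closure, 'current_loss': loss, 'max_ls': 25, 'damping': True}
#       loss, grad, t, ls_step, c_eval, g_eval, desc_dir, fail = optimizer.step(options)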