yms-kan 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/LBFGS.py ADDED
@@ -0,0 +1,492 @@
+ import torch
+ from functools import reduce
+ from torch.optim import Optimizer
+
+ __all__ = ['LBFGS']
+
+
+ def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
+     # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
+     # Compute bounds of interpolation area
+     if bounds is not None:
+         xmin_bound, xmax_bound = bounds
+     else:
+         xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
+
+     # Code for most common case: cubic interpolation of 2 points
+     #   w/ function and derivative values for both
+     # Solution in this case (where x2 is the farthest point):
+     #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
+     #   d2 = sqrt(d1^2 - g1*g2);
+     #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
+     #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
+     d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
+     d2_square = d1 ** 2 - g1 * g2
+     if d2_square >= 0:
+         d2 = d2_square.sqrt()
+         if x1 <= x2:
+             min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
+         else:
+             min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
+         return min(max(min_pos, xmin_bound), xmax_bound)
+     else:
+         return (xmin_bound + xmax_bound) / 2.
+
+
+ def _strong_wolfe(obj_func,
+                   x,
+                   t,
+                   d,
+                   f,
+                   g,
+                   gtd,
+                   c1=1e-4,
+                   c2=0.9,
+                   tolerance_change=1e-9,
+                   max_ls=25):
+     # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
+     d_norm = d.abs().max()
+     g = g.clone(memory_format=torch.contiguous_format)
+     # evaluate objective and gradient using initial step
+     f_new, g_new = obj_func(x, t, d)
+     ls_func_evals = 1
+     gtd_new = g_new.dot(d)
+
+     # bracket an interval containing a point satisfying the Wolfe criteria
+     t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
+     done = False
+     ls_iter = 0
+     while ls_iter < max_ls:
+         # check conditions
+         # print(f_prev, f_new, g_new)
+         if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
+             bracket = [t_prev, t]
+             bracket_f = [f_prev, f_new]
+             bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
+             bracket_gtd = [gtd_prev, gtd_new]
+             break
+
+         if abs(gtd_new) <= -c2 * gtd:
+             bracket = [t]
+             bracket_f = [f_new]
+             bracket_g = [g_new]
+             done = True
+             break
+
+         if gtd_new >= 0:
+             bracket = [t_prev, t]
+             bracket_f = [f_prev, f_new]
+             bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
+             bracket_gtd = [gtd_prev, gtd_new]
+             break
+
+         # interpolate
+         min_step = t + 0.01 * (t - t_prev)
+         max_step = t * 10
+         tmp = t
+         t = _cubic_interpolate(
+             t_prev,
+             f_prev,
+             gtd_prev,
+             t,
+             f_new,
+             gtd_new,
+             bounds=(min_step, max_step))
+
+         # next step
+         t_prev = tmp
+         f_prev = f_new
+         g_prev = g_new.clone(memory_format=torch.contiguous_format)
+         gtd_prev = gtd_new
+         f_new, g_new = obj_func(x, t, d)
+         ls_func_evals += 1
+         gtd_new = g_new.dot(d)
+         ls_iter += 1
+
+     # reached max number of iterations?
+     if ls_iter == max_ls:
+         bracket = [0, t]
+         bracket_f = [f, f_new]
+         bracket_g = [g, g_new]
+
+     # zoom phase: we now have a point satisfying the criteria, or
+     # a bracket around it. We refine the bracket until we find the
+     # exact point satisfying the criteria
+     insuf_progress = False
+     # find high and low points in bracket
+     low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
+     while not done and ls_iter < max_ls:
+         # line-search bracket is too small
+         if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
+             break
+
+         # compute new trial value
+         t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
+                                bracket[1], bracket_f[1], bracket_gtd[1])
+
+         # test that we are making sufficient progress:
+         # in case `t` is too close to a boundary, we mark that we are making
+         # insufficient progress, and if
+         #   + we have made insufficient progress in the last step, or
+         #   + `t` is at one of the boundaries,
+         # we will move `t` to a position which is `0.1 * len(bracket)`
+         # away from the nearest boundary point.
+         eps = 0.1 * (max(bracket) - min(bracket))
+         if min(max(bracket) - t, t - min(bracket)) < eps:
+             # interpolation close to boundary
+             if insuf_progress or t >= max(bracket) or t <= min(bracket):
+                 # evaluate at 0.1 away from boundary
+                 if abs(t - max(bracket)) < abs(t - min(bracket)):
+                     t = max(bracket) - eps
+                 else:
+                     t = min(bracket) + eps
+                 insuf_progress = False
+             else:
+                 insuf_progress = True
+         else:
+             insuf_progress = False
+
+         # Evaluate new point
+         f_new, g_new = obj_func(x, t, d)
+         ls_func_evals += 1
+         gtd_new = g_new.dot(d)
+         ls_iter += 1
+
+         if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
+             # Armijo condition not satisfied or not lower than lowest point
+             bracket[high_pos] = t
+             bracket_f[high_pos] = f_new
+             bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
+             bracket_gtd[high_pos] = gtd_new
+             low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
+         else:
+             if abs(gtd_new) <= -c2 * gtd:
+                 # Wolfe conditions satisfied
+                 done = True
+             elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
+                 # old low becomes new high
+                 bracket[high_pos] = bracket[low_pos]
+                 bracket_f[high_pos] = bracket_f[low_pos]
+                 bracket_g[high_pos] = bracket_g[low_pos]
+                 bracket_gtd[high_pos] = bracket_gtd[low_pos]
+
+             # new point becomes new low
+             bracket[low_pos] = t
+             bracket_f[low_pos] = f_new
+             bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
+             bracket_gtd[low_pos] = gtd_new
+
+     # print(bracket)
+     if len(bracket) == 1:
+         t = bracket[0]
+         f_new = bracket_f[0]
+         g_new = bracket_g[0]
+     else:
+         t = bracket[low_pos]
+         f_new = bracket_f[low_pos]
+         g_new = bracket_g[low_pos]
+     return f_new, g_new, t, ls_func_evals
+
+
+ class LBFGS(Optimizer):
+     """Implements L-BFGS algorithm.
+
+     Heavily inspired by `minFunc
+     <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.
+
+     .. warning::
+         This optimizer doesn't support per-parameter options and parameter
+         groups (there can be only one).
+
+     .. warning::
+         Right now all parameters have to be on a single device. This will be
+         improved in the future.
+
+     .. note::
+         This is a very memory intensive optimizer (it requires additional
+         ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
+         try reducing the history size, or use a different algorithm.
+
+     Args:
+         lr (float): learning rate (default: 1)
+         max_iter (int): maximal number of iterations per optimization step
+             (default: 20)
+         max_eval (int): maximal number of function evaluations per optimization
+             step (default: max_iter * 1.25).
+         tolerance_grad (float): termination tolerance on first order optimality
+             (default: 1e-7).
+         tolerance_change (float): termination tolerance on function
+             value/parameter changes (default: 1e-9).
+         history_size (int): update history size (default: 100).
+         line_search_fn (str): either 'strong_wolfe' or None (default: None).
+     """
+
+     def __init__(self,
+                  params,
+                  lr=1,
+                  max_iter=20,
+                  max_eval=None,
+                  tolerance_grad=1e-7,
+                  tolerance_change=1e-9,
+                  tolerance_ys=1e-32,
+                  history_size=100,
+                  line_search_fn=None):
+         if max_eval is None:
+             max_eval = max_iter * 5 // 4
+         defaults = dict(
+             lr=lr,
+             max_iter=max_iter,
+             max_eval=max_eval,
+             tolerance_grad=tolerance_grad,
+             tolerance_change=tolerance_change,
+             tolerance_ys=tolerance_ys,
+             history_size=history_size,
+             line_search_fn=line_search_fn)
+         super().__init__(params, defaults)
+
+         if len(self.param_groups) != 1:
+             raise ValueError("LBFGS doesn't support per-parameter options "
+                              "(parameter groups)")
+
+         self._params = self.param_groups[0]['params']
+         self._numel_cache = None
+
+     def _numel(self):
+         if self._numel_cache is None:
+             self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
+         return self._numel_cache
+
+     def _gather_flat_grad(self):
+         views = []
+         for p in self._params:
+             if p.grad is None:
+                 view = p.new(p.numel()).zero_()
+             elif p.grad.is_sparse:
+                 view = p.grad.to_dense().view(-1)
+             else:
+                 view = p.grad.view(-1)
+             views.append(view)
+         device = views[0].device
+         return torch.cat(views, dim=0)
+
+     def _add_grad(self, step_size, update):
+         offset = 0
+         for p in self._params:
+             numel = p.numel()
+             # view as to avoid deprecated pointwise semantics
+             p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
+             offset += numel
+         assert offset == self._numel()
+
+     def _clone_param(self):
+         return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
+
+     def _set_param(self, params_data):
+         for p, pdata in zip(self._params, params_data):
+             p.copy_(pdata)
+
+     def _directional_evaluate(self, closure, x, t, d):
+         self._add_grad(t, d)
+         loss = float(closure())
+         flat_grad = self._gather_flat_grad()
+         self._set_param(x)
+         return loss, flat_grad
+
+     @torch.no_grad()
+     def step(self, closure):
+         """Perform a single optimization step.
+
+         Args:
+             closure (Callable): A closure that reevaluates the model
+                 and returns the loss.
+         """
+
+         torch.manual_seed(0)
+
+         assert len(self.param_groups) == 1
+
+         # Make sure the closure is always called with grad enabled
+         closure = torch.enable_grad()(closure)
+
+         group = self.param_groups[0]
+         lr = group['lr']
+         max_iter = group['max_iter']
+         max_eval = group['max_eval']
+         tolerance_grad = group['tolerance_grad']
+         tolerance_change = group['tolerance_change']
+         tolerance_ys = group['tolerance_ys']
+         line_search_fn = group['line_search_fn']
+         history_size = group['history_size']
+
+         # NOTE: LBFGS has only global state, but we register it as state for
+         # the first param, because this helps with casting in load_state_dict
+         state = self.state[self._params[0]]
+         state.setdefault('func_evals', 0)
+         state.setdefault('n_iter', 0)
+
+         # evaluate initial f(x) and df/dx
+         orig_loss = closure()
+         loss = float(orig_loss)
+         current_evals = 1
+         state['func_evals'] += 1
+
+         flat_grad = self._gather_flat_grad()
+         opt_cond = flat_grad.abs().max() <= tolerance_grad
+
+         # optimal condition
+         if opt_cond:
+             return orig_loss
+
+         # tensors cached in state (for tracing)
+         d = state.get('d')
+         t = state.get('t')
+         old_dirs = state.get('old_dirs')
+         old_stps = state.get('old_stps')
+         ro = state.get('ro')
+         H_diag = state.get('H_diag')
+         prev_flat_grad = state.get('prev_flat_grad')
+         prev_loss = state.get('prev_loss')
+
+         n_iter = 0
+         # optimize for a max of max_iter iterations
+         while n_iter < max_iter:
+             # keep track of the number of iterations
+             n_iter += 1
+             state['n_iter'] += 1
+
+             ############################################################
+             # compute gradient descent direction
+             ############################################################
+             if state['n_iter'] == 1:
+                 d = flat_grad.neg()
+                 old_dirs = []
+                 old_stps = []
+                 ro = []
+                 H_diag = 1
+             else:
+                 # do lbfgs update (update memory)
+                 y = flat_grad.sub(prev_flat_grad)
+                 s = d.mul(t)
+                 ys = y.dot(s)  # y*s
+                 if ys > tolerance_ys:
+                     # updating memory
+                     if len(old_dirs) == history_size:
+                         # shift history by one (limited-memory)
+                         old_dirs.pop(0)
+                         old_stps.pop(0)
+                         ro.pop(0)
+
+                     # store new direction/step
+                     old_dirs.append(y)
+                     old_stps.append(s)
+                     ro.append(1. / ys)
+
+                     # update scale of initial Hessian approximation
+                     H_diag = ys / y.dot(y)  # (y*y)
+
+                 # compute the approximate (L-BFGS) inverse Hessian
+                 # multiplied by the gradient
+                 num_old = len(old_dirs)
+
+                 if 'al' not in state:
+                     state['al'] = [None] * history_size
+                 al = state['al']
+
+                 # iteration in L-BFGS loop collapsed to use just one buffer
+                 q = flat_grad.neg()
+                 for i in range(num_old - 1, -1, -1):
+                     al[i] = old_stps[i].dot(q) * ro[i]
+                     q.add_(old_dirs[i], alpha=-al[i])
+
+                 # multiply by initial Hessian
+                 # r/d is the final direction
+                 d = r = torch.mul(q, H_diag)
+                 for i in range(num_old):
+                     be_i = old_dirs[i].dot(r) * ro[i]
+                     r.add_(old_stps[i], alpha=al[i] - be_i)
+
+             if prev_flat_grad is None:
+                 prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
+             else:
+                 prev_flat_grad.copy_(flat_grad)
+             prev_loss = loss
+
+             ############################################################
+             # compute step length
+             ############################################################
+             # reset initial guess for step size
+             if state['n_iter'] == 1:
+                 t = min(1., 1. / flat_grad.abs().sum()) * lr
+             else:
+                 t = lr
+
+             # directional derivative
+             gtd = flat_grad.dot(d)  # g * d
+
+             # directional derivative is below tolerance
+             if gtd > -tolerance_change:
+                 break
+
+             # optional line search: user function
+             ls_func_evals = 0
+             if line_search_fn is not None:
+                 # perform line search, using user function
+                 if line_search_fn != "strong_wolfe":
+                     raise RuntimeError("only 'strong_wolfe' is supported")
+                 else:
+                     x_init = self._clone_param()
+
+                     def obj_func(x, t, d):
+                         return self._directional_evaluate(closure, x, t, d)
+
+                     loss, flat_grad, t, ls_func_evals = _strong_wolfe(
+                         obj_func, x_init, t, d, loss, flat_grad, gtd)
+                 self._add_grad(t, d)
+                 opt_cond = flat_grad.abs().max() <= tolerance_grad
+             else:
+                 # no line search, simply move with fixed step
+                 self._add_grad(t, d)
+                 if n_iter != max_iter:
+                     # re-evaluate function only if not in last iteration
+                     # the reason we do this: in a stochastic setting,
+                     # there is no point re-evaluating the function here
+                     with torch.enable_grad():
+                         loss = float(closure())
+                     flat_grad = self._gather_flat_grad()
+                     opt_cond = flat_grad.abs().max() <= tolerance_grad
+                     ls_func_evals = 1
+
+             # update func eval
+             current_evals += ls_func_evals
+             state['func_evals'] += ls_func_evals
+
+             ############################################################
+             # check conditions
+             ############################################################
+             if n_iter == max_iter:
+                 break
+
+             if current_evals >= max_eval:
+                 break
+
+             # optimal condition
+             if opt_cond:
+                 break
+
+             # lack of progress
+             if d.mul(t).abs().max() <= tolerance_change:
+                 break
+
+             if abs(loss - prev_loss) < tolerance_change:
+                 break
+
+         state['d'] = d
+         state['t'] = t
+         state['old_dirs'] = old_dirs
+         state['old_stps'] = old_stps
+         state['ro'] = ro
+         state['H_diag'] = H_diag
+         state['prev_flat_grad'] = prev_flat_grad
+         state['prev_loss'] = prev_loss
+
+         return orig_loss
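
For orientation, the sketch below shows one way this optimizer would typically be driven once the wheel is installed. It follows the same closure-based pattern as torch.optim.LBFGS, which this file adapts (adding the tolerance_ys threshold on the curvature term y.dot(s) used when updating the history). The model, data, and hyperparameter values are illustrative placeholders and are not part of the package.

    import torch
    from yms_kan.LBFGS import LBFGS

    # Toy problem, purely for illustration.
    model = torch.nn.Linear(3, 1)
    x, y = torch.randn(32, 3), torch.randn(32, 1)

    optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=20,
                      history_size=10, line_search_fn="strong_wolfe")

    def closure():
        # step() may evaluate the objective several times per call,
        # so loss and gradients are recomputed inside the closure.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    for _ in range(10):
        loss = optimizer.step(closure)  # returns the loss measured at the start of the step

Note that in this implementation the closure argument to step() is required rather than optional, and "strong_wolfe" is the only accepted string for line_search_fn (pass None for a fixed-step update).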