yms-kan 0.0.1__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- yms_kan/KANLayer.py +364 -0
- yms_kan/LBFGS.py +492 -0
- yms_kan/MLP.py +361 -0
- yms_kan/MultKAN.py +3085 -0
- yms_kan/Symbolic_KANLayer.py +270 -0
- yms_kan/__init__.py +4 -0
- yms_kan/compiler.py +498 -0
- yms_kan/experiment.py +50 -0
- yms_kan/feynman.py +739 -0
- yms_kan/hypothesis.py +695 -0
- yms_kan/spline.py +144 -0
- yms_kan/tool.py +304 -0
- yms_kan/train_eval_utils.py +175 -0
- yms_kan/utils.py +661 -0
- yms_kan/version.py +1 -0
- yms_kan-0.0.1.dist-info/METADATA +11 -0
- yms_kan-0.0.1.dist-info/RECORD +20 -0
- yms_kan-0.0.1.dist-info/WHEEL +5 -0
- yms_kan-0.0.1.dist-info/licenses/LICENSE +21 -0
- yms_kan-0.0.1.dist-info/top_level.txt +1 -0
yms_kan/LBFGS.py
ADDED
@@ -0,0 +1,492 @@
import torch
from functools import reduce
from torch.optim import Optimizer

__all__ = ['LBFGS']


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    # w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos, xmin_bound), xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1 ** 2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # print(f_prev, f_new, g_new)
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # in case `t` is too close to a boundary, we mark that we are making
        # insufficient progress, and if
        #   + we made insufficient progress in the last step, or
        #   + `t` is at one of the boundaries,
        # we move `t` to a position that is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old low becomes new high
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # print(bracket)
    if len(bracket) == 1:
        t = bracket[0]
        f_new = bracket_f[0]
        g_new = bracket_g[0]
    else:
        t = bracket[low_pos]
        f_new = bracket_f[low_pos]
        g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements the L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory-intensive optimizer (it requires an additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory,
        try reducing the history size, or use a different algorithm.

    Args:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first-order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        tolerance_ys (float): curvature threshold on ``y.dot(s)`` below which a
            (y, s) pair is not added to the L-BFGS memory (default: 1e-32).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
    """
    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 tolerance_ys=1e-32,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            tolerance_ys=tolerance_ys,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        device = views[0].device
        return torch.cat(views, dim=0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """

        torch.manual_seed(0)

        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        tolerance_ys = group['tolerance_ys']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > tolerance_ys:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss
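As the class docstring notes, `step(closure)` drives the whole inner L-BFGS loop through a user-supplied closure. The following is a minimal usage sketch, not part of the package: the linear model, random data, and MSE loss are hypothetical placeholders chosen only to show the calling pattern, with `line_search_fn="strong_wolfe"` selecting the line search implemented above.

import torch
from yms_kan.LBFGS import LBFGS

# Hypothetical model and data; any differentiable module and loss would do.
model = torch.nn.Linear(3, 1)
x, y = torch.randn(32, 3), torch.randn(32, 1)

optimizer = LBFGS(model.parameters(), lr=1.0, max_iter=20,
                  history_size=10, line_search_fn="strong_wolfe")

def closure():
    # step() may re-evaluate the objective several times through this closure,
    # so it must zero gradients, recompute the loss, and call backward() itself.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(10):
    loss = optimizer.step(closure)

Because the optimizer supports only a single parameter group and flattens all gradients into one vector, all parameters passed to it should live on the same device.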