wums 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wums/fitutils.py ADDED
@@ -0,0 +1,989 @@
+ import numpy as np
+ import scipy.optimize
+ import tensorflow as tf
+ import math
+
+ from numpy import (zeros, where, diff, floor, minimum, maximum, array, concatenate, logical_or, logical_xor,
+                    sqrt)
+
+ def cubic_spline_interpolate(xi, yi, x, axis=-1, extrpl=[None, None]):
+
+     # natural cubic spline
+     # if extrpl is given, the spline is linearly extrapolated outside the given boundaries
+
+     # https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf
+     # https://random-walks.org/content/misc/ncs/ncs.html
+
+     # move selected axis to the end
+     tensors = [xi, yi]
+     nelems = [tensor.shape.num_elements() for tensor in tensors]
+
+     max_nelems = max(nelems)
+     broadcast_shape = tensors[nelems.index(max_nelems)].shape
+
+     ndim = len(broadcast_shape)
+
+     if xi.shape.num_elements() < max_nelems:
+         xi = tf.broadcast_to(xi, broadcast_shape)
+     if yi.shape.num_elements() < max_nelems:
+         yi = tf.broadcast_to(yi, broadcast_shape)
+
+     # permutation to move the selected axis to the end
+     selaxis = axis
+     if axis < 0:
+         selaxis = ndim + axis
+     axis = -1
+     permfwd = list(range(ndim))
+     permfwd.remove(selaxis)
+     permfwd.append(selaxis)
+
+     # reverse permutation to restore the original axis order
+     permrev = list(range(ndim))
+     permrev.remove(ndim-1)
+     permrev.insert(selaxis, ndim-1)
+
+
+     ## start spline construction
+     xi = tf.transpose(xi, permfwd)
+     yi = tf.transpose(yi, permfwd)
+     x = tf.transpose(x, permfwd)
+
+     h = tf.experimental.numpy.diff(xi, axis=axis)
+     b = tf.experimental.numpy.diff(yi, axis=axis) / h
+     v = 2 * (h[:, 1:] + h[:, :-1])
+     u = 6 * (b[:, 1:] - b[:, :-1])
+
+     shape = (xi.shape[0], xi.shape[-1]-2, xi.shape[-1]-2)
+     uu = u[:,:,None]
+
+     diag = v
+     superdiag = h[:, 1:-1]
+     subdiag = superdiag
+
+     z = tf.linalg.tridiagonal_solve(diagonals = [superdiag, diag, subdiag], rhs = uu, diagonals_format = "sequence")
+     z = tf.squeeze(z, axis=axis)
+     f = tf.zeros(xi.shape[0], dtype=tf.float64)[:,None]
+     z = tf.concat([f, z, f], axis=axis)
+
+     x_steps = tf.experimental.numpy.diff(x, axis=axis)
+     idx_zero_constant = tf.constant(0, dtype=tf.int64)
+     float64_zero_constant = tf.constant(0., dtype=tf.float64)
+
+     x_compare = x[...,None] < xi[..., None, :]
+     x_compare_all = tf.math.reduce_all(x_compare, axis=axis)
+     x_compare_none = tf.math.reduce_all(tf.logical_not(x_compare), axis=axis)
+     x_index = tf.argmax(x_compare, axis=axis) - 1
+     x_index = tf.where(x_compare_all, idx_zero_constant, x_index)
+     x_index = tf.where(x_compare_none, tf.constant(xi.shape[axis]-2, dtype=tf.int64), x_index)
+
+     nbatch = ndim - 1
+
+     z_xidx = tf.gather(z, x_index, batch_dims=nbatch, axis=-1)
+     z_1pxidx = tf.gather(z, x_index+1, batch_dims=nbatch, axis=-1)
+     h_xidx = tf.gather(h, x_index, batch_dims=nbatch, axis=-1)
+     xi_xidx = tf.gather(xi, x_index, batch_dims=nbatch, axis=-1)
+     xi_1pxidx = tf.gather(xi, x_index+1, batch_dims=nbatch, axis=-1)
+     dxxi = x - xi_xidx
+     dxix = xi_1pxidx - x
+
+     y = z_1pxidx / (6 * h_xidx) * dxxi ** 3 + \
+         z_xidx / (6 * h_xidx) * dxix ** 3 + \
+         (tf.gather(yi, x_index+1, batch_dims=nbatch, axis=-1) / h_xidx - \
+         z_1pxidx * h_xidx / 6) * dxxi + \
+         (tf.gather(yi, x_index, batch_dims=nbatch, axis=-1) / h_xidx - \
+         z_xidx * h_xidx / 6) * dxix
+
+
+     # right side linear extrapolation
+     if extrpl[1] is not None:
+         x_max = tf.reshape(tf.constant(extrpl[1], dtype=tf.float64), (1,1))
+
+         # calculate derivative yp_max at boundary
+         x_compare_max = x_max[...,None] < xi[..., None, :]
+         x_compare_max_all = tf.math.reduce_all(x_compare_max, axis=axis)
+         x_compare_max_none = tf.math.reduce_all(tf.logical_not(x_compare_max), axis=axis)
+         x_index_max = tf.argmax(x_compare_max, axis=axis) - 1
+         x_index_max = tf.where(x_compare_max_all, idx_zero_constant, x_index_max)
+         x_index_max = tf.where(x_compare_max_none, tf.constant(xi.shape[axis]-2, dtype=tf.int64), x_index_max)
+
+         z_xidx = tf.gather(z, x_index_max, batch_dims=nbatch, axis=-1)
+         z_1pxidx = tf.gather(z, x_index_max+1, batch_dims=nbatch, axis=-1)
+         hi_xidx = tf.gather(h, x_index_max, batch_dims=nbatch, axis=-1)
+         xi_xidx = tf.gather(xi, x_index_max, batch_dims=nbatch, axis=-1)
+         xi_1pxidx = tf.gather(xi, x_index_max+1, batch_dims=nbatch, axis=-1)
+         yi_xidx = tf.gather(yi, x_index_max, batch_dims=nbatch, axis=-1)
+         yi_1pxidx = tf.gather(yi, x_index_max+1, batch_dims=nbatch, axis=-1)
+
+         yp_max = z_1pxidx / (2 * hi_xidx) * (x_max - xi_xidx) ** 2 - \
+             z_xidx / (2 * hi_xidx) * (xi_1pxidx - x_max) ** 2 + \
+             1./hi_xidx*(yi_1pxidx - yi_xidx) - hi_xidx/6.*(z_1pxidx - z_xidx)
+
+         # evaluate spline at boundary
+         y_b = cubic_spline_interpolate(xi, yi, x_max, axis=axis) # check shape of x_max
+
+         # replace y by linear extrapolation for x >= x_max
+         extrpl_lin = yp_max*(x-x_max) + y_b
+         cond = x[0,:] >= x_max
+         cond = tf.broadcast_to(cond, (extrpl_lin.shape[0], extrpl_lin.shape[1]))
+         y = tf.where(cond, extrpl_lin, y)
+
+
+     # left side linear extrapolation
+     if extrpl[0] is not None:
+         x_min = tf.reshape(tf.constant(extrpl[0], dtype=tf.float64), (1,1))
+
+         # calculate derivative yp_min at boundary
+         x_compare_min = x_min[...,None] >= xi[..., None, :]
+         x_compare_min_all = tf.math.reduce_all(x_compare_min, axis=axis)
+         x_compare_min_none = tf.math.reduce_all(tf.logical_not(x_compare_min), axis=axis)
+         x_index_min = tf.argmax(x_compare_min, axis=axis)
+         x_index_min = tf.where(x_compare_min_all, idx_zero_constant, x_index_min)
+         x_index_min = tf.where(x_compare_min_none, tf.constant(0, dtype=tf.int64), x_index_min)
+
+         z_xidx = tf.gather(z, x_index_min, batch_dims=nbatch, axis=-1)
+         z_1pxidx = tf.gather(z, x_index_min+1, batch_dims=nbatch, axis=-1)
+         hi_xidx = tf.gather(h, x_index_min, batch_dims=nbatch, axis=-1)
+         xi_xidx = tf.gather(xi, x_index_min, batch_dims=nbatch, axis=-1)
+         xi_1pxidx = tf.gather(xi, x_index_min+1, batch_dims=nbatch, axis=-1)
+         yi_xidx = tf.gather(yi, x_index_min, batch_dims=nbatch, axis=-1)
+         yi_1pxidx = tf.gather(yi, x_index_min+1, batch_dims=nbatch, axis=-1)
+
+         yp_min = z_1pxidx / (2 * hi_xidx) * (x_min - xi_xidx) ** 2 - \
+             z_xidx / (2 * hi_xidx) * (xi_1pxidx - x_min) ** 2 + \
+             1./hi_xidx*(yi_1pxidx - yi_xidx) - hi_xidx/6.*(z_1pxidx - z_xidx)
+
+         # evaluate spline at boundary
+         y_b = cubic_spline_interpolate(xi, yi, x_min, axis=axis) # check shape of x_min
+
+         # replace y by linear extrapolation for x <= x_min
+         extrpl_lin = yp_min*(x-x_min) + y_b
+         cond = x[0,:] <= x_min
+         cond = tf.broadcast_to(cond, (extrpl_lin.shape[0], extrpl_lin.shape[1]))
+         y = tf.where(cond, extrpl_lin, y)
+
+     ## restore the original axis order
+     y = tf.transpose(y, permrev)
+     return y
+
+
+
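For orientation, a minimal usage sketch of the natural cubic spline routine added above. It is not part of the released package; the tensor shapes, knot values and extrapolation boundaries are illustrative assumptions, with the interpolation axis as the last axis.

# Hypothetical example (not part of wums): batched float64 spline with linear extrapolation outside [0, 1].
import numpy as np
import tensorflow as tf

xi = tf.constant(np.linspace(0., 1., 11)[None, :], dtype=tf.float64)    # knots, shape (1, 11)
yi = tf.sin(2. * np.pi * xi)                                            # knot values, same shape
x = tf.constant(np.linspace(-0.1, 1.1, 25)[None, :], dtype=tf.float64)  # evaluation points

y = cubic_spline_interpolate(xi, yi, x, axis=-1, extrpl=[0., 1.])
print(y.shape)  # (1, 25); values outside [0, 1] are linearly extrapolated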
+ def pchip_interpolate(xi, yi, x, axis=-1):
+     '''
+     Functionality:
+         1D PCHIP interpolation
+     Authors:
+         Michael Taylor <mtaylor@atlanticsciences.com>
+         Mathieu Virbel <mat@meltingrocks.com>
+     Link:
+         https://gist.github.com/tito/553f1135959921ce6699652bf656150d
+         https://github.com/tensorflow/tensorflow/issues/46609#issuecomment-774573667
+     '''
+
+     tensors = [xi, yi]
+     nelems = [tensor.shape.num_elements() for tensor in tensors]
+
+     max_nelems = max(nelems)
+     broadcast_shape = tensors[nelems.index(max_nelems)].shape
+
+     ndim = len(broadcast_shape)
+
+     if xi.shape.num_elements() < max_nelems:
+         xi = tf.broadcast_to(xi, broadcast_shape)
+     if yi.shape.num_elements() < max_nelems:
+         yi = tf.broadcast_to(yi, broadcast_shape)
+
+     # permutation to move the selected axis to the end
+     selaxis = axis
+     if axis < 0:
+         selaxis = ndim + axis
+
+     permfwd = list(range(ndim))
+     permfwd.remove(selaxis)
+     permfwd.append(selaxis)
+
+     # reverse permutation to restore the original axis order
+     permrev = list(range(ndim))
+     permrev.remove(ndim-1)
+     permrev.insert(selaxis, ndim-1)
+
+     xi = tf.transpose(xi, permfwd)
+     yi = tf.transpose(yi, permfwd)
+     x = tf.transpose(x, permfwd)
+     axis = -1
+
+     xi_steps = tf.experimental.numpy.diff(xi, axis=axis)
+
+
+     x_steps = tf.experimental.numpy.diff(x, axis=axis)
+
+     idx_zero_constant = tf.constant(0, dtype=tf.int64)
+     float64_zero_constant = tf.constant(0., dtype=tf.float64)
+
+     x_compare = x[...,None] < xi[..., None, :]
+     x_compare_all = tf.math.reduce_all(x_compare, axis=-1)
+     x_compare_none = tf.math.reduce_all(tf.logical_not(x_compare), axis=-1)
+     x_index = tf.argmax(x_compare, axis = -1) - 1
+
+     x_index = tf.where(x_compare_all, idx_zero_constant, x_index)
+     x_index = tf.where(x_compare_none, tf.constant(xi.shape[axis]-2, dtype=tf.int64), x_index)
+
+     # Calculate gradients d
+     h = tf.experimental.numpy.diff(xi, axis=axis)
+
+     d = tf.zeros_like(xi)
+
+     delta = tf.experimental.numpy.diff(yi, axis=axis) / h
+     # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
+     # recipes
+
+     slice01 = [slice(None)]*ndim
+     slice01[axis] = slice(0,1)
+     slice01 = tuple(slice01)
+
+     slice0m1 = [slice(None)]*ndim
+     slice0m1[axis] = slice(0,-1)
+     slice0m1 = tuple(slice0m1)
+
+     slice1 = [slice(None)]*ndim
+     slice1[axis] = slice(1,None)
+     slice1 = tuple(slice1)
+
+     slicem1 = [slice(None)]*ndim
+     slicem1[axis] = slice(-1,None)
+     slicem1 = tuple(slicem1)
+
+     d = tf.concat(
+         (delta[slice01], 3 * (h[slice0m1] + h[slice1]) / ((h[slice0m1] + 2 * h[slice1]) / delta[slice0m1] +
+         (2 * h[slice0m1] + h[slice1]) / delta[slice1]), delta[slicem1]), axis=axis)
+
+     false_shape = [*xi.shape]
+     false_shape[axis] = 1
+     false_const = tf.fill(false_shape, False)
+
+     mask = tf.concat((false_const, tf.math.logical_xor(delta[slice0m1] > 0, delta[slice1] > 0), false_const), axis=axis)
+     d = tf.where(mask, float64_zero_constant, d)
+
+     mask = tf.math.logical_or(tf.concat((false_const, delta == 0), axis=axis), tf.concat((delta == 0, false_const), axis=axis))
+     d = tf.where(mask, float64_zero_constant, d)
+
+     xiperm = xi
+     yiperm = yi
+     dperm = d
+     hperm = h
+
+     nbatch = ndim - 1
+
+     xi_xidx = tf.gather(xiperm, x_index, batch_dims=nbatch, axis=-1)
+     xi_1pxidx = tf.gather(xiperm, 1 + x_index, batch_dims=nbatch, axis=-1)
+     yi_xidx = tf.gather(yiperm, x_index, batch_dims=nbatch, axis=-1)
+     yi_1pxidx = tf.gather(yiperm, 1 + x_index, batch_dims=nbatch, axis=-1)
+     d_xidx = tf.gather(dperm, x_index, batch_dims=nbatch, axis=-1)
+     d_1pxidx = tf.gather(dperm, 1 + x_index, batch_dims=nbatch, axis=-1)
+     h_xidx = tf.gather(hperm, x_index, batch_dims=nbatch, axis=-1)
+
+     dxxi = x - xi_xidx
+     dxxid = x - xi_1pxidx
+     dxxi2 = tf.math.pow(dxxi, 2)
+     dxxid2 = tf.math.pow(dxxid, 2)
+
+     y = (2 / tf.math.pow(h_xidx, 3) *
+          (yi_xidx * dxxid2 * (dxxi + h_xidx / 2) - yi_1pxidx * dxxi2 *
+          (dxxid - h_xidx / 2)) + 1 / tf.math.pow(h_xidx, 2) *
+          (d_xidx * dxxid2 * dxxi + d_1pxidx * dxxi2 * dxxid))
+
+     y = tf.transpose(y, permrev)
+
+     return y
+
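A small sanity-check sketch for the TensorFlow PCHIP routine above, assuming 2D float64 tensors with the interpolation axis last. The comparison against scipy.interpolate.PchipInterpolator is only indicative, since the endpoint-slope conventions of the two implementations differ slightly.

# Hypothetical cross-check (not part of wums) of pchip_interpolate against SciPy.
import numpy as np
import scipy.interpolate
import tensorflow as tf

xi_np = np.linspace(0., 1., 9)
yi_np = np.exp(-xi_np)
x_np = np.linspace(0.05, 0.95, 21)

y_tf = pchip_interpolate(
    tf.constant(xi_np[None, :], dtype=tf.float64),
    tf.constant(yi_np[None, :], dtype=tf.float64),
    tf.constant(x_np[None, :], dtype=tf.float64),
    axis=-1,
)
y_sp = scipy.interpolate.PchipInterpolator(xi_np, yi_np)(x_np)

# endpoint-derivative conventions differ, so expect close but not identical values
print(np.max(np.abs(y_tf.numpy()[0] - y_sp)))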
+ def pchip_interpolate_np(xi, yi, x, mode="mono", verbose=False):
+     '''
+     Functionality:
+         1D PCHIP interpolation
+     Authors:
+         Michael Taylor <mtaylor@atlanticsciences.com>
+         Mathieu Virbel <mat@meltingrocks.com>
+     Link:
+         https://gist.github.com/tito/553f1135959921ce6699652bf656150d
+     '''
+
+     if mode not in ("mono", "quad"):
+         raise ValueError("Unrecognized mode string")
+
+     # Search for [xi,xi+1] interval for each x
+     xi = xi.astype("double")
+     yi = yi.astype("double")
+
+     x_index = zeros(len(x), dtype="int")
+     xi_steps = diff(xi)
+     if not all(xi_steps > 0):
+         raise ValueError("x-coordinates are not in increasing order.")
+
+     x_steps = diff(x)
+     if xi_steps.max() / xi_steps.min() < 1.000001:
+         # uniform input grid
+         if verbose:
+             print("pchip: uniform input grid")
+         xi_start = xi[0]
+         xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
+         x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)
+
+         # Calculate gradients d
+         h = (xi[-1] - xi[0]) / (len(xi) - 1)
+         d = zeros(len(xi), dtype="double")
+         if mode == "quad":
+             # quadratic polynomial fit
+             d[[0]] = (yi[1] - yi[0]) / h
+             d[[-1]] = (yi[-1] - yi[-2]) / h
+             d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
+         else:
+             # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
+             # recipes
+             delta = diff(yi) / h
+             d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
+             d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
+             d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
+                 (delta == 0, array([False]))))] = 0
+         # Calculate output values y
+         dxxi = x - xi[x_index]
+         dxxid = x - xi[1 + x_index]
+         dxxi2 = pow(dxxi, 2)
+         dxxid2 = pow(dxxid, 2)
+         y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
+              (dxxid - h / 2)) + 1 / pow(h, 2) *
+              (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
+     else:
+         # non-uniform input grid
+         if (x_steps.max() / x_steps.min() < 1.000001 and x_steps.max() / x_steps.min() > 0.999999):
+             # non-uniform input grid, uniform output grid
+             if verbose:
+                 print("pchip: non-uniform input grid, uniform output grid")
+             x_decreasing = x[-1] < x[0]
+             if x_decreasing:
+                 x = x[::-1]
+             x_start = x[0]
+             x_step = (x[-1] - x[0]) / (len(x) - 1)
+             x_indexprev = -1
+             for xi_loop in range(len(xi) - 2):
+                 x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
+                 x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
+                 x_indexprev = x_indexcur
+             x_index[1 + x_indexprev:] = len(xi) - 2
+             if x_decreasing:
+                 x = x[::-1]
+                 x_index = x_index[::-1]
+         elif all(x_steps > 0) or all(x_steps < 0):
+             # non-uniform input/output grids, output grid monotonic
+             if verbose:
+                 print("pchip: non-uniform in/out grid, output grid monotonic")
+             x_decreasing = x[-1] < x[0]
+             if x_decreasing:
+                 x = x[::-1]
+             x_len = len(x)
+             x_loop = 0
+             for xi_loop in range(len(xi) - 1):
+                 while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
+                     x_index[x_loop] = xi_loop
+                     x_loop += 1
+             x_index[x_loop:] = len(xi) - 2
+             if x_decreasing:
+                 x = x[::-1]
+                 x_index = x_index[::-1]
+         else:
+             # non-uniform input/output grids, output grid not monotonic
+             if verbose:
+                 print("pchip: non-uniform in/out grids, output grid not monotonic")
+             for index in range(len(x)):
+                 loc = where(x[index] < xi)[0]
+                 if loc.size == 0:
+                     x_index[index] = len(xi) - 2
+                 elif loc[0] == 0:
+                     x_index[index] = 0
+                 else:
+                     x_index[index] = loc[0] - 1
+         # Calculate gradients d
+         h = diff(xi)
+         d = zeros(len(xi), dtype="double")
+         delta = diff(yi) / h
+         if mode == "quad":
+             # quadratic polynomial fit
+             d[[0, -1]] = delta[[0, -1]]
+             d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
+         else:
+             # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
+             # recipes
+             d = concatenate(
+                 (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
+                 (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
+             d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
+             d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
+                 (delta == 0, array([False]))))] = 0
+         dxxi = x - xi[x_index]
+         dxxid = x - xi[1 + x_index]
+         dxxi2 = pow(dxxi, 2)
+         dxxid2 = pow(dxxid, 2)
+         y = (2 / pow(h[x_index], 3) *
+              (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
+              (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
+              (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
+     return y
+
+
+
+ def pchip_interpolate_np_forced(xi, yi, x, mode="mono", verbose=False):
+     '''
+     Functionality:
+         1D PCHIP interpolation
+     Authors:
+         Michael Taylor <mtaylor@atlanticsciences.com>
+         Mathieu Virbel <mat@meltingrocks.com>
+     Link:
+         https://gist.github.com/tito/553f1135959921ce6699652bf656150d
+     '''
+
+     if mode not in ("mono", "quad"):
+         raise ValueError("Unrecognized mode string")
+
+     # Search for [xi,xi+1] interval for each x
+     xi = xi.astype("double")
+     yi = yi.astype("double")
+
+     x_index = zeros(len(x), dtype="int")
+     xi_steps = diff(xi)
+     if not all(xi_steps > 0):
+         raise ValueError("x-coordinates are not in increasing order.")
+
+     x_steps = diff(x)
+     # if xi_steps.max() / xi_steps.min() < 1.000001:
+     if False:
+         # uniform input grid
+         if verbose:
+             print("pchip: uniform input grid")
+         xi_start = xi[0]
+         xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
+         x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)
+
+         # Calculate gradients d
+         h = (xi[-1] - xi[0]) / (len(xi) - 1)
+         d = zeros(len(xi), dtype="double")
+         if mode == "quad":
+             # quadratic polynomial fit
+             d[[0]] = (yi[1] - yi[0]) / h
+             d[[-1]] = (yi[-1] - yi[-2]) / h
+             d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
+         else:
+             # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
+             # recipes
+             delta = diff(yi) / h
+             d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
+             d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
+             d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
+                 (delta == 0, array([False]))))] = 0
+         # Calculate output values y
+         dxxi = x - xi[x_index]
+         dxxid = x - xi[1 + x_index]
+         dxxi2 = pow(dxxi, 2)
+         dxxid2 = pow(dxxid, 2)
+         y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
+              (dxxid - h / 2)) + 1 / pow(h, 2) *
+              (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
+     else:
+         # non-uniform input grid
+         # if (x_steps.max() / x_steps.min() < 1.000001 and x_steps.max() / x_steps.min() > 0.999999):
+         if False:
+             # non-uniform input grid, uniform output grid
+             if verbose:
+                 print("pchip: non-uniform input grid, uniform output grid")
+             x_decreasing = x[-1] < x[0]
+             if x_decreasing:
+                 x = x[::-1]
+             x_start = x[0]
+             x_step = (x[-1] - x[0]) / (len(x) - 1)
+             x_indexprev = -1
+             for xi_loop in range(len(xi) - 2):
+                 x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
+                 x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
+                 x_indexprev = x_indexcur
+             x_index[1 + x_indexprev:] = len(xi) - 2
+             if x_decreasing:
+                 x = x[::-1]
+                 x_index = x_index[::-1]
+         # elif all(x_steps > 0) or all(x_steps < 0):
+         elif True:
+             # non-uniform input/output grids, output grid monotonic
+             if verbose:
+                 print("pchip: non-uniform in/out grid, output grid monotonic")
+             # x_decreasing = x[-1] < x[0]
+             x_decreasing = False
+             if x_decreasing:
+                 x = x[::-1]
+             x_len = len(x)
+             x_loop = 0
+             for xi_loop in range(len(xi) - 1):
+                 while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
+                     x_index[x_loop] = xi_loop
+                     x_loop += 1
+             x_index[x_loop:] = len(xi) - 2
+
+             print("np_forced x_index", x_index)
+             if x_decreasing:
+                 x = x[::-1]
+                 x_index = x_index[::-1]
+         else:
+             # non-uniform input/output grids, output grid not monotonic
+             if verbose:
+                 print("pchip: non-uniform in/out grids, output grid not monotonic")
+             for index in range(len(x)):
+                 loc = where(x[index] < xi)[0]
+                 if loc.size == 0:
+                     x_index[index] = len(xi) - 2
+                 elif loc[0] == 0:
+                     x_index[index] = 0
+                 else:
+                     x_index[index] = loc[0] - 1
+         # Calculate gradients d
+         h = diff(xi)
+         d = zeros(len(xi), dtype="double")
+         delta = diff(yi) / h
+         if mode == "quad":
+             # quadratic polynomial fit
+             d[[0, -1]] = delta[[0, -1]]
+             d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
+         else:
+             # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
+             # recipes
+             d = concatenate(
+                 (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
+                 (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
+             d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
+             d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
+                 (delta == 0, array([False]))))] = 0
+         dxxi = x - xi[x_index]
+         dxxid = x - xi[1 + x_index]
+         dxxi2 = pow(dxxi, 2)
+         dxxid2 = pow(dxxid, 2)
+         y = (2 / pow(h[x_index], 3) *
+              (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
+              (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
+              (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
+     return y
+
+ def qparms_to_quantiles(qparms, x_low = 0., x_high = 1., axis = -1):
+     deltax = tf.exp(qparms)
+     sumdeltax = tf.math.reduce_sum(deltax, axis=axis, keepdims=True)
+
+     deltaxnorm = deltax/sumdeltax
+
+     x0shape = list(deltaxnorm.shape)
+     x0shape[axis] = 1
+     x0 = tf.fill(x0shape, x_low)
+     x0 = tf.cast(x0, tf.float64)
+
+     deltaxfull = (x_high - x_low)*deltaxnorm
+     deltaxfull = tf.concat([x0, deltaxfull], axis = axis)
+
+     quants = tf.math.cumsum(deltaxfull, axis=axis)
+
+     return quants
+
+
+
+ def quantiles_to_qparms(quants, quant_errs = None, x_low = 0., x_high = 1., axis = -1):
+
+     deltaxfull = tf.experimental.numpy.diff(quants, axis=axis)
+     deltaxnorm = deltaxfull/(x_high - x_low)
+     qparms = tf.math.log(deltaxnorm)
+
+     if quant_errs is not None:
+         quant_vars = tf.math.square(quant_errs)
+
+         ndim = len(quant_errs.shape)
+
+         slicem1 = [slice(None)]*ndim
+         slicem1[axis] = slice(None,-1)
+         slicem1 = tuple(slicem1)
+
+         slice1 = [slice(None)]*ndim
+         slice1[axis] = slice(1,None)
+         slice1 = tuple(slice1)
+
+         deltaxfull_vars = quant_vars[slice1] + quant_vars[slicem1]
+         deltaxfull_errs = tf.math.sqrt(deltaxfull_vars)
+
+         qparm_errs = deltaxfull_errs/deltaxfull
+
+         return qparms, qparm_errs
+     else:
+         return qparms
+
+
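The two helpers above define the unconstrained parameterization used for quantile fitting: the qparms are log bin widths that are normalized and cumulatively summed into monotonic quantiles on [x_low, x_high]. A minimal round-trip sketch, with illustrative values only:

# Hypothetical round trip (not part of wums): quantiles -> qparms -> quantiles.
import tensorflow as tf

quants = tf.constant([[0., 0.2, 0.5, 0.9, 1.]], dtype=tf.float64)  # monotonic, spanning [0, 1]
qparms = quantiles_to_qparms(quants, x_low=0., x_high=1., axis=-1)
quants_back = qparms_to_quantiles(qparms, x_low=0., x_high=1., axis=-1)
print(quants_back.numpy())  # recovers the original quantiles up to floating-point error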
+ def hist_to_quantiles(h, quant_cdfvals, axis = -1):
+     dtype = tf.float64
+
+     xvals = [tf.constant(center, dtype=dtype) for center in h.axes.centers]
+     xwidths = [tf.constant(width, dtype=dtype) for width in h.axes.widths]
+     xedges = [tf.constant(edge, dtype=dtype) for edge in h.axes.edges]
+     yvals = tf.constant(h.values(), dtype=dtype)
+
+     if not isinstance(quant_cdfvals, tf.Tensor):
+         quant_cdfvals = tf.constant(quant_cdfvals, tf.float64)
+
+     x_flat = tf.reshape(xedges[axis], (-1,))
+     x_low = x_flat[0]
+     x_high = x_flat[-1]
+
+     hist_cdfvals = tf.cumsum(yvals, axis=axis)/tf.reduce_sum(yvals, axis=axis, keepdims=True)
+
+     x0shape = list(hist_cdfvals.shape)
+     x0shape[axis] = 1
+     x0 = tf.zeros(x0shape, dtype = dtype)
+
+     hist_cdfvals = tf.concat([x0, hist_cdfvals], axis=axis)
+
+     quants = pchip_interpolate(hist_cdfvals, xedges[axis], quant_cdfvals, axis=axis)
+
+     quants = tf.where(quant_cdfvals == 0., x_low, quants)
+     quants = tf.where(quant_cdfvals == 1., x_high, quants)
+
+     ntot = tf.math.reduce_sum(yvals, axis=axis, keepdims=True)
+
+     quant_cdf_bar = ntot/(1.+ntot)*(quant_cdfvals + 0.5/ntot)
+     quant_cdfval_errs = ntot/(1.+ntot)*tf.math.sqrt(quant_cdfvals*(1.-quant_cdfvals)/ntot + 0.25/ntot/ntot)
+
+     quant_cdfvals_up = quant_cdf_bar + quant_cdfval_errs
+     quant_cdfvals_up = tf.clip_by_value(quant_cdfvals_up, 0., 1.)
+
+     quant_cdfvals_down = quant_cdf_bar - quant_cdfval_errs
+     quant_cdfvals_down = tf.clip_by_value(quant_cdfvals_down, 0., 1.)
+
+     quants_up = pchip_interpolate(hist_cdfvals, xedges[axis], quant_cdfvals_up, axis=axis)
+     quants_up = tf.where(quant_cdfvals_up == 0., x_low, quants_up)
+     quants_up = tf.where(quant_cdfvals_up == 1., x_high, quants_up)
+
+     quants_down = pchip_interpolate(hist_cdfvals, xedges[axis], quant_cdfvals_down, axis=axis)
+     quants_down = tf.where(quant_cdfvals_down == 0., x_low, quants_down)
+     quants_down = tf.where(quant_cdfvals_down == 1., x_high, quants_down)
+
+     quant_errs = 0.5*(quants_up - quants_down)
+
+     zero_const = tf.constant(0., dtype)
+
+     quant_errs = tf.where(quant_cdfvals == 0., zero_const, quant_errs)
+     quant_errs = tf.where(quant_cdfvals == 1., zero_const, quant_errs)
+
+     return quants.numpy(), quant_errs.numpy()
+
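A usage sketch for hist_to_quantiles, assuming a 1D histogram from the hist/boost-histogram library with weighted storage; the filled values and requested CDF points are illustrative.

# Hypothetical example (not part of wums): quantiles and uncertainties from a 1D hist histogram.
import numpy as np
import hist

h = hist.Hist(hist.axis.Regular(50, -4., 4.), storage=hist.storage.Weight())
h.fill(np.random.default_rng(0).normal(size=100000))

quant_cdfvals = [0., 0.25, 0.5, 0.75, 1.]
quants, quant_errs = hist_to_quantiles(h, quant_cdfvals, axis=-1)
print(quants, quant_errs)  # the 0 and 1 entries map to the histogram range edges with zero error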
+ def func_cdf_for_quantile_fit(xvals, xedges, qparms, quant_cdfvals, axis=-1, transform = None):
+     x_flat = tf.reshape(xedges[axis], (-1,))
+     x_low = x_flat[0]
+     x_high = x_flat[-1]
+
+     quants = qparms_to_quantiles(qparms, x_low = x_low, x_high = x_high, axis = axis)
+
+     spline_edges = xedges[axis]
+
+     ndim = len(xvals)
+
+     if transform is not None:
+         transform_cdf, transform_quantile = transform
+
+         slicelim = [slice(None)]*ndim
+         slicelim[axis] = slice(1, -1)
+         slicelim = tuple(slicelim)
+
+         quants = quants[slicelim]
+         quant_cdfvals = quant_cdfvals[slicelim]
+
+         quant_cdfvals = transform_quantile(quant_cdfvals)
+
+     cdfvals = pchip_interpolate(quants, quant_cdfvals, spline_edges, axis=axis)
+
+     if transform is not None:
+         cdfvals = transform_cdf(cdfvals)
+
+     slicefirst = [slice(None)]*ndim
+     slicefirst[axis] = slice(None, 1)
+     slicefirst = tuple(slicefirst)
+
+     slicelast = [slice(None)]*ndim
+     slicelast[axis] = slice(-1, None)
+     slicelast = tuple(slicelast)
+
+     cdfvals = (cdfvals - cdfvals[slicefirst])/(cdfvals[slicelast] - cdfvals[slicefirst])
+
+     return cdfvals
+
+ def func_constraint_for_quantile_fit(xvals, xedges, qparms, axis=-1):
+     constraints = 0.5*tf.math.square(tf.math.reduce_sum(tf.exp(qparms), axis=axis) - 1.)
+     constraint = tf.math.reduce_sum(constraints)
+     return constraint
+
+ @tf.function
+ def val_grad(func, *args, **kwargs):
+     xdep = kwargs["parms"]
+     with tf.GradientTape() as t1:
+         t1.watch(xdep)
+         val = func(*args, **kwargs)
+     grad = t1.gradient(val, xdep)
+     return val, grad
+
+ #TODO forward-over-reverse also here?
+ @tf.function
+ def val_grad_hess(func, *args, **kwargs):
+     xdep = kwargs["parms"]
+     with tf.GradientTape() as t2:
+         t2.watch(xdep)
+         with tf.GradientTape() as t1:
+             t1.watch(xdep)
+             val = func(*args, **kwargs)
+         grad = t1.gradient(val, xdep)
+     hess = t2.jacobian(grad, xdep)
+
+     return val, grad, hess
+
+ @tf.function
+ def val_grad_hessp(func, p, *args, **kwargs):
+     xdep = kwargs["parms"]
+     with tf.autodiff.ForwardAccumulator(xdep, p) as acc:
+         with tf.GradientTape() as grad_tape:
+             grad_tape.watch(xdep)
+             val = func(*args, **kwargs)
+         grad = grad_tape.gradient(val, xdep)
+     hessp = acc.jvp(grad)
+
+     return val, grad, hessp
+
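The three helpers above wrap TensorFlow autodiff: val_grad is a plain reverse-mode gradient, val_grad_hess builds the full Hessian with a second tape over the gradient, and val_grad_hessp returns a Hessian-vector product by forward-over-reverse, which is the shape scipy's hessp interface expects. A tiny self-contained check, with an illustrative toy loss and values:

# Hypothetical check (not part of wums): Hessian-vector product of a simple quartic loss.
import tensorflow as tf

def toy_loss(parms):
    return tf.reduce_sum(parms**2) + tf.reduce_sum(parms**4)

parms = tf.constant([1., 2., 3.], dtype=tf.float64)
p = tf.constant([1., 0., 0.], dtype=tf.float64)

val, grad, hessp = val_grad_hessp(toy_loss, p, parms=parms)
# the Hessian of this loss is diag(2 + 12*x**2), so hessp should be [14., 0., 0.]
print(val.numpy(), grad.numpy(), hessp.numpy())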
+ def loss_with_constraint(func_loss, parms, func_constraint = None, args_loss = (), extra_args_loss=(), args_constraint = (), extra_args_constraint = ()):
+     loss = func_loss(parms, *args_loss, *extra_args_loss)
+     if func_constraint is not None:
+         loss += func_constraint(*args_constraint, parms, *extra_args_constraint)
+
+     return loss
+
+ def chisq_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
+     fvals = func(xvals, parms, *args)
+
+     # exclude zero-variance bins
+     variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
+     chisqv = (fvals - yvals)**2/variances_safe
+     chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
+     return tf.reduce_sum(chisqv_safe)
+
+ def chisq_normalized_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
+     fvals = func(xvals, parms, *args)
+     norm = tf.reduce_sum(fvals, keepdims=True, axis = norm_axes)
+     sumw = tf.reduce_sum(yvals, keepdims=True, axis = norm_axes)
+     if norm_axes is None:
+         for xwidth in xwidths:
+             norm *= xwidth
+     else:
+         for norm_axis in norm_axes:
+             norm *= xwidths[norm_axis]
+
+     # exclude zero-variance bins
+     variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
+     chisqv = (sumw*fvals/norm - yvals)**2/variances_safe
+     chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
+     return tf.reduce_sum(chisqv_safe)
+
+ def nll_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
+     fvals = func(xvals, parms, *args)
+
+     # compute overall scaling needed to restore mean == variance condition
+     yval_total = tf.reduce_sum(yvals, keepdims = True, axis = norm_axes)
+     variance_total = tf.reduce_sum(yvariances, keepdims = True, axis = norm_axes)
+     isnull_total = variance_total == 0.
+     variance_total_safe = tf.where(isnull_total, tf.ones_like(variance_total), variance_total)
+     scale_total = yval_total/variance_total_safe
+     scale_total_safe = tf.where(isnull_total, tf.ones_like(scale_total), scale_total)
+
+     # skip likelihood calculation for empty bins to avoid inf or nan
+     # compute per-bin scaling needed to restore mean == variance condition, falling
+     # back to overall scaling for empty bins
+     isnull = tf.logical_or(yvals == 0., yvariances == 0.)
+     variances_safe = tf.where(isnull, tf.ones_like(yvariances), yvariances)
+     scale = yvals/variances_safe
+     scale_safe = tf.where(isnull, scale_total_safe*tf.ones_like(scale), scale)
+
+     norm = tf.reduce_sum(scale_safe*fvals, keepdims=True, axis = norm_axes)
+     if norm_axes is None:
+         for xwidth in xwidths:
+             norm *= xwidth
+     else:
+         for norm_axis in norm_axes:
+             norm *= xwidths[norm_axis]
+
+     fvalsnorm = fvals/norm
+
+     fvalsnorm_safe = tf.where(isnull, tf.ones_like(fvalsnorm), fvalsnorm)
+     nllv = -scale_safe*yvals*tf.math.log(scale_safe*fvalsnorm_safe)
+     nllv_safe = tf.where(isnull, tf.zeros_like(nllv), nllv)
+     nllsum = tf.reduce_sum(nllv_safe)
+     return nllsum
+
+ def nll_loss_bin_integrated(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
+     #TODO reduce code duplication with nll_loss
+
+     norm_axis = 0
+     if norm_axes is not None:
+         if len(norm_axes) > 1:
+             raise ValueError("Only 1 normalization axis supported for bin-integrated nll")
+         norm_axis = norm_axes[0]
+
+     cdfvals = func(xvals, xedges, parms, *args)
+
+     slices_low = [slice(None)]*len(cdfvals.shape)
+     slices_low[norm_axis] = slice(None,-1)
+
+     slices_high = [slice(None)]*len(cdfvals.shape)
+     slices_high[norm_axis] = slice(1,None)
+
+     # bin_integrals = cdfvals[1:] - cdfvals[:-1]
+     bin_integrals = cdfvals[tuple(slices_high)] - cdfvals[tuple(slices_low)]
+     bin_integrals = tf.maximum(bin_integrals, tf.zeros_like(bin_integrals))
+
+     fvals = bin_integrals
+
+     # compute overall scaling needed to restore mean == variance condition
+     yval_total = tf.reduce_sum(yvals, keepdims = True, axis = norm_axes)
+     variance_total = tf.reduce_sum(yvariances, keepdims = True, axis = norm_axes)
+     isnull_total = variance_total == 0.
+     variance_total_safe = tf.where(isnull_total, tf.ones_like(variance_total), variance_total)
+     scale_total = yval_total/variance_total_safe
+     scale_total_safe = tf.where(isnull_total, tf.ones_like(scale_total), scale_total)
+
+     # skip likelihood calculation for empty bins to avoid inf or nan
+     # compute per-bin scaling needed to restore mean == variance condition, falling
+     # back to overall scaling for empty bins
+     isnull = tf.logical_or(yvals == 0., yvariances == 0.)
+     variances_safe = tf.where(isnull, tf.ones_like(yvariances), yvariances)
+     scale = yvals/variances_safe
+     scale_safe = tf.where(isnull, scale_total_safe*tf.ones_like(scale), scale)
+
+     norm = tf.reduce_sum(scale_safe*fvals, keepdims=True, axis = norm_axes)
+
+     fvalsnorm = fvals/norm
+
+     fvalsnorm_safe = tf.where(isnull, tf.ones_like(fvalsnorm), fvalsnorm)
+     nllv = -scale_safe*yvals*tf.math.log(scale_safe*fvalsnorm_safe)
+     nllv_safe = tf.where(isnull, tf.zeros_like(nllv), nllv)
+     nllsum = tf.reduce_sum(nllv_safe)
+     return nllsum
+
+ def chisq_loss_bin_integrated(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
+     #FIXME this is only defined in 1D for now
+     cdfvals = func(xedges, parms, *args)
+     bin_integrals = cdfvals[1:] - cdfvals[:-1]
+     bin_integrals = tf.maximum(bin_integrals, tf.zeros_like(bin_integrals))
+
+     fvals = bin_integrals
+
+     # exclude zero-variance bins
+     variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
+     chisqv = (fvals - yvals)**2/variances_safe
+     chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
+     chisqsum = tf.reduce_sum(chisqv_safe)
+
+     return chisqsum
+
+
+ def fit_hist(hist, func, initial_parmvals, max_iter = 5, edmtol = 1e-5, mode = "chisq", norm_axes = None, func_constraint = None, args = (), args_constraint=()):
+
+     dtype = tf.float64
+
+     xvals = [tf.constant(center, dtype=dtype) for center in hist.axes.centers]
+     xwidths = [tf.constant(width, dtype=dtype) for width in hist.axes.widths]
+     xedges = [tf.constant(edge, dtype=dtype) for edge in hist.axes.edges]
+     yvals = tf.constant(hist.values(), dtype=dtype)
+     yvariances = tf.constant(hist.variances(), dtype=dtype)
+
+     covscale = 1.
+     if mode == "chisq":
+         floss = chisq_loss
+         covscale = 2.
+     elif mode == "nll":
+         floss = nll_loss
+     elif mode == "nll_bin_integrated":
+         floss = nll_loss_bin_integrated
+     elif mode == "chisq_normalized":
+         floss = chisq_normalized_loss
+         covscale = 2.
+     elif mode == "chisq_loss_bin_integrated":
+         floss = chisq_loss_bin_integrated
+         covscale = 2.
+     elif mode == "nll_extended":
+         raise Exception("Not Implemented")
+     else:
+         raise Exception("unsupported mode")
+
+     val_grad_args = { "func_loss" : floss,
+                       "func_constraint" : func_constraint,
+                       "args_loss" : (xvals, xwidths, xedges, yvals, yvariances, func, norm_axes),
+                       "extra_args_loss" : args,
+                       "args_constraint" : (xvals, xedges),
+                       "extra_args_constraint" : args_constraint}
+
+     def scipy_loss(parmvals, *args):
+         parms = tf.constant(parmvals, dtype=dtype)
+
+         # loss, grad = val_grad(floss, parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes, *args)
+         loss, grad = val_grad(loss_with_constraint, parms=parms, **val_grad_args)
+         return loss.numpy(), grad.numpy()
+
+     def scipy_hessp(parmvals, p, *args):
+         parms = tf.constant(parmvals, dtype=dtype)
+
+         # loss, grad, hessp = val_grad_hessp(floss, p, parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes, *args)
+         loss, grad, hessp = val_grad_hessp(loss_with_constraint, p, parms=parms, **val_grad_args)
+         return hessp.numpy()
+
+     current_parmvals = initial_parmvals
+     for iiter in range(max_iter):
+
+         res = scipy.optimize.minimize(scipy_loss, current_parmvals, method = "trust-krylov", jac = True, hessp = scipy_hessp, args = args)
+
+         current_parmvals = res.x
+
+         parms = tf.constant(current_parmvals, dtype=dtype)
+
+         # loss, grad, hess = val_grad_hess(floss, parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes, *args)
+         loss, grad, hess = val_grad_hess(loss_with_constraint, parms=parms, **val_grad_args)
+         loss, grad, hess = loss.numpy(), grad.numpy(), hess.numpy()
+
+         try:
+             eigvals = np.linalg.eigvalsh(hess)
+             gradv = grad[:, np.newaxis]
+             edmval = 0.5*gradv.transpose()@np.linalg.solve(hess, gradv)
+             edmval = edmval[0][0]
+         except np.linalg.LinAlgError:
+             eigvals = np.zeros_like(grad)
+             edmval = 99.
+
+         converged = edmval < edmtol and np.abs(edmval) >= 0. and eigvals[0] > 0.
+         if converged:
+             break
+
+     status = 1
+     covstatus = 1
+
+     if edmval < edmtol and edmval >= -0.:
+         status = 0
+     if eigvals[0] > 0.:
+         covstatus = 0
+
+     try:
+         cov = covscale*np.linalg.inv(hess)
+     except np.linalg.LinAlgError:
+         cov = np.zeros_like(hess)
+         covstatus = 1
+
+     res = { "x" : current_parmvals,
+             "hess" : hess,
+             "cov" : cov,
+             "status" : status,
+             "covstatus" : covstatus,
+             "hess_eigvals" : eigvals,
+             "edmval" : edmval,
+             "loss_val" : loss }
+
+     return res
+
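Finally, a usage sketch for fit_hist, fitting a Gaussian shape to a 1D histogram with the default chi-square loss. The model function, histogram contents and starting values are illustrative assumptions and not part of the package.

# Hypothetical example (not part of wums): chi-square fit of a Gaussian model to a 1D hist histogram.
import numpy as np
import hist
import tensorflow as tf

h = hist.Hist(hist.axis.Regular(40, -4., 4.), storage=hist.storage.Weight())
h.fill(np.random.default_rng(1).normal(size=20000))

def gauss_model(xvals, parms):
    # parms = (peak value, mean, sigma); xvals is the list of bin-center tensors built by fit_hist
    peak, mu, sigma = parms[0], parms[1], parms[2]
    x = xvals[0]
    return peak * tf.exp(-0.5 * ((x - mu) / sigma) ** 2)

res = fit_hist(h, gauss_model, np.array([h.values().max(), 0., 1.]), mode="chisq")
print(res["x"], res["status"], res["edmval"])  # fitted parameters, fit status, estimated distance to minimum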