wums 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wums/boostHistHelpers.py CHANGED
@@ -40,7 +40,7 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
40
40
  h2.ndim - 1 - i: h2.values(flow=flow).shape[h2.ndim - 1 - i]
41
41
  for i in range(h2.ndim - h1.ndim)
42
42
  }
43
-
43
+
44
44
  broadcast_shape = list(moves.values()) + list(s1)
45
45
 
46
46
  try:
@@ -53,6 +53,10 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
53
53
  f" h2.axes: {h2.axes}"
54
54
  )
55
55
 
56
+ if by_ax_name:
57
+ # We also have to move axes that are in common between h1 and h2 but in different order
58
+ moves.update({sum([k<i for k in moves.keys()]) + h1.axes.name.index(n2): None for i, n2 in enumerate(h2.axes.name) if n2 in h1.axes.name})
59
+
56
60
  # move back to original order
57
61
  new_vals = np.moveaxis(new_vals, np.arange(len(moves)), list(moves.keys()))
58
62
 
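Note on the hunk above: `broadcastSystHist` broadcasts the lower-dimensional histogram's values onto the higher-dimensional one and then relies on `np.moveaxis` to restore the original axis order; the new `moves.update(...)` line extends this to axes that both histograms share but list in a different order. A minimal sketch of that `np.moveaxis` mechanic (the shapes and axis layout below are made up for illustration, not taken from the package):

```python
import numpy as np

# hypothetical layout: h1 has axes (x,), h2 has axes (syst, x),
# and the broadcast intermediate is built as (x, syst)
vals_h1 = np.arange(4.0)                               # shape (4,)
broadcast = np.broadcast_to(vals_h1[:, None], (4, 3))  # shape (4, 3) -> (x, syst)

# move the trailing syst axis back to the front to match h2's (syst, x) order
reordered = np.moveaxis(broadcast, 1, 0)
print(reordered.shape)                                 # (3, 4)
```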
wums/fitutils.py ADDED
@@ -0,0 +1,989 @@
1
+ import numpy as np
2
+ import scipy
3
+ import tensorflow as tf
4
+ import math
5
+
6
+ from numpy import (zeros, where, diff, floor, minimum, maximum, array, concatenate, logical_or, logical_xor,
7
+ sqrt)
8
+
9
+ def cubic_spline_interpolate(xi, yi, x, axis=-1, extrpl=[None, None]):
10
+
11
+ # natural cubic spline
12
+ # if extrpl is given, the spline is linearly extrapolated outside the given boundaries
13
+
14
+ # https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf
15
+ # https://random-walks.org/content/misc/ncs/ncs.html
16
+
17
+ # move selected axis to the end
18
+ tensors = [xi, yi]
19
+ nelems = [tensor.shape.num_elements() for tensor in tensors]
20
+
21
+ max_nelems = max(nelems)
22
+ broadcast_shape = tensors[nelems.index(max_nelems)].shape
23
+
24
+ ndim = len(broadcast_shape)
25
+
26
+ if xi.shape.num_elements() < max_nelems:
27
+ xi = tf.broadcast_to(xi, broadcast_shape)
28
+ if yi.shape.num_elements() < max_nelems:
29
+ yi = tf.broadcast_to(yi, broadcast_shape)
30
+
31
+ # # permutation to move the selected axis to the end
32
+ selaxis = axis
33
+ if axis < 0:
34
+ selaxis = ndim + axis
35
+ axis = -1
36
+ permfwd = list(range(ndim))
37
+ permfwd.remove(selaxis)
38
+ permfwd.append(selaxis)
39
+
40
+ # reverse permutation to restore the original axis order
41
+ permrev = list(range(ndim))
42
+ permrev.remove(ndim-1)
43
+ permrev.insert(selaxis, ndim-1)
44
+
45
+
46
+ ## start spline construction
47
+ xi = tf.transpose(xi, permfwd)
48
+ yi = tf.transpose(yi, permfwd)
49
+ x = tf.transpose(x, permfwd)
50
+
51
+ h = tf.experimental.numpy.diff(xi, axis=axis)
52
+ b = tf.experimental.numpy.diff(yi, axis=axis) / h
53
+ v = 2 * (h[:, 1:] + h[:, :-1])
54
+ u = 6 * (b[:, 1:] - b[:, :-1])
55
+
56
+ shape = (xi.shape[0], xi.shape[-1]-2, xi.shape[-1]-2)
57
+ uu = u[:,:,None]
58
+
59
+ diag = v
60
+ superdiag = h[:, 1:-1]
61
+ subdiag = superdiag
62
+
63
+ z = tf.linalg.tridiagonal_solve(diagonals = [superdiag, diag, subdiag], rhs = uu, diagonals_format = "sequence")
64
+ z = tf.squeeze(z, axis=axis)
65
+ f = tf.zeros(xi.shape[0], dtype=tf.float64)[:,None]
66
+ z = tf.concat([f, z, f], axis=axis)
67
+
68
+ x_steps = tf.experimental.numpy.diff(x, axis=axis)
69
+ idx_zero_constant = tf.constant(0, dtype=tf.int64)
70
+ float64_zero_constant = tf.constant(0., dtype=tf.float64)
71
+
72
+ x_compare = x[...,None] < xi[..., None, :]
73
+ x_compare_all = tf.math.reduce_all(x_compare, axis=axis)
74
+ x_compare_none = tf.math.reduce_all(tf.logical_not(x_compare), axis=axis)
75
+ x_index = tf.argmax(x_compare, axis=axis) - 1
76
+ x_index = tf.where(x_compare_all, idx_zero_constant, x_index)
77
+ x_index = tf.where(x_compare_none, tf.constant(xi.shape[axis]-2, dtype=tf.int64), x_index)
78
+
79
+ nbatch = ndim - 1
80
+
81
+ z_xidx = tf.gather(z, x_index, batch_dims=nbatch, axis=-1)
82
+ z_1pxidx = tf.gather(z, x_index+1, batch_dims=nbatch, axis=-1)
83
+ h_xidx = tf.gather(h, x_index, batch_dims=nbatch, axis=-1)
84
+ xi_xidx = tf.gather(xi, x_index, batch_dims=nbatch, axis=-1)
85
+ xi_1pxidx = tf.gather(xi, x_index+1, batch_dims=nbatch, axis=-1)
86
+ dxxi = x - xi_xidx
87
+ dxix = xi_1pxidx - x
88
+
89
+ y = z_1pxidx / (6 * h_xidx) * dxxi ** 3 + \
90
+ z_xidx / (6 * h_xidx) * dxix ** 3 + \
91
+ (tf.gather(yi, x_index+1, batch_dims=nbatch, axis=-1) / h_xidx - \
92
+ z_1pxidx * h_xidx / 6) * dxxi + \
93
+ (tf.gather(yi, x_index, batch_dims=nbatch, axis=-1) / h_xidx - \
94
+ z_xidx * h_xidx / 6) * dxix
95
+
96
+
97
+ # right side linear extrapolation
98
+ if extrpl[1] != None:
99
+ x_max = tf.reshape(tf.constant(extrpl[1], dtype=tf.float64), (1,1))
100
+
101
+ # calculate derivative yp_max at boundary
102
+ x_compare_max = x_max[...,None] < xi[..., None, :]
103
+ x_compare_max_all = tf.math.reduce_all(x_compare_max, axis=axis)
104
+ x_compare_max_none = tf.math.reduce_all(tf.logical_not(x_compare_max), axis=axis)
105
+ x_index_max = tf.argmax(x_compare_max, axis=axis) - 1
106
+ x_index_max = tf.where(x_compare_max_all, idx_zero_constant, x_index_max)
107
+ x_index_max = tf.where(x_compare_max_none, tf.constant(xi.shape[axis]-2, dtype=tf.int64), x_index_max)
108
+
109
+ z_xidx = tf.gather(z, x_index_max, batch_dims=nbatch, axis=-1)
110
+ z_1pxidx = tf.gather(z, x_index_max+1, batch_dims=nbatch, axis=-1)
111
+ hi_xidx = tf.gather(h, x_index_max, batch_dims=nbatch, axis=-1)
112
+ xi_xidx = tf.gather(xi, x_index_max, batch_dims=nbatch, axis=-1)
113
+ xi_1pxidx = tf.gather(xi, x_index_max+1, batch_dims=nbatch, axis=-1)
114
+ yi_xidx = tf.gather(yi, x_index_max, batch_dims=nbatch, axis=-1)
115
+ yi_1pxidx = tf.gather(yi, x_index_max+1, batch_dims=nbatch, axis=-1)
116
+
117
+ yp_max = z_1pxidx / (2 * hi_xidx) * (x_max - xi_xidx) ** 2 - \
118
+ z_xidx / (2 * hi_xidx) * (xi_1pxidx - x_max) ** 2 + \
119
+ 1./hi_xidx*(yi_1pxidx - yi_xidx) - hi_xidx/6.*(z_1pxidx - z_xidx)
120
+
121
+ # evaluate spline at boundary
122
+ y_b = cubic_spline_interpolate(xi, yi, x_max, axis=axis) # check shape of x_max
123
+
124
+ # replace y by lin for x > x_max
125
+ extrpl_lin = yp_max*(x-x_max) + y_b
126
+ cond = x[0,:] >= x_max
127
+ cond = tf.broadcast_to(cond, (extrpl_lin.shape[0], extrpl_lin.shape[1]))
128
+ y = tf.where(cond, extrpl_lin, y)
129
+
130
+
131
+ # left side linear extrapolation
132
+ if extrpl[0] != None:
133
+ x_min = tf.reshape(tf.constant(extrpl[0], dtype=tf.float64), (1,1))
134
+
135
+ # calculate derivative yp_min at boundary
136
+ x_compare_min = x_min[...,None] >= xi[..., None, :]
137
+ x_compare_min_all = tf.math.reduce_all(x_compare_min, axis=axis)
138
+ x_compare_min_none = tf.math.reduce_all(tf.logical_not(x_compare_min), axis=axis)
139
+ x_index_min = tf.argmax(x_compare_min, axis=axis)
140
+ x_index_min = tf.where(x_compare_min_all, idx_zero_constant, x_index_min)
141
+ x_index_min = tf.where(x_compare_min_none, tf.constant(0, dtype=tf.int64), x_index_min)
142
+
143
+ z_xidx = tf.gather(z, x_index_min, batch_dims=nbatch, axis=-1)
144
+ z_1pxidx = tf.gather(z, x_index_min+1, batch_dims=nbatch, axis=-1)
145
+ hi_xidx = tf.gather(h, x_index_min, batch_dims=nbatch, axis=-1)
146
+ xi_xidx = tf.gather(xi, x_index_min, batch_dims=nbatch, axis=-1)
147
+ xi_1pxidx = tf.gather(xi, x_index_min+1, batch_dims=nbatch, axis=-1)
148
+ yi_xidx = tf.gather(yi, x_index_min, batch_dims=nbatch, axis=-1)
149
+ yi_1pxidx = tf.gather(yi, x_index_min+1, batch_dims=nbatch, axis=-1)
150
+
151
+ yp_min = z_1pxidx / (2 * hi_xidx) * (x_min - xi_xidx) ** 2 - \
152
+ z_xidx / (2 * hi_xidx) * (xi_1pxidx - x_min) ** 2 + \
153
+ 1./hi_xidx*(yi_1pxidx - yi_xidx) - hi_xidx/6.*(z_1pxidx - z_xidx)
154
+
155
+ # evaluate spline at boundary
156
+ y_b = cubic_spline_interpolate(xi, yi, x_min, axis=axis) # check shape of x_min
157
+
158
+ # replace y by lin for x < x_min
159
+ extrpl_lin = yp_min*(x-x_min) + y_b
160
+ cond = x[0,:] <= x_min
161
+ cond = tf.broadcast_to(cond, (extrpl_lin.shape[0], extrpl_lin.shape[1]))
162
+ y = tf.where(cond, extrpl_lin, y)
163
+
164
+ ## convert back axes
165
+ y = tf.transpose(y, permrev)
166
+ return y
167
+
168
+
169
+
170
+ def pchip_interpolate(xi, yi, x, axis=-1):
171
+ '''
172
+ Functionality:
173
+ 1D PCHIP interpolation
174
+ Authors:
175
+ Michael Taylor <mtaylor@atlanticsciences.com>
176
+ Mathieu Virbel <mat@meltingrocks.com>
177
+ Link:
178
+ https://gist.github.com/tito/553f1135959921ce6699652bf656150d
179
+ https://github.com/tensorflow/tensorflow/issues/46609#issuecomment-774573667
180
+ '''
181
+
182
+ tensors = [xi, yi]
183
+ nelems = [tensor.shape.num_elements() for tensor in tensors]
184
+
185
+ max_nelems = max(nelems)
186
+ broadcast_shape = tensors[nelems.index(max_nelems)].shape
187
+
188
+ ndim = len(broadcast_shape)
189
+
190
+ if xi.shape.num_elements() < max_nelems:
191
+ xi = tf.broadcast_to(xi, broadcast_shape)
192
+ if yi.shape.num_elements() < max_nelems:
193
+ yi = tf.broadcast_to(yi, broadcast_shape)
194
+
195
+ # # permutation to move the selected axis to the end
196
+ selaxis = axis
197
+ if axis < 0:
198
+ selaxis = ndim + axis
199
+
200
+ permfwd = list(range(ndim))
201
+ permfwd.remove(selaxis)
202
+ permfwd.append(selaxis)
203
+
204
+ # reverse permutation to restore the original axis order
205
+ permrev = list(range(ndim))
206
+ permrev.remove(ndim-1)
207
+ permrev.insert(selaxis, ndim-1)
208
+
209
+ xi = tf.transpose(xi, permfwd)
210
+ yi = tf.transpose(yi, permfwd)
211
+ x = tf.transpose(x, permfwd)
212
+ axis = -1
213
+
214
+ xi_steps = tf.experimental.numpy.diff(xi, axis=axis)
215
+
216
+
217
+ x_steps = tf.experimental.numpy.diff(x, axis=axis)
218
+
219
+ idx_zero_constant = tf.constant(0, dtype=tf.int64)
220
+ float64_zero_constant = tf.constant(0., dtype=tf.float64)
221
+
222
+ x_compare = x[...,None] < xi[..., None, :]
223
+ x_compare_all = tf.math.reduce_all(x_compare, axis=-1)
224
+ x_compare_none = tf.math.reduce_all(tf.logical_not(x_compare), axis=-1)
225
+ x_index = tf.argmax(x_compare, axis = -1) - 1
226
+
227
+ x_index = tf.where(x_compare_all, idx_zero_constant, x_index)
228
+ x_index = tf.where(x_compare_none, tf.constant(xi.shape[axis]-2, dtype=tf.int64), x_index)
229
+
230
+ # Calculate gradients d
231
+ h = tf.experimental.numpy.diff(xi, axis=axis)
232
+
233
+ d = tf.zeros_like(xi)
234
+
235
+ delta = tf.experimental.numpy.diff(yi, axis=axis) / h
236
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
237
+ # recipe
238
+
239
+ slice01 = [slice(None)]*ndim
240
+ slice01[axis] = slice(0,1)
241
+ slice01 = tuple(slice01)
242
+
243
+ slice0m1 = [slice(None)]*ndim
244
+ slice0m1[axis] = slice(0,-1)
245
+ slice0m1 = tuple(slice0m1)
246
+
247
+ slice1 = [slice(None)]*ndim
248
+ slice1[axis] = slice(1,None)
249
+ slice1 = tuple(slice1)
250
+
251
+ slicem1 = [slice(None)]*ndim
252
+ slicem1[axis] = slice(-1,None)
253
+ slicem1 = tuple(slicem1)
254
+
255
+ d = tf.concat(
256
+ (delta[slice01], 3 * (h[slice0m1] + h[slice1]) / ((h[slice0m1] + 2 * h[slice1]) / delta[slice0m1] +
257
+ (2 * h[slice0m1] + h[slice1]) / delta[slice1]), delta[slicem1]), axis=axis)
258
+
259
+ false_shape = [*xi.shape]
260
+ false_shape[axis] = 1
261
+ false_const = tf.fill(false_shape, False)
262
+
263
+ mask = tf.concat((false_const, tf.math.logical_xor(delta[slice0m1] > 0, delta[slice1] > 0), false_const), axis=axis)
264
+ d = tf.where(mask, float64_zero_constant, d)
265
+
266
+ mask = tf.math.logical_or(tf.concat((false_const, delta == 0), axis=axis), tf.concat((delta == 0, false_const), axis=axis))
267
+ d = tf.where(mask, float64_zero_constant, d)
268
+
269
+ xiperm = xi
270
+ yiperm = yi
271
+ dperm = d
272
+ hperm = h
273
+
274
+ nbatch = ndim - 1
275
+
276
+ xi_xidx = tf.gather(xiperm, x_index, batch_dims=nbatch, axis=-1)
277
+ xi_1pxidx = tf.gather(xiperm, 1 + x_index, batch_dims=nbatch, axis=-1)
278
+ yi_xidx = tf.gather(yiperm, x_index, batch_dims=nbatch, axis=-1)
279
+ yi_1pxidx = tf.gather(yiperm, 1 + x_index, batch_dims=nbatch, axis=-1)
280
+ d_xidx = tf.gather(dperm, x_index, batch_dims=nbatch, axis=-1)
281
+ d_1pxidx = tf.gather(dperm, 1 + x_index, batch_dims=nbatch, axis=-1)
282
+ h_xidx = tf.gather(hperm, x_index, batch_dims=nbatch, axis=-1)
283
+
284
+ dxxi = x - xi_xidx
285
+ dxxid = x - xi_1pxidx
286
+ dxxi2 = tf.math.pow(dxxi, 2)
287
+ dxxid2 = tf.math.pow(dxxid, 2)
288
+
289
+ y = (2 / tf.math.pow(h_xidx, 3) *
290
+ (yi_xidx * dxxid2 * (dxxi + h_xidx / 2) - yi_1pxidx * dxxi2 *
291
+ (dxxid - h_xidx / 2)) + 1 / tf.math.pow(h_xidx, 2) *
292
+ (d_xidx * dxxid2 * dxxi + d_1pxidx * dxxi2 * dxxid))
293
+
294
+ y = tf.transpose(y, permrev)
295
+
296
+ return y
297
+
298
+ def pchip_interpolate_np(xi, yi, x, mode="mono", verbose=False):
299
+ '''
300
+ Functionality:
301
+ 1D PCHIP interpolation
302
+ Authors:
303
+ Michael Taylor <mtaylor@atlanticsciences.com>
304
+ Mathieu Virbel <mat@meltingrocks.com>
305
+ Link:
306
+ https://gist.github.com/tito/553f1135959921ce6699652bf656150d
307
+ '''
308
+
309
+ if mode not in ("mono", "quad"):
310
+ raise ValueError("Unrecognized mode string")
311
+
312
+ # Search for [xi,xi+1] interval for each x
313
+ xi = xi.astype("double")
314
+ yi = yi.astype("double")
315
+
316
+ x_index = zeros(len(x), dtype="int")
317
+ xi_steps = diff(xi)
318
+ if not all(xi_steps > 0):
319
+ raise ValueError("x-coordinates are not in increasing order.")
320
+
321
+ x_steps = diff(x)
322
+ if xi_steps.max() / xi_steps.min() < 1.000001:
323
+ # uniform input grid
324
+ if verbose:
325
+ print("pchip: uniform input grid")
326
+ xi_start = xi[0]
327
+ xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
328
+ x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)
329
+
330
+ # Calculate gradients d
331
+ h = (xi[-1] - xi[0]) / (len(xi) - 1)
332
+ d = zeros(len(xi), dtype="double")
333
+ if mode == "quad":
334
+ # quadratic polynomial fit
335
+ d[[0]] = (yi[1] - yi[0]) / h
336
+ d[[-1]] = (yi[-1] - yi[-2]) / h
337
+ d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
338
+ else:
339
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
340
+ # recipe
341
+ delta = diff(yi) / h
342
+ d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
343
+ d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
344
+ d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
345
+ (delta == 0, array([False]))))] = 0
346
+ # Calculate output values y
347
+ dxxi = x - xi[x_index]
348
+ dxxid = x - xi[1 + x_index]
349
+ dxxi2 = pow(dxxi, 2)
350
+ dxxid2 = pow(dxxid, 2)
351
+ y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
352
+ (dxxid - h / 2)) + 1 / pow(h, 2) *
353
+ (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
354
+ else:
355
+ # not uniform input grid
356
+ if (x_steps.max() / x_steps.min() < 1.000001 and x_steps.max() / x_steps.min() > 0.999999):
357
+ # non-uniform input grid, uniform output grid
358
+ if verbose:
359
+ print("pchip: non-uniform input grid, uniform output grid")
360
+ x_decreasing = x[-1] < x[0]
361
+ if x_decreasing:
362
+ x = x[::-1]
363
+ x_start = x[0]
364
+ x_step = (x[-1] - x[0]) / (len(x) - 1)
365
+ x_indexprev = -1
366
+ for xi_loop in range(len(xi) - 2):
367
+ x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
368
+ x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
369
+ x_indexprev = x_indexcur
370
+ x_index[1 + x_indexprev:] = len(xi) - 2
371
+ if x_decreasing:
372
+ x = x[::-1]
373
+ x_index = x_index[::-1]
374
+ elif all(x_steps > 0) or all(x_steps < 0):
375
+ # non-uniform input/output grids, output grid monotonic
376
+ if verbose:
377
+ print("pchip: non-uniform in/out grid, output grid monotonic")
378
+ x_decreasing = x[-1] < x[0]
379
+ if x_decreasing:
380
+ x = x[::-1]
381
+ x_len = len(x)
382
+ x_loop = 0
383
+ for xi_loop in range(len(xi) - 1):
384
+ while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
385
+ x_index[x_loop] = xi_loop
386
+ x_loop += 1
387
+ x_index[x_loop:] = len(xi) - 2
388
+ if x_decreasing:
389
+ x = x[::-1]
390
+ x_index = x_index[::-1]
391
+ else:
392
+ # non-uniform input/output grids, output grid not monotonic
393
+ if verbose:
394
+ print("pchip: non-uniform in/out grids, " "output grid not monotonic")
395
+ for index in range(len(x)):
396
+ loc = where(x[index] < xi)[0]
397
+ if loc.size == 0:
398
+ x_index[index] = len(xi) - 2
399
+ elif loc[0] == 0:
400
+ x_index[index] = 0
401
+ else:
402
+ x_index[index] = loc[0] - 1
403
+ # Calculate gradients d
404
+ h = diff(xi)
405
+ d = zeros(len(xi), dtype="double")
406
+ delta = diff(yi) / h
407
+ if mode == "quad":
408
+ # quadratic polynomial fit
409
+ d[[0, -1]] = delta[[0, -1]]
410
+ d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
411
+ else:
412
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
413
+ # recipe
414
+ d = concatenate(
415
+ (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
416
+ (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
417
+ d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
418
+ d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
419
+ (delta == 0, array([False]))))] = 0
420
+ dxxi = x - xi[x_index]
421
+ dxxid = x - xi[1 + x_index]
422
+ dxxi2 = pow(dxxi, 2)
423
+ dxxid2 = pow(dxxid, 2)
424
+ y = (2 / pow(h[x_index], 3) *
425
+ (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
426
+ (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
427
+ (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
428
+ return y
429
+
430
+
431
+
432
+ def pchip_interpolate_np_forced(xi, yi, x, mode="mono", verbose=False):
433
+ '''
434
+ Functionality:
435
+ 1D PCHIP interpolation
436
+ Authors:
437
+ Michael Taylor <mtaylor@atlanticsciences.com>
438
+ Mathieu Virbel <mat@meltingrocks.com>
439
+ Link:
440
+ https://gist.github.com/tito/553f1135959921ce6699652bf656150d
441
+ '''
442
+
443
+ if mode not in ("mono", "quad"):
444
+ raise ValueError("Unrecognized mode string")
445
+
446
+ # Search for [xi,xi+1] interval for each x
447
+ xi = xi.astype("double")
448
+ yi = yi.astype("double")
449
+
450
+ x_index = zeros(len(x), dtype="int")
451
+ xi_steps = diff(xi)
452
+ if not all(xi_steps > 0):
453
+ raise ValueError("x-coordinates are not in increasing order.")
454
+
455
+ x_steps = diff(x)
456
+ # if xi_steps.max() / xi_steps.min() < 1.000001:
457
+ if False:
458
+ # uniform input grid
459
+ if verbose:
460
+ print("pchip: uniform input grid")
461
+ xi_start = xi[0]
462
+ xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
463
+ x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)
464
+
465
+ # Calculate gradients d
466
+ h = (xi[-1] - xi[0]) / (len(xi) - 1)
467
+ d = zeros(len(xi), dtype="double")
468
+ if mode == "quad":
469
+ # quadratic polynomial fit
470
+ d[[0]] = (yi[1] - yi[0]) / h
471
+ d[[-1]] = (yi[-1] - yi[-2]) / h
472
+ d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
473
+ else:
474
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
475
+ # recipe
476
+ delta = diff(yi) / h
477
+ d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
478
+ d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
479
+ d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
480
+ (delta == 0, array([False]))))] = 0
481
+ # Calculate output values y
482
+ dxxi = x - xi[x_index]
483
+ dxxid = x - xi[1 + x_index]
484
+ dxxi2 = pow(dxxi, 2)
485
+ dxxid2 = pow(dxxid, 2)
486
+ y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
487
+ (dxxid - h / 2)) + 1 / pow(h, 2) *
488
+ (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
489
+ else:
490
+ # not uniform input grid
491
+ # if (x_steps.max() / x_steps.min() < 1.000001 and x_steps.max() / x_steps.min() > 0.999999):
492
+ if False:
493
+ # non-uniform input grid, uniform output grid
494
+ if verbose:
495
+ print("pchip: non-uniform input grid, uniform output grid")
496
+ x_decreasing = x[-1] < x[0]
497
+ if x_decreasing:
498
+ x = x[::-1]
499
+ x_start = x[0]
500
+ x_step = (x[-1] - x[0]) / (len(x) - 1)
501
+ x_indexprev = -1
502
+ for xi_loop in range(len(xi) - 2):
503
+ x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
504
+ x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
505
+ x_indexprev = x_indexcur
506
+ x_index[1 + x_indexprev:] = len(xi) - 2
507
+ if x_decreasing:
508
+ x = x[::-1]
509
+ x_index = x_index[::-1]
510
+ # elif all(x_steps > 0) or all(x_steps < 0):
511
+ elif True:
512
+ # non-uniform input/output grids, output grid monotonic
513
+ if verbose:
514
+ print("pchip: non-uniform in/out grid, output grid monotonic")
515
+ # x_decreasing = x[-1] < x[0]
516
+ x_decreasing = False
517
+ if x_decreasing:
518
+ x = x[::-1]
519
+ x_len = len(x)
520
+ x_loop = 0
521
+ for xi_loop in range(len(xi) - 1):
522
+ while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
523
+ x_index[x_loop] = xi_loop
524
+ x_loop += 1
525
+ x_index[x_loop:] = len(xi) - 2
526
+
527
+ print("np_forced x_index", x_index)
528
+ if x_decreasing:
529
+ x = x[::-1]
530
+ x_index = x_index[::-1]
531
+ else:
532
+ # non-uniform input/output grids, output grid not monotonic
533
+ if verbose:
534
+ print("pchip: non-uniform in/out grids, " "output grid not monotonic")
535
+ for index in range(len(x)):
536
+ loc = where(x[index] < xi)[0]
537
+ if loc.size == 0:
538
+ x_index[index] = len(xi) - 2
539
+ elif loc[0] == 0:
540
+ x_index[index] = 0
541
+ else:
542
+ x_index[index] = loc[0] - 1
543
+ # Calculate gradients d
544
+ h = diff(xi)
545
+ d = zeros(len(xi), dtype="double")
546
+ delta = diff(yi) / h
547
+ if mode == "quad":
548
+ # quadratic polynomial fit
549
+ d[[0, -1]] = delta[[0, -1]]
550
+ d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
551
+ else:
552
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
553
+ # recipe
554
+ d = concatenate(
555
+ (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
556
+ (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
557
+ d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
558
+ d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
559
+ (delta == 0, array([False]))))] = 0
560
+ dxxi = x - xi[x_index]
561
+ dxxid = x - xi[1 + x_index]
562
+ dxxi2 = pow(dxxi, 2)
563
+ dxxid2 = pow(dxxid, 2)
564
+ y = (2 / pow(h[x_index], 3) *
565
+ (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
566
+ (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
567
+ (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
568
+ return y
569
+
570
+ def qparms_to_quantiles(qparms, x_low = 0., x_high = 1., axis = -1):
571
+ deltax = tf.exp(qparms)
572
+ sumdeltax = tf.math.reduce_sum(deltax, axis=axis, keepdims=True)
573
+
574
+ deltaxnorm = deltax/sumdeltax
575
+
576
+ x0shape = list(deltaxnorm.shape)
577
+ x0shape[axis] = 1
578
+ x0 = tf.fill(x0shape, x_low)
579
+ x0 = tf.cast(x0, tf.float64)
580
+
581
+ deltaxfull = (x_high - x_low)*deltaxnorm
582
+ deltaxfull = tf.concat([x0, deltaxfull], axis = axis)
583
+
584
+ quants = tf.math.cumsum(deltaxfull, axis=axis)
585
+
586
+ return quants
587
+
588
+
589
+
590
+ def quantiles_to_qparms(quants, quant_errs = None, x_low = 0., x_high = 1., axis = -1):
591
+
592
+ deltaxfull = tf.experimental.numpy.diff(quants, axis=axis)
593
+ deltaxnorm = deltaxfull/(x_high - x_low)
594
+ qparms = tf.math.log(deltaxnorm)
595
+
596
+ if quant_errs is not None:
597
+ quant_vars = tf.math.square(quant_errs)
598
+
599
+ ndim = len(quant_errs.shape)
600
+
601
+ slicem1 = [slice(None)]*ndim
602
+ slicem1[axis] = slice(None,-1)
603
+ slicem1 = tuple(slicem1)
604
+
605
+ slice1 = [slice(None)]*ndim
606
+ slice1[axis] = slice(1,None)
607
+ slice1 = tuple(slice1)
608
+
609
+ deltaxfull_vars = quant_vars[slice1] + quant_vars[slicem1]
610
+ deltaxfull_errs = tf.math.sqrt(deltaxfull_vars)
611
+
612
+ qparm_errs = deltaxfull_errs/deltaxfull
613
+
614
+ return qparms, qparm_errs
615
+ else:
616
+ return qparms
617
+
618
+
619
+ def hist_to_quantiles(h, quant_cdfvals, axis = -1):
620
+ dtype = tf.float64
621
+
622
+ xvals = [tf.constant(center, dtype=dtype) for center in h.axes.centers]
623
+ xwidths = [tf.constant(width, dtype=dtype) for width in h.axes.widths]
624
+ xedges = [tf.constant(edge, dtype=dtype) for edge in h.axes.edges]
625
+ yvals = tf.constant(h.values(), dtype=dtype)
626
+
627
+ if not isinstance(quant_cdfvals, tf.Tensor):
628
+ quant_cdfvals = tf.constant(quant_cdfvals, tf.float64)
629
+
630
+ x_flat = tf.reshape(xedges[axis], (-1,))
631
+ x_low = x_flat[0]
632
+ x_high = x_flat[-1]
633
+
634
+ hist_cdfvals = tf.cumsum(yvals, axis=axis)/tf.reduce_sum(yvals, axis=axis, keepdims=True)
635
+
636
+ x0shape = list(hist_cdfvals.shape)
637
+ x0shape[axis] = 1
638
+ x0 = tf.zeros(x0shape, dtype = dtype)
639
+
640
+ hist_cdfvals = tf.concat([x0, hist_cdfvals], axis=axis)
641
+
642
+ quants = pchip_interpolate(hist_cdfvals, xedges[axis], quant_cdfvals, axis=axis)
643
+
644
+ quants = tf.where(quant_cdfvals == 0., x_low, quants)
645
+ quants = tf.where(quant_cdfvals == 1., x_high, quants)
646
+
647
+ ntot = tf.math.reduce_sum(yvals, axis=axis, keepdims=True)
648
+
649
+ quant_cdf_bar = ntot/(1.+ntot)*(quant_cdfvals + 0.5/ntot)
650
+ quant_cdfval_errs = ntot/(1.+ntot)*tf.math.sqrt(quant_cdfvals*(1.-quant_cdfvals)/ntot + 0.25/ntot/ntot)
651
+
652
+ quant_cdfvals_up = quant_cdf_bar + quant_cdfval_errs
653
+ quant_cdfvals_up = tf.clip_by_value(quant_cdfvals_up, 0., 1.)
654
+
655
+ quant_cdfvals_down = quant_cdf_bar - quant_cdfval_errs
656
+ quant_cdfvals_down = tf.clip_by_value(quant_cdfvals_down, 0., 1.)
657
+
658
+ quants_up = pchip_interpolate(hist_cdfvals, xedges[axis], quant_cdfvals_up, axis=axis)
659
+ quants_up = tf.where(quant_cdfvals_up == 0., x_low, quants_up)
660
+ quants_up = tf.where(quant_cdfvals_up == 1., x_high, quants_up)
661
+
662
+ quants_down = pchip_interpolate(hist_cdfvals, xedges[axis], quant_cdfvals_down, axis=axis)
663
+ quants_down = tf.where(quant_cdfvals_down == 0., x_low, quants_down)
664
+ quants_down = tf.where(quant_cdfvals_down == 1., x_high, quants_down)
665
+
666
+ quant_errs = 0.5*(quants_up - quants_down)
667
+
668
+ zero_const = tf.constant(0., dtype)
669
+
670
+ quant_errs = tf.where(quant_cdfvals == 0., zero_const, quant_errs)
671
+ quant_errs = tf.where(quant_cdfvals == 1., zero_const, quant_errs)
672
+
673
+ return quants.numpy(), quant_errs.numpy()
674
+
675
+ def func_cdf_for_quantile_fit(xvals, xedges, qparms, quant_cdfvals, axis=-1, transform = None):
676
+ x_flat = tf.reshape(xedges[axis], (-1,))
677
+ x_low = x_flat[0]
678
+ x_high = x_flat[-1]
679
+
680
+ quants = qparms_to_quantiles(qparms, x_low = x_low, x_high = x_high, axis = axis)
681
+
682
+ spline_edges = xedges[axis]
683
+
684
+ ndim = len(xvals)
685
+
686
+ if transform is not None:
687
+ transform_cdf, transform_quantile = transform
688
+
689
+ slicelim = [slice(None)]*ndim
690
+ slicelim[axis] = slice(1, -1)
691
+ slicelim = tuple(slicelim)
692
+
693
+ quants = quants[slicelim]
694
+ quant_cdfvals = quant_cdfvals[slicelim]
695
+
696
+ quant_cdfvals = transform_quantile(quant_cdfvals)
697
+
698
+ cdfvals = pchip_interpolate(quants, quant_cdfvals, spline_edges, axis=axis)
699
+
700
+ if transform is not None:
701
+ cdfvals = transform_cdf(cdfvals)
702
+
703
+ slicefirst = [slice(None)]*ndim
704
+ slicefirst[axis] = slice(None, 1)
705
+ slicefirst = tuple(slicefirst)
706
+
707
+ slicelast = [slice(None)]*ndim
708
+ slicelast[axis] = slice(-1, None)
709
+ slicelast = tuple(slicelast)
710
+
711
+ cdfvals = (cdfvals - cdfvals[slicefirst])/(cdfvals[slicelast] - cdfvals[slicefirst])
712
+
713
+ return cdfvals
714
+
715
+ def func_constraint_for_quantile_fit(xvals, xedges, qparms, axis=-1):
716
+ constraints = 0.5*tf.math.square(tf.math.reduce_sum(tf.exp(qparms), axis=axis) - 1.)
717
+ constraint = tf.math.reduce_sum(constraints)
718
+ return constraint
719
+
720
+ @tf.function
721
+ def val_grad(func, *args, **kwargs):
722
+ xdep = kwargs["parms"]
723
+ with tf.GradientTape() as t1:
724
+ t1.watch(xdep)
725
+ val = func(*args, **kwargs)
726
+ grad = t1.gradient(val, xdep)
727
+ return val, grad
728
+
729
+ #TODO forward-over-reverse also here?
730
+ @tf.function
731
+ def val_grad_hess(func, *args, **kwargs):
732
+ xdep = kwargs["parms"]
733
+ with tf.GradientTape() as t2:
734
+ t2.watch(xdep)
735
+ with tf.GradientTape() as t1:
736
+ t1.watch(xdep)
737
+ val = func(*args, **kwargs)
738
+ grad = t1.gradient(val, xdep)
739
+ hess = t2.jacobian(grad, xdep)
740
+
741
+ return val, grad, hess
742
+
743
+ @tf.function
744
+ def val_grad_hessp(func, p, *args, **kwargs):
745
+ xdep = kwargs["parms"]
746
+ with tf.autodiff.ForwardAccumulator(xdep, p) as acc:
747
+ with tf.GradientTape() as grad_tape:
748
+ grad_tape.watch(xdep)
749
+ val = func(*args, **kwargs)
750
+ grad = grad_tape.gradient(val, xdep)
751
+ hessp = acc.jvp(grad)
752
+
753
+ return val, grad, hessp
754
+
755
+ def loss_with_constraint(func_loss, parms, func_constraint = None, args_loss = (), extra_args_loss=(), args_constraint = (), extra_args_constraint = ()):
756
+ loss = func_loss(parms, *args_loss, *extra_args_loss)
757
+ if func_constraint is not None:
758
+ loss += func_constraint(*args_constraint, parms, *extra_args_constraint)
759
+
760
+ return loss
761
+
762
+ def chisq_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
763
+ fvals = func(xvals, parms, *args)
764
+
765
+ # exclude zero-variance bins
766
+ variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
767
+ chisqv = (fvals - yvals)**2/variances_safe
768
+ chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
769
+ return tf.reduce_sum(chisqv_safe)
770
+
771
+ def chisq_normalized_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
772
+ fvals = func(xvals, parms, *args)
773
+ norm = tf.reduce_sum(fvals, keepdims=True, axis = norm_axes)
774
+ sumw = tf.reduce_sum(yvals, keepdims=True, axis = norm_axes)
775
+ if norm_axes is None:
776
+ for xwidth in xwidths:
777
+ norm *= xwidth
778
+ else:
779
+ for norm_axis in norm_axes:
780
+ norm *= xwidths[norm_axis]
781
+
782
+ # exclude zero-variance bins
783
+ variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
784
+ chisqv = (sumw*fvals/norm - yvals)**2/variances_safe
785
+ chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
786
+ return tf.reduce_sum(chisqv_safe)
787
+
788
+ def nll_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
789
+ fvals = func(xvals, parms, *args)
790
+
791
+ # compute overall scaling needed to restore mean == variance condition
792
+ yval_total = tf.reduce_sum(yvals, keepdims = True, axis = norm_axes)
793
+ variance_total = tf.reduce_sum(yvariances, keepdims = True, axis = norm_axes)
794
+ isnull_total = variance_total == 0.
795
+ variance_total_safe = tf.where(isnull_total, tf.ones_like(variance_total), variance_total)
796
+ scale_total = yval_total/variance_total_safe
797
+ scale_total_safe = tf.where(isnull_total, tf.ones_like(scale_total), scale_total)
798
+
799
+ # skip likelihood calculation for empty bins to avoid inf or nan
800
+ # compute per-bin scaling needed to restore mean == variance condition, falling
801
+ # back to overall scaling for empty bins
802
+ isnull = tf.logical_or(yvals == 0., yvariances == 0.)
803
+ variances_safe = tf.where(isnull, tf.ones_like(yvariances), yvariances)
804
+ scale = yvals/variances_safe
805
+ scale_safe = tf.where(isnull, scale_total_safe*tf.ones_like(scale), scale)
806
+
807
+ norm = tf.reduce_sum(scale_safe*fvals, keepdims=True, axis = norm_axes)
808
+ if norm_axes is None:
809
+ for xwidth in xwidths:
810
+ norm *= xwidth
811
+ else:
812
+ for norm_axis in norm_axes:
813
+ norm *= xwidths[norm_axis]
814
+
815
+ fvalsnorm = fvals/norm
816
+
817
+ fvalsnorm_safe = tf.where(isnull, tf.ones_like(fvalsnorm), fvalsnorm)
818
+ nllv = -scale_safe*yvals*tf.math.log(scale_safe*fvalsnorm_safe)
819
+ nllv_safe = tf.where(isnull, tf.zeros_like(nllv), nllv)
820
+ nllsum = tf.reduce_sum(nllv_safe)
821
+ return nllsum
822
+
823
+ def nll_loss_bin_integrated(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
824
+ #TODO reduce code duplication with nll_loss_bin
825
+
826
+ norm_axis = 0
827
+ if norm_axes is not None:
828
+ if len(norm_axes) > 1:
829
+ raise ValueError("Only 1 normalization axis supported for bin-integrated nll")
830
+ norm_axis = norm_axes[0]
831
+
832
+ cdfvals = func(xvals, xedges, parms, *args)
833
+
834
+ slices_low = [slice(None)]*len(cdfvals.shape)
835
+ slices_low[norm_axis] = slice(None,-1)
836
+
837
+ slices_high = [slice(None)]*len(cdfvals.shape)
838
+ slices_high[norm_axis] = slice(1,None)
839
+
840
+ # bin_integrals = cdfvals[1:] - cdfvals[:-1]
841
+ bin_integrals = cdfvals[tuple(slices_high)] - cdfvals[tuple(slices_low)]
842
+ bin_integrals = tf.maximum(bin_integrals, tf.zeros_like(bin_integrals))
843
+
844
+ fvals = bin_integrals
845
+
846
+ # compute overall scaling needed to restore mean == variance condition
847
+ yval_total = tf.reduce_sum(yvals, keepdims = True, axis = norm_axes)
848
+ variance_total = tf.reduce_sum(yvariances, keepdims = True, axis = norm_axes)
849
+ isnull_total = variance_total == 0.
850
+ variance_total_safe = tf.where(isnull_total, tf.ones_like(variance_total), variance_total)
851
+ scale_total = yval_total/variance_total_safe
852
+ scale_total_safe = tf.where(isnull_total, tf.ones_like(scale_total), scale_total)
853
+
854
+ # skip likelihood calculation for empty bins to avoid inf or nan
855
+ # compute per-bin scaling needed to restore mean == variance condition, falling
856
+ # back to overall scaling for empty bins
857
+ isnull = tf.logical_or(yvals == 0., yvariances == 0.)
858
+ variances_safe = tf.where(isnull, tf.ones_like(yvariances), yvariances)
859
+ scale = yvals/variances_safe
860
+ scale_safe = tf.where(isnull, scale_total_safe*tf.ones_like(scale), scale)
861
+
862
+ norm = tf.reduce_sum(scale_safe*fvals, keepdims=True, axis = norm_axes)
863
+
864
+ fvalsnorm = fvals/norm
865
+
866
+ fvalsnorm_safe = tf.where(isnull, tf.ones_like(fvalsnorm), fvalsnorm)
867
+ nllv = -scale_safe*yvals*tf.math.log(scale_safe*fvalsnorm_safe)
868
+ nllv_safe = tf.where(isnull, tf.zeros_like(nllv), nllv)
869
+ nllsum = tf.reduce_sum(nllv_safe)
870
+ return nllsum
871
+
872
+ def chisq_loss_bin_integrated(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
873
+ #FIXME this is only defined in 1D for now
874
+ cdfvals = func(xedges, parms, *args)
875
+ bin_integrals = cdfvals[1:] - cdfvals[:-1]
876
+ bin_integrals = tf.maximum(bin_integrals, tf.zeros_like(bin_integrals))
877
+
878
+ fvals = bin_integrals
879
+
880
+ # exclude zero-variance bins
881
+ variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
882
+ chisqv = (fvals - yvals)**2/variances_safe
883
+ chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
884
+ chisqsum = tf.reduce_sum(chisqv_safe)
885
+
886
+ return chisqsum
887
+
888
+
889
+ def fit_hist(hist, func, initial_parmvals, max_iter = 5, edmtol = 1e-5, mode = "chisq", norm_axes = None, func_constraint = None, args = (), args_constraint=()):
890
+
891
+ dtype = tf.float64
892
+
893
+ xvals = [tf.constant(center, dtype=dtype) for center in hist.axes.centers]
894
+ xwidths = [tf.constant(width, dtype=dtype) for width in hist.axes.widths]
895
+ xedges = [tf.constant(edge, dtype=dtype) for edge in hist.axes.edges]
896
+ yvals = tf.constant(hist.values(), dtype=dtype)
897
+ yvariances = tf.constant(hist.variances(), dtype=dtype)
898
+
899
+ covscale = 1.
900
+ if mode == "chisq":
901
+ floss = chisq_loss
902
+ covscale = 2.
903
+ elif mode == "nll":
904
+ floss = nll_loss
905
+ elif mode == "nll_bin_integrated":
906
+ floss = nll_loss_bin_integrated
907
+ elif mode == "chisq_normalized":
908
+ floss = chisq_normalized_loss
909
+ covscale = 2.
910
+ elif mode == "chisq_loss_bin_integrated":
911
+ floss = chisq_loss_bin_integrated
912
+ covscale = 2.
913
+ elif mode == "nll_extended":
914
+ raise Exception("Not Implemented")
915
+ else:
916
+ raise Exception("unsupported mode")
917
+
918
+ val_grad_args = { "func_loss" : floss,
919
+ "func_constraint" : func_constraint,
920
+ "args_loss" : (xvals, xwidths, xedges, yvals, yvariances, func, norm_axes),
921
+ "extra_args_loss" : args,
922
+ "args_constraint" : (xvals, xedges),
923
+ "extra_args_constraint" : args_constraint}
924
+
925
+ def scipy_loss(parmvals, *args):
926
+ parms = tf.constant(parmvals, dtype=dtype)
927
+
928
+ # loss, grad = val_grad(floss, parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes, *args)
929
+ loss, grad = val_grad(loss_with_constraint, parms=parms, **val_grad_args)
930
+ return loss.numpy(), grad.numpy()
931
+
932
+ def scipy_hessp(parmvals, p, *args):
933
+ parms = tf.constant(parmvals, dtype=dtype)
934
+
935
+ # loss, grad, hessp = val_grad_hessp(floss, p, parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes, *args)
936
+ loss, grad, hessp = val_grad_hessp(loss_with_constraint, p, parms=parms, **val_grad_args)
937
+ return hessp.numpy()
938
+
939
+ current_parmvals = initial_parmvals
940
+ for iiter in range(max_iter):
941
+
942
+ res = scipy.optimize.minimize(scipy_loss, current_parmvals, method = "trust-krylov", jac = True, hessp = scipy_hessp, args = args)
943
+
944
+ current_parmvals = res.x
945
+
946
+ parms = tf.constant(current_parmvals, dtype=dtype)
947
+
948
+ # loss, grad, hess = val_grad_hess(floss, parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes, *args)
949
+ loss, grad, hess = val_grad_hess(loss_with_constraint, parms=parms, **val_grad_args)
950
+ loss, grad, hess = loss.numpy(), grad.numpy(), hess.numpy()
951
+
952
+ try:
953
+ eigvals = np.linalg.eigvalsh(hess)
954
+ gradv = grad[:, np.newaxis]
955
+ edmval = 0.5*gradv.transpose()@np.linalg.solve(hess, gradv)
956
+ edmval = edmval[0][0]
957
+ except np.linalg.LinAlgError:
958
+ eigvals = np.zeros_like(grad)
959
+ edmval = 99.
960
+
961
+ converged = edmval < edmtol and np.abs(edmval) >= 0. and eigvals[0] > 0.
962
+ if converged:
963
+ break
964
+
965
+ status = 1
966
+ covstatus = 1
967
+
968
+ if edmval < edmtol and edmval >= -0.:
969
+ status = 0
970
+ if eigvals[0] > 0.:
971
+ covstatus = 0
972
+
973
+ try:
974
+ cov = covscale*np.linalg.inv(hess)
975
+ except np.linalg.LinAlgError:
976
+ cov = np.zeros_like(hess)
977
+ covstatus = 1
978
+
979
+ res = { "x" : current_parmvals,
980
+ "hess" : hess,
981
+ "cov" : cov,
982
+ "status" : status,
983
+ "covstatus" : covstatus,
984
+ "hess_eigvals" : eigvals,
985
+ "edmval" : edmval,
986
+ "loss_val" : loss }
987
+
988
+ return res
989
+
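For orientation, a minimal usage sketch of the new `fitutils.fit_hist` entry point added above. Everything below the import lines is hypothetical (toy histogram, Gaussian model, starting values); it only illustrates the calling convention, in which the model receives the list of per-axis bin-center tensors and the parameter vector:

```python
import numpy as np
import tensorflow as tf
import hist
from wums import fitutils

# toy 1D histogram with Weight() storage so values() and variances() are defined
rng = np.random.default_rng(42)
h = hist.Hist.new.Reg(50, -5.0, 5.0).Weight()
h.fill(rng.normal(0.0, 1.0, 10000))

# model evaluated at the bin centers: xvals is a list with one tensor per axis
def gauss(xvals, parms):
    x = xvals[0]
    amp, mu, sigma = parms[0], parms[1], parms[2]
    return amp * tf.exp(-0.5 * ((x - mu) / sigma) ** 2)

res = fitutils.fit_hist(h, gauss, np.array([800.0, 0.0, 1.0]), mode="chisq")
print(res["x"], res["status"], res["edmval"])
```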
wums/fitutilsjax.py ADDED
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+ import scipy
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ def chisqloss(xvals, yvals, yvariances, func, parms):
7
+ return jnp.sum( (func(xvals, parms) - yvals)**2/yvariances )
8
+
9
+ chisqloss_grad = jax.jit(jax.value_and_grad(chisqloss, argnums = 4), static_argnums = 3)
10
+
11
+ def _chisqloss_grad_hess(xvals, yvals, yvariances, func, parms):
12
+ def lossf(parms):
13
+ return chisqloss(xvals, yvals, yvariances, func, parms)
14
+
15
+ gradf = jax.grad(lossf)
16
+ hessf = jax.jacfwd(gradf)
17
+
18
+ loss = lossf(parms)
19
+ grad = gradf(parms)
20
+ hess = hessf(parms)
21
+
22
+ return loss, grad, hess
23
+
24
+ chisqloss_grad_hess = jax.jit(_chisqloss_grad_hess, static_argnums = 3)
25
+
26
+ def _chisqloss_hessp(xvals, yvals, yvariances, func, parms, p):
27
+ def lossf(parms):
28
+ return chisqloss(xvals, yvals, yvariances, func, parms)
29
+
30
+ gradf = jax.grad(lossf)
31
+ hessp = jax.jvp(gradf, (parms,), (p,))[1]
32
+ return hessp
33
+
34
+ chisqloss_hessp = jax.jit(_chisqloss_hessp, static_argnums = 3)
35
+
36
+ def fit_hist_jax(hist, func, parmvals, max_iter = 5, edmtol = 1e-5):
37
+
38
+ xvals = [jnp.array(center) for center in hist.axes.centers]
39
+ yvals = jnp.array(hist.values())
40
+ yvariances = jnp.array(hist.variances())
41
+
42
+ def scipy_loss(parmvals):
43
+ parms = jnp.array(parmvals)
44
+ loss, grad = chisqloss_grad(xvals, yvals, yvariances, func, parms)
45
+ return np.asarray(loss).item(), np.asarray(grad)
46
+
47
+ def scipy_hessp(parmvals, p):
48
+ parms = jnp.array(parmvals)
49
+ tangent = jnp.array(p)
50
+ hessp = chisqloss_hessp(xvals, yvals, yvariances, func, parms, tangent)
51
+ return np.asarray(hessp)
52
+
53
+ for iiter in range(max_iter):
54
+ res = scipy.optimize.minimize(scipy_loss, parmvals, method = "trust-krylov", jac = True, hessp = scipy_hessp)
55
+
56
+ parms = jnp.array(res.x)
57
+ loss, grad, hess = chisqloss_grad_hess(xvals, yvals, yvariances, func, parms)
58
+ loss, grad, hess = np.asarray(loss).item(), np.asarray(grad), np.asarray(hess)
59
+
60
+ eigvals = np.linalg.eigvalsh(hess)
61
+ cov = 2.*np.linalg.inv(hess)
62
+
63
+ gradv = grad[:, np.newaxis]
64
+ edmval = 0.5*gradv.transpose()@cov@gradv
65
+ edmval = edmval[0][0]
66
+
67
+ converged = edmval < edmtol and np.abs(edmval) >= 0. and eigvals[0] > 0.
68
+ if converged:
69
+ break
70
+
71
+ status = 1
72
+ covstatus = 1
73
+ if edmval < edmtol and np.abs(edmval) >= 0.:
74
+ status = 0
75
+ if eigvals[0] > 0.:
76
+ covstatus = 0
77
+
78
+ res = { "x" : res.x,
79
+ "cov" : cov,
80
+ "status" : status,
81
+ "covstatus" : covstatus,
82
+ "hess_eigvals" : eigvals,
83
+ "edmval" : edmval,
84
+ "chisqval" : loss }
85
+
86
+ return res
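A corresponding sketch for the JAX variant, `fitutilsjax.fit_hist_jax`. Again the histogram, model, and starting values are made up; note that the chi-square loss here divides by the per-bin variances, so the toy binning is chosen coarse enough that no bin is empty:

```python
import numpy as np
import jax.numpy as jnp
import hist
from wums import fitutilsjax

rng = np.random.default_rng(7)
h = hist.Hist.new.Reg(20, -3.0, 3.0).Weight()
h.fill(rng.normal(0.0, 1.0, 50000))

def gauss(xvals, parms):
    x = xvals[0]
    return parms[0] * jnp.exp(-0.5 * ((x - parms[1]) / parms[2]) ** 2)

res = fitutilsjax.fit_hist_jax(h, gauss, np.array([6000.0, 0.0, 1.0]))
print(res["x"], res["covstatus"], res["chisqval"])
```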
wums/logging.py CHANGED
@@ -42,19 +42,19 @@ def set_logging_level(log, verbosity):
42
42
  log.setLevel(logging_verboseLevel[max(0, min(4, verbosity))])
43
43
 
44
44
 
45
- def setup_logger(basefile, verbosity=3, no_colors=False, initName="wremnants"):
45
+ def setup_logger(basefile, verbosity=3, no_colors=False, initName="wums"):
46
46
 
47
47
  setup_func = setup_base_logger if no_colors else setup_color_logger
48
48
  logger = setup_func(os.path.basename(basefile), verbosity, initName)
49
49
  # count messages of base logger
50
- base_logger = logging.getLogger("wremnants")
50
+ base_logger = logging.getLogger("wums")
51
51
  add_logging_counter(base_logger)
52
52
  # stop total time
53
53
  add_time_info("Total time")
54
54
  return logger
55
55
 
56
56
 
57
- def setup_color_logger(name, verbosity, initName="wremnants"):
57
+ def setup_color_logger(name, verbosity, initName="wums"):
58
58
  base_logger = logging.getLogger(initName)
59
59
  # set console handler
60
60
  ch = logging.StreamHandler()
@@ -65,14 +65,14 @@ def setup_color_logger(name, verbosity, initName="wremnants"):
65
65
  return base_logger.getChild(name)
66
66
 
67
67
 
68
- def setup_base_logger(name, verbosity, initName="wremnants"):
68
+ def setup_base_logger(name, verbosity, initName="wums"):
69
69
  logging.basicConfig(format="%(levelname)s: %(message)s")
70
70
  base_logger = logging.getLogger(initName)
71
71
  set_logging_level(base_logger, verbosity)
72
72
  return base_logger.getChild(name)
73
73
 
74
74
 
75
- def child_logger(name, initName="wremnants"):
75
+ def child_logger(name, initName="wums"):
76
76
  # count messages of child logger
77
77
  logger = logging.getLogger(initName).getChild(name)
78
78
  add_logging_counter(logger)
@@ -110,7 +110,7 @@ def print_logging_count(logger, verbosity=logging.WARNING):
110
110
  )
111
111
 
112
112
 
113
- def add_time_info(tag, logger=logging.getLogger("wremnants")):
113
+ def add_time_info(tag, logger=logging.getLogger("wums")):
114
114
  if not hasattr(logger, "times"):
115
115
  logger.times = {}
116
116
  logger.times[tag] = time.time()
@@ -125,7 +125,7 @@ def print_time_info(logger):
125
125
 
126
126
 
127
127
  def summary(verbosity=logging.WARNING, extended=True):
128
- base_logger = logging.getLogger("wremnants")
128
+ base_logger = logging.getLogger("wums")
129
129
 
130
130
  base_logger.info(f"--------------------------------------")
131
131
  base_logger.info(f"----------- logger summary -----------")
@@ -141,5 +141,5 @@ def summary(verbosity=logging.WARNING, extended=True):
141
141
  # Iterate through all child loggers and print their names, levels, and counts
142
142
  all_loggers = logging.Logger.manager.loggerDict
143
143
  for logger_name, logger_obj in all_loggers.items():
144
- if logger_name.startswith("wremnants."):
144
+ if logger_name.startswith("wums."):
145
145
  print_logging_count(logger_obj, verbosity=verbosity)
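The renames above change the base logger from "wremnants" to "wums". A minimal sketch of how the module is typically driven after this change (the script name and child-logger name are hypothetical):

```python
from wums import logging as wums_logging

# creates a child of the "wums" base logger and starts the total-time counter
logger = wums_logging.setup_logger(__file__, verbosity=3)
child = wums_logging.child_logger("my_module")

logger.info("configured")
child.warning("something worth counting")

# prints message counts for the base logger and all "wums.*" children
wums_logging.summary()
```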
wums/tfutils.py ADDED
@@ -0,0 +1,81 @@
1
+ import tensorflow as tf
2
+
3
+ def function_to_tflite(funcs, input_signatures, func_names=""):
4
+ """Convert function to tflite model using python dynamic execution trickery to ensure that inputs
5
+ and outputs are alphabetically ordered, since this is apparently the only way to prevent tflite from
6
+ scrambling them"""
7
+
8
+ if not isinstance(funcs, list):
9
+ funcs = [funcs]
10
+ input_signatures = [input_signatures]
11
+ func_names = [func_names]
12
+ func_names = [funcs[iif].__name__ if func_names[iif]=="" else func_names[iif] for iif in range(len(funcs))]
13
+
14
+ def wrapped_func(iif, *args):
15
+ outputs = funcs[iif](*args)
16
+
17
+ if not isinstance(outputs, tuple):
18
+ outputs = (outputs,)
19
+
20
+ output_dict = {}
21
+ for i,output in enumerate(outputs):
22
+ output_name = f"output_{iif:05d}_{i:05d}"
23
+ output_dict[output_name] = output
24
+
25
+ return output_dict
26
+
27
+ arg_string = []
28
+ for iif, input_signature in enumerate(input_signatures):
29
+ inputs = []
30
+ for i in range(len(input_signature)):
31
+ input_name = f"input_{iif:05d}_{i:05d}"
32
+ inputs.append(input_name)
33
+ arg_string.append(", ".join(inputs))
34
+
35
+ def_string = ""
36
+ def_string += "def make_module(wrapped_func, input_signatures):\n"
37
+ def_string += " class Export_Module(tf.Module):\n"
38
+ for i, func in enumerate(funcs):
39
+ def_string += f" @tf.function(input_signature = input_signatures[{i}])\n"
40
+ def_string += f" def {func_names[i]}(self, {arg_string[i]}):\n"
41
+ def_string += f" return wrapped_func({i}, {arg_string[i]})\n"
42
+ def_string += " return Export_Module"
43
+
44
+ ldict = {}
45
+ exec(def_string, globals(), ldict)
46
+
47
+ make_module = ldict["make_module"]
48
+ Export_Module = make_module(wrapped_func, input_signatures)
49
+
50
+ module = Export_Module()
51
+ concrete_functions = [getattr(module, func_name).get_concrete_function() for func_name in func_names]
52
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(concrete_functions, module)
53
+
54
+ # enable TensorFlow ops and DISABLE builtin TFLite ops since these apparently slow things down
55
+ converter.target_spec.supported_ops = [
56
+ tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
57
+ ]
58
+
59
+ converter._experimental_allow_all_select_tf_ops = True
60
+
61
+ tflite_model = converter.convert()
62
+
63
+ test_interp = tf.lite.Interpreter(model_content = tflite_model)
64
+ print(test_interp.get_input_details())
65
+ print(test_interp.get_output_details())
66
+ print(test_interp.get_signature_list())
67
+
68
+ return tflite_model
69
+
70
+
71
+
72
+ def function_to_saved_model(func, input_signature, output):
73
+
74
+ class Export_Module(tf.Module):
75
+ @tf.function(input_signature = input_signature)
76
+ def __call__(self, *args):
77
+ return func(*args)
78
+
79
+ model = Export_Module()
80
+
81
+ tf.saved_model.save(model, output)
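For orientation, a minimal sketch of calling the new `tfutils.function_to_tflite` helper with a single function. The toy function, signature, and output filename are hypothetical; the point is that a plain callable plus a list of `tf.TensorSpec`s is enough, and the helper returns the serialized tflite model bytes:

```python
import tensorflow as tf
from wums import tfutils

# toy function to export; inputs/outputs get stable, alphabetically ordered names
def scale_and_sum(x, w):
    return tf.reduce_sum(w * x), tf.reduce_sum(w)

sig = [
    tf.TensorSpec(shape=[None], dtype=tf.float64),
    tf.TensorSpec(shape=[None], dtype=tf.float64),
]

tflite_model = tfutils.function_to_tflite(scale_and_sum, sig)
with open("scale_and_sum.tflite", "wb") as f:
    f.write(tflite_model)
```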
wums-0.1.7.dist-info/METADATA ADDED
@@ -0,0 +1,54 @@
1
+ Metadata-Version: 2.2
2
+ Name: wums
3
+ Version: 0.1.7
4
+ Summary: .
5
+ Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/WMass/wums
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ Requires-Dist: hist
15
+ Requires-Dist: numpy
16
+ Provides-Extra: plotting
17
+ Requires-Dist: matplotlib; extra == "plotting"
18
+ Requires-Dist: mplhep; extra == "plotting"
19
+ Provides-Extra: fitting
20
+ Requires-Dist: tensorflow; extra == "fitting"
21
+ Requires-Dist: jax; extra == "fitting"
22
+ Requires-Dist: scipy; extra == "fitting"
23
+ Provides-Extra: pickling
24
+ Requires-Dist: boost_histogram; extra == "pickling"
25
+ Requires-Dist: h5py; extra == "pickling"
26
+ Requires-Dist: hdf5plugin; extra == "pickling"
27
+ Requires-Dist: lz4; extra == "pickling"
28
+ Provides-Extra: all
29
+ Requires-Dist: plotting; extra == "all"
30
+ Requires-Dist: fitting; extra == "all"
31
+ Requires-Dist: pickling; extra == "all"
32
+
33
+ # WUMS: Wremnants Utilities, Modules, and other Stuff
34
+
35
+ As the name suggests, this is a collection of different things, all Python-based:
36
+ - Fitting with tensorflow or jax
37
+ - Custom pickling via h5py
38
+ - Plotting functionality
39
+
40
+ ## Install
41
+
42
+ The `wums` package can be pip installed with minimal dependencies:
43
+ ```bash
44
+ pip install wums
45
+ ```
46
+ Additional dependencies can be added with the `plotting`, `fitting`, or `pickling` extras to enable the corresponding functionality.
47
+ For example, one can install with
48
+ ```bash
49
+ pip install wums[plotting,fitting]
50
+ ```
51
+ Or all dependencies with
52
+ ```bash
53
+ pip install wums[all]
54
+ ```
wums-0.1.7.dist-info/RECORD ADDED
@@ -0,0 +1,14 @@
1
+ wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ wums/boostHistHelpers.py,sha256=mgdPXAgmxriqoOhrhMctyZcfwEOPfV07V27CvGt2sk8,39260
3
+ wums/fitutils.py,sha256=sPCMJqZGdXvDfc8OxjOB-Bpf45GWHKxmKkDV3SlMUQs,38297
4
+ wums/fitutilsjax.py,sha256=HE1AcIZmI6N_xIHo8OHCPaYkHSnND_B-vI4Gl3vaUmA,2659
5
+ wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
6
+ wums/logging.py,sha256=L4514Xyq7L1z77Tkh8KE2HX88ZZ06o6SSRyQo96DbC0,4494
7
+ wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
8
+ wums/plot_tools.py,sha256=4iPx9Nr9y8c3p4ovy8XOS-xU_w11OyQEjISKkygxqcA,55918
9
+ wums/tfutils.py,sha256=9efkkvxH7VtwJN2yBS6_-P9dLKs3CXdxMFdrEBNsna8,2892
10
+ wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
11
+ wums-0.1.7.dist-info/METADATA,sha256=GrQyVuatMvHdallbstH7YdiACEMLIo5isHyugfFawW8,1784
12
+ wums-0.1.7.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
13
+ wums-0.1.7.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
14
+ wums-0.1.7.dist-info/RECORD,,
wums-0.1.7.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
wums-0.1.6.dist-info/METADATA DELETED
@@ -1,29 +0,0 @@
1
- Metadata-Version: 2.2
2
- Name: wums
3
- Version: 0.1.6
4
- Summary: .
5
- Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>
6
- License: MIT
7
- Project-URL: Homepage, https://github.com/WMass/wums
8
- Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
- Classifier: License :: OSI Approved :: MIT License
11
- Classifier: Operating System :: OS Independent
12
- Requires-Python: >=3.8
13
- Description-Content-Type: text/markdown
14
- Requires-Dist: boost_histogram
15
- Requires-Dist: h5py
16
- Requires-Dist: hdf5plugin
17
- Requires-Dist: hist
18
- Requires-Dist: lz4
19
- Requires-Dist: matplotlib
20
- Requires-Dist: mplhep
21
- Requires-Dist: numpy
22
- Requires-Dist: uproot
23
-
24
- # WUMS: Wremnants Utilities, Modules, and other Stuff
25
-
26
- The `wums` package can be pip installed:
27
- ```bash
28
- pip install wums
29
- ```
wums-0.1.6.dist-info/RECORD DELETED
@@ -1,11 +0,0 @@
1
- wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- wums/boostHistHelpers.py,sha256=F4SwQEVjNObFscfs0qrJEyOHYNKqUCmusW8HIF1o-0c,38993
3
- wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
4
- wums/logging.py,sha256=zNnLVJUwG3HMvr9NeXmiheX07VmsnSt8cQ6R4q4XBk4,4534
5
- wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
6
- wums/plot_tools.py,sha256=4iPx9Nr9y8c3p4ovy8XOS-xU_w11OyQEjISKkygxqcA,55918
7
- wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
8
- wums-0.1.6.dist-info/METADATA,sha256=pTmIMc-rth2X53tju6Ef8WbJDda2zbr8isEqUpeqhDo,843
9
- wums-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
10
- wums-0.1.6.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
11
- wums-0.1.6.dist-info/RECORD,,