wums 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wums/boostHistHelpers.py +5 -1
- wums/fitutils.py +989 -0
- wums/fitutilsjax.py +86 -0
- wums/logging.py +8 -8
- wums/tfutils.py +81 -0
- wums-0.1.7.dist-info/METADATA +54 -0
- wums-0.1.7.dist-info/RECORD +14 -0
- {wums-0.1.6.dist-info → wums-0.1.7.dist-info}/WHEEL +1 -1
- wums-0.1.6.dist-info/METADATA +0 -29
- wums-0.1.6.dist-info/RECORD +0 -11
- {wums-0.1.6.dist-info → wums-0.1.7.dist-info}/top_level.txt +0 -0
wums/boostHistHelpers.py
CHANGED
|
@@ -40,7 +40,7 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
|
|
|
40
40
|
h2.ndim - 1 - i: h2.values(flow=flow).shape[h2.ndim - 1 - i]
|
|
41
41
|
for i in range(h2.ndim - h1.ndim)
|
|
42
42
|
}
|
|
43
|
-
|
|
43
|
+
|
|
44
44
|
broadcast_shape = list(moves.values()) + list(s1)
|
|
45
45
|
|
|
46
46
|
try:
|
|
@@ -53,6 +53,10 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
|
|
|
53
53
|
f" h2.axes: {h2.axes}"
|
|
54
54
|
)
|
|
55
55
|
|
|
56
|
+
if by_ax_name:
|
|
57
|
+
# We also have to move axes that are in common between h1 and h2 but in different order
|
|
58
|
+
moves.update({sum([k<i for k in moves.keys()]) + h1.axes.name.index(n2): None for i, n2 in enumerate(h2.axes.name) if n2 in h1.axes.name})
|
|
59
|
+
|
|
56
60
|
# move back to original order
|
|
57
61
|
new_vals = np.moveaxis(new_vals, np.arange(len(moves)), list(moves.keys()))
|
|
58
62
|
|
wums/fitutils.py
ADDED
|
@@ -0,0 +1,989 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import scipy
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from numpy import (zeros, where, diff, floor, minimum, maximum, array, concatenate, logical_or, logical_xor,
|
|
7
|
+
sqrt)
|
|
8
|
+
|
|
9
|
+
def cubic_spline_interpolate(xi, yi, x, axis=-1, extrpl=(None, None)):
    """Evaluate a natural cubic spline through the knots (xi, yi) at x.

    The spline is built along ``axis``; all other axes are batch axes.  ``xi``
    and ``yi`` are broadcast against each other (the one with fewer elements
    is broadcast to the shape of the other) and ``x`` must have the same rank,
    since it is transposed with the same permutation.  Query points outside
    the knot range are evaluated with the first/last cubic segment unless a
    linear extrapolation boundary is requested.

    References:
        https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf
        https://random-walks.org/content/misc/ncs/ncs.html

    Args:
        xi: knot positions, assumed increasing along ``axis``.
        yi: knot values.
        x: query points.
        axis: interpolation axis (negative values count from the end).
        extrpl: ``(x_min, x_max)``.  When an entry is not None, the spline is
            replaced by its tangent line at that boundary for ``x <= x_min``
            (resp. ``x >= x_max``).  A list is accepted as well.

    Returns:
        Tensor of spline values, in the axis order of the inputs.
    """
    # Broadcast the knot tensors against each other.
    tensors = [xi, yi]
    nelems = [tensor.shape.num_elements() for tensor in tensors]
    max_nelems = max(nelems)
    broadcast_shape = tensors[nelems.index(max_nelems)].shape
    ndim = len(broadcast_shape)

    if xi.shape.num_elements() < max_nelems:
        xi = tf.broadcast_to(xi, broadcast_shape)
    if yi.shape.num_elements() < max_nelems:
        yi = tf.broadcast_to(yi, broadcast_shape)

    # Permutation moving the interpolation axis last, and its inverse.
    selaxis = axis + ndim if axis < 0 else axis
    permfwd = [i for i in range(ndim) if i != selaxis] + [selaxis]
    permrev = list(range(ndim - 1))
    permrev.insert(selaxis, ndim - 1)

    # From here on every operation works on the last axis, whatever the
    # sign of the requested ``axis`` was.
    xi = tf.transpose(xi, permfwd)
    yi = tf.transpose(yi, permfwd)
    x = tf.transpose(x, permfwd)
    nbatch = ndim - 1

    # Second derivatives z at the knots: tridiagonal system for the interior
    # knots, z = 0 at both ends (natural spline).  The off-diagonals have one
    # element fewer than the diagonal; tridiagonal_solve pads them in
    # "sequence" format.
    h = tf.experimental.numpy.diff(xi, axis=-1)
    b = tf.experimental.numpy.diff(yi, axis=-1) / h
    diag = 2 * (h[..., 1:] + h[..., :-1])
    rhs = 6 * (b[..., 1:] - b[..., :-1])
    offdiag = h[..., 1:-1]
    z = tf.linalg.tridiagonal_solve(
        diagonals=[offdiag, diag, offdiag], rhs=rhs[..., None], diagonals_format="sequence"
    )
    z = tf.squeeze(z, axis=-1)
    z_end = tf.zeros_like(z[..., :1])
    z = tf.concat([z_end, z, z_end], axis=-1)

    # Evaluate the cubic of the interval containing each query point.
    x_index = _cubic_spline_interval_index(x, xi)

    def take(t, offset):
        return tf.gather(t, x_index + offset, batch_dims=nbatch, axis=-1)

    z_lo, z_hi = take(z, 0), take(z, 1)
    h_k = take(h, 0)
    dxxi = x - take(xi, 0)
    dxix = take(xi, 1) - x

    y = z_hi / (6 * h_k) * dxxi ** 3 + \
        z_lo / (6 * h_k) * dxix ** 3 + \
        (take(yi, 1) / h_k - z_hi * h_k / 6) * dxxi + \
        (take(yi, 0) / h_k - z_lo * h_k / 6) * dxix

    # Optional linear extrapolation: upper boundary first, then lower.
    if extrpl[1] is not None:
        y = _cubic_spline_extrapolate(y, x, xi, yi, z, h, extrpl[1], upper=True)
    if extrpl[0] is not None:
        y = _cubic_spline_extrapolate(y, x, xi, yi, z, h, extrpl[0], upper=False)

    # Restore the original axis order.
    return tf.transpose(y, permrev)


def _cubic_spline_interval_index(xq, xi):
    """Return the int64 index k of the knot interval [xi_k, xi_k+1] for each xq.

    Both tensors have the knot/query axis last.  Points below the first knot
    get interval 0 and points at or above the last knot get the last interval
    (n - 2), so the end cubics are extended outside the knot range.
    """
    below = xq[..., None] < xi[..., None, :]
    below_all = tf.math.reduce_all(below, axis=-1)
    below_none = tf.math.reduce_all(tf.logical_not(below), axis=-1)
    # Index of the first knot strictly above the point, minus one.
    index = tf.argmax(below, axis=-1) - 1
    index = tf.where(below_all, tf.constant(0, dtype=tf.int64), index)
    index = tf.where(below_none, tf.constant(xi.shape[-1] - 2, dtype=tf.int64), index)
    return index


def _cubic_spline_slope(xq, xi, yi, z, h, nbatch):
    """First derivative at xq of the spline given by knots and second derivatives z (knot axis last)."""
    index = _cubic_spline_interval_index(xq, xi)

    def take(t, offset):
        return tf.gather(t, index + offset, batch_dims=nbatch, axis=-1)

    z_lo, z_hi = take(z, 0), take(z, 1)
    h_k = take(h, 0)
    xi_lo, xi_hi = take(xi, 0), take(xi, 1)
    yi_lo, yi_hi = take(yi, 0), take(yi, 1)
    return z_hi / (2 * h_k) * (xq - xi_lo) ** 2 - \
        z_lo / (2 * h_k) * (xi_hi - xq) ** 2 + \
        1. / h_k * (yi_hi - yi_lo) - h_k / 6. * (z_hi - z_lo)


def _cubic_spline_extrapolate(y, x, xi, yi, z, h, bound, upper):
    """Replace ``y`` by the spline's tangent line at ``bound`` beyond it.

    With ``upper=True`` points with ``x >= bound`` are replaced, otherwise
    points with ``x <= bound``.  All tensors have the interpolation axis last.
    """
    ndim = len(xi.shape)
    xb = tf.reshape(tf.constant(bound, dtype=xi.dtype), [1] * ndim)
    slope = _cubic_spline_slope(xb, xi, yi, z, h, ndim - 1)
    # Spline value at the boundary (knots already have the axis last).
    y_bound = cubic_spline_interpolate(xi, yi, xb, axis=-1)
    y_lin = slope * (x - xb) + y_bound
    beyond = x >= xb if upper else x <= xb
    beyond = tf.broadcast_to(beyond, tf.shape(y_lin))
    return tf.where(beyond, y_lin, y)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def pchip_interpolate(xi, yi, x, axis=-1):
    '''
    Functionality:
        1D PCHP interpolation (TensorFlow, batched)
    Authors:
        Michael Taylor <mtaylor@atlanticsciences.com>
        Mathieu Virbel <mat@meltingrocks.com>
    Link:
        https://gist.github.com/tito/553f1135959921ce6699652bf656150d
        https://github.com/tensorflow/tensorflow/issues/46609#issuecomment-774573667

    Piecewise cubic Hermite interpolation with Fritsch-Carlson slopes along
    ``axis``; every other axis is a batch axis.  ``xi`` and ``yi`` are
    broadcast against each other (the one with fewer elements takes the shape
    of the other) and ``x`` must have the same rank, since it is transposed
    with the same permutation.  Query points outside the node range are
    evaluated with the first/last cubic.

    Args:
        xi: node positions, assumed increasing along ``axis``.
        yi: node values.
        x: query points.
        axis: interpolation axis (negative values count from the end).

    Returns:
        Tensor of interpolated values, in the axis order of the inputs.
    '''
    # Broadcast the node tensors against each other.
    tensors = [xi, yi]
    nelems = [tensor.shape.num_elements() for tensor in tensors]
    max_nelems = max(nelems)
    broadcast_shape = tensors[nelems.index(max_nelems)].shape
    ndim = len(broadcast_shape)

    if xi.shape.num_elements() < max_nelems:
        xi = tf.broadcast_to(xi, broadcast_shape)
    if yi.shape.num_elements() < max_nelems:
        yi = tf.broadcast_to(yi, broadcast_shape)

    # Permutation moving the interpolation axis last, and its inverse.
    selaxis = axis + ndim if axis < 0 else axis
    permfwd = [i for i in range(ndim) if i != selaxis] + [selaxis]
    permrev = list(range(ndim - 1))
    permrev.insert(selaxis, ndim - 1)

    xi = tf.transpose(xi, permfwd)
    yi = tf.transpose(yi, permfwd)
    x = tf.transpose(x, permfwd)
    nbatch = ndim - 1

    # Interval index k with xi[k] <= x < xi[k+1]; points below/above the node
    # range use the first/last interval.
    below = x[..., None] < xi[..., None, :]
    below_all = tf.math.reduce_all(below, axis=-1)
    below_none = tf.math.reduce_all(tf.logical_not(below), axis=-1)
    x_index = tf.argmax(below, axis=-1) - 1
    x_index = tf.where(below_all, tf.constant(0, dtype=tf.int64), x_index)
    x_index = tf.where(below_none, tf.constant(xi.shape[-1] - 2, dtype=tf.int64), x_index)

    # Node spacings and secant slopes.
    h = tf.experimental.numpy.diff(xi, axis=-1)
    delta = tf.experimental.numpy.diff(yi, axis=-1) / h
    h_lo, h_hi = h[..., :-1], h[..., 1:]
    delta_lo, delta_hi = delta[..., :-1], delta[..., 1:]

    # Fritsch-Carlson slopes (fortran numerical recipe): weighted harmonic
    # mean of adjacent secants at interior nodes, one-sided secant at the ends.
    d = tf.concat(
        (delta[..., :1],
         3 * (h_lo + h_hi) / ((h_lo + 2 * h_hi) / delta_lo + (2 * h_lo + h_hi) / delta_hi),
         delta[..., -1:]),
        axis=-1)

    # End nodes are never masked; built from delta so dynamic shapes work.
    no_mask = tf.zeros_like(delta[..., :1], dtype=tf.bool)
    zero = tf.constant(0., dtype=d.dtype)

    # Zero the slope where the secants change sign (local extremum) ...
    sign_change = tf.concat((no_mask, tf.math.logical_xor(delta_lo > 0, delta_hi > 0), no_mask), axis=-1)
    d = tf.where(sign_change, zero, d)

    # ... and at both ends of every flat segment.
    flat = delta == 0
    touches_flat = tf.math.logical_or(tf.concat((no_mask, flat), axis=-1), tf.concat((flat, no_mask), axis=-1))
    d = tf.where(touches_flat, zero, d)

    def take(t, offset):
        return tf.gather(t, x_index + offset, batch_dims=nbatch, axis=-1)

    xi_xidx, xi_1pxidx = take(xi, 0), take(xi, 1)
    yi_xidx, yi_1pxidx = take(yi, 0), take(yi, 1)
    d_xidx, d_1pxidx = take(d, 0), take(d, 1)
    h_xidx = take(h, 0)

    # Cubic Hermite basis evaluated on the selected interval.
    dxxi = x - xi_xidx
    dxxid = x - xi_1pxidx
    dxxi2 = tf.math.pow(dxxi, 2)
    dxxid2 = tf.math.pow(dxxid, 2)

    y = (2 / tf.math.pow(h_xidx, 3) *
         (yi_xidx * dxxid2 * (dxxi + h_xidx / 2) - yi_1pxidx * dxxi2 *
          (dxxid - h_xidx / 2)) + 1 / tf.math.pow(h_xidx, 2) *
         (d_xidx * dxxid2 * dxxi + d_1pxidx * dxxi2 * dxxid))

    # Restore the original axis order.
    return tf.transpose(y, permrev)
|
|
297
|
+
|
|
298
|
+
def pchip_interpolate_np(xi, yi, x, mode="mono", verbose=False):
    '''
    Functionality:
        1D PCHP interpolation (NumPy implementation)
    Authors:
        Michael Taylor <mtaylor@atlanticsciences.com>
        Mathieu Virbel <mat@meltingrocks.com>
    Link:
        https://gist.github.com/tito/553f1135959921ce6699652bf656150d

    Args:
        xi: 1D array-like of strictly increasing nodes (at least two).
        yi: 1D array-like of values at the nodes, same length as ``xi``.
        x: 1D array-like of evaluation points.  Points outside
            ``[xi[0], xi[-1]]`` are evaluated with the first/last cubic.
        mode: "mono" for Fritsch-Carlson (monotonicity-preserving) slopes,
            "quad" for three-point (quadratic fit) slopes.
        verbose: print which interval-search strategy is used.

    Returns:
        numpy array with the interpolated values at ``x``.

    Raises:
        ValueError: unknown ``mode``, fewer than two nodes, mismatched
            ``xi``/``yi`` lengths, or ``xi`` not strictly increasing.
    '''

    if mode not in ("mono", "quad"):
        raise ValueError("Unrecognized mode string")

    # Accept any array-like input; nodes are promoted to double precision.
    xi = np.asarray(xi).astype("double")
    yi = np.asarray(yi).astype("double")
    x = np.asarray(x)

    if len(xi) < 2:
        raise ValueError("At least two interpolation nodes are required.")
    if len(yi) != len(xi):
        raise ValueError("xi and yi must have the same length.")

    # Search for [xi,xi+1] interval for each x
    x_index = zeros(len(x), dtype="int")
    xi_steps = diff(xi)
    if not all(xi_steps > 0):
        raise ValueError("x-coordinates are not in increasing order.")

    x_steps = diff(x)
    if xi_steps.max() / xi_steps.min() < 1.000001:
        # uniform input grid: the interval follows from the position directly
        if verbose:
            print("pchip: uniform input grid")
        xi_start = xi[0]
        xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
        x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)

        # Calculate gradients d
        h = (xi[-1] - xi[0]) / (len(xi) - 1)
        d = zeros(len(xi), dtype="double")
        if mode == "quad":
            # quadratic polynomial fit
            d[[0]] = (yi[1] - yi[0]) / h
            d[[-1]] = (yi[-1] - yi[-2]) / h
            d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
        else:
            # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
            # recipe.  1/delta is infinite on flat segments; those slopes are
            # zeroed right after, so the division warnings are silenced.
            delta = diff(yi) / h
            with np.errstate(divide="ignore", invalid="ignore"):
                d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
            d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
            d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
                (delta == 0, array([False]))))] = 0
        # Calculate output values y
        dxxi = x - xi[x_index]
        dxxid = x - xi[1 + x_index]
        dxxi2 = pow(dxxi, 2)
        dxxid2 = pow(dxxid, 2)
        y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
                              (dxxid - h / 2)) + 1 / pow(h, 2) *
             (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
    else:
        # not uniform input grid
        # With fewer than two evaluation points x_steps is empty and has no
        # max/min; the ratio is then NaN and the monotonic scan handles it.
        with np.errstate(divide="ignore", invalid="ignore"):
            x_step_ratio = x_steps.max() / x_steps.min() if len(x_steps) > 0 else np.nan
        if 0.999999 < x_step_ratio < 1.000001:
            # non-uniform input grid, uniform output grid
            if verbose:
                print("pchip: non-uniform input grid, uniform output grid")
            x_decreasing = x[-1] < x[0]
            if x_decreasing:
                x = x[::-1]
            x_start = x[0]
            x_step = (x[-1] - x[0]) / (len(x) - 1)
            x_indexprev = -1
            for xi_loop in range(len(xi) - 2):
                # last output point that still belongs to interval xi_loop
                x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
                x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
                x_indexprev = x_indexcur
            x_index[1 + x_indexprev:] = len(xi) - 2
            if x_decreasing:
                x = x[::-1]
                x_index = x_index[::-1]
        elif all(x_steps > 0) or all(x_steps < 0):
            # non-uniform input/output grids, output grid monotonic
            if verbose:
                print("pchip: non-uniform in/out grid, output grid monotonic")
            x_decreasing = x[-1] < x[0] if len(x) > 0 else False
            if x_decreasing:
                x = x[::-1]
            x_len = len(x)
            x_loop = 0
            for xi_loop in range(len(xi) - 1):
                while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
                    x_index[x_loop] = xi_loop
                    x_loop += 1
            x_index[x_loop:] = len(xi) - 2
            if x_decreasing:
                x = x[::-1]
                x_index = x_index[::-1]
        else:
            # non-uniform input/output grids, output grid not monotonic
            if verbose:
                print("pchip: non-uniform in/out grids, " "output grid not monotonic")
            for index in range(len(x)):
                loc = where(x[index] < xi)[0]
                if loc.size == 0:
                    x_index[index] = len(xi) - 2
                elif loc[0] == 0:
                    x_index[index] = 0
                else:
                    x_index[index] = loc[0] - 1
        # Calculate gradients d
        h = diff(xi)
        d = zeros(len(xi), dtype="double")
        delta = diff(yi) / h
        if mode == "quad":
            # quadratic polynomial fit
            d[[0, -1]] = delta[[0, -1]]
            d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
        else:
            # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
            # recipe; flat-segment divisions are zeroed right after.
            with np.errstate(divide="ignore", invalid="ignore"):
                d = concatenate(
                    (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
                                                          (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
            d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
            d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
                (delta == 0, array([False]))))] = 0
        dxxi = x - xi[x_index]
        dxxid = x - xi[1 + x_index]
        dxxi2 = pow(dxxi, 2)
        dxxid2 = pow(dxxid, 2)
        y = (2 / pow(h[x_index], 3) *
             (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
              (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
             (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
    return y
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def pchip_interpolate_np_forced(xi, yi, x, mode="mono", verbose=False):
    '''
    Functionality:
        1D PCHP interpolation that always takes the general code path:
        non-uniform-grid slopes and a single forward scan for the intervals,
        without the uniform-grid shortcuts of ``pchip_interpolate_np``.
    Authors:
        Michael Taylor <mtaylor@atlanticsciences.com>
        Mathieu Virbel <mat@meltingrocks.com>
    Link:
        https://gist.github.com/tito/553f1135959921ce6699652bf656150d

    The interval search scans ``x`` once from the front, so it assumes ``x``
    is sorted in increasing order: once the scan has moved past a node, later
    points are never assigned an earlier interval.

    Args:
        xi: 1D array-like of strictly increasing nodes.
        yi: 1D array-like of values at the nodes.
        x: 1D array-like of evaluation points (increasing).
        mode: "mono" (Fritsch-Carlson slopes) or "quad" (three-point slopes).
        verbose: print the code path and the computed interval indices.

    Returns:
        numpy array with the interpolated values at ``x``.

    Raises:
        ValueError: unknown ``mode`` or ``xi`` not strictly increasing.
    '''

    if mode not in ("mono", "quad"):
        raise ValueError("Unrecognized mode string")

    xi = np.asarray(xi).astype("double")
    yi = np.asarray(yi).astype("double")
    x = np.asarray(x)

    xi_steps = diff(xi)
    if not all(xi_steps > 0):
        raise ValueError("x-coordinates are not in increasing order.")

    if verbose:
        print("pchip: non-uniform in/out grid, output grid monotonic")

    # Single forward scan: x[j] gets the first interval k with x[j] < xi[k+1];
    # points not below any later node use the last interval.
    x_len = len(x)
    x_index = zeros(x_len, dtype="int")
    x_loop = 0
    for xi_loop in range(len(xi) - 1):
        while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
            x_index[x_loop] = xi_loop
            x_loop += 1
    x_index[x_loop:] = len(xi) - 2

    if verbose:
        print("np_forced x_index", x_index)

    # Calculate gradients d
    h = diff(xi)
    delta = diff(yi) / h
    if mode == "quad":
        # quadratic polynomial fit
        d = zeros(len(xi), dtype="double")
        d[[0, -1]] = delta[[0, -1]]
        d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
    else:
        # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
        # recipe; divisions by flat secants are zeroed right after.
        with np.errstate(divide="ignore", invalid="ignore"):
            d = concatenate(
                (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
                                                      (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
        d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
        d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
            (delta == 0, array([False]))))] = 0

    # Cubic Hermite evaluation on the selected intervals.
    dxxi = x - xi[x_index]
    dxxid = x - xi[1 + x_index]
    dxxi2 = pow(dxxi, 2)
    dxxid2 = pow(dxxid, 2)
    y = (2 / pow(h[x_index], 3) *
         (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
          (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
         (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
    return y
|
|
569
|
+
|
|
570
|
+
def qparms_to_quantiles(qparms, x_low = 0., x_high = 1., axis = -1):
    """Map unconstrained parameters to increasing quantile positions.

    The normalized widths ``exp(qparms) / sum(exp(qparms))`` along ``axis``
    partition ``[x_low, x_high]``; their cumulative sum, prefixed by
    ``x_low``, gives the quantile positions.

    Args:
        qparms: floating-point parameters, one per interval along ``axis``.
        x_low, x_high: range covered by the quantiles (scalars or scalar
            tensors).
        axis: axis along which the intervals are laid out.

    Returns:
        Tensor with one more entry along ``axis`` than ``qparms``.  The first
        entry is ``x_low`` and the last equals ``x_high`` up to rounding.
    """
    # softmax is exp(qparms)/sum(exp(qparms)) computed without overflowing
    # for large parameters.
    deltaxnorm = tf.nn.softmax(qparms, axis=axis)
    dtype = deltaxnorm.dtype

    # Python scalars are converted directly to `dtype`.  Filling with a
    # Python float first would build a float32 tensor and lose precision
    # (e.g. for 0.1) before any cast.
    x_low = tf.cast(tf.convert_to_tensor(x_low, dtype_hint=dtype), dtype)
    x_high = tf.cast(tf.convert_to_tensor(x_high, dtype_hint=dtype), dtype)

    x0shape = list(deltaxnorm.shape)
    x0shape[axis] = 1
    x0 = tf.fill(x0shape, x_low)

    deltaxfull = tf.concat([x0, (x_high - x_low) * deltaxnorm], axis=axis)

    return tf.math.cumsum(deltaxfull, axis=axis)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def quantiles_to_qparms(quants, quant_errs = None, x_low = 0., x_high = 1., axis = -1):
    """Convert quantile positions back to log-width parameters.

    ``qparms = log(diff(quants) / (x_high - x_low))`` along ``axis``.

    If ``quant_errs`` is given, the width errors are formed by adding the
    variances of the two adjacent quantiles (i.e. treating them as
    uncorrelated).  They are propagated to the log as ``sigma_w / w``, and
    the pair ``(qparms, qparm_errs)`` is returned.  Otherwise only
    ``qparms`` is returned.
    """
    widths = tf.experimental.numpy.diff(quants, axis=axis)
    qparms = tf.math.log(widths / (x_high - x_low))

    if quant_errs is None:
        return qparms

    variances = tf.math.square(quant_errs)

    upper = [slice(None)] * len(quant_errs.shape)
    lower = list(upper)
    upper[axis] = slice(1, None)
    lower[axis] = slice(None, -1)

    width_errs = tf.math.sqrt(variances[tuple(upper)] + variances[tuple(lower)])

    # d log(w) = dw / w to first order
    return qparms, width_errs / widths
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def hist_to_quantiles(h, quant_cdfvals, axis = -1):
    """Estimate quantiles of a histogram along ``axis`` and their uncertainties.

    The normalized cumulative distribution of ``h.values()`` is evaluated at
    the bin edges (with 0 prepended at the first edge).  It is inverted with
    ``pchip_interpolate`` at the requested CDF levels.  Levels exactly 0 or 1
    are pinned to the first/last edge.

    The uncertainty is half the spread between the quantiles at the shifted
    levels ``cdf_bar +- err``, clipped to [0, 1].  With ``n`` the total
    count along ``axis``:
        cdf_bar = n/(1+n) * (p + 0.5/n)
        err     = n/(1+n) * sqrt(p*(1-p)/n + 0.25/n**2)
    The uncertainty is zero at levels 0 and 1.

    Args:
        h: histogram object exposing ``axes.edges`` and ``values()``
            (presumably a boost-histogram/hist object — confirm with callers).
        quant_cdfvals: CDF levels in [0, 1]; converted to float64 if not
            already a tensor.
        axis: histogram axis along which quantiles are computed.

    Returns:
        ``(quants, quant_errs)`` as numpy arrays (requires eager execution).
    """
    dtype = tf.float64

    xedges = [tf.constant(edge, dtype=dtype) for edge in h.axes.edges]
    yvals = tf.constant(h.values(), dtype=dtype)

    if not isinstance(quant_cdfvals, tf.Tensor):
        quant_cdfvals = tf.constant(quant_cdfvals, dtype)

    edges = xedges[axis]
    x_flat = tf.reshape(edges, (-1,))
    x_low = x_flat[0]
    x_high = x_flat[-1]

    # Total count and normalized CDF at the upper bin edges ...
    ntot = tf.math.reduce_sum(yvals, axis=axis, keepdims=True)
    hist_cdfvals = tf.cumsum(yvals, axis=axis) / ntot

    # ... with CDF = 0 prepended at the lowest edge.
    x0shape = list(hist_cdfvals.shape)
    x0shape[axis] = 1
    hist_cdfvals = tf.concat([tf.zeros(x0shape, dtype=dtype), hist_cdfvals], axis=axis)

    def cdf_to_quantile(cdfvals):
        # Invert the CDF and pin the exact end levels to the range limits.
        quantiles = pchip_interpolate(hist_cdfvals, edges, cdfvals, axis=axis)
        quantiles = tf.where(cdfvals == 0., x_low, quantiles)
        return tf.where(cdfvals == 1., x_high, quantiles)

    quants = cdf_to_quantile(quant_cdfvals)

    quant_cdf_bar = ntot/(1.+ntot)*(quant_cdfvals + 0.5/ntot)
    quant_cdfval_errs = ntot/(1.+ntot)*tf.math.sqrt(quant_cdfvals*(1.-quant_cdfvals)/ntot + 0.25/ntot/ntot)

    quants_up = cdf_to_quantile(tf.clip_by_value(quant_cdf_bar + quant_cdfval_errs, 0., 1.))
    quants_down = cdf_to_quantile(tf.clip_by_value(quant_cdf_bar - quant_cdfval_errs, 0., 1.))

    quant_errs = 0.5*(quants_up - quants_down)

    zero_const = tf.constant(0., dtype)
    quant_errs = tf.where(quant_cdfvals == 0., zero_const, quant_errs)
    quant_errs = tf.where(quant_cdfvals == 1., zero_const, quant_errs)

    return quants.numpy(), quant_errs.numpy()
|
|
674
|
+
|
|
675
|
+
def func_cdf_for_quantile_fit(xvals, xedges, qparms, quant_cdfvals, axis=-1, transform = None):
    """Evaluate a CDF at the bin edges along ``axis`` from quantile parameters.

    ``qparms_to_quantiles`` turns ``qparms`` into quantile positions bounded by the first
    and last edge along ``axis``. ``pchip_interpolate`` is then used to interpolate the
    map (quantile position -> ``quant_cdfvals``) at the edges. The result is affinely
    rescaled so that the first edge maps to 0 and the last edge maps to 1.

    Args:
        xvals: per-axis bin centers; only ``len(xvals)`` (the number of axes) is used.
        xedges: per-axis bin edges; ``xedges[axis]`` is used.
        qparms: quantile parameters passed to ``qparms_to_quantiles``.
        quant_cdfvals: CDF values associated with the quantiles.
        axis: axis along which the CDF is built.
        transform: optional pair ``(transform_cdf, transform_quantile)``. When given:
            - the first and last quantile points along ``axis`` are dropped;
            - ``quant_cdfvals`` is mapped with ``transform_quantile`` before interpolation;
            - the interpolated values are mapped back with ``transform_cdf``.
            Presumably the endpoints are dropped because the transform is singular at
            CDF values 0 and 1 -- TODO confirm.

    Returns:
        The normalized CDF values at the edges.
    """
    ndim = len(xvals)
    edges = xedges[axis]
    edges_flat = tf.reshape(edges, (-1,))

    quants = qparms_to_quantiles(qparms, x_low = edges_flat[0], x_high = edges_flat[-1], axis = axis)

    def index_along_axis(slc):
        # full-rank index tuple that applies ``slc`` on ``axis`` only
        index = [slice(None)]*ndim
        index[axis] = slc
        return tuple(index)

    if transform is not None:
        transform_cdf, transform_quantile = transform
        interior = index_along_axis(slice(1, -1))
        quants = quants[interior]
        quant_cdfvals = transform_quantile(quant_cdfvals[interior])

    cdfvals = pchip_interpolate(quants, quant_cdfvals, edges, axis=axis)

    if transform is not None:
        cdfvals = transform_cdf(cdfvals)

    cdf_first = cdfvals[index_along_axis(slice(None, 1))]
    cdf_last = cdfvals[index_along_axis(slice(-1, None))]
    return (cdfvals - cdf_first)/(cdf_last - cdf_first)
|
|
714
|
+
|
|
715
|
+
def func_constraint_for_quantile_fit(xvals, xedges, qparms, axis=-1):
    """Quadratic penalty pulling ``sum(exp(qparms))`` along ``axis`` towards 1.

    ``xvals`` and ``xedges`` are not used. They are accepted so that the signature fits the
    ``func_constraint(*args_constraint, parms, ...)`` call made by ``loss_with_constraint``.

    Returns:
        Scalar tensor: sum over the remaining entries of 0.5*(sum(exp(qparms)) - 1)**2.
    """
    deviation = tf.math.reduce_sum(tf.exp(qparms), axis=axis) - 1.
    return tf.math.reduce_sum(0.5*tf.math.square(deviation))
|
|
719
|
+
|
|
720
|
+
@tf.function
def val_grad(func, *args, **kwargs):
    """Return ``func(*args, **kwargs)`` and its gradient with respect to ``kwargs["parms"]``."""
    parms = kwargs["parms"]
    with tf.GradientTape() as tape:
        tape.watch(parms)
        value = func(*args, **kwargs)
    gradient = tape.gradient(value, parms)
    return value, gradient
|
|
728
|
+
|
|
729
|
+
# TODO: use forward-over-reverse (as in val_grad_hessp) here as well?
@tf.function
def val_grad_hess(func, *args, **kwargs):
    """Return the value, gradient and full Hessian of ``func`` w.r.t. ``kwargs["parms"]``.

    Two nested reverse-mode tapes are used; the Hessian is the Jacobian of the gradient.
    """
    parms = kwargs["parms"]
    with tf.GradientTape() as outer_tape:
        outer_tape.watch(parms)
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(parms)
            value = func(*args, **kwargs)
        # computed inside the outer tape so that it can be differentiated again
        gradient = inner_tape.gradient(value, parms)
    hessian = outer_tape.jacobian(gradient, parms)
    return value, gradient, hessian
|
|
742
|
+
|
|
743
|
+
@tf.function
def val_grad_hessp(func, p, *args, **kwargs):
    """Return the value, gradient and Hessian-vector product H @ p of ``func``.

    Derivatives are taken w.r.t. ``kwargs["parms"]``. The gradient is computed in reverse
    mode and differentiated in forward mode along ``p`` (forward-over-reverse), so the full
    Hessian is never materialized.
    """
    parms = kwargs["parms"]
    with tf.autodiff.ForwardAccumulator(parms, p) as accumulator:
        with tf.GradientTape() as tape:
            tape.watch(parms)
            value = func(*args, **kwargs)
        gradient = tape.gradient(value, parms)
    hessian_p = accumulator.jvp(gradient)
    return value, gradient, hessian_p
|
|
754
|
+
|
|
755
|
+
def loss_with_constraint(func_loss, parms, func_constraint = None, args_loss = (), extra_args_loss=(), args_constraint = (), extra_args_constraint = ()):
    """Evaluate ``func_loss`` and, if given, add the ``func_constraint`` penalty.

    The loss is called as ``func_loss(parms, *args_loss, *extra_args_loss)``. The constraint
    is called as ``func_constraint(*args_constraint, parms, *extra_args_constraint)``, i.e.
    ``parms`` comes after the leading constraint arguments.
    """
    total = func_loss(parms, *args_loss, *extra_args_loss)
    if func_constraint is None:
        return total
    total += func_constraint(*args_constraint, parms, *extra_args_constraint)
    return total
|
|
761
|
+
|
|
762
|
+
def chisq_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
    """Chi^2 between ``func(xvals, parms, *args)`` and ``yvals``.

    Bins with zero variance contribute nothing. ``xwidths``, ``xedges`` and ``norm_axes``
    are unused; they keep the signature common to all losses used by ``fit_hist``.
    """
    model = func(xvals, parms, *args)

    zero_variance = yvariances == 0.
    # divide by 1 in zero-variance bins to avoid inf/nan, then mask those bins out
    denominator = tf.where(zero_variance, tf.ones_like(yvariances), yvariances)
    terms = (model - yvals)**2/denominator
    terms = tf.where(zero_variance, tf.zeros_like(terms), terms)
    return tf.reduce_sum(terms)
|
|
770
|
+
|
|
771
|
+
def chisq_normalized_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
    """Chi^2 between ``yvals`` and ``func`` rescaled to the observed total along ``norm_axes``.

    The model ``fvals = func(xvals, parms, *args)`` is divided by its sum over ``norm_axes``
    (all axes if None) times the bin widths of those axes. It is then multiplied by the
    observed sum of ``yvals`` over the same axes. Zero-variance bins are excluded.

    NOTE(review): the rescaled model is divided (not multiplied) by the bin widths, so for
    non-unit widths it is density-like while ``yvals`` are per-bin values -- confirm
    that this is intended.

    Args:
        parms: fit parameters.
        xvals: per-axis bin centers passed to ``func``.
        xwidths: per-axis 1D tensors of bin widths.
        xedges: unused.
        yvals: observed bin contents.
        yvariances: observed bin variances.
        func: model ``func(xvals, parms, *args)``.
        norm_axes: axes over which the model is normalized (None means all axes).

    Returns:
        Scalar chi^2 tensor.
    """
    fvals = func(xvals, parms, *args)
    norm = tf.reduce_sum(fvals, keepdims=True, axis = norm_axes)
    sumw = tf.reduce_sum(yvals, keepdims=True, axis = norm_axes)

    # assumes fvals has one dimension per entry of xwidths -- TODO confirm for batched models
    ndim = len(xwidths)
    width_axes = range(ndim) if norm_axes is None else norm_axes
    for width_axis in width_axes:
        # Reshape the 1D widths so they broadcast along their own axis. A bare 1D tensor
        # broadcasts along the last axis only, which breaks multi-dimensional histograms.
        width_shape = [1]*ndim
        width_shape[width_axis] = -1
        norm *= tf.reshape(xwidths[width_axis], width_shape)

    # exclude zero-variance bins
    variances_safe = tf.where(yvariances == 0., tf.ones_like(yvariances), yvariances)
    chisqv = (sumw*fvals/norm - yvals)**2/variances_safe
    chisqv_safe = tf.where(yvariances == 0., tf.zeros_like(chisqv), chisqv)
    return tf.reduce_sum(chisqv_safe)
|
|
787
|
+
|
|
788
|
+
def nll_loss(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
    """Negative log-likelihood of ``yvals`` for the shape of ``func`` normalized along ``norm_axes``.

    Weighted bins are converted to effective counts using the per-bin scale
    ``yvals/yvariances``, which restores the mean == variance condition. Empty or
    zero-variance bins fall back to the overall scale and are excluded from the sum.

    Args:
        parms: fit parameters.
        xvals: per-axis bin centers passed to ``func``.
        xwidths: per-axis 1D tensors of bin widths.
        xedges: unused.
        yvals: observed bin contents.
        yvariances: observed bin variances.
        func: model ``func(xvals, parms, *args)``.
        norm_axes: axes over which the model is normalized (None means all axes).

    Returns:
        Scalar NLL tensor.
    """
    fvals = func(xvals, parms, *args)

    # compute overall scaling needed to restore mean == variance condition
    yval_total = tf.reduce_sum(yvals, keepdims = True, axis = norm_axes)
    variance_total = tf.reduce_sum(yvariances, keepdims = True, axis = norm_axes)
    isnull_total = variance_total == 0.
    variance_total_safe = tf.where(isnull_total, tf.ones_like(variance_total), variance_total)
    scale_total = yval_total/variance_total_safe
    scale_total_safe = tf.where(isnull_total, tf.ones_like(scale_total), scale_total)

    # skip likelihood calculation for empty bins to avoid inf or nan
    # compute per-bin scaling needed to restore mean == variance condition, falling
    # back to overall scaling for empty bins
    isnull = tf.logical_or(yvals == 0., yvariances == 0.)
    variances_safe = tf.where(isnull, tf.ones_like(yvariances), yvariances)
    scale = yvals/variances_safe
    scale_safe = tf.where(isnull, scale_total_safe*tf.ones_like(scale), scale)

    norm = tf.reduce_sum(scale_safe*fvals, keepdims=True, axis = norm_axes)

    # assumes fvals has one dimension per entry of xwidths -- TODO confirm for batched models
    ndim = len(xwidths)
    width_axes = range(ndim) if norm_axes is None else norm_axes
    for width_axis in width_axes:
        # Reshape the 1D widths so they broadcast along their own axis. A bare 1D tensor
        # broadcasts along the last axis only, which breaks multi-dimensional histograms.
        width_shape = [1]*ndim
        width_shape[width_axis] = -1
        norm *= tf.reshape(xwidths[width_axis], width_shape)

    fvalsnorm = fvals/norm

    fvalsnorm_safe = tf.where(isnull, tf.ones_like(fvalsnorm), fvalsnorm)
    nllv = -scale_safe*yvals*tf.math.log(scale_safe*fvalsnorm_safe)
    nllv_safe = tf.where(isnull, tf.zeros_like(nllv), nllv)
    nllsum = tf.reduce_sum(nllv_safe)
    return nllsum
|
|
822
|
+
|
|
823
|
+
def nll_loss_bin_integrated(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
    """Negative log-likelihood using per-bin integrals of a CDF model.

    ``func(xvals, xedges, parms, *args)`` returns CDF values at the bin edges. Consecutive
    differences along the normalization axis, clipped at zero, are the per-bin model values.
    The likelihood is then computed as in ``nll_loss``: weighted bins are converted to
    effective counts via ``yvals/yvariances``, and empty or zero-variance bins are skipped.

    Args:
        norm_axes: None (normalize along axis 0) or a sequence with exactly one axis.

    Returns:
        Scalar NLL tensor.

    Raises:
        ValueError: if more than one normalization axis is given.
    """
    #TODO reduce code duplication with nll_loss

    norm_axis = 0
    if norm_axes is not None:
        if len(norm_axes) > 1:
            raise ValueError("Only 1 normalization axis supported for bin-integrated nll")
        norm_axis = norm_axes[0]

    cdfvals = func(xvals, xedges, parms, *args)

    slices_low = [slice(None)]*len(cdfvals.shape)
    slices_low[norm_axis] = slice(None,-1)

    slices_high = [slice(None)]*len(cdfvals.shape)
    slices_high[norm_axis] = slice(1,None)

    # bin_integrals = cdfvals[1:] - cdfvals[:-1] along norm_axis
    bin_integrals = cdfvals[tuple(slices_high)] - cdfvals[tuple(slices_low)]
    bin_integrals = tf.maximum(bin_integrals, tf.zeros_like(bin_integrals))

    fvals = bin_integrals

    # compute overall scaling needed to restore mean == variance condition
    yval_total = tf.reduce_sum(yvals, keepdims = True, axis = norm_axes)
    variance_total = tf.reduce_sum(yvariances, keepdims = True, axis = norm_axes)
    isnull_total = variance_total == 0.
    variance_total_safe = tf.where(isnull_total, tf.ones_like(variance_total), variance_total)
    scale_total = yval_total/variance_total_safe
    scale_total_safe = tf.where(isnull_total, tf.ones_like(scale_total), scale_total)

    # skip likelihood calculation for empty bins to avoid inf or nan
    # compute per-bin scaling needed to restore mean == variance condition, falling
    # back to overall scaling for empty bins
    isnull = tf.logical_or(yvals == 0., yvariances == 0.)
    variances_safe = tf.where(isnull, tf.ones_like(yvariances), yvariances)
    scale = yvals/variances_safe
    scale_safe = tf.where(isnull, scale_total_safe*tf.ones_like(scale), scale)

    norm = tf.reduce_sum(scale_safe*fvals, keepdims=True, axis = norm_axes)

    fvalsnorm = fvals/norm

    fvalsnorm_safe = tf.where(isnull, tf.ones_like(fvalsnorm), fvalsnorm)
    nllv = -scale_safe*yvals*tf.math.log(scale_safe*fvalsnorm_safe)
    nllv_safe = tf.where(isnull, tf.zeros_like(nllv), nllv)
    nllsum = tf.reduce_sum(nllv_safe)
    return nllsum
|
|
871
|
+
|
|
872
|
+
def chisq_loss_bin_integrated(parms, xvals, xwidths, xedges, yvals, yvariances, func, norm_axes = None, *args):
    """Chi^2 between ``yvals`` and per-bin differences of ``func(xedges, parms, *args)``.

    The model output is presumably a CDF at the bin edges. Consecutive differences along the
    first axis, clipped at zero, are compared with ``yvals``; zero-variance bins are excluded.
    ``xvals``, ``xwidths`` and ``norm_axes`` are unused.
    """
    #FIXME this is only defined in 1D for now
    cdf_at_edges = func(xedges, parms, *args)
    per_bin = cdf_at_edges[1:] - cdf_at_edges[:-1]
    per_bin = tf.maximum(per_bin, tf.zeros_like(per_bin))

    zero_variance = yvariances == 0.
    denominator = tf.where(zero_variance, tf.ones_like(yvariances), yvariances)
    terms = (per_bin - yvals)**2/denominator
    terms = tf.where(zero_variance, tf.zeros_like(terms), terms)

    return tf.reduce_sum(terms)
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
def fit_hist(hist, func, initial_parmvals, max_iter = 5, edmtol = 1e-5, mode = "chisq", norm_axes = None, func_constraint = None, args = (), args_constraint=()):
    """Fit ``func`` to a histogram by minimizing a TensorFlow loss with scipy's trust-krylov.

    ``scipy.optimize.minimize`` is run up to ``max_iter`` times, each time restarting from the
    previous result. After each run the exact Hessian is used to compute the estimated
    distance to minimum, ``0.5 * g^T H^-1 g``. The loop stops once that is below ``edmtol``
    and the Hessian is positive definite.

    Args:
        hist: histogram providing ``axes.centers``, ``axes.widths``, ``axes.edges``,
            ``values()`` and ``variances()``.
        func: model. Its call signature depends on ``mode``:
            - ``"chisq"``, ``"chisq_normalized"`` and ``"nll"``: ``func(xvals, parms, *args)``;
            - ``"nll_bin_integrated"``: ``func(xvals, xedges, parms, *args)``;
            - ``"chisq_loss_bin_integrated"``: ``func(xedges, parms, *args)``.
        initial_parmvals: starting parameter values.
        max_iter: maximum number of minimizer restarts (must be >= 1).
        edmtol: convergence threshold on the estimated distance to minimum.
        mode: one of ``"chisq"``, ``"nll"``, ``"nll_bin_integrated"``,
            ``"chisq_normalized"``, ``"chisq_loss_bin_integrated"``.
        norm_axes: normalization axes forwarded to the loss.
        func_constraint: optional penalty, called as
            ``func_constraint(xvals, xedges, parms, *args_constraint)`` and added to the loss.
        args: extra positional arguments forwarded to the loss (and from there to ``func``).
        args_constraint: extra positional arguments forwarded to ``func_constraint``.

    Returns:
        dict with keys:
            - ``x``: best-fit parameters;
            - ``hess``: Hessian at ``x``;
            - ``cov``: ``covscale * H^-1``, or zeros if H is singular;
            - ``status``: 0 if EDM < edmtol, else 1;
            - ``covstatus``: 0 if H is positive definite and invertible, else 1;
            - ``hess_eigvals``: eigenvalues of H;
            - ``edmval``: estimated distance to minimum;
            - ``loss_val``: loss at ``x``.

    Raises:
        NotImplementedError: for ``mode="nll_extended"``.
        ValueError: for an unknown ``mode`` or ``max_iter < 1``.
    """
    if max_iter < 1:
        # otherwise edmval/eigvals/hess below would never be assigned
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    # mode -> (loss function, factor converting H^-1 into a covariance).
    # chi^2-type losses correspond to -2 ln L, hence the factor 2.
    losses = {
        "chisq": (chisq_loss, 2.),
        "nll": (nll_loss, 1.),
        "nll_bin_integrated": (nll_loss_bin_integrated, 1.),
        "chisq_normalized": (chisq_normalized_loss, 2.),
        "chisq_loss_bin_integrated": (chisq_loss_bin_integrated, 2.),
    }
    if mode == "nll_extended":
        raise NotImplementedError("Not Implemented")
    if mode not in losses:
        raise ValueError(f"unsupported mode: {mode}")
    floss, covscale = losses[mode]

    dtype = tf.float64

    xvals = [tf.constant(center, dtype=dtype) for center in hist.axes.centers]
    xwidths = [tf.constant(width, dtype=dtype) for width in hist.axes.widths]
    xedges = [tf.constant(edge, dtype=dtype) for edge in hist.axes.edges]
    yvals = tf.constant(hist.values(), dtype=dtype)
    yvariances = tf.constant(hist.variances(), dtype=dtype)

    val_grad_args = { "func_loss" : floss,
                      "func_constraint" : func_constraint,
                      "args_loss" : (xvals, xwidths, xedges, yvals, yvariances, func, norm_axes),
                      "extra_args_loss" : args,
                      "args_constraint" : (xvals, xedges),
                      "extra_args_constraint" : args_constraint}

    def scipy_loss(parmvals):
        parms = tf.constant(parmvals, dtype=dtype)
        loss, grad = val_grad(loss_with_constraint, parms=parms, **val_grad_args)
        return loss.numpy(), grad.numpy()

    def scipy_hessp(parmvals, p):
        parms = tf.constant(parmvals, dtype=dtype)
        _, _, hessp = val_grad_hessp(loss_with_constraint, p, parms=parms, **val_grad_args)
        return hessp.numpy()

    current_parmvals = initial_parmvals
    for _ in range(max_iter):
        res = scipy.optimize.minimize(scipy_loss, current_parmvals, method = "trust-krylov", jac = True, hessp = scipy_hessp)
        current_parmvals = res.x

        parms = tf.constant(current_parmvals, dtype=dtype)
        loss, grad, hess = val_grad_hess(loss_with_constraint, parms=parms, **val_grad_args)
        loss, grad, hess = loss.numpy(), grad.numpy(), hess.numpy()

        try:
            eigvals = np.linalg.eigvalsh(hess)
            gradv = grad[:, np.newaxis]
            # expected loss decrease of a Newton step
            edmval = 0.5*(gradv.transpose()@np.linalg.solve(hess, gradv))[0][0]
        except np.linalg.LinAlgError:
            eigvals = np.zeros_like(grad)
            edmval = 99.

        # np.abs(edmval) >= 0. is False only for nan
        converged = edmval < edmtol and np.abs(edmval) >= 0. and eigvals[0] > 0.
        if converged:
            break

    status = 0 if (edmval < edmtol and edmval >= 0.) else 1
    covstatus = 0 if eigvals[0] > 0. else 1

    try:
        cov = covscale*np.linalg.inv(hess)
    except np.linalg.LinAlgError:
        cov = np.zeros_like(hess)
        covstatus = 1

    return { "x" : current_parmvals,
             "hess" : hess,
             "cov" : cov,
             "status" : status,
             "covstatus" : covstatus,
             "hess_eigvals" : eigvals,
             "edmval" : edmval,
             "loss_val" : loss }
|
|
989
|
+
|
wums/fitutilsjax.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import scipy
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
|
|
6
|
+
def chisqloss(xvals, yvals, yvariances, func, parms):
    """Return the chi^2 between ``func(xvals, parms)`` and ``yvals``.

    Bins with zero variance are excluded, consistent with ``chisq_loss`` in
    ``wums.fitutils``. A safe denominator is used so that neither the value nor the
    gradient becomes inf/nan.
    """
    fvals = func(xvals, parms)
    zero_variance = yvariances == 0.
    variances_safe = jnp.where(zero_variance, 1., yvariances)
    chisqv = jnp.where(zero_variance, 0., (fvals - yvals)**2/variances_safe)
    return jnp.sum(chisqv)
|
|
8
|
+
|
|
9
|
+
# jit-compiled (chi^2, gradient w.r.t. parms); func (positional argument 3) is static and must be hashable
chisqloss_grad = jax.jit(jax.value_and_grad(chisqloss, argnums=4), static_argnums=(3,))
|
|
10
|
+
|
|
11
|
+
def _chisqloss_grad_hess(xvals, yvals, yvariances, func, parms):
    """Return the chi^2, its gradient and its Hessian with respect to ``parms``."""
    def loss_of(p):
        return chisqloss(xvals, yvals, yvariances, func, p)

    grad_of = jax.grad(loss_of)
    # forward-mode Jacobian of the reverse-mode gradient (forward-over-reverse)
    hess_of = jax.jacfwd(grad_of)

    return loss_of(parms), grad_of(parms), hess_of(parms)
|
|
23
|
+
|
|
24
|
+
# jit-compiled (chi^2, gradient, Hessian); func (positional argument 3) is static and must be hashable
chisqloss_grad_hess = jax.jit(_chisqloss_grad_hess, static_argnums=(3,))
|
|
25
|
+
|
|
26
|
+
def _chisqloss_hessp(xvals, yvals, yvariances, func, parms, p):
    """Return the Hessian-vector product H @ p of the chi^2 at ``parms``."""
    grad_of = jax.grad(lambda q: chisqloss(xvals, yvals, yvariances, func, q))
    # directional (forward-mode) derivative of the gradient along p
    _, hvp = jax.jvp(grad_of, (parms,), (p,))
    return hvp
|
|
33
|
+
|
|
34
|
+
# jit-compiled Hessian-vector product; func (positional argument 3) is static and must be hashable
chisqloss_hessp = jax.jit(_chisqloss_hessp, static_argnums=(3,))
|
|
35
|
+
|
|
36
|
+
def fit_hist_jax(hist, func, parmvals, max_iter = 5, edmtol = 1e-5):
    """Chi^2 fit of ``func(xvals, parms)`` to a histogram using JAX derivatives.

    ``scipy.optimize.minimize`` (trust-krylov) is run up to ``max_iter`` times, each run
    starting from the previous result. The loop stops once the estimated distance to
    minimum ``0.5 * g^T H^-1 g`` is below ``edmtol`` and the Hessian is positive definite.

    NOTE: ``jnp.array`` produces float32 unless JAX 64-bit mode is enabled, which may limit
    the reachable precision.

    Args:
        hist: histogram providing ``axes.centers``, ``values()`` and ``variances()``.
        func: hashable model ``func(xvals, parms)`` (static argument of the jitted losses).
        parmvals: starting parameter values.
        max_iter: maximum number of minimizer restarts (must be >= 1).
        edmtol: convergence threshold on the estimated distance to minimum.

    Returns:
        dict with keys:
            - ``x``: best-fit parameters;
            - ``cov``: ``2 H^-1``, or zeros if H is singular;
            - ``status``: 0 if EDM < edmtol, else 1;
            - ``covstatus``: 0 if H is positive definite and invertible, else 1;
            - ``hess_eigvals``: eigenvalues of H;
            - ``edmval``: estimated distance to minimum;
            - ``chisqval``: chi^2 at ``x``.

    Raises:
        ValueError: if ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    xvals = [jnp.array(center) for center in hist.axes.centers]
    yvals = jnp.array(hist.values())
    yvariances = jnp.array(hist.variances())

    def scipy_loss(parmvals):
        parms = jnp.array(parmvals)
        loss, grad = chisqloss_grad(xvals, yvals, yvariances, func, parms)
        return np.asarray(loss).item(), np.asarray(grad)

    def scipy_hessp(parmvals, p):
        parms = jnp.array(parmvals)
        tangent = jnp.array(p)
        hessp = chisqloss_hessp(xvals, yvals, yvariances, func, parms, tangent)
        return np.asarray(hessp)

    current_parmvals = parmvals
    for _ in range(max_iter):
        # continue from the previous result; restarting from the initial values would
        # repeat the identical minimization on every iteration
        res = scipy.optimize.minimize(scipy_loss, current_parmvals, method = "trust-krylov", jac = True, hessp = scipy_hessp)
        current_parmvals = res.x

        parms = jnp.array(current_parmvals)
        loss, grad, hess = chisqloss_grad_hess(xvals, yvals, yvariances, func, parms)
        loss, grad, hess = np.asarray(loss).item(), np.asarray(grad), np.asarray(hess)

        try:
            eigvals = np.linalg.eigvalsh(hess)
            gradv = grad[:, np.newaxis]
            # expected chi^2 decrease of a Newton step (same definition as fitutils.fit_hist)
            edmval = 0.5*(gradv.transpose()@np.linalg.solve(hess, gradv))[0][0]
        except np.linalg.LinAlgError:
            eigvals = np.zeros_like(grad)
            edmval = 99.

        # np.abs(edmval) >= 0. is False only for nan
        converged = edmval < edmtol and np.abs(edmval) >= 0. and eigvals[0] > 0.
        if converged:
            break

    status = 0 if (edmval < edmtol and edmval >= 0.) else 1
    covstatus = 0 if eigvals[0] > 0. else 1

    try:
        cov = 2.*np.linalg.inv(hess)
    except np.linalg.LinAlgError:
        cov = np.zeros_like(hess)
        covstatus = 1

    return { "x" : current_parmvals,
             "cov" : cov,
             "status" : status,
             "covstatus" : covstatus,
             "hess_eigvals" : eigvals,
             "edmval" : edmval,
             "chisqval" : loss }
|
wums/logging.py
CHANGED
|
@@ -42,19 +42,19 @@ def set_logging_level(log, verbosity):
|
|
|
42
42
|
log.setLevel(logging_verboseLevel[max(0, min(4, verbosity))])
|
|
43
43
|
|
|
44
44
|
|
|
45
|
-
def setup_logger(basefile, verbosity=3, no_colors=False, initName="
|
|
45
|
+
def setup_logger(basefile, verbosity=3, no_colors=False, initName="wums"):
|
|
46
46
|
|
|
47
47
|
setup_func = setup_base_logger if no_colors else setup_color_logger
|
|
48
48
|
logger = setup_func(os.path.basename(basefile), verbosity, initName)
|
|
49
49
|
# count messages of base logger
|
|
50
|
-
base_logger = logging.getLogger("
|
|
50
|
+
base_logger = logging.getLogger("wums")
|
|
51
51
|
add_logging_counter(base_logger)
|
|
52
52
|
# stop total time
|
|
53
53
|
add_time_info("Total time")
|
|
54
54
|
return logger
|
|
55
55
|
|
|
56
56
|
|
|
57
|
-
def setup_color_logger(name, verbosity, initName="
|
|
57
|
+
def setup_color_logger(name, verbosity, initName="wums"):
|
|
58
58
|
base_logger = logging.getLogger(initName)
|
|
59
59
|
# set console handler
|
|
60
60
|
ch = logging.StreamHandler()
|
|
@@ -65,14 +65,14 @@ def setup_color_logger(name, verbosity, initName="wremnants"):
|
|
|
65
65
|
return base_logger.getChild(name)
|
|
66
66
|
|
|
67
67
|
|
|
68
|
-
def setup_base_logger(name, verbosity, initName="
|
|
68
|
+
def setup_base_logger(name, verbosity, initName="wums"):
|
|
69
69
|
logging.basicConfig(format="%(levelname)s: %(message)s")
|
|
70
70
|
base_logger = logging.getLogger(initName)
|
|
71
71
|
set_logging_level(base_logger, verbosity)
|
|
72
72
|
return base_logger.getChild(name)
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
def child_logger(name, initName="
|
|
75
|
+
def child_logger(name, initName="wums"):
|
|
76
76
|
# count messages of child logger
|
|
77
77
|
logger = logging.getLogger(initName).getChild(name)
|
|
78
78
|
add_logging_counter(logger)
|
|
@@ -110,7 +110,7 @@ def print_logging_count(logger, verbosity=logging.WARNING):
|
|
|
110
110
|
)
|
|
111
111
|
|
|
112
112
|
|
|
113
|
-
def add_time_info(tag, logger=logging.getLogger("
|
|
113
|
+
def add_time_info(tag, logger=logging.getLogger("wums")):
|
|
114
114
|
if not hasattr(logger, "times"):
|
|
115
115
|
logger.times = {}
|
|
116
116
|
logger.times[tag] = time.time()
|
|
@@ -125,7 +125,7 @@ def print_time_info(logger):
|
|
|
125
125
|
|
|
126
126
|
|
|
127
127
|
def summary(verbosity=logging.WARNING, extended=True):
|
|
128
|
-
base_logger = logging.getLogger("
|
|
128
|
+
base_logger = logging.getLogger("wums")
|
|
129
129
|
|
|
130
130
|
base_logger.info(f"--------------------------------------")
|
|
131
131
|
base_logger.info(f"----------- logger summary -----------")
|
|
@@ -141,5 +141,5 @@ def summary(verbosity=logging.WARNING, extended=True):
|
|
|
141
141
|
# Iterate through all child loggers and print their names, levels, and counts
|
|
142
142
|
all_loggers = logging.Logger.manager.loggerDict
|
|
143
143
|
for logger_name, logger_obj in all_loggers.items():
|
|
144
|
-
if logger_name.startswith("
|
|
144
|
+
if logger_name.startswith("wums."):
|
|
145
145
|
print_logging_count(logger_obj, verbosity=verbosity)
|
wums/tfutils.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import tensorflow as tf
|
|
2
|
+
|
|
3
|
+
def function_to_tflite(funcs, input_signatures, func_names=""):
    """Convert one or more functions to a tflite model.

    Python dynamic execution is used to generate a ``tf.Module`` whose methods have explicitly
    named positional inputs ``input_{iif:05d}_{i:05d}``. The methods return dicts keyed by
    ``output_{iif:05d}_{i:05d}``, so that inputs and outputs are alphabetically ordered; this
    is apparently the only way to prevent tflite from scrambling them.

    Args:
        funcs: a function, or a list of functions.
        input_signatures: the ``tf.function`` input signature of the function, or a list with
            one signature per function.
        func_names: exported name of the function, or a list with one name per function.
            Empty (or None) names default to the function's ``__name__``; the default ``""``
            applies to every function when a list of functions is given.

    Returns:
        The converted model as returned by ``TFLiteConverter.convert()``.

    Raises:
        ValueError: if an exported name is not a valid Python identifier (names become
            method names in generated code).
    """
    import keyword

    if not isinstance(funcs, list):
        funcs = [funcs]
        input_signatures = [input_signatures]
        func_names = [func_names]
    elif func_names is None or func_names == "":
        # a single default name cannot be indexed per function
        func_names = [""]*len(funcs)
    func_names = [funcs[iif].__name__ if func_names[iif] in ("", None) else func_names[iif] for iif in range(len(funcs))]

    for func_name in func_names:
        if not isinstance(func_name, str) or not func_name.isidentifier() or keyword.iskeyword(func_name):
            raise ValueError(f"function name {func_name!r} is not a valid Python identifier")

    def wrapped_func(iif, *args):
        outputs = funcs[iif](*args)

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        return {f"output_{iif:05d}_{i:05d}": output for i, output in enumerate(outputs)}

    arg_string = [
        ", ".join(f"input_{iif:05d}_{i:05d}" for i in range(len(input_signature)))
        for iif, input_signature in enumerate(input_signatures)
    ]

    def_string = ""
    def_string += "def make_module(wrapped_func, input_signatures):\n"
    def_string += "    class Export_Module(tf.Module):\n"
    for i in range(len(funcs)):
        def_string += f"        @tf.function(input_signature = input_signatures[{i}])\n"
        def_string += f"        def {func_names[i]}(self, {arg_string[i]}):\n"
        def_string += f"            return wrapped_func({i}, {arg_string[i]})\n"
    def_string += "    return Export_Module"

    ldict = {}
    exec(def_string, globals(), ldict)

    make_module = ldict["make_module"]
    Export_Module = make_module(wrapped_func, input_signatures)

    module = Export_Module()
    concrete_functions = [getattr(module, func_name).get_concrete_function() for func_name in func_names]
    converter = tf.lite.TFLiteConverter.from_concrete_functions(concrete_functions, module)

    # enable TensorFlow ops and DISABLE builtin TFLite ops since these apparently slow things down
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
    ]

    converter._experimental_allow_all_select_tf_ops = True

    tflite_model = converter.convert()

    test_interp = tf.lite.Interpreter(model_content = tflite_model)
    print(test_interp.get_input_details())
    print(test_interp.get_output_details())
    print(test_interp.get_signature_list())

    return tflite_model
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def function_to_saved_model(func, input_signature, output):
    """Export ``func`` as a TensorFlow SavedModel at ``output``.

    ``func`` is wrapped as the traced ``__call__`` of a ``tf.Module`` with the given
    ``input_signature``.
    """

    class Export_Module(tf.Module):
        @tf.function(input_signature = input_signature)
        def __call__(self, *args):
            return func(*args)

    tf.saved_model.save(Export_Module(), output)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: wums
|
|
3
|
+
Version: 0.1.7
|
|
4
|
+
Summary: .
|
|
5
|
+
Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>, Jan Eysermans <jan.eysermans@cern.ch>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/WMass/wums
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
Requires-Dist: hist
|
|
15
|
+
Requires-Dist: numpy
|
|
16
|
+
Provides-Extra: plotting
|
|
17
|
+
Requires-Dist: matplotlib; extra == "plotting"
|
|
18
|
+
Requires-Dist: mplhep; extra == "plotting"
|
|
19
|
+
Provides-Extra: fitting
|
|
20
|
+
Requires-Dist: tensorflow; extra == "fitting"
|
|
21
|
+
Requires-Dist: jax; extra == "fitting"
|
|
22
|
+
Requires-Dist: scipy; extra == "fitting"
|
|
23
|
+
Provides-Extra: pickling
|
|
24
|
+
Requires-Dist: boost_histogram; extra == "pickling"
|
|
25
|
+
Requires-Dist: h5py; extra == "pickling"
|
|
26
|
+
Requires-Dist: hdf5plugin; extra == "pickling"
|
|
27
|
+
Requires-Dist: lz4; extra == "pickling"
|
|
28
|
+
Provides-Extra: all
|
|
29
|
+
Requires-Dist: plotting; extra == "all"
|
|
30
|
+
Requires-Dist: fitting; extra == "all"
|
|
31
|
+
Requires-Dist: pickling; extra == "all"
|
|
32
|
+
|
|
33
|
+
# WUMS: Wremnants Utilities, Modules, and other Stuff
|
|
34
|
+
|
|
35
|
+
As the name suggests, this is a collection of different things, all Python-based:
|
|
36
|
+
- Fitting with tensorflow or jax
|
|
37
|
+
- Custom pickling of h5py objects
|
|
38
|
+
- Plotting functionality
|
|
39
|
+
|
|
40
|
+
## Install
|
|
41
|
+
|
|
42
|
+
The `wums` package can be pip installed with minimal dependencies:
|
|
43
|
+
```bash
|
|
44
|
+
pip install wums
|
|
45
|
+
```
|
|
46
|
+
Different dependencies can be added with `plotting`, `fitting`, `pickling` to use the corresponding scripts.
|
|
47
|
+
For example, one can install with
|
|
48
|
+
```bash
|
|
49
|
+
pip install wums[plotting,fitting]
|
|
50
|
+
```
|
|
51
|
+
Or all dependencies with
|
|
52
|
+
```bash
|
|
53
|
+
pip install wums[all]
|
|
54
|
+
```
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
wums/boostHistHelpers.py,sha256=mgdPXAgmxriqoOhrhMctyZcfwEOPfV07V27CvGt2sk8,39260
|
|
3
|
+
wums/fitutils.py,sha256=sPCMJqZGdXvDfc8OxjOB-Bpf45GWHKxmKkDV3SlMUQs,38297
|
|
4
|
+
wums/fitutilsjax.py,sha256=HE1AcIZmI6N_xIHo8OHCPaYkHSnND_B-vI4Gl3vaUmA,2659
|
|
5
|
+
wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
|
|
6
|
+
wums/logging.py,sha256=L4514Xyq7L1z77Tkh8KE2HX88ZZ06o6SSRyQo96DbC0,4494
|
|
7
|
+
wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
|
|
8
|
+
wums/plot_tools.py,sha256=4iPx9Nr9y8c3p4ovy8XOS-xU_w11OyQEjISKkygxqcA,55918
|
|
9
|
+
wums/tfutils.py,sha256=9efkkvxH7VtwJN2yBS6_-P9dLKs3CXdxMFdrEBNsna8,2892
|
|
10
|
+
wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
|
|
11
|
+
wums-0.1.7.dist-info/METADATA,sha256=GrQyVuatMvHdallbstH7YdiACEMLIo5isHyugfFawW8,1784
|
|
12
|
+
wums-0.1.7.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
13
|
+
wums-0.1.7.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
|
|
14
|
+
wums-0.1.7.dist-info/RECORD,,
|
wums-0.1.6.dist-info/METADATA
DELETED
|
@@ -1,29 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.2
|
|
2
|
-
Name: wums
|
|
3
|
-
Version: 0.1.6
|
|
4
|
-
Summary: .
|
|
5
|
-
Author-email: David Walter <david.walter@cern.ch>, Josh Bendavid <josh.bendavid@cern.ch>, Kenneth Long <kenneth.long@cern.ch>
|
|
6
|
-
License: MIT
|
|
7
|
-
Project-URL: Homepage, https://github.com/WMass/wums
|
|
8
|
-
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
10
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
-
Classifier: Operating System :: OS Independent
|
|
12
|
-
Requires-Python: >=3.8
|
|
13
|
-
Description-Content-Type: text/markdown
|
|
14
|
-
Requires-Dist: boost_histogram
|
|
15
|
-
Requires-Dist: h5py
|
|
16
|
-
Requires-Dist: hdf5plugin
|
|
17
|
-
Requires-Dist: hist
|
|
18
|
-
Requires-Dist: lz4
|
|
19
|
-
Requires-Dist: matplotlib
|
|
20
|
-
Requires-Dist: mplhep
|
|
21
|
-
Requires-Dist: numpy
|
|
22
|
-
Requires-Dist: uproot
|
|
23
|
-
|
|
24
|
-
# WUMS: Wremnants Utilities, Modules, and other Stuff
|
|
25
|
-
|
|
26
|
-
The `wums` package can be pip installed:
|
|
27
|
-
```bash
|
|
28
|
-
pip install wums
|
|
29
|
-
```
|
wums-0.1.6.dist-info/RECORD
DELETED
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
wums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
wums/boostHistHelpers.py,sha256=F4SwQEVjNObFscfs0qrJEyOHYNKqUCmusW8HIF1o-0c,38993
|
|
3
|
-
wums/ioutils.py,sha256=ziyfQQ8CB3Ir2BJKJU3_a7YMF-Jd2nGXKoMQoJ2T8fo,12334
|
|
4
|
-
wums/logging.py,sha256=zNnLVJUwG3HMvr9NeXmiheX07VmsnSt8cQ6R4q4XBk4,4534
|
|
5
|
-
wums/output_tools.py,sha256=SHcZqXAdqL9AkA57UF0b-R-U4u7rzDgL8Def4E-ulW0,6713
|
|
6
|
-
wums/plot_tools.py,sha256=4iPx9Nr9y8c3p4ovy8XOS-xU_w11OyQEjISKkygxqcA,55918
|
|
7
|
-
wums/Templates/index.php,sha256=9EYmfc0ltMqr5oOdA4_BVIHdSbef5aA0ORoRZBEADVw,4348
|
|
8
|
-
wums-0.1.6.dist-info/METADATA,sha256=pTmIMc-rth2X53tju6Ef8WbJDda2zbr8isEqUpeqhDo,843
|
|
9
|
-
wums-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
10
|
-
wums-0.1.6.dist-info/top_level.txt,sha256=DCE1TVg7ySraosR3kYZkLIZ2w1Pwk2pVTdkqx6E-yRY,5
|
|
11
|
-
wums-0.1.6.dist-info/RECORD,,
|
|
File without changes
|