wums 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scripts/test/testsplinepdf.py +90 -0
- scripts/test/testsplinepdf2d.py +323 -0
- wums/boostHistHelpers.py +5 -1
- wums/fitutils.py +989 -0
- wums/fitutilsjax.py +86 -0
- wums/logging.py +8 -8
- wums/plot_tools.py +1 -68
- wums/tfutils.py +81 -0
- wums-0.1.8.dist-info/METADATA +54 -0
- wums-0.1.8.dist-info/RECORD +16 -0
- {wums-0.1.6.dist-info → wums-0.1.8.dist-info}/WHEEL +1 -1
- wums-0.1.8.dist-info/top_level.txt +2 -0
- wums-0.1.6.dist-info/METADATA +0 -29
- wums-0.1.6.dist-info/RECORD +0 -11
- wums-0.1.6.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import wums.fitutils
|
|
2
|
+
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import hist
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
# Fixed seed so the toy sample (and therefore the fit result) is reproducible.
np.random.seed(1234)

# Number of toy events.
nevt = 100000

# Toy sample: standard-normal draws.
rgaus = np.random.normal(size=(nevt,))

print(rgaus.dtype)
print(rgaus)

# 100 regular bins on [-5, 5].
axis0 = hist.axis.Regular(100, -5., 5.)

htest = hist.Hist(axis0)
htest.fill(rgaus)

print(htest)


# CDF values at which the quantiles are parameterized. Both end points (0 and 1)
# are included, with denser points in the tails (1e-3, 0.02, 0.05, ...).
quant_cdfvals = tf.constant([0.0, 1e-3, 0.02, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.98, 1.0-1e-3, 1.0], tf.float64)

# Number of CDF nodes (17 here).
nquants = quant_cdfvals.shape.num_elements()
|
|
31
|
+
|
|
32
|
+
def func_transform_cdf(quantile):
    """Map a quantile value to a CDF value using the standard normal CDF.

    Computes ``0.5 * (1 + erf(quantile / sqrt(2)))`` in the dtype of ``quantile``.
    """
    sqrt2 = tf.constant(math.sqrt(2.), quantile.dtype)
    scaled = quantile/sqrt2
    return 0.5*(1. + tf.math.erf(scaled))


def func_transform_quantile(cdf):
    """Inverse of ``func_transform_cdf``: the standard normal quantile function.

    Computes ``sqrt(2) * erfinv(2*cdf - 1)`` in the dtype of ``cdf``.
    """
    sqrt2 = tf.constant(math.sqrt(2.), cdf.dtype)
    centered = 2*cdf - 1.
    return sqrt2*tf.math.erfinv(centered)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def func_cdf(xvals, xedges, parms, quant_cdfvals):
    """Evaluate the quantile-parameterized model CDF for the 1D fit.

    Args:
        xvals: list of bin-center tensors, one per histogram axis.
        xedges: list of bin-edge tensors, one per histogram axis.
        parms: fit parameters, used directly as the quantile parameters.
        quant_cdfvals: CDF values at which the quantiles are parameterized.

    Returns:
        The CDF from ``wums.fitutils.func_cdf_for_quantile_fit``, using the
        normal-CDF transform pair defined in this script. ``func_pdf``
        differences consecutive entries, so these are presumably CDF values at
        the bin edges; confirm in wums.fitutils.
    """
    # In this 1D test the fit parameters are the quantile parameters as-is.
    qparms = parms

    # This script imports ``wums.fitutils``; the name ``narf`` is never imported.
    cdf = wums.fitutils.func_cdf_for_quantile_fit(
        xvals,
        xedges,
        qparms,
        quant_cdfvals,
        transform=(func_transform_cdf, func_transform_quantile),
    )

    return cdf
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Only used for plotting.
def func_pdf(h, parms):
    """Per-bin probabilities of the fitted model for histogram ``h``.

    Converts the bin centers and edges of ``h`` to float64 tensors and
    evaluates ``func_cdf`` with ``parms``. Adjacent CDF values are then
    differenced, and any negative differences are clipped to zero.
    """
    xvals, xedges = [], []
    for center, edge in zip(h.axes.centers, h.axes.edges):
        xvals.append(tf.constant(center, dtype=tf.float64))
        xedges.append(tf.constant(edge, dtype=tf.float64))

    cdf = func_cdf(xvals, xedges, tf.constant(parms), quant_cdfvals)

    bin_probs = cdf[1:] - cdf[:-1]
    return tf.maximum(bin_probs, tf.zeros_like(bin_probs))
|
|
64
|
+
|
|
65
|
+
# One parameter per interval between consecutive CDF nodes.
nparms = nquants-1


# Uniform starting point: every parameter is set to log(1/nparms).
initial_parms = np.full(nparms, np.log(1./nparms))

# Use the imported ``wums.fitutils`` module (``narf`` is not imported here).
# NOTE(review): testsplinepdf2d.py calls func_constraint_for_quantile_fit as
# (xvals, xedges, qparms); confirm how fit_hist forwards ``args`` to it.
res = wums.fitutils.fit_hist(htest, func_cdf, initial_parms, mode="nll_bin_integrated", func_constraint=wums.fitutils.func_constraint_for_quantile_fit, args=(quant_cdfvals,))

print(res)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Best-fit parameter vector from the fit result.
parmvals = res["x"]


# Model per-bin probabilities, scaled so their sum equals the histogram total
# and the curve can be overlaid on the counts.
pdfvals = func_pdf(htest, parmvals)
pdfvals *= htest.sum()/np.sum(pdfvals)

# Overlay the histogram and the fitted shape on a log-y plot and save it.
plot = plt.figure()
plt.yscale("log")
htest.plot()
plt.plot(htest.axes[0].centers, pdfvals)
# plt.show()
plot.savefig("test.png")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
import wums.fitutils
|
|
2
|
+
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import hist
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
import onnx
|
|
12
|
+
import tf2onnx
|
|
13
|
+
|
|
14
|
+
# Fixed seed so the toy sample (and therefore the fits) is reproducible.
np.random.seed(1234)

# Total number of toy events; half go to "data" and half to "mc" below.
nevt = 20000

# Column 0: uniform in [0, 1) ("pt"); column 1: standard normal (becomes "recoil").
runiform = np.random.random((nevt,))
rgaus = np.random.normal(size=(nevt,))

data = np.stack([runiform, rgaus], axis=-1)

# "pt"-dependent mean and sigma: recoil = (-0.1 + 0.1*pt) + (1 + 0.2*pt) * gaus
data[:,1] = -0.1 + 0.1*data[:,0] + (1. + 0.2*data[:,0])*data[:,1]


# print(rgaus.dtype)
# print(rgaus)

axis0 = hist.axis.Regular(50, 0., 1., name="pt")
axis1 = hist.axis.Regular(100, -5., 5., name="recoil")

htest_data = hist.Hist(axis0, axis1)
htest_mc = hist.Hist(axis0, axis1)

# The first half of the events fills "data" and the second half fills "mc". Both
# come from the same generating model, so the two fits should be compatible.
# print("data.shape", data.shape)
htest_data.fill(data[:nevt//2,0], data[:nevt//2, 1])
htest_mc.fill(data[nevt//2:,0], data[nevt//2:, 1])




# CDF values at which the quantiles are parameterized (end points included,
# with denser points in the tails).
quant_cdfvals = tf.constant([0.0, 1e-3, 0.02, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.98, 1.0-1e-3, 1.0], dtype = tf.float64)
nquants = quant_cdfvals.shape.num_elements()

print("nquants", nquants)

#cdf is in terms of axis1, so shapes need to be compatible
# (leading singleton axis added: shape becomes (1, nquants)).
quant_cdfvals = quant_cdfvals[None, :]


# get quantiles from histogram, e.g. to help initialize the parameters for the fit (not actually used here)
# NOTE(review): ``htest`` is not defined in this script; use htest_data or
# htest_mc if this is re-enabled.

# hist_quantiles, hist_quantile_errs = wums.fitutils.hist_to_quantiles(htest, quant_cdfvals, axis=1)
#
# print(hist_quantiles)
# print(hist_quantile_errs)
#
# hist_qparms, hist_qparm_errs = wums.fitutils.quantiles_to_qparms(hist_quantiles, hist_quantile_errs)
#
# print(hist_qparms)
# print(hist_qparm_errs)
|
|
63
|
+
|
|
64
|
+
def parms_to_qparms(xvals, parms):
    """Build quantile parameters that depend linearly on the axis-0 value.

    ``parms`` is a flat vector of interleaved (constant, slope) pairs, reshaped
    to (-1, 2). The result is ``const + slope * xvals[0]``, where const and
    slope each carry a leading singleton axis so they broadcast against
    ``xvals[0]``. NOTE(review): ``xvals[0]`` is presumably shaped (n_pt, 1)
    (broadcast-ready hist centers, or the (1, 1) reshape in func_cdf_mc);
    confirm with callers.
    """
    pairs = tf.reshape(parms, (-1, 2))

    # The cdf is in terms of axis 1, so the parameter index goes on axis 1.
    offsets = pairs[None, :, 0]
    slopes = pairs[None, :, 1]

    return offsets + slopes*xvals[0]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def func_transform_cdf(quantile):
    """Map a quantile value to a CDF value using the standard normal CDF.

    Computes ``0.5 * (1 + erf(quantile / sqrt(2)))`` in the dtype of ``quantile``.
    """
    sqrt2 = tf.constant(math.sqrt(2.), quantile.dtype)
    scaled = quantile/sqrt2
    return 0.5*(1. + tf.math.erf(scaled))


def func_transform_quantile(cdf):
    """Inverse of ``func_transform_cdf``: the standard normal quantile function.

    Computes ``sqrt(2) * erfinv(2*cdf - 1)`` in the dtype of ``cdf``.
    """
    sqrt2 = tf.constant(math.sqrt(2.), cdf.dtype)
    centered = 2.*cdf - 1.
    return sqrt2*tf.math.erfinv(centered)


# Possible alternative (logistic) transform pair; not currently used.
# func_transform_cdf maps quantile -> cdf and func_transform_quantile maps
# cdf -> quantile, so the logistic equivalents are:
#
# def func_transform_cdf(quantile):
#     return tf.math.sigmoid(quantile)
#
# def func_transform_quantile(cdf):
#     return tf.math.log(cdf/(1.-cdf))
|
|
92
|
+
|
|
93
|
+
def func_cdf(xvals, xedges, parms):
    """CDF of the pt-dependent quantile model along axis 1 (the recoil axis).

    ``parms`` is first converted to per-pt quantile parameters with
    ``parms_to_qparms``. Those are passed, together with the module-level
    ``quant_cdfvals`` and the normal-CDF transform pair, to
    ``wums.fitutils.func_cdf_for_quantile_fit``.
    """
    transform_pair = (func_transform_cdf, func_transform_quantile)
    qparms = parms_to_qparms(xvals, parms)
    return wums.fitutils.func_cdf_for_quantile_fit(
        xvals,
        xedges,
        qparms,
        quant_cdfvals,
        axis=1,
        transform=transform_pair,
    )
|
|
98
|
+
|
|
99
|
+
def func_constraint(xvals, xedges, parms):
    """Constraint term for the pt-dependent quantile fit.

    Converts ``parms`` with ``parms_to_qparms`` and delegates to
    ``wums.fitutils.func_constraint_for_quantile_fit``.
    NOTE(review): the fit_hist calls in this script do not pass this function.
    """
    return wums.fitutils.func_constraint_for_quantile_fit(
        xvals, xedges, parms_to_qparms(xvals, parms)
    )
|
|
102
|
+
|
|
103
|
+
# Only used for plotting.
def func_pdf(h, parms):
    """Per-bin probabilities of the fitted 2D model for histogram ``h``.

    Evaluates ``func_cdf`` on the bin centers and edges of ``h`` (as float64
    tensors). Adjacent CDF values are differenced along axis 1, and negative
    differences are clipped to zero. NOTE(review): the result is presumably
    (n_pt, n_recoil); confirm against func_cdf's output shape.
    """
    xvals, xedges = [], []
    for center, edge in zip(h.axes.centers, h.axes.edges):
        xvals.append(tf.constant(center, dtype=tf.float64))
        xedges.append(tf.constant(edge, dtype=tf.float64))

    cdf = func_cdf(xvals, xedges, tf.constant(parms))

    bin_probs = cdf[:, 1:] - cdf[:, :-1]
    return tf.maximum(bin_probs, tf.zeros_like(bin_probs))
|
|
117
|
+
|
|
118
|
+
# One quantile parameter per interval between consecutive CDF nodes.
nparms = nquants-1

# Each quantile parameter is a (constant, slope) pair in pt. Constants start at
# log(1/nparms) and slopes at zero. The pairs are interleaved into one flat
# vector [c0, s0, c1, s1, ...], matching the (-1, 2) reshape in parms_to_qparms.
initial_parms_const = np.full(nparms, np.log(1./nparms))
initial_parms_slope = np.zeros_like(initial_parms_const)

initial_parms = np.stack([initial_parms_const, initial_parms_slope], axis=-1).reshape(-1)
|
|
130
|
+
|
|
131
|
+
# Fit the pt-dependent quantile model independently to the "data" and "mc"
# histograms. NOTE(review): norm_axes=[1] presumably normalizes along the
# recoil axis within each pt bin; confirm in wums.fitutils.fit_hist.
res_data = wums.fitutils.fit_hist(htest_data, func_cdf, initial_parms, mode="nll_bin_integrated", norm_axes=[1])

res_mc = wums.fitutils.fit_hist(htest_mc, func_cdf, initial_parms, mode="nll_bin_integrated", norm_axes=[1])

print(res_data)


# Best-fit parameters as float64 tensors; transform_mc watches these on a GradientTape.
parmvals_data = tf.constant(res_data["x"], tf.float64)
parmvals_mc = tf.constant(res_mc["x"], tf.float64)

# Hessians at each minimum, used to build scaled eigenvector bases.
hess_data = res_data["hess"]
hess_mc = res_mc["hess"]
|
|
143
|
+
|
|
144
|
+
def get_scaled_eigenvectors(hess, num_null = 2):
    """Return Hessian eigenvectors scaled by 1/sqrt(eigenvalue).

    ``np.linalg.eigh`` returns eigenvalues in ascending order, so the first
    ``num_null`` (smallest-eigenvalue) directions are dropped. Each remaining
    eigenvector column is divided by the square root of its eigenvalue, so the
    result ``S`` satisfies ``S.T @ hess @ S == I``.

    Args:
        hess: symmetric matrix (the fit Hessian).
        num_null: number of lowest-eigenvalue directions to discard.

    Returns:
        Array of shape (n, n - num_null) with scaled eigenvectors as columns.
    """
    eigvals, eigvecs = np.linalg.eigh(hess)

    kept_vals = eigvals[num_null:]
    kept_vecs = eigvecs[:, num_null:]

    return kept_vecs/np.sqrt(kept_vals)[np.newaxis, :]
|
|
155
|
+
|
|
156
|
+
# Scaled eigenvector bases of each fit's Hessian, used to project parameter
# gradients in transform_mc.
vscaled_data = tf.constant(get_scaled_eigenvectors(hess_data), tf.float64)
# Must use the MC fit's Hessian: this basis projects gradients with respect to
# parmvals_mc. Building it from hess_data was a copy-paste error.
vscaled_mc = tf.constant(get_scaled_eigenvectors(hess_mc), tf.float64)

print("vscaled_data.shape", vscaled_data.shape)

# Lower and upper edges of the recoil axis, passed as x_low/x_high when
# converting quantile parameters to quantiles.
ut_flat = np.reshape(htest_data.axes.edges[1], (-1,))
ut_low = tf.constant(ut_flat[0], tf.float64)
ut_high = tf.constant(ut_flat[-1], tf.float64)
|
|
164
|
+
|
|
165
|
+
def func_cdf_mc(pt, ut):
    """Fitted MC CDF of the recoil, evaluated at a single (pt, ut) point.

    The MC fit parameters are converted to quantiles at ``pt`` (bounded by the
    recoil-axis range). ``wums.fitutils.pchip_interpolate`` then maps ``ut``
    from the quantile scale onto ``quant_cdfvals``.
    """
    recoil_axis = 1

    # Scalars reshaped to (1, 1) so they broadcast like the 2D histogram axes:
    # pt acts as an axis-0 value and ut as an axis-1 evaluation point.
    xvals = [tf.reshape(pt, (1, 1)), None]
    ut_points = tf.reshape(ut, (1, 1))

    qparms = parms_to_qparms(xvals, parmvals_mc)
    quants = wums.fitutils.qparms_to_quantiles(
        qparms, x_low=ut_low, x_high=ut_high, axis=recoil_axis
    )

    return wums.fitutils.pchip_interpolate(quants, quant_cdfvals, ut_points, axis=recoil_axis)
|
|
184
|
+
|
|
185
|
+
def func_cdfinv_data(pt, quant):
    """Inverse of the fitted data CDF of the recoil at a single pt.

    ``quant`` is a CDF value. The data fit parameters are converted to
    quantiles at ``pt``. ``wums.fitutils.pchip_interpolate`` then maps
    ``quant`` from ``quant_cdfvals`` onto those quantiles, giving a recoil value.
    """
    recoil_axis = 1

    # Scalars reshaped to (1, 1) so they broadcast like the 2D histogram axes.
    xvals = [tf.reshape(pt, (1, 1)), None]
    cdf_points = tf.reshape(quant, (1, 1))

    qparms = parms_to_qparms(xvals, parmvals_data)
    quants = wums.fitutils.qparms_to_quantiles(
        qparms, x_low=ut_low, x_high=ut_high, axis=recoil_axis
    )

    return wums.fitutils.pchip_interpolate(quant_cdfvals, quants, cdf_points, axis=recoil_axis)
|
|
204
|
+
|
|
205
|
+
def func_cdfinv_pdf_data(pt, quant):
    """Return the data inverse CDF at ``quant`` and the matching density.

    The density is the reciprocal of d(cdfinv)/d(quant), computed with a
    GradientTape that watches ``quant``.
    """
    with tf.GradientTape() as tape:
        tape.watch(quant)
        cdfinv = func_cdfinv_data(pt, quant)
    dcdfinv_dquant = tape.gradient(cdfinv, quant)
    return cdfinv, 1./dcdfinv_dquant
|
|
212
|
+
|
|
213
|
+
# Input spec for the exported functions: a float64 scalar.
scalar_spec = tf.TensorSpec([], tf.float64)


def transform_mc(pt, ut):
    """Map an MC recoil value to data and return its parameter sensitivities.

    ``ut`` is passed through the fitted MC CDF and then through the inverse of
    the fitted data CDF, both at ``pt``. The second output is the gradient of
    log(pdf), where pdf is the density from func_cdfinv_pdf_data. It is taken
    with respect to the MC and the data parameters, projected onto
    ``vscaled_mc`` / ``vscaled_data`` respectively, and concatenated (MC first).

    Returns:
        (ut_transformed, weight_grad_eig): a scalar tensor and a 1D tensor.
    """
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(parmvals_mc)
        tape.watch(parmvals_data)

        ut_transformed, pdf = func_cdfinv_pdf_data(pt, func_cdf_mc(pt, ut))

        # Reshape while recording so the scalar pdf stays connected to the
        # watched parameters.
        ut_transformed = tf.reshape(ut_transformed, [])
        pdf = tf.reshape(pdf, [])

    grad_mc = tape.gradient(pdf, parmvals_mc)
    grad_data = tape.gradient(pdf, parmvals_data)
    del tape

    projected = []
    for grad, basis in ((grad_mc, vscaled_mc), (grad_data, vscaled_data)):
        # d log(pdf) / d parms as a row vector, projected onto the scaled basis.
        row = (grad/pdf)[None, :]
        projected.append(tf.reshape(row @ basis, [-1]))

    return ut_transformed, tf.concat(projected, axis=0)
|
|
248
|
+
|
|
249
|
+
@tf.function
def transform_mc_simple(pt, ut):
    """Map an MC recoil value to data (MC CDF, then inverse data CDF).

    Same transformed value as ``transform_mc``, but without the parameter
    gradients. Returns a scalar tensor.
    """
    ut_transformed, _ = func_cdfinv_pdf_data(pt, func_cdf_mc(pt, ut))
    return tf.reshape(ut_transformed, [])
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# Eager smoke test of transform_mc at a single (pt, ut) point.
pt_test = tf.constant(0.2, tf.float64)
ut_test = tf.constant(1.0, tf.float64)

ut, grad = transform_mc(pt_test, ut_test)
# ut = transform_mc(pt_test, ut_test)

print("shapes", ut.shape, grad.shape)

print("ut", ut)
print("grad", grad)
|
|
270
|
+
|
|
271
|
+
# Signature of the exported function: (pt, ut) as float64 scalars.
input_signature = [scalar_spec, scalar_spec]


class TestMod(tf.Module):
    """tf.Module wrapper that exposes ``transform_mc`` for export.

    ``__call__`` takes two float64 scalars (pt, ut) and returns
    ``transform_mc(pt, ut)``.
    """

    # Reuse the module-level signature instead of repeating the list literal,
    # so the exported signature is defined in one place.
    @tf.function(input_signature=input_signature)
    def __call__(self, pt, ut):
        return transform_mc(pt, ut)
|
|
278
|
+
|
|
279
|
+
module = TestMod()
# tf.saved_model.save(module, "test")

# No arguments needed: the input_signature on __call__ fixes the traced signature.
concrete_function = module.__call__.get_concrete_function()

# Convert the model
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function], module)

# converter = tf.lite.TFLiteConverter.from_saved_model("test") # path to the SavedModel directory
# Allow full TensorFlow ops alongside TFLite builtins.
# NOTE(review): presumably needed for ops in this graph that have no TFLite builtin; confirm which.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]

tflite_model = converter.convert()

# print(tflite_model)

# Save the model.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)


# Alternative ONNX export (disabled).
# onnx_model, _ = tf2onnx.convert.from_function(transform_mc, input_signature)
# onnx.save(onnx_model, "test.onnx")
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
# Best-fit parameters of the data fit.
parmvals = res_data["x"]

# Index of the pt bin whose recoil distribution is plotted.
ipt_plot = 5

# func_pdf gives one recoil distribution per pt bin. Scale each pt slice to
# that slice's own event count so the overlay matches the plotted slice. A
# single global factor (total / sum of all slices) is only correct on average
# over the pt bins.
pdfvals = np.asarray(func_pdf(htest_data, parmvals))
slice_counts = htest_data.values().sum(axis=1, keepdims=True)
slice_norms = pdfvals.sum(axis=1, keepdims=True)
# Guard against empty model slices: those get a scale factor of zero.
pdfvals = pdfvals * np.divide(slice_counts, slice_norms, out=np.zeros_like(slice_norms), where=slice_norms > 0)

plot = plt.figure()
plt.yscale("log")
htest_data[ipt_plot,:].plot()
plt.plot(htest_data.axes[1].centers, pdfvals[ipt_plot])
# plt.show()
plot.savefig("test.png")
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
|
wums/boostHistHelpers.py
CHANGED
|
@@ -40,7 +40,7 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
|
|
|
40
40
|
h2.ndim - 1 - i: h2.values(flow=flow).shape[h2.ndim - 1 - i]
|
|
41
41
|
for i in range(h2.ndim - h1.ndim)
|
|
42
42
|
}
|
|
43
|
-
|
|
43
|
+
|
|
44
44
|
broadcast_shape = list(moves.values()) + list(s1)
|
|
45
45
|
|
|
46
46
|
try:
|
|
@@ -53,6 +53,10 @@ def broadcastSystHist(h1, h2, flow=True, by_ax_name=True):
|
|
|
53
53
|
f" h2.axes: {h2.axes}"
|
|
54
54
|
)
|
|
55
55
|
|
|
56
|
+
if by_ax_name:
|
|
57
|
+
# We also have to move axes that are in common between h1 and h2 but in different order
|
|
58
|
+
moves.update({sum([k<i for k in moves.keys()]) + h1.axes.name.index(n2): None for i, n2 in enumerate(h2.axes.name) if n2 in h1.axes.name})
|
|
59
|
+
|
|
56
60
|
# move back to original order
|
|
57
61
|
new_vals = np.moveaxis(new_vals, np.arange(len(moves)), list(moves.keys()))
|
|
58
62
|
|