data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev720__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/METADATA +669 -42
- data_manipulation_utilities-0.2.8.dev720.dist-info/RECORD +45 -0
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/WHEEL +1 -2
- data_manipulation_utilities-0.2.8.dev720.dist-info/entry_points.txt +8 -0
- dmu/generic/hashing.py +34 -8
- dmu/generic/utilities.py +164 -11
- dmu/logging/log_store.py +34 -2
- dmu/logging/messages.py +96 -0
- dmu/ml/cv_classifier.py +3 -3
- dmu/ml/cv_diagnostics.py +3 -0
- dmu/ml/cv_performance.py +58 -0
- dmu/ml/cv_predict.py +149 -46
- dmu/ml/train_mva.py +482 -100
- dmu/ml/utilities.py +29 -10
- dmu/pdataframe/utilities.py +28 -3
- dmu/plotting/fwhm.py +2 -2
- dmu/plotting/matrix.py +1 -1
- dmu/plotting/plotter.py +23 -3
- dmu/plotting/plotter_1d.py +96 -32
- dmu/plotting/plotter_2d.py +5 -0
- dmu/rdataframe/utilities.py +54 -3
- dmu/rfile/ddfgetter.py +102 -0
- dmu/stats/fit_stats.py +129 -0
- dmu/stats/fitter.py +55 -22
- dmu/stats/gof_calculator.py +7 -0
- dmu/stats/model_factory.py +153 -62
- dmu/stats/parameters.py +100 -0
- dmu/stats/utilities.py +443 -12
- dmu/stats/wdata.py +187 -0
- dmu/stats/zfit.py +17 -0
- dmu/stats/zfit_plotter.py +147 -36
- dmu/testing/utilities.py +102 -24
- dmu/workflow/__init__.py +0 -0
- dmu/workflow/cache.py +266 -0
- data_manipulation_utilities-0.2.7.data/scripts/publish +0 -89
- data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
- data_manipulation_utilities-0.2.7.dist-info/entry_points.txt +0 -6
- data_manipulation_utilities-0.2.7.dist-info/top_level.txt +0 -3
- dmu_data/ml/tests/diagnostics_from_file.yaml +0 -13
- dmu_data/ml/tests/diagnostics_from_model.yaml +0 -10
- dmu_data/ml/tests/diagnostics_multiple_methods.yaml +0 -10
- dmu_data/ml/tests/diagnostics_overlay.yaml +0 -33
- dmu_data/ml/tests/train_mva.yaml +0 -58
- dmu_data/ml/tests/train_mva_with_diagnostics.yaml +0 -82
- dmu_data/plotting/tests/2d.yaml +0 -24
- dmu_data/plotting/tests/fig_size.yaml +0 -13
- dmu_data/plotting/tests/high_stat.yaml +0 -22
- dmu_data/plotting/tests/legend.yaml +0 -12
- dmu_data/plotting/tests/name.yaml +0 -14
- dmu_data/plotting/tests/no_bounds.yaml +0 -12
- dmu_data/plotting/tests/normalized.yaml +0 -9
- dmu_data/plotting/tests/plug_fwhm.yaml +0 -24
- dmu_data/plotting/tests/plug_stats.yaml +0 -19
- dmu_data/plotting/tests/simple.yaml +0 -9
- dmu_data/plotting/tests/stats.yaml +0 -9
- dmu_data/plotting/tests/styling.yaml +0 -11
- dmu_data/plotting/tests/title.yaml +0 -14
- dmu_data/plotting/tests/weights.yaml +0 -13
- dmu_data/text/transform.toml +0 -4
- dmu_data/text/transform.txt +0 -6
- dmu_data/text/transform_set.toml +0 -8
- dmu_data/text/transform_set.txt +0 -6
- dmu_data/text/transform_trf.txt +0 -12
- dmu_scripts/git/publish +0 -89
- dmu_scripts/physics/check_truth.py +0 -121
- dmu_scripts/rfile/compare_root_files.py +0 -299
- dmu_scripts/rfile/print_trees.py +0 -35
- dmu_scripts/ssh/coned.py +0 -168
- dmu_scripts/text/transform_text.py +0 -46
- {dmu_data → dmu}/__init__.py +0 -0
dmu/stats/utilities.py
CHANGED
@@ -1,18 +1,112 @@
|
|
1
1
|
'''
|
2
2
|
Module with utility functions related to the dmu.stats project
|
3
3
|
'''
|
4
|
+
# pylint: disable=import-error
|
5
|
+
|
4
6
|
import os
|
5
7
|
import re
|
8
|
+
import pickle
|
6
9
|
from typing import Union
|
7
|
-
import zfit
|
8
10
|
|
9
|
-
|
11
|
+
import numpy
|
12
|
+
import pandas as pnd
|
13
|
+
import matplotlib.pyplot as plt
|
14
|
+
|
15
|
+
import dmu.pdataframe.utilities as put
|
16
|
+
import dmu.generic.utilities as gut
|
17
|
+
|
18
|
+
from dmu.stats.zfit import zfit
|
19
|
+
from dmu.stats.fitter import Fitter
|
20
|
+
from dmu.stats.zfit_plotter import ZFitPlotter
|
21
|
+
from dmu.logging.log_store import LogStore
|
22
|
+
|
23
|
+
import tensorflow as tf
|
24
|
+
|
25
|
+
from omegaconf import OmegaConf, DictConfig
|
26
|
+
from zfit.core.interfaces import ZfitData as zdata
|
27
|
+
from zfit.core.interfaces import ZfitSpace as zobs
|
28
|
+
from zfit.core.interfaces import ZfitPDF as zpdf
|
29
|
+
from zfit.core.parameter import Parameter as zpar
|
30
|
+
from zfit.result import FitResult as zres
|
10
31
|
|
11
32
|
# Module level logger, obtained from LogStore under the name 'dmu:stats:utilities'
log = LogStore.add_logger('dmu:stats:utilities')
|
12
33
|
#-------------------------------------------------------
|
34
|
+
class Data:
    '''
    Holds settings shared by the fitting helpers in this module.

    weight_name: name of the column used for the per-entry weights when a
                 zfit dataset is converted to, or built from, a pandas dataframe
    '''
    weight_name = 'weight'
|
39
|
+
#-------------------------------------------------------
|
40
|
+
def name_from_obs(obs : zobs) -> str:
    '''
    Returns the name of a one dimensional zfit observable

    Parameters
    --------------
    obs: zfit observable (space)

    Raises
    --------------
    ValueError: if the names cannot be read as a tuple, or there is not exactly one name
    '''
    names = obs.obs
    if not isinstance(names, tuple):
        raise ValueError(f'Cannot retrieve name for: {obs}')

    if len(names) != 1:
        raise ValueError(f'Observable is not 1D: {names}')

    [name] = names

    return name
|
52
|
+
#-------------------------------------------------------
|
53
|
+
def range_from_obs(obs : zobs) -> tuple[float,float]:
    '''
    Takes zfit observable, returns tuple with two floats, (lower limit, upper limit)

    Parameters
    --------------
    obs: One dimensional zfit observable with a single range

    Raises
    --------------
    ValueError: if the limits cannot be read, or each limit holds more than one value
    '''
    if not isinstance(obs.limits, tuple):
        raise ValueError(f'Cannot retrieve range for: {obs}')

    if len(obs.limits) != 2:
        raise ValueError(f'Observable has more than one range: {obs.limits}')

    minx, maxx = obs.limits

    # Only a single value per limit can be returned as a float; anything else
    # would otherwise be silently truncated to its first element
    if numpy.size(minx) != 1 or numpy.size(maxx) != 1:
        raise ValueError(f'Observable is not 1D: {obs.limits}')

    return float(minx[0][0]), float(maxx[0][0])
|
66
|
+
#-------------------------------------------------------
|
67
|
+
def yield_from_zdata(data : zdata) -> float:
    '''
    Computes the yield of a zfit dataset

    Parameters
    --------------
    data : Zfit dataset

    Returns
    --------------
    Sum of weights for a weighted dataset, otherwise the number of entries

    Raises
    --------------
    ValueError: if the yield is negative
    '''
    weights = data.weights
    if weights is None:
        val = len(data.to_numpy())
    else:
        val = weights.numpy().sum()

    if val < 0:
        raise ValueError(f'Yield cannot be negative, found {val}')

    return val
|
88
|
+
#-------------------------------------------------------
|
89
|
+
# Check PDF
|
90
|
+
#-------------------------------------------------------
|
91
|
+
def is_pdf_usable(pdf : zpdf) -> bool:
    '''
    Will check if the PDF is usable, i.e. can be evaluated on 100 evenly
    spaced points spanning the range of its space

    Parameters
    --------------
    pdf : zfit PDF, its space has to be 1D with a single range, otherwise
          the unpacking of the limits raises ValueError

    Returns
    --------------
    False if the evaluation raised tensorflow's InvalidArgumentError, True otherwise
    '''
    [[[minx]], [[maxx]]] = pdf.space.limits

    arr_x = numpy.linspace(minx, maxx, 100)

    try:
        pdf.pdf(arr_x)
    except tf.errors.InvalidArgumentError:
        log.warning('PDF cannot be evaluated')
        return False

    return True
|
106
|
+
#-------------------------------------------------------
|
13
107
|
#Zfit/print_pdf
|
14
108
|
#-------------------------------------------------------
|
15
|
-
def _get_const(par :
|
109
|
+
def _get_const(par : zpar , d_const : Union[None, dict[str, tuple[float,float]]]) -> str:
|
16
110
|
'''
|
17
111
|
Takes zfit parameter and dictionary of constraints
|
18
112
|
Returns a formatted string with the value of the constraint on that parameter
|
@@ -23,13 +117,13 @@ def _get_const(par : zfit.Parameter, d_const : Union[None, dict[str, list[float]
|
|
23
117
|
obj = d_const[par.name]
|
24
118
|
if isinstance(obj, (list, tuple)):
|
25
119
|
[mu, sg] = obj
|
26
|
-
val = f'{mu:.3e}
|
120
|
+
val = f'{mu:.3e}___{sg:.3e}' # This separator needs to be readable and not a space
|
27
121
|
else:
|
28
122
|
val = str(obj)
|
29
123
|
|
30
124
|
return val
|
31
125
|
#-------------------------------------------------------
|
32
|
-
def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[
|
126
|
+
def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zpar]:
|
33
127
|
'''
|
34
128
|
Takes set of zfit parameters and list of parameter names to blind
|
35
129
|
returns set of zfit parameters that should be blinded
|
@@ -45,7 +139,7 @@ def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zfi
|
|
45
139
|
return s_par_blind
|
46
140
|
#-------------------------------------------------------
|
47
141
|
def _get_pars(
|
48
|
-
pdf
|
142
|
+
pdf : zpdf,
|
49
143
|
blind : Union[None, list[str]]) -> tuple[list, list]:
|
50
144
|
|
51
145
|
s_par_flt = pdf.get_params(floating= True)
|
@@ -63,7 +157,7 @@ def _get_pars(
|
|
63
157
|
return l_par_flt, l_par_fix
|
64
158
|
#-------------------------------------------------------
|
65
159
|
def _get_messages(
|
66
|
-
pdf :
|
160
|
+
pdf : zpdf,
|
67
161
|
l_par_flt : list,
|
68
162
|
l_par_fix : list,
|
69
163
|
d_const : Union[None, dict[str,list[float]]] = None) -> list[str]:
|
@@ -95,11 +189,11 @@ def _get_messages(
|
|
95
189
|
return l_msg
|
96
190
|
#-------------------------------------------------------
|
97
191
|
def print_pdf(
|
98
|
-
pdf :
|
99
|
-
d_const : Union[None, dict[str,
|
100
|
-
txt_path : Union[str,None]
|
101
|
-
level : int
|
102
|
-
blind : Union[None, list[str]]
|
192
|
+
pdf : zpdf,
|
193
|
+
d_const : Union[None, dict[str,tuple[float, float]]] = None,
|
194
|
+
txt_path : Union[str,None] = None,
|
195
|
+
level : int = 20,
|
196
|
+
blind : Union[None, list[str]] = None):
|
103
197
|
'''
|
104
198
|
Function used to print zfit PDFs
|
105
199
|
|
@@ -131,4 +225,341 @@ def print_pdf(
|
|
131
225
|
log.debug(msg)
|
132
226
|
else:
|
133
227
|
raise ValueError(f'Invalid level: {level}')
|
228
|
+
#---------------------------------------------
|
229
|
+
def _parameters_from_result(result : zres) -> dict[str,tuple[float,float]]:
    '''
    Extracts the fitted parameters from a zfit result

    Parameters
    --------------
    result: zfit result object

    Returns
    --------------
    Dictionary {name : (value, error)}. The error comes from the `minuit_hesse`
    entry if present, otherwise from `hesse`, and is None when neither exists
    '''
    log.debug('Reading parameters from:')
    # 10 is the numeric value of the DEBUG logging level
    if log.getEffectiveLevel() == 10:
        print(result)

    separator = 60 * '-'
    log.debug(separator)
    log.debug('Reading parameters')
    log.debug(separator)

    d_par = {}
    for name, d_val in result.params.items():
        value = d_val['value']

        # Later keys overwrite earlier ones, so minuit_hesse wins over hesse
        error = None
        for hesse_key in ('hesse', 'minuit_hesse'):
            if hesse_key in d_val:
                error = d_val[hesse_key]['error']

        log.debug(f'{name:<20}{value:<20.3f}{error}')
        d_par[name] = value, error

    return d_par
|
252
|
+
#---------------------------------------------
|
253
|
+
def save_fit(
        data    : zdata,
        model   : zpdf|None,
        res     : zres|None,
        fit_dir : str,
        d_const : dict[str,tuple[float,float]]|None = None,
        d_tex   : dict[str,str]|None                = None) -> None:
    r'''
    Function used to save fit results, meant to reduce boiler plate code.
    `fit_dir` is created if it does not exist.

    Plots: If:

    ptr = ZFitPlotter(data=dat, model=pdf)
    ptr.plot()

    was done before calling this method, the plot will also be saved as fit.png
    and all the open figures will be closed

    Parameters
    --------------------
    data   : zfit dataset, saved as data.json, with the weights in column Data.weight_name
    model  : PDF whose parameters are written to post_fit.txt and post_fit.tex, if None, will skip these steps
    res    : zfit result, saved as fit.pkl, parameters.json and parameters.yaml, if None, will skip these steps
    fit_dir: Directory where the outputs go
    d_const: Constraints on parameters, {name : (mu, sigma)}, used when printing the model
    d_tex  : Maps parameter names to their LaTeX names for the table, by default {'mu' : r'$\mu$'}
    '''
    if d_tex is None:
        d_tex = {'mu' : r'$\mu$'}

    os.makedirs(fit_dir, exist_ok=True)
    log.info(f'Saving fit to: {fit_dir}')

    if plt.get_fignums():
        fit_path = f'{fit_dir}/fit.png'
        log.info(f'Saving fit to: {fit_path}')
        plt.savefig(fit_path)
        plt.close('all')
    else:
        log.info('No fit plot found')

    _save_result(fit_dir=fit_dir, res=res)

    df    = data.to_pandas(weightsname=Data.weight_name)
    opath = f'{fit_dir}/data.json'
    log.debug(f'Saving data to: {opath}')
    df.to_json(opath, indent=2)

    if model is None:
        return

    txt_path = f'{fit_dir}/post_fit.txt'
    print_pdf(model, txt_path=txt_path, d_const=d_const)
    pdf_to_tex(path=txt_path, d_par=d_tex, skip_fixed=True)
|
296
|
+
#-------------------------------------------------------
|
297
|
+
def _save_result(fit_dir : str, res : zres|None) -> None:
    '''
    Saves a fit result into `fit_dir` as:

    - fit.pkl        : pickled result object
    - parameters.json: {name : (value, error)}
    - parameters.yaml: {name : {value, error}}

    Parameters
    ---------------
    fit_dir: Directory where fit result will go, has to exist
    res    : Zfit result object, if None, nothing is written
    '''
    if res is None:
        log.info('No result object found, not saving parameters in pkl or JSON')
        return

    # The result is frozen before being pickled; AttributeError is ignored
    # TODO: Remove this once there be a safer way to freeze
    # see https://github.com/zfit/zfit/issues/632
    try:
        res.freeze()
    except AttributeError:
        pass

    pkl_path = f'{fit_dir}/fit.pkl'
    with open(pkl_path, 'wb') as pkl_file:
        pickle.dump(res, pkl_file)

    d_par     = _parameters_from_result(result=res)
    json_path = f'{fit_dir}/parameters.json'
    log.debug(f'Saving parameters to: {json_path}')
    gut.dump_json(d_par, json_path)

    yaml_path = f'{fit_dir}/parameters.yaml'
    OmegaConf.save(zres_to_cres(res=res), yaml_path)
|
328
|
+
#-------------------------------------------------------
|
329
|
+
# Make latex table from text file
|
330
|
+
#-------------------------------------------------------
|
331
|
+
def _reformat_expo(val : str) -> str:
    r'''
    Rewrites a number in scientific notation as a LaTeX product, e.g.

    '-1.23e-05' -> '-1.23\cdot 10^{-5}'
    '1e+03'     -> '1\cdot 10^{3}'

    Parameters
    --------------
    val: Number as a string, e.g. produced by f'{x:.3g}'. Mantissa and exponent may carry a sign

    Raises
    --------------
    ValueError: if the mantissa and exponent cannot be extracted
    '''
    # Signs are accepted on both parts: negative values and '+' exponents
    # (e.g. f'{1000:.3g}' == '1e+03') are produced by Python's formatting
    regex = r'([-+]?[\d.]+)[eE]([-+]?\d+)'
    mtch  = re.match(regex, val)
    if not mtch:
        raise ValueError(f'Cannot extract value and exponent from: {val}')

    [val, exp] = mtch.groups()
    exp = int(exp)

    # Raw f-string, the backslash starts a LaTeX command, not an escape sequence
    return rf'{val}\cdot 10^{{{exp}}}'
|
134
341
|
#-------------------------------------------------------
|
342
|
+
def _format_float_str(val : str) -> str:
    '''
    Takes number as string and returns a version formatted for a LaTeX table:

    - absolute value above 1000: no decimals, with thousands separator, e.g. '12,345'
    - otherwise: three significant digits, scientific notation rewritten by _reformat_expo
    '''
    number = float(val)
    if abs(number) > 1000:
        return f'{number:,.0f}'

    short = f'{number:.3g}'

    return _reformat_expo(short) if 'e' in short else short
|
356
|
+
#-------------------------------------------------------
|
357
|
+
def _info_from_line(line : str) -> tuple|None:
    '''
    Parses one row of the parameters table written by print_pdf

    Parameters
    --------------
    line: Row starting with six whitespace separated fields, the second one is not used

    Returns
    --------------
    None if the line does not match, otherwise the tuple
    (parameter, low, high, floating, constraint), where low and high are
    formatted for LaTeX and the constraint is either 'none' or, for strings
    like 'mu___sigma', a LaTeX string with both numbers
    '''
    regex = r'(^\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)'
    mtch  = re.match(regex, line)
    if not mtch:
        return None

    log.debug(f'Reading information from: {line}')

    [par, _, low, high, floating, cons] = mtch.groups()

    low  = _format_float_str(low)
    high = _format_float_str(high)

    if cons != 'none':
        [mu, sg] = cons.split('___')

        mu = _format_float_str(mu)
        sg = _format_float_str(sg)

        # Raw f-string: the backslashes are LaTeX commands, not escape sequences
        cons = rf'$\mu={mu}; \sigma={sg}$'

    return par, low, high, floating, cons
|
379
|
+
#-------------------------------------------------------
|
380
|
+
def _df_from_lines(l_line : list[str]) -> pnd.DataFrame:
    '''
    Builds a dataframe with columns Parameter, Low, High, Floating and Constraint
    from the rows of a print_pdf table. Lines that cannot be parsed are skipped.

    Parameters
    --------------
    l_line: Lines of the table, without the header

    Returns
    --------------
    Dataframe with one row per parsed line, empty if nothing could be parsed
    '''
    l_column = ['Parameter', 'Low', 'High', 'Floating', 'Constraint']

    l_row = []
    for line in l_line:
        info = _info_from_line(line=line)
        if info is None:
            continue

        # info is (parameter, low, high, floating, constraint), same order as l_column
        l_row.append(info)

    # Rows are collected first, enlarging a dataframe row by row copies it every time
    return pnd.DataFrame(l_row, columns=l_column)
|
398
|
+
#-------------------------------------------------------
|
399
|
+
def pdf_to_tex(path : str, d_par : dict[str,str], skip_fixed : bool = True) -> None:
    '''
    Takes

    path      : path to a `txt` file produced by stats/utilities:print_pdf
    d_par     : Dictionary mapping parameter names in this file to proper latex names,
                names not found are used with underscores replaced by spaces
    skip_fixed: If True (default), only rows with Floating equal to '1' are kept and
                the Floating column is dropped

    Creates a latex table with the same name as `path` but the extension replaced by `tex`.
    Unconstrained parameters come first, then constrained ones, each group sorted by name.
    '''
    path = str(path)
    with open(path, encoding='utf-8') as ifile:
        l_line = ifile.read().splitlines()

    l_line = l_line[4:] # Remove header

    df = _df_from_lines(l_line)
    df['Parameter'] = df.Parameter.apply(lambda x : d_par.get(x, x.replace('_', ' ')))

    # Swap only the extension: str.replace would also rewrite '.txt' inside directory
    # names and would overwrite the input file when the path does not contain '.txt'
    stem, _  = os.path.splitext(path)
    out_path = f'{stem}.tex'

    if skip_fixed:
        df = df[df.Floating == '1']
        df = df.drop(columns='Floating')

    df_1 = df[df.Constraint == 'none']
    df_2 = df[df.Constraint != 'none']

    df_1 = df_1.sort_values(by='Parameter', ascending=True)
    df_2 = df_2.sort_values(by='Parameter', ascending=True)
    df   = pnd.concat([df_1, df_2])

    put.df_to_tex(df, out_path)
|
431
|
+
#---------------------------------------------
|
432
|
+
# Fake/Placeholder fit
|
433
|
+
#---------------------------------------------
|
434
|
+
def get_model(
        kind : str,
        obs  : zobs|None = None,
        lam  : float     = -0.0001) -> zpdf:
    '''
    Returns zfit PDF for tests

    Parameters:

    kind: 'signal' for Gaussian, 's+b' for extended Gaussian plus extended exponential
    obs : If provided, will use it, by default None and a 'mass' space in [4500, 7000] will be built
    lam : Decay constant of exponential component, set to -0.0001 by default.
          Its parameter is declared with bounds [-0.01, 0]

    Raises:

    NotImplementedError: if kind is not one of the above
    '''
    # Validate first, so that no zfit parameter is created for an invalid kind
    if kind not in ('signal', 's+b'):
        raise NotImplementedError(f'Invalid kind of fit: {kind}')

    if obs is None:
        obs = zfit.Space('mass', limits=(4500, 7000))

    mu   = zfit.Parameter('mu', 5200, 4500, 6000)
    sg   = zfit.Parameter('sg',   50,   10,  200)
    gaus = zfit.pdf.Gauss(obs=obs, mu=mu, sigma=sg)

    if kind == 'signal':
        return gaus

    c    = zfit.Parameter('c', lam, -0.01, 0)
    expo = zfit.pdf.Exponential(obs=obs, lam=c)

    nexpo = zfit.param.Parameter('nbkg', 1000, 0, 1000_000)
    ngaus = zfit.param.Parameter('nsig', 1000, 0, 1000_000)

    bkg = expo.create_extended(nexpo)
    sig = gaus.create_extended(ngaus)
    pdf = zfit.pdf.SumPDF([bkg, sig])

    return pdf
|
471
|
+
#---------------------------------------------
|
472
|
+
def _pdf_to_data(pdf : zpdf, add_weights : bool) -> zdata:
    '''
    Builds a toy dataset with 10_000 entries through the sampler of `pdf`.
    If `add_weights` is True, attaches weights drawn from a normal distribution
    with mean 1 and standard deviation 0.1 (not seeded here)
    '''
    nentries = 10_000
    sampler  = pdf.create_sampler(n=nentries)
    if add_weights:
        arr_wgt = numpy.random.normal(loc=1, scale=0.1, size=nentries)
        sampler = sampler.with_weights(arr_wgt)

    return sampler
|
482
|
+
#---------------------------------------------
|
483
|
+
def placeholder_fit(
        kind     : str,
        fit_dir  : str,
        df       : pnd.DataFrame|None = None,
        plot_fit : bool               = True) -> None:
    '''
    Function meant to run toy fits that produce output needed as an input
    to develop tools on top of them

    kind    : Kind of fit, e.g. s+b for the simplest signal plus background fit
    fit_dir : Directory where the output of the fit will go, created if needed
    df      : pandas dataframe, if passed, will reuse that data, needed to test data caching.
              The weights are read from the column named Data.weight_name.
              If not passed, weighted toy data is sampled from the model
    plot_fit: Will plot the fit or not, by default True
    '''
    pdf = get_model(kind)

    # pre_fit.txt goes into fit_dir; create it, since print_pdf might not
    os.makedirs(fit_dir, exist_ok=True)
    print_pdf(pdf, txt_path=f'{fit_dir}/pre_fit.txt')

    if df is None:
        log.debug('No data passed, sampling toy data from model')
        data = _pdf_to_data(pdf=pdf, add_weights=True)
    else:
        log.warning('Using user provided data')
        data = zfit.Data.from_pandas(df, obs=pdf.space, weights=Data.weight_name)

    d_const = {'sg' : [50, 3]}

    obj = Fitter(pdf, data)
    res = obj.fit(cfg={'constraints' : d_const})

    if plot_fit:
        ptr = ZFitPlotter(data=data, model=pdf)
        ptr.plot(nbins=50, stacked=True)

    save_fit(data=data, model=pdf, res=res, fit_dir=fit_dir, d_const=d_const)
|
515
|
+
#---------------------------------------------
|
516
|
+
def _reformat_values(d_par : dict) -> dict:
    '''
    Parameters
    --------------
    d_par: Dictionary formatted as:

    {'minuit_hesse': {'cl': 0.6,
                      'error': np.float64(0.04),
                      'weightcorr': <WeightCorr.FALSE: False>},
     'value' : 0.34},

    The error may also come under a 'hesse' key, or be missing

    Returns
    --------------
    Dictionary formatted as:

    {
    'error' : 0.04,
    'value' : 0.34
    }

    with plain Python floats. The error is taken from 'minuit_hesse' if present,
    otherwise from 'hesse', and is None if neither is present
    '''
    # Later keys overwrite earlier ones: same precedence as _parameters_from_result
    error = None
    for hesse_key in ('hesse', 'minuit_hesse'):
        if hesse_key in d_par:
            error = float(d_par[hesse_key]['error'])

    # Plain floats; OmegaConf does not accept every numpy scalar type
    value = float(d_par['value'])

    return {'value' : value, 'error' : error}
|
543
|
+
#---------------------------------------------
|
544
|
+
def zres_to_cres(res : zres) -> DictConfig:
    '''
    Converts a zfit result into an OmegaConf configuration

    Parameters
    --------------
    res : Zfit result object, it gets frozen

    Returns
    --------------
    OmegaConf DictConfig instance, {parameter name : {'value' : ..., 'error' : ...}}
    '''
    # Guards against the crash seen when the result object was already frozen
    try:
        res.freeze()
    except AttributeError:
        pass

    d_par = {}
    for name, d_val in res.params.items():
        d_par[name] = _reformat_values(d_par=d_val)

    return OmegaConf.create(d_par)
|
565
|
+
#---------------------------------------------
|
dmu/stats/wdata.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
'''
|
2
|
+
Module with Wdata class
|
3
|
+
'''
|
4
|
+
from typing import Union
|
5
|
+
|
6
|
+
import zfit
|
7
|
+
import numpy
|
8
|
+
import pandas as pnd
|
9
|
+
from zfit.core.interfaces import ZfitSpace as zobs
|
10
|
+
from zfit.core.data import Data as zdata
|
11
|
+
|
12
|
+
from dmu.logging.log_store import LogStore
|
13
|
+
|
14
|
+
|
15
|
+
# Module level logger, obtained from LogStore under the name 'dmu:stats:wdata'
log=LogStore.add_logger('dmu:stats:wdata')
|
16
|
+
# -------------------------------
|
17
|
+
class Wdata:
    '''
    Class meant to symbolize weighted data. It holds:

    - The data, a numpy array or pandas dataframe with one entry per row
    - One weight per entry, ones if no weights were passed
    - Optionally, a pandas dataframe with extra columns, one row per entry
    '''
    # -------------------------------
    def __init__(self,
                 data          : Union[numpy.ndarray, pnd.DataFrame],
                 weights       : Union[numpy.ndarray, None] = None,
                 extra_columns : Union[pnd.DataFrame, None] = None):
        '''
        data         : Numpy array or pandas dataframe with the values
        weights      : Numpy array with weights, if not passed, will use ones as weights
        extra_columns: Extra information that can be attached to the Wdata object in the form of a pandas dataframe, default None

        Raises ValueError if weights or extra columns have the wrong type or a length different from the data
        '''
        self._data    = data
        self._weights = self._get_weights(weights)
        self._df      = self._get_df_extr(extra_columns)
    # -------------------------------
    def _get_df_extr(self, df : Union[pnd.DataFrame, None]) -> Union[pnd.DataFrame,None]:
        '''
        Validates the extra columns; returns them, or None if none were passed
        '''
        if df is None:
            return None

        if not isinstance(df, pnd.DataFrame):
            arg_type = type(df)
            raise ValueError(f'Expected a pandas dataframe, got {arg_type}')

        if len(df) != self.size:
            raise ValueError('Input dataframe differs in length from data')

        return df
    # -------------------------------
    def _get_weights(self, weights : Union[numpy.ndarray, None]) -> numpy.ndarray:
        '''
        Validates the weights; returns them, or an array of ones if none were passed
        '''
        if weights is None:
            log.info('Weights not found, using ones')
            return numpy.ones(self.size)

        if not isinstance(weights, numpy.ndarray):
            raise ValueError('Weights argument is not a numpy array')

        weights_size = len(weights)
        if weights_size != self.size:
            raise ValueError(f'Data size and weights size differ: {self.size} != {weights_size}')

        return weights
    # -------------------------------
    def _build_new_array(self, arr_other : numpy.ndarray, kind : str) -> numpy.ndarray:
        '''
        Concatenates the attribute of this instance named `kind`, followed by `arr_other`
        '''
        arr_this = getattr(self, kind)
        arr      = numpy.concatenate([arr_this, arr_other])

        return arr
    # -------------------------------
    def __add__(self, other : 'Wdata') -> 'Wdata':
        '''
        Takes instance of Wdata and adds it to this instance
        returning sum.

        Addition is defined as concatenating data, weights and extra columns,
        with the entries of this instance first.
        '''
        if not isinstance(other, Wdata):
            other_type = type(other)
            raise NotImplementedError(f'Cannot add Wdata instance to {other_type} instance')

        log.debug('Adding instances of Wdata')
        data    = self._build_new_array(arr_other = other._data   , kind='_data'   )
        weights = self._build_new_array(arr_other = other._weights, kind='_weights')
        df      = self._build_extra_df(df_other = other._df)

        return Wdata(data=data, weights=weights, extra_columns=df)
    # -------------------------------
    def __str__(self) -> str:
        message = '\n'
        message+= f'{"Size ":<20}{self.size:<20}\n'
        message+= f'{"Sumw ":<20}{self.sumw:<20.3f}\n'
        if self._df is None:
            return message

        message+= f'{"Columns":<20}{" ":<20}\n'
        for column in self._df.columns:
            # f-string, so that non-string column labels do not raise TypeError
            message += f' {column}\n'

        return message
    # -------------------------------
    def _build_extra_df(self, df_other : Union[pnd.DataFrame, None]) -> Union[pnd.DataFrame,None]:
        '''
        Concatenates the extra columns of this instance followed by `df_other`.
        Returns None if neither instance has extra columns; raises ValueError if only one has them
        '''
        if df_other is None and self._df is None:
            return None

        if df_other is None or self._df is None:
            raise ValueError('One of the two Wdata instances does not contain extra column information')

        # This instance goes first, same order used for the data and weights in __add__,
        # otherwise the extra columns would not be aligned with the entries
        df = pnd.concat([self._df, df_other], axis=0, ignore_index=True)

        return df
    # -------------------------------
    def _is_extra_data_equal(self, df_other : Union[pnd.DataFrame, None], rtol : float) -> bool:
        '''
        True if both instances lack extra columns, or both have them with the same shape
        and values close within the relative tolerance `rtol`
        '''
        df_this = self._df

        if df_other is None and df_this is None:
            return True

        if df_other is None or df_this is None:
            log.warning('One of the weighted data compared does not have extra columns information')
            return False

        # allclose would broadcast compatible shapes or raise for incompatible ones
        if df_this.shape != df_other.shape:
            return False

        return numpy.allclose(df_this.values, df_other.values, rtol=rtol)
    # -------------------------------
    def __eq__(self, other) -> bool:
        '''
        Checks that the data, weights and extra columns are the same within a 1e-5 relative tolerance.
        Instances whose data have different shapes are different.
        Comparison with objects that are not Wdata is delegated back to Python.
        '''
        if not isinstance(other, Wdata):
            return NotImplemented

        # allclose would broadcast compatible shapes or raise for incompatible ones
        if numpy.shape(self._data) != numpy.shape(other._data):
            return False

        rtol             = 1e-5
        equal_data       = numpy.allclose(other._data   , self._data   , rtol=rtol)
        equal_weights    = numpy.allclose(other._weights, self._weights, rtol=rtol)
        equal_extra_data = self._is_extra_data_equal(df_other=other.extra_columns, rtol=rtol)

        return equal_data and equal_weights and equal_extra_data
    # -------------------------------
    @property
    def extra_columns(self) -> Union[pnd.DataFrame,None]:
        '''
        Dataframe with extra columns, or None, if not passed
        '''
        return self._df
    # -------------------------------
    @property
    def size(self) -> int:
        '''
        Returns number of entries in dataset
        '''
        return len(self._data)
    # -------------------------------
    @property
    def sumw(self) -> float:
        '''
        Returns sum of weights
        '''
        return numpy.sum(self._weights)
    # -------------------------------
    def update_weights(self, weights : numpy.ndarray, replace : bool) -> 'Wdata':
        '''
        Takes array of weights to either:

        - Replace existing array
        - Update by multiply by existing array

        depending on the replace argument value. It returns a new instance of Wdata,
        which keeps the data and the extra columns of this instance
        '''
        if self._weights.shape != weights.shape:
            raise ValueError(f'Invalid shape for array of weights, expected/got: {self._weights.shape}/{weights.shape}')

        new_weights = weights if replace else weights * self._weights

        data = Wdata(data=self._data, weights=new_weights, extra_columns=self._df)

        return data
    # -------------------------------
    def to_zfit(self, obs : zobs) -> zdata:
        '''
        Function that takes a zfit observable and uses it
        to build a zfit data instance that it then returns,
        carrying the weights of this instance
        '''
        log.debug('Building zfit dataset')

        data = zfit.data.Data(obs=obs, data=self._data, weights=self._weights)

        return data
|
187
|
+
# -------------------------------
|
dmu/stats/zfit.py
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
'''
|
2
|
+
Module intended to wrap zfit
|
3
|
+
|
4
|
+
Needed in order to silence tensorflow messages
|
5
|
+
'''
|
6
|
+
# pylint: disable=unused-import, wrong-import-order
|
7
|
+
|
8
|
+
try:
|
9
|
+
import ROOT
|
10
|
+
except ImportError:
|
11
|
+
pass
|
12
|
+
|
13
|
+
import dmu.generic.utilities as gut
|
14
|
+
with gut.silent_import():
|
15
|
+
import tensorflow
|
16
|
+
|
17
|
+
import zfit
|