data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +641 -44
- data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
- dmu/__init__.py +0 -0
- dmu/generic/hashing.py +34 -8
- dmu/generic/utilities.py +164 -11
- dmu/logging/log_store.py +34 -2
- dmu/logging/messages.py +96 -0
- dmu/ml/cv_classifier.py +3 -3
- dmu/ml/cv_diagnostics.py +3 -0
- dmu/ml/cv_performance.py +58 -0
- dmu/ml/cv_predict.py +149 -46
- dmu/ml/train_mva.py +482 -100
- dmu/ml/utilities.py +29 -10
- dmu/pdataframe/utilities.py +28 -3
- dmu/plotting/fwhm.py +2 -2
- dmu/plotting/matrix.py +1 -1
- dmu/plotting/plotter.py +23 -3
- dmu/plotting/plotter_1d.py +96 -32
- dmu/plotting/plotter_2d.py +5 -0
- dmu/rdataframe/utilities.py +54 -3
- dmu/rfile/ddfgetter.py +102 -0
- dmu/stats/fit_stats.py +129 -0
- dmu/stats/fitter.py +55 -22
- dmu/stats/gof_calculator.py +7 -0
- dmu/stats/model_factory.py +153 -62
- dmu/stats/parameters.py +100 -0
- dmu/stats/utilities.py +443 -12
- dmu/stats/wdata.py +187 -0
- dmu/stats/zfit.py +17 -0
- dmu/stats/zfit_plotter.py +147 -36
- dmu/testing/utilities.py +102 -24
- dmu/workflow/__init__.py +0 -0
- dmu/workflow/cache.py +266 -0
- dmu_data/ml/tests/train_mva.yaml +9 -7
- dmu_data/ml/tests/train_mva_def.yaml +75 -0
- dmu_data/ml/tests/train_mva_with_diagnostics.yaml +10 -5
- dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
- dmu_data/plotting/tests/2d.yaml +5 -5
- dmu_data/plotting/tests/line.yaml +15 -0
- dmu_data/plotting/tests/styling.yaml +8 -1
- dmu_data/rfile/friends.yaml +13 -0
- dmu_data/stats/fitter/test_simple.yaml +28 -0
- dmu_data/stats/kde_optimizer/control.json +1 -0
- dmu_data/stats/kde_optimizer/signal.json +1 -0
- dmu_data/stats/parameters/data.yaml +178 -0
- dmu_data/tests/config.json +6 -0
- dmu_data/tests/config.yaml +4 -0
- dmu_data/tests/pdf_to_tex.txt +34 -0
- dmu_scripts/kerberos/check_expiration +21 -0
- dmu_scripts/kerberos/convert_certificate +22 -0
- dmu_scripts/ml/compare_classifiers.py +85 -0
- data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
- {data_manipulation_utilities-0.2.7.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
dmu/stats/fit_stats.py
ADDED
@@ -0,0 +1,129 @@
|
|
1
|
+
'''
|
2
|
+
Module with FitStats class
|
3
|
+
'''
|
4
|
+
|
5
|
+
import re
|
6
|
+
import pprint
|
7
|
+
import pickle
|
8
|
+
from typing import Union
|
9
|
+
|
10
|
+
import numpy
|
11
|
+
import pandas as pnd
|
12
|
+
from zfit.result import FitResult as zres
|
13
|
+
from dmu.logging.log_store import LogStore
|
14
|
+
|
15
|
+
log = LogStore.add_logger('dmu:fit_stats')
|
16
|
+
# -------------------------------
|
17
|
+
class FitStats:
    '''
    Class meant to provide fit statistics

    The statistics are read from a directory holding the outputs of a fit:

    post_fit.txt: One parameter per line, as six whitespace-separated tokens:
                  name value low high is_floating mu___sg
                  where `is_floating` is 0 or 1 and `mu___sg` is either
                  `none` or two floats joined by `___`.
                  Lines not matching this format are ignored.
    fit.pkl     : Pickled fit result, used to read the parameter errors
    '''
    # -------------------------------
    def __init__(self, fit_dir : str):
        '''
        fit_dir : Path to directory where fit outputs are stored
        '''
        self._fit_dir = fit_dir
        self._regex   = r'^([^\s]+)\s+([^\s]+)\s+([^\s]+)\s+([^\s]+)\s+([^\s]+)\s+([^\s]+)\s*$'
        self._sig_yld = 'nsig'

        # Functions need to be called at the end
        # When all the needed attributes are already set
        self._df = self._get_data()
    # -------------------------------
    def _row_from_line(self, line : str) -> Union[list,None]:
        '''
        Parses one line of post_fit.txt

        Returns [name, value, low, high, is_floating, mu, sg], or None if the
        line does not have the expected six-token format.
        mu and sg are NaN when the last token is `none`.
        '''
        mtch = re.match(self._regex, line)
        if not mtch:
            return None

        [name, value, low, high, is_floating, mu_sg] = mtch.groups()

        if mu_sg == 'none':
            mu = numpy.nan
            sg = numpy.nan
        else:
            [mu, sg] = mu_sg.split('___')
            mu = float(mu)
            sg = float(sg)

        # Direct conversion from '0' to bool would give True, go through int first
        is_floating = bool(int(is_floating))

        return [name, float(value), float(low), float(high), is_floating, mu, sg]
    # -------------------------------
    def _get_data(self) -> pnd.DataFrame:
        '''
        Reads post_fit.txt from the fit directory, parses it and attaches
        the parameter errors.

        Returns dataframe with columns:
        name, value, low, high, float, mu, sg, error
        '''
        fit_path = f'{self._fit_dir}/post_fit.txt'

        with open(fit_path, encoding='utf-8') as ifile:
            l_line = ifile.read().splitlines()

        # Build all rows first and create the dataframe once,
        # instead of enlarging it one row at a time
        l_row = [ self._row_from_line(line) for line in l_line ]
        l_row = [ row for row in l_row if row is not None ]

        df = pnd.DataFrame(l_row, columns=['name', 'value', 'low', 'high', 'float', 'mu', 'sg'])
        df = self._attach_errors(df)
        log.debug(df)

        return df
    # -------------------------------
    def _error_from_res(self, row : pnd.Series, res : zres) -> float:
        '''
        Returns the error of the parameter described by `row`, read from the fit result `res`

        Fixed parameters get an error of zero. Otherwise the error is read
        from the `hesse` entry of the parameter, or from `minuit_hesse` if
        `hesse` is missing.

        Raises KeyError if the parameter is not in the result, or has no error
        '''
        if not row['float']: # If this parameter is fixed in the fit, the error is zero
            return 0.0

        name = row['name']
        if name not in res.params:
            for this_name in res.params:
                log.info(this_name)

            raise KeyError(f'{name} not found')

        d_data = res.params[name]

        for error_key in ['hesse', 'minuit_hesse']:
            if error_key in d_data:
                return d_data[error_key]['error']

        pprint.pprint(d_data)
        raise KeyError(f'Cannot find error for parameter {name} in dictionary')
    # -------------------------------
    def _attach_errors(self, df : pnd.DataFrame) -> pnd.DataFrame:
        '''
        Adds an `error` column to `df`, with the errors read from fit.pkl

        NOTE: pickle.load can execute arbitrary code,
        only use this class with trusted fit directories
        '''
        pkl_path = f'{self._fit_dir}/fit.pkl'
        with open(pkl_path, 'rb') as ifile:
            res = pickle.load(ifile)

        # A list is used instead of df.apply(axis=1). For an empty dataframe,
        # apply returns a dataframe, which cannot be assigned to one column
        df['error'] = [ self._error_from_res(row, res) for _, row in df.iterrows() ]

        return df
    # -------------------------------
    def print_blind_stats(self) -> None:
        '''
        Will print statistics, excluding signal information
        '''
        df_blind = self._df[self._df['name'] != self._sig_yld]
        log.info(df_blind)
    # -------------------------------
    def get_value(self, name : str, kind : str) -> float:
        '''
        Returns float with value associated to fit

        name : Name of variable, e.g. mu, sg, nsig
        kind : Type of quantity, i.e. a column of the table, e.g. value, error, low, high

        Raises ValueError if there is not exactly one parameter called `name`
        Raises KeyError if `kind` is not a column of the table
        '''
        log.info(f'Retrieving {kind} of {name}')
        df   = self._df[self._df['name'] == name]
        nrow = len(df)
        if nrow != 1:
            self.print_blind_stats()
            raise ValueError(f'Cannot retrieve one and only one row, found {nrow}')

        # Calling float() on a single-element Series is deprecated in pandas,
        # take the element explicitly
        val = df[kind].iloc[0]

        return float(val)
|
129
|
+
# -------------------------------
|
dmu/stats/fitter.py
CHANGED
@@ -1,20 +1,23 @@
|
|
1
1
|
'''
|
2
2
|
Module holding zfitter class
|
3
3
|
'''
|
4
|
+
# pylint: disable=wrong-import-order, import-error
|
4
5
|
|
5
6
|
import pprint
|
6
7
|
from typing import Union
|
7
8
|
from functools import lru_cache
|
8
9
|
|
9
10
|
import numpy
|
10
|
-
import zfit
|
11
11
|
import pandas as pd
|
12
12
|
|
13
|
-
from
|
13
|
+
from dmu.logging import messages as mes
|
14
|
+
from dmu.stats.zfit import zfit
|
15
|
+
from dmu.logging.log_store import LogStore
|
16
|
+
|
14
17
|
from zfit.minimizers.strategy import FailMinimizeNaN
|
15
|
-
from zfit.result import FitResult
|
16
18
|
from zfit.core.data import Data
|
17
|
-
from
|
19
|
+
from zfit.result import FitResult as zres
|
20
|
+
from scipy import stats
|
18
21
|
|
19
22
|
log = LogStore.add_logger('dmu:statistics:fitter')
|
20
23
|
#------------------------------
|
@@ -43,6 +46,15 @@ class Fitter:
|
|
43
46
|
self._obs : zfit.Space
|
44
47
|
self._d_par : dict
|
45
48
|
|
49
|
+
# These are substrings found in tensorflow messages
|
50
|
+
# that are pretty useless and need to be hidden
|
51
|
+
self._l_hidden_tf_lines= [
|
52
|
+
'abnormal_detected_host @',
|
53
|
+
'Skipping loop optimization for Merge',
|
54
|
+
'Creating GpuSolver handles for stream',
|
55
|
+
'Loaded cuDNN version',
|
56
|
+
'All log messages before absl::InitializeLog()']
|
57
|
+
|
46
58
|
self._ndof = 10
|
47
59
|
self._pval_threshold = 0.01
|
48
60
|
self._initialized = False
|
@@ -53,7 +65,7 @@ class Fitter:
|
|
53
65
|
|
54
66
|
self._check_data()
|
55
67
|
|
56
|
-
self.
|
68
|
+
self._initialized = True
|
57
69
|
#------------------------------
|
58
70
|
def _check_data(self):
|
59
71
|
if isinstance(self._data_in, numpy.ndarray):
|
@@ -83,11 +95,9 @@ class Fitter:
|
|
83
95
|
elif len(shp) == 2:
|
84
96
|
_, jval = shp
|
85
97
|
if jval != 1:
|
86
|
-
|
87
|
-
raise
|
98
|
+
raise ValueError(f'Invalid data shape: {shp}')
|
88
99
|
else:
|
89
|
-
|
90
|
-
raise
|
100
|
+
raise ValueError(f'Invalid data shape: {shp}')
|
91
101
|
|
92
102
|
ival = data.shape[0]
|
93
103
|
|
@@ -158,7 +168,7 @@ class Fitter:
|
|
158
168
|
log.debug(f'Ndof: {self._ndof}')
|
159
169
|
log.debug(f'pval: {pvalue:<.3e}')
|
160
170
|
|
161
|
-
return
|
171
|
+
return sum_chi2, self._ndof, pvalue
|
162
172
|
#------------------------------
|
163
173
|
def _get_float_pars(self):
|
164
174
|
npar = 0
|
@@ -240,7 +250,9 @@ class Fitter:
|
|
240
250
|
if 'ranges' not in cfg:
|
241
251
|
return [None]
|
242
252
|
|
243
|
-
|
253
|
+
ranges_any = cfg['ranges']
|
254
|
+
|
255
|
+
ranges = [ tuple(elm) for elm in ranges_any ]
|
244
256
|
log.info('-' * 30)
|
245
257
|
log.info(f'{"Low edge":>15}{"High edge":>15}')
|
246
258
|
log.info('-' * 30)
|
@@ -359,9 +371,13 @@ class Fitter:
|
|
359
371
|
log.info(header)
|
360
372
|
log.info(parval)
|
361
373
|
#------------------------------
|
362
|
-
def _minimize(self, nll, cfg : dict) -> tuple[
|
374
|
+
def _minimize(self, nll, cfg : dict) -> tuple[zres, tuple]:
|
363
375
|
mnm = zfit.minimize.Minuit()
|
364
|
-
|
376
|
+
|
377
|
+
with mes.filter_stderr(banned_substrings=self._l_hidden_tf_lines):
|
378
|
+
res = mnm.minimize(nll)
|
379
|
+
|
380
|
+
res = self._calculate_error(res)
|
365
381
|
|
366
382
|
try:
|
367
383
|
gof = self._calc_gof()
|
@@ -376,7 +392,16 @@ class Fitter:
|
|
376
392
|
|
377
393
|
return res, gof
|
378
394
|
#------------------------------
|
379
|
-
def
|
395
|
+
def _gof_is_bad(self, gof : tuple[float, int, float]) -> bool:
|
396
|
+
chi2, ndof, pval = gof
|
397
|
+
|
398
|
+
good_ndof = 0 <= ndof < numpy.inf
|
399
|
+
good_chi2 = 0 <= chi2 < numpy.inf
|
400
|
+
good_pval = 0 <= pval < numpy.inf
|
401
|
+
|
402
|
+
return not (good_chi2 and good_pval and good_ndof)
|
403
|
+
#------------------------------
|
404
|
+
def _fit_retries(self, cfg : dict) -> tuple[dict, zres]:
|
380
405
|
ntries = cfg['strategy']['retry']['ntries']
|
381
406
|
pvalue_thresh= cfg['strategy']['retry']['pvalue_thresh']
|
382
407
|
ignore_status= cfg['strategy']['retry']['ignore_status']
|
@@ -401,6 +426,12 @@ class Fitter:
|
|
401
426
|
continue
|
402
427
|
|
403
428
|
chi2, _, pval = gof
|
429
|
+
|
430
|
+
if self._gof_is_bad(gof):
|
431
|
+
log.debug('Reshufling and skipping, found bad gof')
|
432
|
+
self._reshuffle_pdf_pars()
|
433
|
+
continue
|
434
|
+
|
404
435
|
d_pval_res[chi2]=res
|
405
436
|
|
406
437
|
if pval > pvalue_thresh:
|
@@ -414,7 +445,7 @@ class Fitter:
|
|
414
445
|
|
415
446
|
return d_pval_res, last_res
|
416
447
|
#------------------------------
|
417
|
-
def _pick_best_fit(self, d_pval_res : dict, last_res :
|
448
|
+
def _pick_best_fit(self, d_pval_res : dict, last_res : zres) -> zres:
|
418
449
|
nsucc = len(d_pval_res)
|
419
450
|
if nsucc == 0:
|
420
451
|
log.warning('None of the fits succeeded, returning last result')
|
@@ -426,7 +457,7 @@ class Fitter:
|
|
426
457
|
l_pval_res.sort()
|
427
458
|
_, res = l_pval_res[0]
|
428
459
|
|
429
|
-
log.debug('Picking out best fit from {nsucc} fits')
|
460
|
+
log.debug(f'Picking out best fit from {nsucc} fits')
|
430
461
|
for chi2, _ in l_pval_res:
|
431
462
|
log.debug(f'{chi2:.3f}')
|
432
463
|
|
@@ -434,7 +465,7 @@ class Fitter:
|
|
434
465
|
|
435
466
|
return res
|
436
467
|
#------------------------------
|
437
|
-
def _fit_in_steps(self, cfg : dict) ->
|
468
|
+
def _fit_in_steps(self, cfg : dict) -> zres:
|
438
469
|
l_nsample = cfg['strategy']['steps']['nsteps']
|
439
470
|
l_nsigma = cfg['strategy']['steps']['nsigma']
|
440
471
|
l_yield = cfg['strategy']['steps']['yields']
|
@@ -453,7 +484,6 @@ class Fitter:
|
|
453
484
|
log.info('Fitting full sample')
|
454
485
|
nll = self._get_full_nll(cfg = cfg)
|
455
486
|
res, _ = self._minimize(nll, cfg)
|
456
|
-
res.hesse(method='minuit_hesse')
|
457
487
|
|
458
488
|
if res is None:
|
459
489
|
nsteps = len(l_nsample)
|
@@ -461,7 +491,7 @@ class Fitter:
|
|
461
491
|
|
462
492
|
return res
|
463
493
|
#------------------------------
|
464
|
-
def _result_to_value_error(self, res :
|
494
|
+
def _result_to_value_error(self, res : zres) -> dict[str, list[float]]:
|
465
495
|
d_par = {}
|
466
496
|
for par, d_val in res.params.items():
|
467
497
|
try:
|
@@ -475,7 +505,7 @@ class Fitter:
|
|
475
505
|
|
476
506
|
return d_par
|
477
507
|
#------------------------------
|
478
|
-
def _update_par_bounds(self, res :
|
508
|
+
def _update_par_bounds(self, res : zres, nsigma : float, yields : list[str]) -> None:
|
479
509
|
s_shape_par = self._pdf.get_params(is_yield=False, floating=True)
|
480
510
|
d_shp_par = { par.name : par for par in s_shape_par if par.name not in yields}
|
481
511
|
d_fit_par = self._result_to_value_error(res)
|
@@ -494,6 +524,11 @@ class Fitter:
|
|
494
524
|
|
495
525
|
log.info(f'{name:<20}{val - err:<20.3e}{val + err:<20.3e}')
|
496
526
|
#------------------------------
|
527
|
+
def _calculate_error(self, res : zres) -> zres:
    '''
    Runs the Hesse error calculation on the fit result `res`

    The errors are stored in the result under the `minuit_hesse` key.
    The method is passed explicitly, so that the key always matches the
    algorithm used, rather than relying on the default method.

    Returns the same result object, for convenience
    '''
    res.hesse(method='minuit_hesse', name='minuit_hesse')

    return res
|
531
|
+
#------------------------------
|
497
532
|
def fit(self, cfg : Union[dict, None] = None):
|
498
533
|
'''
|
499
534
|
Runs the fit using the configuration specified by the cfg dictionary
|
@@ -508,7 +543,6 @@ class Fitter:
|
|
508
543
|
if 'strategy' not in cfg:
|
509
544
|
nll = self._get_full_nll(cfg = cfg)
|
510
545
|
res, _ = self._minimize(nll, cfg)
|
511
|
-
res.hesse(method='minuit_hesse')
|
512
546
|
elif 'retry' in cfg['strategy']:
|
513
547
|
d_pval_res, last_res = self._fit_retries(cfg)
|
514
548
|
res = self._pick_best_fit(d_pval_res, last_res)
|
@@ -517,6 +551,5 @@ class Fitter:
|
|
517
551
|
else:
|
518
552
|
raise ValueError('Unsupported fitting strategy')
|
519
553
|
|
520
|
-
|
521
554
|
return res
|
522
555
|
#------------------------------
|
dmu/stats/gof_calculator.py
CHANGED
@@ -110,6 +110,13 @@ class GofCalculator:
|
|
110
110
|
arr_data = self._get_data_bin_contents()
|
111
111
|
arr_modl = self._get_pdf_bin_contents()
|
112
112
|
|
113
|
+
log.debug(40 * '-')
|
114
|
+
log.debug(f'{"Data":<20}{"Model":<20}')
|
115
|
+
log.debug(40 * '-')
|
116
|
+
for dval, mval in zip(arr_data, arr_modl):
|
117
|
+
log.debug(f'{dval:<20.3f}{mval:<20.3f}')
|
118
|
+
log.debug(40 * '-')
|
119
|
+
|
113
120
|
norm = numpy.sum(arr_data) / numpy.sum(arr_modl)
|
114
121
|
arr_modl = norm * arr_modl
|
115
122
|
arr_res = arr_modl - arr_data
|