data-manipulation-utilities 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
- data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
- data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
- data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
- data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
- dmu/arrays/utilities.py +55 -0
- dmu/dataframe/dataframe.py +36 -0
- dmu/generic/utilities.py +69 -0
- dmu/logging/log_store.py +129 -0
- dmu/ml/cv_classifier.py +122 -0
- dmu/ml/cv_predict.py +152 -0
- dmu/ml/train_mva.py +257 -0
- dmu/ml/utilities.py +132 -0
- dmu/plotting/plotter.py +227 -0
- dmu/plotting/plotter_1d.py +113 -0
- dmu/plotting/plotter_2d.py +87 -0
- dmu/rdataframe/atr_mgr.py +79 -0
- dmu/rdataframe/utilities.py +72 -0
- dmu/rfile/rfprinter.py +91 -0
- dmu/rfile/utilities.py +34 -0
- dmu/stats/fitter.py +515 -0
- dmu/stats/function.py +314 -0
- dmu/stats/utilities.py +134 -0
- dmu/testing/utilities.py +119 -0
- dmu/text/transformer.py +182 -0
- dmu_data/__init__.py +0 -0
- dmu_data/ml/tests/train_mva.yaml +37 -0
- dmu_data/plotting/tests/2d.yaml +14 -0
- dmu_data/plotting/tests/fig_size.yaml +13 -0
- dmu_data/plotting/tests/high_stat.yaml +22 -0
- dmu_data/plotting/tests/name.yaml +14 -0
- dmu_data/plotting/tests/no_bounds.yaml +12 -0
- dmu_data/plotting/tests/simple.yaml +8 -0
- dmu_data/plotting/tests/title.yaml +14 -0
- dmu_data/plotting/tests/weights.yaml +13 -0
- dmu_data/text/transform.toml +4 -0
- dmu_data/text/transform.txt +6 -0
- dmu_data/text/transform_set.toml +8 -0
- dmu_data/text/transform_set.txt +6 -0
- dmu_data/text/transform_trf.txt +12 -0
- dmu_scripts/physics/check_truth.py +121 -0
- dmu_scripts/rfile/compare_root_files.py +299 -0
- dmu_scripts/rfile/print_trees.py +35 -0
- dmu_scripts/ssh/coned.py +168 -0
- dmu_scripts/text/transform_text.py +46 -0
dmu/ml/train_mva.py
ADDED
@@ -0,0 +1,257 @@
|
|
1
|
+
'''
|
2
|
+
Module with TrainMva class
|
3
|
+
'''
|
4
|
+
import os
|
5
|
+
from typing import Union
|
6
|
+
|
7
|
+
import joblib
|
8
|
+
import pandas as pnd
|
9
|
+
import numpy
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
|
12
|
+
from sklearn.metrics import roc_curve, auc
|
13
|
+
from sklearn.model_selection import StratifiedKFold
|
14
|
+
|
15
|
+
from ROOT import RDataFrame
|
16
|
+
|
17
|
+
import dmu.ml.utilities as ut
|
18
|
+
from dmu.ml.cv_classifier import CVClassifier as cls
|
19
|
+
from dmu.plotting.plotter_1d import Plotter1D as Plotter
|
20
|
+
from dmu.logging.log_store import LogStore
|
21
|
+
|
22
|
+
# Module-level logger, registered through the package's LogStore.
# NOTE(review): the name uses a 'data_checks' prefix, unlike the 'dmu:...'
# names of other dmu modules; kept unchanged since log levels may be
# configured elsewhere by this exact name.
log = LogStore.add_logger('data_checks:train_mva')
|
23
|
+
# ---------------------------------------------
|
24
|
+
class TrainMva:
    '''
    Interface to scikit learn used to train classifier
    '''
    # ---------------------------------------------
    def __init__(self, bkg=None, sig=None, cfg=None):
        '''
        bkg (ROOT dataframe): Holds real data
        sig (ROOT dataframe): Holds simulation
        cfg (dict)          : Dictionary storing configuration for training

        Raises ValueError if either dataframe is missing or cfg is not a dictionary.
        The features of both samples are read here, in the constructor.
        '''
        if bkg is None:
            raise ValueError('Background dataframe is not a ROOT dataframe')

        if sig is None:
            raise ValueError('Signal dataframe is not a ROOT dataframe')

        if not isinstance(cfg, dict):
            raise ValueError('Config dictionary is not a dictionary')

        self._rdf_bkg = bkg
        self._rdf_sig = sig
        # cfg is guaranteed to be a dictionary by the check above
        self._cfg     = cfg

        self._l_model : cls

        self._l_ft_name = self._cfg['training']['features']

        self._df_ft, self._l_lab = self._get_inputs()
    # ---------------------------------------------
    def _get_inputs(self) -> tuple[pnd.DataFrame, numpy.ndarray]:
        '''
        Returns the features of signal and background, stacked in that order,
        and the matching labels: 1 for signal, 0 for background
        '''
        log.info('Getting signal')
        df_sig, arr_lab_sig = self._get_sample_inputs(self._rdf_sig, label = 1)

        log.info('Getting background')
        df_bkg, arr_lab_bkg = self._get_sample_inputs(self._rdf_bkg, label = 0)

        df      = pnd.concat([df_sig, df_bkg], axis=0)
        arr_lab = numpy.concatenate([arr_lab_sig, arr_lab_bkg])

        return df, arr_lab
    # ---------------------------------------------
    def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, numpy.ndarray]:
        '''
        rdf  : ROOT dataframe from which the features are read
        label: Label assigned to every entry of this sample

        Returns the features, after ut.cleanup (repeated rows and rows with NaNs removed),
        and an array with one label per remaining row
        '''
        d_ft    = rdf.AsNumpy(self._l_ft_name)
        df      = pnd.DataFrame(d_ft)
        df      = ut.cleanup(df)
        arr_lab = numpy.full(len(df), label)

        return df, arr_lab
    # ---------------------------------------------
    def _get_model(self, arr_index : numpy.ndarray) -> cls:
        '''
        Trains a classifier on the rows of the features dataframe
        at positions `arr_index` and returns it
        '''
        model = cls(cfg = self._cfg)
        df_ft = self._df_ft.iloc[arr_index]
        l_lab = self._l_lab[arr_index]

        log.debug(f'Training feature shape: {df_ft.shape}')
        log.debug(f'Training label size: {len(l_lab)}')

        model.fit(df_ft, l_lab)

        return model
    # ---------------------------------------------
    def _get_models(self):
        # pylint: disable = too-many-locals
        '''
        Will create one model per fold of a stratified k-fold split, train them and return them.
        For every fold, the score distributions and ROC curves of the training
        and testing splits are also plotted.
        '''
        nfold = self._cfg['training']['nfold']
        rdmst = self._cfg['training']['rdm_stat']

        kfold = StratifiedKFold(n_splits=nfold, shuffle=True, random_state=rdmst)

        l_model = []
        for ifold, (arr_itr, arr_its) in enumerate(kfold.split(self._df_ft, self._l_lab)):
            log.debug(20 * '-')
            log.info(f'Training fold: {ifold}')
            log.debug(20 * '-')
            model = self._get_model(arr_itr)
            l_model.append(model)

            arr_sig_sig_tr, arr_sig_bkg_tr, arr_sig_all_tr, arr_lab_tr = self._get_scores(model, arr_itr, on_training_ok= True)
            arr_sig_sig_ts, arr_sig_bkg_ts, arr_sig_all_ts, arr_lab_ts = self._get_scores(model, arr_its, on_training_ok=False)

            self._plot_scores(arr_sig_sig_tr, arr_sig_sig_ts, arr_sig_bkg_tr, arr_sig_bkg_ts, ifold)

            self._plot_roc(arr_lab_ts, arr_sig_all_ts, arr_lab_tr, arr_sig_all_tr, ifold)

        return l_model
    # ---------------------------------------------
    def _get_scores(self, model : cls, arr_index : numpy.ndarray, on_training_ok : bool) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
        '''
        Evaluates `model` on the rows at positions `arr_index` and returns a tuple of four arrays

        arr_sig : Signal probabilities for signal
        arr_bkg : Signal probabilities for background
        arr_all : Signal probabilities for both
        arr_lab : Labels for both
        '''
        nentries = len(arr_index)
        log.debug(f'Getting {nentries} signal probabilities')

        df_ft    = self._df_ft.iloc[arr_index]
        arr_prob = model.predict_proba(df_ft, on_training_ok=on_training_ok)
        arr_lab  = self._l_lab[arr_index]

        # Each entry of arr_prob is [p_bkg, p_sig], keep the signal probability
        arr_all = numpy.array([ sig_prob for [_, sig_prob] in arr_prob ])

        arr_sig, arr_bkg = self._split_scores(arr_prob=arr_prob, arr_label=arr_lab)

        return arr_sig, arr_bkg, arr_all, arr_lab
    # ---------------------------------------------
    def _split_scores(self, arr_prob : numpy.ndarray, arr_label : numpy.ndarray) -> tuple[numpy.ndarray, numpy.ndarray]:
        '''
        Splits the signal probabilities according to the true labels

        arr_prob : Collection of [p_bkg, p_sig] pairs
        arr_label: True labels, 1 for signal, 0 for background

        Returns the signal probabilities of the signal entries and of the background entries
        '''
        l_sig = [ prb[1] for prb, lab in zip(arr_prob, arr_label) if lab == 1]
        l_bkg = [ prb[1] for prb, lab in zip(arr_prob, arr_label) if lab == 0]

        return numpy.array(l_sig), numpy.array(l_bkg)
    # ---------------------------------------------
    def _save_model(self, model, ifold):
        '''
        Saves a model, associated to a specific fold, with joblib.

        The path in `saving:path` is used as a template; the fold index is
        appended to the file name, e.g. `dir/model.pkl` -> `dir/model_003.pkl`.
        If that per-fold file already exists, the model is not saved.
        '''
        base_path  = self._cfg['saving']['path']
        stem, ext  = os.path.splitext(base_path)
        model_path = f'{stem}_{ifold:03}{ext}'

        # Check the per-fold file, which is the one written below
        if os.path.isfile(model_path):
            log.info(f'Model found in {model_path}, not saving')
            return

        dir_name = os.path.dirname(model_path)
        # dirname is empty for a bare file name and os.makedirs('') would raise
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        log.info(f'Saving model to: {model_path}')
        joblib.dump(model, model_path)
    # ---------------------------------------------
    def _plot_scores(self, arr_sig_trn, arr_sig_tst, arr_bkg_trn, arr_bkg_tst, ifold):
        # pylint: disable = too-many-arguments, too-many-positional-arguments
        '''
        Will plot the normalized distributions of signal probabilities for signal and background,
        training (filled) and testing (step), for a given fold, in `<val_dir>/fold_<ifold>/scores.png`.
        Does nothing if `plotting:val_dir` is not in the config.
        '''
        log.debug(f'Plotting scores for {ifold} fold')

        if 'val_dir' not in self._cfg['plotting']:
            log.warning('Scores path not passed, not plotting scores')
            return

        val_dir = self._cfg['plotting']['val_dir']
        val_dir = f'{val_dir}/fold_{ifold:03}'
        os.makedirs(val_dir, exist_ok=True)

        plt.hist(arr_sig_trn, alpha = 0.3, bins=50, range=(0,1), color='b', density=True, label='Signal Train')
        plt.hist(arr_sig_tst, histtype='step', bins=50, range=(0,1), color='b', density=True, label='Signal Test')

        plt.hist(arr_bkg_trn, alpha = 0.3, bins=50, range=(0,1), color='r', density=True, label='Background Train')
        plt.hist(arr_bkg_tst, histtype='step', bins=50, range=(0,1), color='r', density=True, label='Background Test')

        plt.legend()
        plt.title(f'Fold: {ifold}')
        plt.xlabel('Signal probability')
        plt.ylabel('Normalized')
        plt.savefig(f'{val_dir}/scores.png')
        plt.close()
    # ---------------------------------------------
    def _plot_roc(self,
                  l_lab_ts : numpy.ndarray,
                  l_prb_ts : numpy.ndarray,
                  l_lab_tr : numpy.ndarray,
                  l_prb_tr : numpy.ndarray,
                  ifold    : int):
        '''
        Takes the labels and signal probabilities of the testing (ts) and training (tr)
        splits and plots, for a given fold, the ROC curves as signal efficiency
        (true positive rate) versus background rejection (1 - false positive rate).
        The area under each curve is shown in the legend.
        Does nothing if `plotting:val_dir` is not in the config.
        '''
        # pylint: disable = too-many-arguments, too-many-positional-arguments
        log.debug(f'Plotting ROC curve for {ifold} fold')

        d_plt = self._cfg['plotting']
        if 'val_dir' not in d_plt:
            log.warning('Validation directory not passed, not plotting ROC curve')
            return

        val_dir = d_plt['val_dir']
        val_dir = f'{val_dir}/fold_{ifold:03}'
        os.makedirs(val_dir, exist_ok=True)

        # roc_curve returns (false positive rate, true positive rate, thresholds);
        # with label 1 = signal these are background and signal efficiencies.
        # The x axis is turned into background rejection.
        xval_ts, yval_ts, _ = roc_curve(l_lab_ts, l_prb_ts)
        xval_ts = 1 - xval_ts
        area_ts = auc(xval_ts, yval_ts)

        xval_tr, yval_tr, _ = roc_curve(l_lab_tr, l_prb_tr)
        xval_tr = 1 - xval_tr
        area_tr = auc(xval_tr, yval_tr)

        # Lower bounds of the axes, optionally taken from plotting:roc:min = [min_x, min_y]
        d_roc = d_plt.get('roc') or {}
        [min_x, min_y] = d_roc.get('min', [0, 0])

        plt.plot(xval_ts, yval_ts, color='b', label=f'Test: {area_ts:.3f}')
        plt.plot(xval_tr, yval_tr, color='r', label=f'Train: {area_tr:.3f}')
        plt.xlabel('Background rejection')
        plt.ylabel('Signal efficiency')
        plt.title(f'Fold: {ifold}')
        plt.xlim(min_x, 1)
        plt.ylim(min_y, 1)
        plt.legend()
        plt.savefig(f'{val_dir}/roc.png')
        plt.close()
    # ---------------------------------------------
    def _plot_features(self):
        '''
        Will plot the features of signal and background with Plotter1D,
        using the settings in `plotting:features` of the config
        '''
        d_cfg = self._cfg['plotting']['features']
        ptr   = Plotter(d_rdf = {'Signal' : self._rdf_sig, 'Background' : self._rdf_bkg}, cfg=d_cfg)
        ptr.run()
    # ---------------------------------------------
    def run(self):
        '''
        Will plot the features, train one model per fold and save each model
        '''
        self._plot_features()

        l_mod = self._get_models()
        for ifold, mod in enumerate(l_mod):
            self._save_model(mod, ifold)
|
257
|
+
# ---------------------------------------------
|
dmu/ml/utilities.py
ADDED
@@ -0,0 +1,132 @@
|
|
1
|
+
'''
|
2
|
+
Module containing utility functions for ML tools
|
3
|
+
'''
|
4
|
+
|
5
|
+
import hashlib
|
6
|
+
from typing import Union
|
7
|
+
|
8
|
+
import numpy
|
9
|
+
import pandas as pnd
|
10
|
+
|
11
|
+
from dmu.logging.log_store import LogStore
|
12
|
+
|
13
|
+
# Module-level logger, registered through the package's LogStore
log = LogStore.add_logger('dmu:ml:utilities')
|
14
|
+
# ---------------------------------------------
|
15
|
+
# Patch dataframe with features
|
16
|
+
# ---------------------------------------------
|
17
|
+
def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
    '''
    Replaces the NaNs of a pandas dataframe with `value`, by default 0

    df    : Dataframe with features, possibly containing NaNs
    value : Value used to replace the NaNs

    Returns:
    - If no NaN is found, the input dataframe itself, unchanged
    - Otherwise a new dataframe with the NaNs replaced, where the index labels
      of the rows that had at least one NaN are stored, as a numpy array,
      in `attrs['patched_indices']`
    '''
    l_nan = df.index[df.isna().any(axis=1)].tolist()
    nnan  = len(l_nan)
    if nnan == 0:
        log.debug('No NaNs found')
        return df

    # nnan counts rows with at least one NaN, not individual NaN values
    log.warning(f'Found {nnan} rows with NaNs, patching them with {value}')

    df_pa = df.fillna(value)
    df_pa.attrs['patched_indices'] = numpy.array(l_nan)

    return df_pa
|
35
|
+
# ---------------------------------------------
|
36
|
+
# Cleanup of dataframe with features
|
37
|
+
# ---------------------------------------------
|
38
|
+
def cleanup(df : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Cleans a pandas dataframe holding classification features.
    Repeated rows are dropped first, then rows containing NaNs.

    Returns the cleaned dataframe
    '''
    df_unique = _remove_repeated(df)

    return _remove_nans(df_unique)
|
48
|
+
# ---------------------------------------------
|
49
|
+
def _remove_nans(df : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Drops every row of `df` holding at least one NaN.
    Returns `df` itself when there is nothing to drop.
    '''
    has_nan = df.isna().any().any()
    if has_nan:
        nrows_before = len(df)
        df_clean     = df.dropna()
        log.warning(f'NaNs found, cleaning dataset: {nrows_before} -> {len(df_clean)}')
        return df_clean

    log.debug('No NaNs found in dataframe')
    return df
|
61
|
+
# ---------------------------------------------
|
62
|
+
def _remove_repeated(df : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Removes rows whose values repeat an earlier row. Rows are compared
    through their hashes, as computed by `get_hashes`.

    If no repeated row is found, the input dataframe is returned unchanged.
    Otherwise a new dataframe is returned, indexed by the row hashes
    (index name `hash_index`), keeping the first occurrence of each row.

    The input dataframe is never modified.
    '''
    l_hash = get_hashes(df, rvalue='list')
    ninit  = len(l_hash)
    nfinl  = len(set(l_hash))

    if ninit == nfinl:
        log.debug('No cleaning needed for dataframe')
        return df

    log.warning(f'Repeated entries found, cleaning up: {ninit} -> {nfinl}')

    # Build the hash index directly, instead of adding a helper column to the
    # caller's dataframe (which also clobbered any existing `hash_index` column)
    df_hsh   = df.set_index(pnd.Index(l_hash, name='hash_index'))
    df_clean = df_hsh[~df_hsh.index.duplicated(keep='first')]

    if not isinstance(df_clean, pnd.DataFrame):
        raise ValueError('Cleaning did not return pandas dataframe')

    return df_clean
|
83
|
+
# ----------------------------------
|
84
|
+
# ---------------------------------------------
|
85
|
+
def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set, list]:
    '''
    Will return hashes for each row in the feature dataframe

    rvalue (str): Return value, can be a set or a list

    Raises ValueError if rvalue is neither 'set' nor 'list'
    '''
    if rvalue not in ('set', 'list'):
        log.error(f'Invalid return value: {rvalue}')
        raise ValueError(f'Invalid return value: {rvalue}')

    # iterrows is kept on purpose: each row is returned as a Series, with values cast
    # to a common dtype, and that fixes their string form and hence the hashes.
    # Iterating differently (e.g. itertuples) would change the hashes.
    l_hash = [ hash_from_row(row) for _, row in df_ft.iterrows() ]

    return set(l_hash) if rvalue == 'set' else l_hash
|
101
|
+
# ----------------------------------
|
102
|
+
def hash_from_row(row):
    '''
    Will return a hash identifying an event, i.e. a pandas dataframe row

    row: Iterable with the row values, e.g. a Series yielded by DataFrame.iterrows.
         The values are converted with `str`, joined with commas and UTF-8 encoded.

    Returns the SHA-256 digest as a hexadecimal string
    '''
    payload = ','.join(map(str, row)).encode('utf-8')

    return hashlib.sha256(payload).hexdigest()
|
117
|
+
# ----------------------------------
|
118
|
+
def index_with_hashes(df):
    '''
    Replaces the index of a features dataframe with per-row hashes

    df: Pandas dataframe with features

    Returns a new dataframe with the same columns, whose index holds the hash
    of each row, as computed by `get_hashes`; the previous index is discarded
    '''
    hash_index = pnd.Index(get_hashes(df, rvalue='list'))

    return df.set_index(hash_index, drop=True)
|
132
|
+
# ----------------------------------
|
dmu/plotting/plotter.py
ADDED
@@ -0,0 +1,227 @@
|
|
1
|
+
'''
|
2
|
+
Module containing plotter class
|
3
|
+
'''
|
4
|
+
|
5
|
+
import os
|
6
|
+
import math
|
7
|
+
from typing import Union
|
8
|
+
|
9
|
+
import numpy
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
|
12
|
+
from ROOT import RDataFrame
|
13
|
+
from dmu.logging.log_store import LogStore
|
14
|
+
|
15
|
+
# Module-level logger, registered through the package's LogStore
log = LogStore.add_logger('dmu:plotting:Plotter')
|
16
|
+
# --------------------------------------------
|
17
|
+
class Plotter:
    '''
    Base class of Plotter1D and Plotter2D
    '''
    #-------------------------------------
    def __init__(self, d_rdf=None, cfg=None):
        '''
        d_rdf (dict): Maps sample name to ROOT dataframe
        cfg   (dict): Plotting configuration

        Raises ValueError if either argument is not a dictionary.
        The dataframes are preprocessed (definitions, selection) here, in the constructor.
        '''
        if not isinstance( cfg, dict):
            raise ValueError('Config dictionary not passed')

        if not isinstance(d_rdf, dict):
            raise ValueError('Dataframe dictionary not passed')

        self._d_cfg = cfg
        self._d_rdf : dict[str, RDataFrame] = { name : self._preprocess_rdf(rdf) for name, rdf in d_rdf.items()}
        # Weights last returned by _get_weights; not set until that method is called
        self._d_wgt : Union[dict[str, Union[numpy.ndarray, None]], None]
        # Weights already read from the dataframes, keyed by name of the weights column
        self._d_wgt_cache : dict[str, dict[str, Union[numpy.ndarray, None]]] = {}
    #-------------------------------------
    def _check_quantile(self, qnt : float):
        '''
        Will check validity of quantile, raising ValueError if it is outside (0.5, 1.0]
        '''

        if 0.5 < qnt <= 1.0:
            return

        raise ValueError(f'Invalid quantile: {qnt:.3e}, value needs to be in (0.5, 1.0] interval')
    #-------------------------------------
    def _find_bounds(self, d_data : dict, qnt : float = 0.98):
        '''
        Will take dictionary between kinds of data and numpy array.
        For each array, the (1 - qnt) and qnt quantiles are computed.

        Returns tuple (min, max) with the smallest lower quantile and the
        largest upper quantile over all arrays, i.e. with the default qnt,
        bounds containing at least the central 96% of each array.
        Raises ValueError if the bounds are not increasing.
        '''
        self._check_quantile(qnt)

        l_max = []
        l_min = []

        for arr_val in d_data.values():
            minv = numpy.quantile(arr_val, 1 - qnt)
            maxv = numpy.quantile(arr_val, qnt)

            l_max.append(maxv)
            l_min.append(minv)

        minx = min(l_min)
        maxx = max(l_max)

        if minx >= maxx:
            raise ValueError(f'Could not calculate bounds correctly: [{minx:.3e}, {maxx:.3e}]')

        return minx, maxx
    #-------------------------------------
    def _preprocess_rdf(self, rdf):
        '''
        rdf (RDataFrame): ROOT dataframe

        Defines the extra variables and, if the config has a `selection` section,
        applies the cuts and the entry prescaling.

        returns preprocessed dataframe
        '''

        rdf = self._define_vars(rdf)
        # Both steps read self._d_cfg['selection'], so they only run when it exists
        if 'selection' in self._d_cfg:
            rdf = self._apply_selection(rdf)
            rdf = self._max_ran_entries(rdf)

        return rdf
    #-------------------------------------
    def _define_vars(self, rdf):
        '''
        Will define extra columns in dataframe, from the `definitions` section
        of the config (name -> expression), and return updated dataframe
        '''

        if 'definitions' not in self._d_cfg:
            log.debug('No definitions section found, returning same RDF')
            return rdf

        d_def = self._d_cfg['definitions']

        log.info('Defining extra variables')
        for name, expr in d_def.items():
            log.debug(f'{name:<30}{expr:<150}')
            rdf = rdf.Define(name, expr)

        return rdf
    #-------------------------------------
    def _apply_selection(self, rdf):
        '''
        Will take dataframe, apply the cuts in `selection:cuts` (name -> expression)
        and return dataframe
        '''
        if 'cuts' not in self._d_cfg['selection']:
            log.debug('Cuts not found in selection section, not applying any cuts')
            return rdf

        d_cut = self._d_cfg['selection']['cuts']

        log.info('Applying cuts')
        for name, cut in d_cut.items():
            log.debug(f'{name:<50}{cut:<150}')
            rdf = rdf.Filter(cut, name)

        return rdf
    #-------------------------------------
    def _max_ran_entries(self, rdf):
        '''
        Will take dataframe and, if `selection:max_ran_entries` is in the config,
        reduce the number of entries by keeping only those whose entry number
        (rdfentry_) is a multiple of floor(total / max_ran_entries).
        Nothing is dropped if that factor is below 2.
        The choice of entries is deterministic, it depends only on the entry numbers.

        Returns the (possibly) reduced dataframe
        '''

        if 'max_ran_entries' not in self._d_cfg['selection']:
            log.debug('max_ran_entries not found in selection section, not dropping entries')
            return rdf

        tot_entries = rdf.Count().GetValue()
        max_entries = self._d_cfg['selection']['max_ran_entries']

        if tot_entries < max_entries:
            log.debug(f'Not dropping random entries: {tot_entries} < {max_entries}')
            return rdf

        prescale = math.floor(tot_entries / max_entries)
        if prescale < 2:
            log.debug(f'Not dropping random entries, prescale is below 2: {tot_entries}/{max_entries}')
            return rdf

        rdf = rdf.Filter(f'rdfentry_ % {prescale} == 0', 'max_ran_entries')

        fnl_entries = rdf.Count().GetValue()

        log.info(f'Dropped entries randomly: {tot_entries} -> {fnl_entries}')

        return rdf
    # --------------------------------------------
    def _print_weights(self, arr_wgt : Union[numpy.ndarray, None], var : str, sample : str) -> None:
        '''
        Logs, at debug level, the number and sum of the weights used for a sample and variable
        '''
        if arr_wgt is None:
            log.debug(f'Not using weights for {sample}:{var}')
            return

        num_wgt = len(arr_wgt)
        sum_wgt = numpy.sum(arr_wgt)

        log.debug(f'Using weights [{num_wgt},{sum_wgt:.0f}] for {var}')
    # --------------------------------------------
    def _get_fig_size(self):
        '''
        Will read size list from `general:size` in the config dictionary if found,
        otherwise will return None
        '''
        if 'general' not in self._d_cfg:
            return None

        if 'size' not in self._d_cfg['general']:
            return None

        fig_size = self._d_cfg['general']['size']

        return fig_size
    #-------------------------------------
    def _get_weights(self, var) -> Union[dict[str, Union[numpy.ndarray, None]], None]:
        '''
        var (str): Key of the plot in the `plots` section of the config

        Returns None if the plot does not request weights. Otherwise a dictionary
        mapping each sample name to its array of weights, or to None if the sample
        has no column with the requested name. Weights are read once per weight
        name and cached, so plots using different weight columns get their own weights.
        '''
        d_cfg = self._d_cfg['plots'][var]
        if 'weights' not in d_cfg:
            return None

        wgt_name = d_cfg['weights']
        if wgt_name not in self._d_wgt_cache:
            self._d_wgt_cache[wgt_name] = {sam_name : self._read_weights(wgt_name, rdf) for sam_name, rdf in self._d_rdf.items()}

        # Kept for code reading the last weights through this attribute
        self._d_wgt = self._d_wgt_cache[wgt_name]

        return self._d_wgt
    # --------------------------------------------
    def _read_weights(self, name : str, rdf : RDataFrame) -> Union[numpy.ndarray, None]:
        '''
        Returns the column `name` of the dataframe as a numpy array,
        or None if the dataframe has no such column
        '''
        v_col = rdf.GetColumnNames()
        l_col = [ col.c_str() for col in v_col ]

        if name not in l_col:
            log.debug(f'Weight {name} not found')
            return None

        arr_wgt = rdf.AsNumpy([name])[name]

        return arr_wgt
    #-------------------------------------
    def _get_plot_name(self, var : str) -> str:
        '''
        Returns the name of the file (without extension) in which the plot of `var` is saved
        '''
        if 'plots_2d' in self._d_cfg:
            #For 2D plots the name will always be specified in the config
            return var

        if 'name' not in self._d_cfg['plots'][var]:
            # For 1D plots the name can be taken from variable name itself or specified
            return var

        return self._d_cfg['plots'][var]['name']
    #-------------------------------------
    def _save_plot(self, var):
        '''
        Will add the legend and save the plot to PNG, in `saving:plt_dir`,
        then close the figure named `var`

        var (str) : Name of variable, needed for plot name
        '''
        plt.legend()

        plt_dir = self._d_cfg['saving']['plt_dir']
        os.makedirs(plt_dir, exist_ok=True)

        name = self._get_plot_name(var)

        plot_path = f'{plt_dir}/{name}.png'
        log.info(f'Saving to: {plot_path}')
        plt.tight_layout()
        plt.savefig(plot_path)
        plt.close(var)
|
227
|
+
# --------------------------------------------
|