data-manipulation-utilities 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
  2. data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
  3. data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
  4. data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
  5. data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
  6. dmu/arrays/utilities.py +55 -0
  7. dmu/dataframe/dataframe.py +36 -0
  8. dmu/generic/utilities.py +69 -0
  9. dmu/logging/log_store.py +129 -0
  10. dmu/ml/cv_classifier.py +122 -0
  11. dmu/ml/cv_predict.py +152 -0
  12. dmu/ml/train_mva.py +257 -0
  13. dmu/ml/utilities.py +132 -0
  14. dmu/plotting/plotter.py +227 -0
  15. dmu/plotting/plotter_1d.py +113 -0
  16. dmu/plotting/plotter_2d.py +87 -0
  17. dmu/rdataframe/atr_mgr.py +79 -0
  18. dmu/rdataframe/utilities.py +72 -0
  19. dmu/rfile/rfprinter.py +91 -0
  20. dmu/rfile/utilities.py +34 -0
  21. dmu/stats/fitter.py +515 -0
  22. dmu/stats/function.py +314 -0
  23. dmu/stats/utilities.py +134 -0
  24. dmu/testing/utilities.py +119 -0
  25. dmu/text/transformer.py +182 -0
  26. dmu_data/__init__.py +0 -0
  27. dmu_data/ml/tests/train_mva.yaml +37 -0
  28. dmu_data/plotting/tests/2d.yaml +14 -0
  29. dmu_data/plotting/tests/fig_size.yaml +13 -0
  30. dmu_data/plotting/tests/high_stat.yaml +22 -0
  31. dmu_data/plotting/tests/name.yaml +14 -0
  32. dmu_data/plotting/tests/no_bounds.yaml +12 -0
  33. dmu_data/plotting/tests/simple.yaml +8 -0
  34. dmu_data/plotting/tests/title.yaml +14 -0
  35. dmu_data/plotting/tests/weights.yaml +13 -0
  36. dmu_data/text/transform.toml +4 -0
  37. dmu_data/text/transform.txt +6 -0
  38. dmu_data/text/transform_set.toml +8 -0
  39. dmu_data/text/transform_set.txt +6 -0
  40. dmu_data/text/transform_trf.txt +12 -0
  41. dmu_scripts/physics/check_truth.py +121 -0
  42. dmu_scripts/rfile/compare_root_files.py +299 -0
  43. dmu_scripts/rfile/print_trees.py +35 -0
  44. dmu_scripts/ssh/coned.py +168 -0
  45. dmu_scripts/text/transform_text.py +46 -0
dmu/ml/train_mva.py ADDED
@@ -0,0 +1,257 @@
1
+ '''
2
+ Module with TrainMva class
3
+ '''
4
+ import os
5
+ from typing import Union
6
+
7
+ import joblib
8
+ import pandas as pnd
9
+ import numpy
10
+ import matplotlib.pyplot as plt
11
+
12
+ from sklearn.metrics import roc_curve, auc
13
+ from sklearn.model_selection import StratifiedKFold
14
+
15
+ from ROOT import RDataFrame
16
+
17
+ import dmu.ml.utilities as ut
18
+ from dmu.ml.cv_classifier import CVClassifier as cls
19
+ from dmu.plotting.plotter_1d import Plotter1D as Plotter
20
+ from dmu.logging.log_store import LogStore
21
+
22
+ log = LogStore.add_logger('data_checks:train_mva')
23
+ # ---------------------------------------------
24
+ class TrainMva:
25
+ '''
26
+ Interface to scikit learn used to train classifier
27
+ '''
28
+ # ---------------------------------------------
29
def __init__(self, bkg=None, sig=None, cfg=None):
    '''
    Validates the inputs, stores them and builds the training dataset

    bkg (ROOT dataframe): Holds real data, it is labeled as background (0)
    sig (ROOT dataframe): Holds simulation, it is labeled as signal (1)
    cfg (dict)          : Dictionary storing configuration for training, the names
                          of the features are read from cfg['training']['features']

    Raises ValueError if any dataframe is None or if cfg is not a dictionary
    '''
    l_check = [
        (bkg, 'Background dataframe is not a ROOT dataframe'),
        (sig, 'Signal dataframe is not a ROOT dataframe')]

    for rdf, message in l_check:
        if rdf is None:
            raise ValueError(message)

    if not isinstance(cfg, dict):
        raise ValueError('Config dictionary is not a dictionary')

    self._rdf_bkg = bkg
    self._rdf_sig = sig
    # At this point cfg is known to be a dictionary, no fallback is needed
    self._cfg     = cfg

    self._l_model : cls

    self._l_ft_name = self._cfg['training']['features']

    # Features and labels for signal and background together
    self._df_ft, self._l_lab = self._get_inputs()
53
+ # ---------------------------------------------
54
def _get_inputs(self) -> tuple[pnd.DataFrame, numpy.ndarray]:
    '''
    Builds the dataset used for the training

    Returns tuple with:

    - Dataframe with the features, signal rows come first, followed by background rows
    - Array with the labels in the same order, 1 for signal and 0 for background
    '''
    l_df  = []
    l_lab = []
    for kind, rdf, label in [('signal', self._rdf_sig, 1), ('background', self._rdf_bkg, 0)]:
        log.info(f'Getting {kind}')
        df_sample, arr_sample = self._get_sample_inputs(rdf, label = label)

        l_df.append(df_sample)
        l_lab.append(arr_sample)

    return pnd.concat(l_df, axis=0), numpy.concatenate(l_lab)
65
+ # ---------------------------------------------
66
def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, numpy.ndarray]:
    '''
    Reads the training features of one sample

    rdf   : ROOT dataframe with the sample
    label : Value assigned to each of the entries left after the cleanup

    Returns tuple with the cleaned dataframe of features and an array with one label per row
    '''
    df_raw   = pnd.DataFrame(rdf.AsNumpy(self._l_ft_name))
    df_clean = ut.cleanup(df_raw)

    # Labels are built after the cleanup, such that they match the rows left
    arr_lab  = numpy.array([label] * len(df_clean))

    return df_clean, arr_lab
73
+ # ---------------------------------------------
74
def _get_model(self, arr_index : numpy.ndarray) -> cls:
    '''
    Trains a classifier with a subset of the dataset

    arr_index : Positions (not index labels) of the rows used for the training

    Returns the trained classifier
    '''
    df_trn  = self._df_ft.iloc[arr_index]
    arr_lab = self._l_lab[arr_index]

    log.debug(f'Training feature shape: {df_trn.shape}')
    log.debug(f'Training label size: {len(arr_lab)}')

    classifier = cls(cfg = self._cfg)
    classifier.fit(df_trn, arr_lab)

    return classifier
85
+ # ---------------------------------------------
86
def _get_models(self):
    # pylint: disable = too-many-locals
    '''
    Will train one model per fold and return the list of models

    The number of folds and the random state are read from
    cfg['training']['nfold'] and cfg['training']['rdm_stat'].
    For each fold the scores and the ROC curve are plotted
    for both the training and the testing subsets.
    '''
    d_trn = self._cfg['training']
    kfold = StratifiedKFold(n_splits=d_trn['nfold'], shuffle=True, random_state=d_trn['rdm_stat'])

    separator = 20 * '-'
    l_model   = []
    for ifold, (arr_itr, arr_its) in enumerate(kfold.split(self._df_ft, self._l_lab)):
        log.debug(separator)
        log.info(f'Training fold: {ifold}')
        log.debug(separator)

        model = self._get_model(arr_itr)
        l_model.append(model)

        # tr: subset used to train this model, ts: subset held out for this fold
        sig_tr, bkg_tr, all_tr, lab_tr = self._get_scores(model, arr_itr, on_training_ok= True)
        sig_ts, bkg_ts, all_ts, lab_ts = self._get_scores(model, arr_its, on_training_ok=False)

        self._plot_scores(sig_tr, sig_ts, bkg_tr, bkg_ts, ifold)
        self._plot_roc(lab_ts, all_ts, lab_tr, all_tr, ifold)

    return l_model
115
+ # ---------------------------------------------
116
def _get_scores(self, model : cls, arr_index : numpy.ndarray, on_training_ok : bool) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
    '''
    Evaluates a model on a subset of the dataset

    model          : Trained classifier
    arr_index      : Positions of the rows to evaluate
    on_training_ok : Forwarded as is to the predict_proba method of the model

    Returns a tuple of four arrays

    arr_sig : Signal probabilities for the signal entries
    arr_bkg : Signal probabilities for the background entries
    arr_all : Signal probabilities for all the entries
    arr_lab : Labels for all the entries
    '''
    log.debug(f'Getting {len(arr_index)} signal probabilities')

    df_ft    = self._df_ft.iloc[arr_index]
    arr_lab  = self._l_lab[arr_index]
    arr_prob = model.predict_proba(df_ft, on_training_ok=on_training_ok)

    # Each row is unpacked as a pair, the second element is taken as the signal probability
    arr_all  = numpy.array([ sig_prob for [_, sig_prob] in arr_prob ])

    arr_sig, arr_bkg = self._split_scores(arr_prob=arr_prob, arr_label=arr_lab)

    return arr_sig, arr_bkg, arr_all, arr_lab
138
+ # ---------------------------------------------
139
+ def _split_scores(self, arr_prob : numpy.ndarray, arr_label : numpy.ndarray) -> tuple[numpy.ndarray, numpy.ndarray]:
140
+ '''
141
+ Will split the testing scores (predictions) based on the training scores
142
+
143
+ tst is a list of lists as [p_bkg, p_sig]
144
+ '''
145
+
146
+ l_sig = [ prb[1] for prb, lab in zip(arr_prob, arr_label) if lab == 1]
147
+ l_bkg = [ prb[1] for prb, lab in zip(arr_prob, arr_label) if lab == 0]
148
+
149
+ arr_sig = numpy.array(l_sig)
150
+ arr_bkg = numpy.array(l_bkg)
151
+
152
+ return arr_sig, arr_bkg
153
+ # ---------------------------------------------
154
def _save_model(self, model, ifold):
    '''
    Saves a model, associated to a specific fold

    model : Trained classifier, it is serialized with joblib
    ifold : Index of the fold, it is added to the file name as a zero padded suffix

    The output path is built from cfg['saving']['path'], e.g. for the fold 3
    /some/dir/model.pkl will be saved as /some/dir/model_003.pkl.
    Nothing is saved if the path in the config already exists as a file.
    '''
    model_path = self._cfg['saving']['path']
    # NOTE(review): this checks the path from the config, not the fold specific
    # path written below, confirm that this is the intended behavior
    if os.path.isfile(model_path):
        log.info(f'Model found in {model_path}, not saving')
        return

    # dirname is an empty string for a bare file name and makedirs raises on it
    dir_name = os.path.dirname(model_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    # Only the extension is touched. A plain string replacement would also modify
    # a .pkl in the directory part and would leave paths with other extensions
    # unchanged, making every fold overwrite the same file
    root, ext  = os.path.splitext(model_path)
    model_path = f'{root}_{ifold:03}{ext}'

    log.info(f'Saving model to: {model_path}')
    joblib.dump(model, model_path)
170
+ # ---------------------------------------------
171
def _plot_scores(self, arr_sig_trn, arr_sig_tst, arr_bkg_trn, arr_bkg_tst, ifold):
    # pylint: disable = too-many-arguments, too-many-positional-arguments
    '''
    Will plot the distributions of scores, associated to a given fold

    arr_sig_trn, arr_sig_tst : Scores for signal in the training and testing subsets
    arr_bkg_trn, arr_bkg_tst : Scores for background in the training and testing subsets
    ifold                    : Index of the fold, used for the title and the output directory

    The plot is saved as scores.png in a fold specific directory under cfg['plotting']['val_dir'].
    Nothing is plotted if val_dir is not in the config.
    '''
    log.debug(f'Plotting scores for {ifold} fold')

    d_plt = self._cfg['plotting']
    if 'val_dir' not in d_plt:
        log.warning('Scores path not passed, not plotting scores')
        return

    val_dir = d_plt['val_dir']
    val_dir = f'{val_dir}/fold_{ifold:03}'
    os.makedirs(val_dir, exist_ok=True)

    # Binning shared by the four histograms
    d_hist = {'bins' : 50, 'range' : (0, 1), 'density' : True}

    plt.hist(arr_sig_trn, alpha=0.3      , color='b', label='Signal Train'    , **d_hist)
    plt.hist(arr_sig_tst, histtype='step', color='b', label='Signal Test'     , **d_hist)
    plt.hist(arr_bkg_trn, alpha=0.3      , color='r', label='Background Train', **d_hist)
    plt.hist(arr_bkg_tst, histtype='step', color='r', label='Background Test' , **d_hist)

    plt.legend()
    plt.title(f'Fold: {ifold}')
    plt.xlabel('Signal probability')
    plt.ylabel('Normalized')
    plt.savefig(f'{val_dir}/scores.png')
    plt.close()
198
+ # ---------------------------------------------
199
def _plot_roc(self,
              l_lab_ts : numpy.ndarray,
              l_prb_ts : numpy.ndarray,
              l_lab_tr : numpy.ndarray,
              l_prb_tr : numpy.ndarray,
              ifold    : int):
    '''
    Takes the labels and the probabilities and plots ROC
    curve for given fold

    l_lab_ts, l_prb_ts : Labels and signal probabilities for the testing subset
    l_lab_tr, l_prb_tr : Labels and signal probabilities for the training subset
    ifold              : Index of the fold, used for the title and the output directory

    The plot is saved as roc.png in a fold specific directory under cfg['plotting']['val_dir'].
    Nothing is plotted if val_dir is not in the config.
    The lower edges of the axes can be set through cfg['plotting']['roc']['min'] = [min_x, min_y]
    '''
    # pylint: disable = too-many-arguments, too-many-positional-arguments
    log.debug(f'Plotting ROC curve for {ifold} fold')

    d_plt = self._cfg['plotting']
    # Same policy as for the scores plot, a missing directory means that no plot was requested
    if 'val_dir' not in d_plt:
        log.warning('Validation directory not passed, not plotting ROC curve')
        return

    val_dir = d_plt['val_dir']
    val_dir = f'{val_dir}/fold_{ifold:03}'
    os.makedirs(val_dir, exist_ok=True)

    # The roc section is optional, it can also be present but empty
    d_roc        = d_plt.get('roc') or {}
    min_x, min_y = d_roc.get('min', [0, 0])

    l_sample = [
        ('Test' , 'b', l_lab_ts, l_prb_ts),
        ('Train', 'r', l_lab_tr, l_prb_tr)]

    for kind, color, l_lab, l_prb in l_sample:
        # roc_curve returns (false positive rate, true positive rate, thresholds)
        arr_fpr, arr_tpr, _ = roc_curve(l_lab, l_prb)
        arr_rej             = 1 - arr_fpr
        area                = auc(arr_rej, arr_tpr)

        plt.plot(arr_rej, arr_tpr, color=color, label=f'{kind}: {area:.3f}')

    # x is 1 - FPR, i.e. the background rejection, y is the TPR, i.e. the signal efficiency
    plt.xlabel('Background rejection')
    plt.ylabel('Signal efficiency')
    plt.title(f'Fold: {ifold}')
    plt.xlim(min_x, 1)
    plt.ylim(min_y, 1)
    plt.legend()
    plt.savefig(f'{val_dir}/roc.png')
    plt.close()
239
+ # ---------------------------------------------
240
def _plot_features(self):
    '''
    Will plot the features for signal and background,
    the plots are configured through cfg['plotting']['features']
    '''
    d_rdf = {'Signal' : self._rdf_sig, 'Background' : self._rdf_bkg}

    plotter = Plotter(d_rdf = d_rdf, cfg=self._cfg['plotting']['features'])
    plotter.run()
247
+ # ---------------------------------------------
248
def run(self):
    '''
    Will plot the features, train one model per fold and save each model
    '''
    self._plot_features()

    for ifold, model in enumerate(self._get_models()):
        self._save_model(model, ifold)
257
+ # ---------------------------------------------
dmu/ml/utilities.py ADDED
@@ -0,0 +1,132 @@
1
+ '''
2
+ Module containing utility functions for ML tools
3
+ '''
4
+
5
+ import hashlib
6
+ from typing import Union
7
+
8
+ import numpy
9
+ import pandas as pnd
10
+
11
+ from dmu.logging.log_store import LogStore
12
+
13
+ log = LogStore.add_logger('dmu:ml:utilities')
14
+ # ---------------------------------------------
15
+ # Patch dataframe with features
16
+ # ---------------------------------------------
17
def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
    '''
    Takes pandas dataframe and replaces NaNs with the value introduced, by default 0

    Returns:

    - The input dataframe, untouched, if no NaN was found
    - Otherwise the dataframe returned by fillna, whose attrs['patched_indices']
      holds an array with the index labels of the rows that had at least one NaN
    '''
    has_nan = df.isna().any(axis=1)
    l_nan   = df.index[has_nan].tolist()

    if not l_nan:
        log.debug('No NaNs found')
        return df

    log.warning(f'Found {len(l_nan)} NaNs, patching them with {value}')

    df_patched = df.fillna(value)
    df_patched.attrs['patched_indices'] = numpy.array(l_nan)

    return df_patched
35
+ # ---------------------------------------------
36
+ # Cleanup of dataframe with features
37
+ # ---------------------------------------------
38
def cleanup(df : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Takes pandas dataframe with features for classification

    Drops first the repeated entries and then the entries with NaNs
    Returns the cleaned dataframe
    '''
    df_unique = _remove_repeated(df)

    return _remove_nans(df_unique)
48
+ # ---------------------------------------------
49
def _remove_nans(df : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Drops the rows with at least one NaN
    Returns the input dataframe itself if there are no NaNs
    '''
    has_nan = df.isna().any().any()
    if has_nan:
        df_clean = df.dropna()
        log.warning(f'NaNs found, cleaning dataset: {len(df)} -> {len(df_clean)}')

        return df_clean

    log.debug('No NaNs found in dataframe')

    return df
61
+ # ---------------------------------------------
62
def _remove_repeated(df : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Drops the repeated rows, keeping the first occurrence of each

    Rows are compared through the hash of their contents. If nothing is repeated
    the input dataframe is returned as is. Otherwise a new dataframe is returned,
    whose index, named hash_index, holds the hashes of the rows.

    The input dataframe is never modified.
    '''
    l_hash = get_hashes(df, rvalue='list')

    ninit  = len(l_hash)
    nfinl  = len(set(l_hash))

    if ninit == nfinl:
        log.debug('No cleaning needed for dataframe')
        return df

    log.warning(f'Repeated entries found, cleaning up: {ninit} -> {nfinl}')

    # The index is built on the side, assigning the hashes to a new column
    # would modify the dataframe of the caller
    ind_hash = pnd.Index(l_hash, name='hash_index')
    df_hash  = df.set_index(ind_hash, drop=True)
    df_clean = df_hash[~df_hash.index.duplicated(keep='first')]

    if not isinstance(df_clean, pnd.DataFrame):
        raise ValueError('Cleaning did not return pandas dataframe')

    return df_clean
83
+ # ----------------------------------
84
+ # ---------------------------------------------
85
def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set, list]:
    '''
    Will return hashes for each row in the feature dataframe

    rvalue (str): Kind of container returned, either 'set' or 'list'.
                  The list follows the order of the rows

    Raises ValueError for any other value of rvalue
    '''
    if rvalue not in ('set', 'list'):
        log.error(f'Invalid return value: {rvalue}')
        raise ValueError

    gen_hash = ( hash_from_row(row) for _, row in df_ft.iterrows() )

    return set(gen_hash) if rvalue == 'set' else list(gen_hash)
101
+ # ----------------------------------
102
def hash_from_row(row):
    '''
    Will return a hash from a pandas dataframe row
    corresponding to an event

    The values are converted to strings, joined with commas and
    the SHA-256 hexadecimal digest of the result is returned
    '''
    row_str = ','.join(str(val) for val in row)

    return hashlib.sha256(row_str.encode('utf-8')).hexdigest()
117
+ # ----------------------------------
118
def index_with_hashes(df):
    '''
    Will:
    - take dataframe with features
    - calculate the hash of each row
    - return a dataframe where the hashes replace the old index
    '''
    ind_hash = pnd.Index(get_hashes(df, rvalue='list'))

    return df.set_index(ind_hash, drop=True)
132
+ # ----------------------------------
dmu/plotting/plotter.py ADDED
@@ -0,0 +1,227 @@
1
+ '''
2
+ Module containing plotter class
3
+ '''
4
+
5
+ import os
6
+ import math
7
+ from typing import Union
8
+
9
+ import numpy
10
+ import matplotlib.pyplot as plt
11
+
12
+ from ROOT import RDataFrame
13
+ from dmu.logging.log_store import LogStore
14
+
15
+ log = LogStore.add_logger('dmu:plotting:Plotter')
16
+ # --------------------------------------------
17
+ class Plotter:
18
+ '''
19
+ Base class of Plotter1D and Plotter2D
20
+ '''
21
+ #-------------------------------------
22
def __init__(self, d_rdf=None, cfg=None):
    '''
    d_rdf (dict): Maps the name of each sample to its ROOT dataframe
    cfg   (dict): Configuration of the plots

    Each dataframe is preprocessed before it is stored

    Raises ValueError if any of the arguments is not a dictionary
    '''
    l_check = [
        (cfg  , 'Config dictionary not passed'),
        (d_rdf, 'Dataframe dictionary not passed')]

    for obj, message in l_check:
        if not isinstance(obj, dict):
            raise ValueError(message)

    # The config has to be stored first, the preprocessing reads it
    self._d_cfg = cfg

    d_processed = {}
    for name, rdf in d_rdf.items():
        d_processed[name] = self._preprocess_rdf(rdf)

    self._d_rdf : dict[str, RDataFrame] = d_processed
    self._d_wgt : Union[dict[str, Union[numpy.ndarray, None]], None]
32
+ #-------------------------------------
33
+ def _check_quantile(self, qnt : float):
34
+ '''
35
+ Will check validity of quantile
36
+ '''
37
+
38
+ if 0.5 < qnt <= 1.0:
39
+ return
40
+
41
+ raise ValueError(f'Invalid quantile: {qnt:.3e}, value needs to be in (0.5, 1.0] interval')
42
+ #-------------------------------------
43
+ def _find_bounds(self, d_data : dict, qnt : float = 0.98):
44
+ '''
45
+ Will take dictionary between kinds of data and numpy array
46
+ Will return tuple with bounds, where 95% of the data is found
47
+ '''
48
+ self._check_quantile(qnt)
49
+
50
+ l_max = []
51
+ l_min = []
52
+
53
+ for arr_val in d_data.values():
54
+ minv = numpy.quantile(arr_val, 1 - qnt)
55
+ maxv = numpy.quantile(arr_val, qnt)
56
+
57
+ l_max.append(maxv)
58
+ l_min.append(minv)
59
+
60
+ minx = min(l_min)
61
+ maxx = max(l_max)
62
+
63
+ if minx >= maxx:
64
+ raise ValueError(f'Could not calculate bounds correctly: [{minx:.3e}, {maxx:.3e}]')
65
+
66
+ return minx, maxx
67
+ #-------------------------------------
68
def _preprocess_rdf(self, rdf):
    '''
    rdf (RDataFrame): ROOT dataframe

    Will define the extra variables and, if the config has a selection
    section, apply the cuts and limit the number of entries

    returns preprocessed dataframe
    '''
    rdf_out = self._define_vars(rdf)

    # Both steps read the selection section of the config
    if 'selection' in self._d_cfg:
        for step in (self._apply_selection, self._max_ran_entries):
            rdf_out = step(rdf_out)

    return rdf_out
81
+ #-------------------------------------
82
def _define_vars(self, rdf):
    '''
    Will define extra columns in dataframe and return updated dataframe

    The columns are read from the definitions section of the config, which maps
    the name of each new column to its expression. The input dataframe is returned
    as is when the section is missing.
    '''
    if 'definitions' in self._d_cfg:
        d_def = self._d_cfg['definitions']

        log.info('Defining extra variables')
        for name, expr in d_def.items():
            log.debug(f'{name:<30}{expr:<150}')
            rdf = rdf.Define(name, expr)
    else:
        log.debug('No definitions section found, returning same RDF')

    return rdf
99
+ #-------------------------------------
100
def _apply_selection(self, rdf):
    '''
    Will take dataframe, apply selection and return dataframe

    The cuts are read from cfg['selection']['cuts'], which maps the name of each cut
    to its expression. The input dataframe is returned as is when there are no cuts.
    '''
    d_sel = self._d_cfg['selection']

    if 'cuts' in d_sel:
        log.info('Applying cuts')
        for name, cut in d_sel['cuts'].items():
            log.debug(f'{name:<50}{cut:<150}')
            rdf = rdf.Filter(cut, name)
    else:
        log.debug('Cuts not found in selection section, not applying any cuts')

    return rdf
116
+ #-------------------------------------
117
def _max_ran_entries(self, rdf):
    '''
    Will take dataframe and drop entries, such that the number
    of entries left is close to cfg['selection']['max_ran_entries']

    The entries are not picked randomly, one entry is kept every `prescale` entries,
    where prescale = floor(total / maximum), thus the dataframe can end up with
    more entries than the maximum requested.

    The input dataframe is returned as is if:

    - The config has no selection section, or the section has no max_ran_entries
    - The value of max_ran_entries is not positive
    - The dataframe has less than twice the maximum number of entries
    '''
    # The selection section is optional and can be present but empty
    d_sel = self._d_cfg.get('selection') or {}
    if 'max_ran_entries' not in d_sel:
        log.debug('max_ran_entries not found in selection section, not dropping any entries')
        return rdf

    max_entries = d_sel['max_ran_entries']
    # A value of zero would otherwise end up in a division by zero below
    if max_entries <= 0:
        log.warning(f'Not dropping random entries, invalid max_ran_entries: {max_entries}')
        return rdf

    tot_entries = rdf.Count().GetValue()

    if tot_entries < max_entries:
        log.debug(f'Not dropping random entries: {tot_entries} < {max_entries}')
        return rdf

    prescale = math.floor(tot_entries / max_entries)
    if prescale < 2:
        log.debug(f'Not dropping random entries, prescale is below 2: {tot_entries}/{max_entries}')
        return rdf

    # rdfentry_ is the entry number provided by RDataFrame
    rdf = rdf.Filter(f'rdfentry_ % {prescale} == 0', 'max_ran_entries')

    fnl_entries = rdf.Count().GetValue()

    log.info(f'Dropped entries randomly: {tot_entries} -> {fnl_entries}')

    return rdf
145
+ # --------------------------------------------
146
def _print_weights(self, arr_wgt : Union[numpy.ndarray, None], var : str, sample : str) -> None:
    '''
    Will log, at debug level, the number of weights and their sum

    arr_wgt: Array of weights, None if no weights are used
    var    : Name of the variable being plotted
    sample : Name of the sample, only used in the message shown when there are no weights
    '''
    if arr_wgt is not None:
        nwgt  = len(arr_wgt)
        total = numpy.sum(arr_wgt)

        log.debug(f'Using weights [{nwgt},{total:.0f}] for {var}')
        return

    log.debug(f'Not using weights for {sample}:{var}')
155
+ # --------------------------------------------
156
+ def _get_fig_size(self):
157
+ '''
158
+ Will read size list from config dictionary if found
159
+ other wise will return None
160
+ '''
161
+ if 'general' not in self._d_cfg:
162
+ return None
163
+
164
+ if 'size' not in self._d_cfg['general']:
165
+ return None
166
+
167
+ fig_size = self._d_cfg['general']['size']
168
+
169
+ return fig_size
170
+ #-------------------------------------
171
+ def _get_weights(self, var) -> Union[dict[str, Union[numpy.ndarray, None]], None]:
172
+ d_cfg = self._d_cfg['plots'][var]
173
+ if 'weights' not in d_cfg:
174
+ return None
175
+
176
+ if hasattr(self, '_d_wgt'):
177
+ return self._d_wgt
178
+
179
+ wgt_name = d_cfg['weights']
180
+ d_weight = {sam_name : self._read_weights(wgt_name, rdf) for sam_name, rdf in self._d_rdf.items()}
181
+
182
+ self._d_wgt = d_weight
183
+
184
+ return d_weight
185
+ # --------------------------------------------
186
def _read_weights(self, name : str, rdf : RDataFrame) -> Union[numpy.ndarray, None]:
    '''
    Will read a column of weights from a ROOT dataframe

    name : Name of the column with the weights
    rdf  : ROOT dataframe

    Returns the array of weights, or None if the dataframe does not have the column
    '''
    s_col = { col.c_str() for col in rdf.GetColumnNames() }

    if name in s_col:
        return rdf.AsNumpy([name])[name]

    log.debug(f'Weight {name} not found')

    return None
197
+ #-------------------------------------
198
+ def _get_plot_name(self, var : str) -> str:
199
+ if 'plots_2d' in self._d_cfg:
200
+ #For 2D plots the name will always be specified in the config
201
+ return var
202
+
203
+ if 'name' not in self._d_cfg['plots'][var]:
204
+ # For 1D plots the name can be taken from variable name itself or specified
205
+ return var
206
+
207
+ return self._d_cfg['plots'][var]['name']
208
+ #-------------------------------------
209
def _save_plot(self, var):
    '''
    Will add the legend and save the current plot to PNG, in the
    directory specified by cfg['saving']['plt_dir']

    var (str) : Name of variable, needed for plot name
    '''
    plt.legend()

    out_dir = self._d_cfg['saving']['plt_dir']
    os.makedirs(out_dir, exist_ok=True)

    file_name = self._get_plot_name(var)
    plot_path = f'{out_dir}/{file_name}.png'

    log.info(f'Saving to: {plot_path}')
    plt.tight_layout()
    plt.savefig(plot_path)
    # NOTE(review): this only closes a figure labelled as var, confirm that callers create it that way
    plt.close(var)
227
+ # --------------------------------------------