data-manipulation-utilities 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
- data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
- data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
- data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
- data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
- dmu/arrays/utilities.py +55 -0
- dmu/dataframe/dataframe.py +36 -0
- dmu/generic/utilities.py +69 -0
- dmu/logging/log_store.py +129 -0
- dmu/ml/cv_classifier.py +122 -0
- dmu/ml/cv_predict.py +152 -0
- dmu/ml/train_mva.py +257 -0
- dmu/ml/utilities.py +132 -0
- dmu/plotting/plotter.py +227 -0
- dmu/plotting/plotter_1d.py +113 -0
- dmu/plotting/plotter_2d.py +87 -0
- dmu/rdataframe/atr_mgr.py +79 -0
- dmu/rdataframe/utilities.py +72 -0
- dmu/rfile/rfprinter.py +91 -0
- dmu/rfile/utilities.py +34 -0
- dmu/stats/fitter.py +515 -0
- dmu/stats/function.py +314 -0
- dmu/stats/utilities.py +134 -0
- dmu/testing/utilities.py +119 -0
- dmu/text/transformer.py +182 -0
- dmu_data/__init__.py +0 -0
- dmu_data/ml/tests/train_mva.yaml +37 -0
- dmu_data/plotting/tests/2d.yaml +14 -0
- dmu_data/plotting/tests/fig_size.yaml +13 -0
- dmu_data/plotting/tests/high_stat.yaml +22 -0
- dmu_data/plotting/tests/name.yaml +14 -0
- dmu_data/plotting/tests/no_bounds.yaml +12 -0
- dmu_data/plotting/tests/simple.yaml +8 -0
- dmu_data/plotting/tests/title.yaml +14 -0
- dmu_data/plotting/tests/weights.yaml +13 -0
- dmu_data/text/transform.toml +4 -0
- dmu_data/text/transform.txt +6 -0
- dmu_data/text/transform_set.toml +8 -0
- dmu_data/text/transform_set.txt +6 -0
- dmu_data/text/transform_trf.txt +12 -0
- dmu_scripts/physics/check_truth.py +121 -0
- dmu_scripts/rfile/compare_root_files.py +299 -0
- dmu_scripts/rfile/print_trees.py +35 -0
- dmu_scripts/ssh/coned.py +168 -0
- dmu_scripts/text/transform_text.py +46 -0
dmu/stats/function.py
ADDED
@@ -0,0 +1,314 @@
|
|
1
|
+
'''
|
2
|
+
Module containing the Function class
|
3
|
+
'''
|
4
|
+
import os
|
5
|
+
import json
|
6
|
+
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
import numpy
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
|
12
|
+
from scipy.interpolate import interp1d
|
13
|
+
from dmu.logging.log_store import LogStore
|
14
|
+
|
15
|
+
# Module-level logger for dmu.stats.function, obtained from LogStore
log = LogStore.add_logger('dmu:stats:function')
|
16
|
+
#---------------------------------------------------------
|
17
|
+
class FunOutOfBounds(Exception):
    '''
    Raised when a Function, defined only between its smallest and largest
    x coordinate [a, b], is evaluated at a point outside of that interval
    '''
|
21
|
+
#---------------------------------------------------------
|
22
|
+
class Function:
    '''
    Class meant to represent a 1D function created from (x, y) coordinates

    The points are interpolated with scipy's `interp1d`; the function is only
    defined between the smallest and the largest x coordinate.
    '''
    #------------------------------------------------
    def __init__(self, x : list | numpy.ndarray, y : list | numpy.ndarray, kind : str = 'cubic'):
        '''
        x (list)   : List or numpy array with x coordinates
        y (list)   : List or numpy array with y coordinates, same length as x
        kind (str) : Kind of interpolation, passed to interp1d, cubic by default

        Raises ValueError if x and y are neither lists nor numpy arrays, if they
        have different lengths or if fewer than four distinct x coordinates are found
        '''

        x = self._array_to_list(x)
        y = self._array_to_list(y)

        if len(x) != len(y):
            raise ValueError('X and Y coordinates have different lengths')

        # Duplicates are removed before counting: repeated x coordinates
        # are not independent points for the interpolator
        x, y = self._remove_duplicates(x=x, y=y)

        npoint = len(x)
        if npoint < 4:
            raise ValueError(f'Need at least four points, found {npoint}')

        self._max_entries = 400
        self._l_x = x
        self._l_y = y
        self._kind= kind
        self._tag = 'no_tag'

        self._interpolator = interp1d(self._l_x, self._l_y, kind=self._kind)

        self._update_data()
    #------------------------------------------------
    def _get_attributes(self) -> dict:
        '''
        Returns a shallow copy of the attributes of this instance without the
        interpolator, which is not JSON serializable and is recreated by the
        constructor when loading. The instance itself is NOT modified.
        '''
        return { key : val for key, val in self.__dict__.items() if key != '_interpolator' }
    #------------------------------------------------
    def __eq__(self, othr):
        '''
        Two functions are equal if all their attributes, except the interpolator, are equal.
        Neither of the two objects is modified by the comparison.
        '''
        if not isinstance(othr, Function):
            log.warning('Comparison not done with instance of Function')
            return False

        return self._get_attributes() == othr._get_attributes()
    #------------------------------------------------
    def __str__(self):
        npoints = len(self._l_x)
        max_x   = max(self._l_x)
        min_x   = min(self._l_x)

        max_y   = max(self._l_y)
        min_y   = min(self._l_y)

        line = f'\n{"Points":<20}{npoints:<20}\n'
        line+= '-------------------------\n'
        line+= f'{"x-max":<20}{max_x:<20}\n'
        line+= f'{"x-min":<20}{min_x:<20}\n'
        line+= f'{"y-max":<20}{max_y:<20}\n'
        line+= f'{"y-min":<20}{min_y:<20}'

        return line
    #------------------------------------------------
    def __call__(self, xval : float | numpy.ndarray | list, off_bounds_raise : bool = False) -> numpy.ndarray:
        '''
        Evaluates the interpolated function

        xval (float, list or numpy array) : x coordinates where the function is evaluated
        off_bounds_raise (bool)           : If False (default) values outside the range of the
                                            function are moved to the closest bound (and the
                                            input is flattened). If True, FunOutOfBounds is
                                            raised when any value is out of bounds.

        Returns numpy array with the interpolated y values
        '''
        if not off_bounds_raise:
            xval = self._push_in_bounds(xval)

        self._check_xval_validity(xval)

        return self._interpolator(xval)
    #------------------------------------------------
    def _push_in_bounds(self, xval : float | numpy.ndarray | list) -> numpy.ndarray:
        '''
        Takes x coordinates as float, list or numpy array
        Returns flat numpy array of floats where elements above (below) the upper (lower)
        bound are set to that bound. Up to 20 of the modified values are logged.
        '''

        xval = numpy.array(xval).flatten().astype(float)

        max_x = max(self._l_x)
        min_x = min(self._l_x)

        if ((min_x <= xval) & (xval <= max_x)).all():
            log.debug('Input array within bounds, will not push elements')
            return xval

        xmod = numpy.clip(xval, min_x, max_x)

        arr_diff = xval != xmod
        ndiff    = numpy.sum(arr_diff)
        # Only the first entries are shown, to avoid flooding the log
        arr_indx = numpy.where(arr_diff)[0][:20]

        log.warning(f'Sending {ndiff} entries inside bounds [{min_x:.3e}, {max_x:.3e}]')

        for indx in arr_indx:
            org = xval[indx]
            mod = xmod[indx]

            log.info(f'{org:<20.5e}{"-->":<20}{mod:<20.5e}')

        return xmod
    #------------------------------------------------
    @staticmethod
    def json_decoder(d_attr):
        '''
        Used as `object_hook` when loading JSON files
        Takes dictionary of attributes from JSON serialization
        Returns instance of Function

        Raises KeyError if the x values, the y values or the tag are missing.
        If the interpolation kind is missing, the constructor default (cubic) is used.
        '''

        if '_l_x' not in d_attr:
            raise KeyError('X values not found')

        if '_l_y' not in d_attr:
            raise KeyError('Y values not found')

        if '_tag' not in d_attr:
            raise KeyError('tag not found')

        x    = d_attr['_l_x' ]
        y    = d_attr['_l_y' ]
        kind = d_attr.get('_kind', 'cubic')
        tag  = d_attr['_tag' ]

        fun     = Function(x=x, y=y, kind=kind)
        fun.tag = tag

        return fun
    #------------------------------------------------
    @property
    def tag(self):
        '''
        Returns string symbolizing the tag of the function
        '''
        return self._tag

    @tag.setter
    def tag(self, value : str):
        '''
        This sets the _tag property of the function
        '''
        self._tag = value
    #------------------------------------------------
    @staticmethod
    def load(path : str):
        '''
        Will take path to JSON file with serialized function
        Will return function instance

        Raises FileNotFoundError if the path is not an existing file
        '''

        if not os.path.isfile(path):
            raise FileNotFoundError(f'Cannot find: {path}')

        with open(path, encoding='utf-8') as ifile:
            fun = json.loads(ifile.read(), object_hook=Function.json_decoder)

        log.info(f'Loaded from: {path}')

        return fun
    #------------------------------------------------
    def _array_to_list(self, x : Any):
        '''
        Transform from ndarray to list
        Return x if already list
        Raise ValueError otherwise
        '''
        if isinstance(x, list):
            log.debug('Already found list')
            return x

        if isinstance(x, numpy.ndarray):
            log.debug('Transforming argument to list')
            return x.tolist()

        raise ValueError('Object introduced is neither a list nor a numpy array')
    #------------------------------------------------
    def _update_data(self):
        '''
        If number of entries in dataset is larger than _max_entries:

        Use interpolator to scan function and get new (x, y) pairs, evenly spaced in x.
        Note that the interpolator itself is not rebuilt; it keeps using the original points.
        '''
        norg = len(self._l_x)
        if norg <= self._max_entries:
            return

        log.info(f'Trimming dataset: {norg} -> {self._max_entries}')

        min_x = min(self._l_x)
        max_x = max(self._l_x)

        arr_x = numpy.linspace(min_x, max_x, self._max_entries)
        arr_y = self(arr_x)

        self._l_x = arr_x.tolist()
        self._l_y = arr_y.tolist()
    #------------------------------------------------
    def _remove_duplicates(self, x : list, y : list):
        '''
        Takes two lists with the same sizes and removes (x, y) points with repeated
        x coordinates. For a repeated x, the y of its last occurrence is kept, while
        the order of first occurrences is preserved.

        Return tuple with x and y after removal
        '''

        norg  = len(x)

        d_tmp = dict(zip(x, y))

        x = list(d_tmp.keys())
        y = list(d_tmp.values())

        nfnl = len(x)

        if norg != nfnl:
            log.warning(f'Found duplicates: {norg} -> {nfnl}')

        return x, y
    #------------------------------------------------
    def _check_xval_validity(self, xval : float | numpy.ndarray | list):
        '''
        Will check that xval is an acceptable value for the function to be evaluated at:
        a number (int or float, but not bool), a list or a numpy array, with every value
        within the bounds of the function.

        Raises ValueError for other types and FunOutOfBounds, reporting the first
        offending value, if any value lies outside the bounds
        '''

        if isinstance(xval, list):
            xval = numpy.array(xval)

        # bool is a subclass of int, but it is not a meaningful coordinate
        is_number = isinstance(xval, (int, float, numpy.number)) and not isinstance(xval, bool)
        if not is_number and not isinstance(xval, numpy.ndarray):
            raise ValueError(f'x value is not a float or numpy array: {xval}')

        # Vectorized check: bounds computed once, also works for empty arrays and scalars
        arr_x   = numpy.ravel(xval)
        arr_off = (arr_x < min(self._l_x)) | (arr_x > max(self._l_x))
        if arr_off.any():
            self._check_within_bounds(arr_x[arr_off][0])
    #------------------------------------------------
    def _check_within_bounds(self, xval : float):
        '''
        Check that xval is within bounds of function
        Raises FunOutOfBounds otherwise
        '''

        if xval < min(self._l_x) or xval > max(self._l_x):
            log.error(str(self))
            raise FunOutOfBounds(f'x value outside bounds: {xval}')
    #------------------------------------------------
    def _json_encoder(self, obj):
        '''
        Used as `default` hook of json.dump

        Takes Function object, or numpy scalar found among its attributes
        Returns JSON serializable representation. For a Function, this is a copy of
        its attributes without the interpolator; the object itself is not modified.

        Raises TypeError for any other type, as json expects from this hook
        '''
        if isinstance(obj, numpy.generic):
            return obj.item()

        if isinstance(obj, Function):
            return obj._get_attributes()

        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')
    #------------------------------------------------
    def _save_plot(self, path : str):
        '''
        Takes path to PNG, saves line plot of l_y vs l_x, with points ordered in x

        A dedicated figure is used, so that figures opened by the caller are
        neither drawn on nor closed
        '''
        arr_ord = numpy.argsort(self._l_x)
        arr_x   = numpy.array(self._l_x)[arr_ord]
        arr_y   = numpy.array(self._l_y)[arr_ord]

        fig, ax = plt.subplots()
        ax.plot(arr_x, arr_y)
        fig.savefig(path)
        plt.close(fig)

        log.info(f'Saved to: {path}')
    #------------------------------------------------
    def save(self, path : str, plot : bool = False):
        '''
        Saves current object to JSON

        path (str) : Path to file, needs to end in .json
        plot (bool): If True, also save a plot next to the JSON file, with .png extension

        Raises ValueError if the path does not end in .json
        '''

        if not path.endswith('.json'):
            raise ValueError(f'Output path does not end in .json: {path}')

        # A bare file name has no directory part and os.makedirs('') would raise
        dir_name = os.path.dirname(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        with open(path, 'w', encoding='utf-8') as ofile:
            json.dump(self, ofile, indent=4, default=self._json_encoder)

        log.info(f'Saved to: {path}')

        if plot:
            # Only the extension is replaced, not other '.json' substrings in the path
            plot_path = path[:-len('.json')] + '.png'
            self._save_plot(plot_path)
|
314
|
+
#------------------------------------------------
|
dmu/stats/utilities.py
ADDED
@@ -0,0 +1,134 @@
|
|
1
|
+
'''
|
2
|
+
Module with utility functions related to the dmu.stats project
|
3
|
+
'''
|
4
|
+
import os
|
5
|
+
import re
|
6
|
+
from typing import Union
|
7
|
+
import zfit
|
8
|
+
|
9
|
+
from dmu.logging.log_store import LogStore
|
10
|
+
|
11
|
+
# Module-level logger for dmu.stats.utilities, obtained from LogStore
log = LogStore.add_logger('dmu:stats:utilities')
|
12
|
+
#-------------------------------------------------------
|
13
|
+
#Zfit/print_pdf
|
14
|
+
#-------------------------------------------------------
|
15
|
+
def _get_const(par : zfit.Parameter, d_const : Union[None, dict[str, list[float]]]) -> str:
|
16
|
+
'''
|
17
|
+
Takes zfit parameter and dictionary of constraints
|
18
|
+
Returns a formatted string with the value of the constraint on that parameter
|
19
|
+
'''
|
20
|
+
if d_const is None or par.name not in d_const:
|
21
|
+
return 'none'
|
22
|
+
|
23
|
+
obj = d_const[par.name]
|
24
|
+
if isinstance(obj, (list, tuple)):
|
25
|
+
[mu, sg] = obj
|
26
|
+
val = f'{mu:.3e}; {sg:.3e}'
|
27
|
+
else:
|
28
|
+
val = str(obj)
|
29
|
+
|
30
|
+
return val
|
31
|
+
#-------------------------------------------------------
|
32
|
+
def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zfit.Parameter]:
|
33
|
+
'''
|
34
|
+
Takes set of zfit parameters and list of parameter names to blind
|
35
|
+
returns set of zfit parameters that should be blinded
|
36
|
+
'''
|
37
|
+
if l_blind is None:
|
38
|
+
return s_par
|
39
|
+
|
40
|
+
rgx_ors = '|'.join(l_blind)
|
41
|
+
regex = f'({rgx_ors})'
|
42
|
+
|
43
|
+
s_par_blind = { par for par in s_par if not re.match(regex, par.name) }
|
44
|
+
|
45
|
+
return s_par_blind
|
46
|
+
#-------------------------------------------------------
|
47
|
+
def _get_pars(
        pdf   : zfit.pdf.BasePDF,
        blind : Union[None, list[str]]) -> tuple[list, list]:
    '''
    Collect the parameters of a PDF that are meant to be printed

    pdf   : zfit PDF whose parameters are collected
    blind : Optional list of regular expressions, parameters whose names match are left out

    Returns tuple with (floating parameters, fixed parameters), each list sorted by name
    '''
    def _visible_sorted(floating : bool) -> list:
        s_par = pdf.get_params(floating=floating)
        s_par = _blind_vars(s_par, l_blind=blind)

        return sorted(s_par, key=lambda par: par.name)

    return _visible_sorted(True), _visible_sorted(False)
|
64
|
+
#-------------------------------------------------------
|
65
|
+
def _get_messages(
        pdf       : zfit.pdf.BasePDF,
        l_par_flt : list,
        l_par_fix : list,
        d_const   : Union[None, dict[str,list[float]]] = None) -> list[str]:
    '''
    Build the lines of the table printed by `print_pdf`

    pdf       : zfit PDF, used for its name and observable space
    l_par_flt : Floating parameters, printed first
    l_par_fix : Fixed parameters, printed after an empty line
    d_const   : Optional mapping {par_name : [mu, sg]} with constraints

    Returns list of strings, one per line
    '''
    def _format_limit(val) -> str:
        # Guard for parameters without limits, whose bound may be None
        if val is None:
            return f'{"none":>15}'

        return f'{val:>15.3e}'

    def _format_par(par) -> str:
        value = par.value().numpy()
        low   = _format_limit(par.lower)
        hig   = _format_limit(par.upper)
        const = _get_const(par, d_const)
        # str() is needed: a bool formatted with a width spec falls back to
        # integer formatting and would print 1/0 instead of True/False
        flt   = str(par.floating)

        return f'{par.name:<50}{value:>15.3e}{low}{hig}{flt:>10}{const:>25}'

    str_space = str(pdf.space)

    l_msg = [
            '-' * 20,
            f'PDF: {pdf.name}',
            f'OBS: {str_space}',
            f'{"Name":<50}{"Value":>15}{"Low":>15}{"High":>15}{"Floating":>10}{"Constraint":>25}',
            '-' * 20]

    l_msg += [ _format_par(par) for par in l_par_flt ]
    l_msg.append('')
    l_msg += [ _format_par(par) for par in l_par_fix ]

    return l_msg
|
96
|
+
#-------------------------------------------------------
|
97
|
+
def print_pdf(
        pdf      : zfit.pdf.BasePDF,
        d_const  : Union[None, dict[str,list[float]]] = None,
        txt_path : Union[str,None]                    = None,
        level    : int                                = 20,
        blind    : Union[None, list[str]]             = None):
    '''
    Function used to print zfit PDFs

    Parameters
    -------------------
    pdf (zfit.PDF): PDF
    d_const (dict): Optional dictionary mapping {par_name : [mu, sg]}
    txt_path (str): Optionally, dump output to text in this path. In that case nothing is printed on screen
    level (int)   : Optionally set the level at which the printing happens in screen,
                    20 (default) for info, 10 for debug. 30 is also accepted and means debug,
                    kept for backward compatibility
    blind (list)  : List of regular expressions matching variable names to blind in printout

    Raises ValueError if the output goes to screen and the level is not supported
    '''
    l_par_flt, l_par_fix = _get_pars(pdf, blind)
    l_msg                = _get_messages(pdf, l_par_flt, l_par_fix, d_const)

    if txt_path is not None:
        log.debug(f'Saving to: {txt_path}')
        message  = '\n'.join(l_msg)
        # A bare file name has no directory part and os.makedirs('') would raise
        dir_path = os.path.dirname(txt_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        with open(txt_path, 'w', encoding='utf-8') as ofile:
            ofile.write(message)

        return

    d_printer = {20 : log.info, 10 : log.debug, 30 : log.debug}
    if level not in d_printer:
        raise ValueError(f'Invalid level: {level}')

    printer = d_printer[level]
    for msg in l_msg:
        printer(msg)
|
134
|
+
#-------------------------------------------------------
|
dmu/testing/utilities.py
ADDED
@@ -0,0 +1,119 @@
|
|
1
|
+
'''
|
2
|
+
Module containing utility functions needed by unit tests
|
3
|
+
'''
|
4
|
+
import os
|
5
|
+
from typing import Union
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from importlib.resources import files
|
8
|
+
|
9
|
+
from ROOT import RDF, TFile, RDataFrame
|
10
|
+
|
11
|
+
import pandas as pnd
|
12
|
+
import numpy
|
13
|
+
import yaml
|
14
|
+
|
15
|
+
from dmu.logging.log_store import LogStore
|
16
|
+
|
17
|
+
# Module-level logger for dmu.testing.utilities, obtained from LogStore
log = LogStore.add_logger('dmu:testing:utilities')
|
18
|
+
# -------------------------------
|
19
|
+
@dataclass
class Data:
    '''
    Settings shared by the toy-data helpers of this module
    '''
    # Number of entries per column in the toy datasets. It has no annotation,
    # so it stays a plain class attribute and is not a dataclass field.
    nentries = 3_000
|
25
|
+
# -------------------------------
|
26
|
+
def _double_data(d_data : dict) -> dict:
|
27
|
+
df_1 = pnd.DataFrame(d_data)
|
28
|
+
df_2 = pnd.DataFrame(d_data)
|
29
|
+
|
30
|
+
df = pnd.concat([df_1, df_2], axis=0)
|
31
|
+
|
32
|
+
d_data = { name : df[name].to_numpy() for name in df.columns }
|
33
|
+
|
34
|
+
return d_data
|
35
|
+
# -------------------------------
|
36
|
+
def _add_nans(d_data : dict) -> dict:
|
37
|
+
df_good = pnd.DataFrame(d_data)
|
38
|
+
df_bad = pnd.DataFrame(d_data)
|
39
|
+
df_bad[:] = numpy.nan
|
40
|
+
|
41
|
+
df = pnd.concat([df_good, df_bad])
|
42
|
+
d_data = { name : df[name].to_numpy() for name in df.columns }
|
43
|
+
|
44
|
+
return d_data
|
45
|
+
# -------------------------------
|
46
|
+
def get_rdf(kind : Union[str,None] = None,
            repeated : bool = False,
            add_nans : bool = False):
    '''
    Return ROOT dataframe with toy data

    kind (str)      : 'sig' for columns w, x, y, z drawn from a normal distribution with mean 0,
                      'bkg' for mean 1. Width is 1 in both cases.
    repeated (bool) : If True, the dataset is appended to itself, i.e. every entry appears twice
    add_nans (bool) : If True, as many NaN entries as the dataset has are appended

    Raises ValueError for any other kind
    '''
    # Tuple membership uses ==, so unhashable kinds also end up in the ValueError branch
    if kind not in ('sig', 'bkg'):
        log.error(f'Invalid kind: {kind}')
        raise ValueError(f'Invalid kind: {kind}')

    mean   = 0 if kind == 'sig' else 1
    d_data = { name : numpy.random.normal(mean, 1, size=Data.nentries) for name in ['w', 'x', 'y', 'z'] }

    if repeated:
        d_data = _double_data(d_data)

    if add_nans:
        d_data = _add_nans(d_data)

    rdf = RDF.FromNumpy(d_data)

    return rdf
|
76
|
+
# -------------------------------
|
77
|
+
def get_config(name : Union[str,None] = None):
    '''
    Takes path to the YAML config file, relative to the `dmu_data` package,
    e.g. `plotting/tests/simple.yaml`
    Returns dictionary with config

    Raises ValueError if no name is passed
    '''
    if name is None:
        raise ValueError('Name not passed')

    cfg_path = files('dmu_data').joinpath(name)
    cfg_path = str(cfg_path)
    with open(cfg_path, encoding='utf-8') as ifile:
        d_cfg = yaml.safe_load(ifile)

    return d_cfg
|
91
|
+
# -------------------------------
|
92
|
+
def _get_rdf(nentries : int) -> RDataFrame:
    '''
    Takes number of entries
    Returns ROOT dataframe with constant columns x = 0, y = 1 and z = 2
    '''
    rdf = RDataFrame(nentries)
    for column, expression in [('x', '0'), ('y', '1'), ('z', '2')]:
        rdf = rdf.Define(column, expression)

    return rdf
|
99
|
+
# -------------------------------
|
100
|
+
def get_file_with_trees(path : str) -> TFile:
    '''
    Takes path to toy ROOT file, in the form of /a/b/c/file.root, and writes to it the trees:

    odir/idir/a : 100 entries
    dir/b       : 200 entries
    c           : 300 entries

    each with constant columns x, y and z (see `_get_rdf`). The file is recreated.
    Returns TFile handle to it, opened with ROOT's default options
    '''
    # A bare file name has no directory part and os.makedirs('') would raise
    dir_name = os.path.dirname(path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    snap = RDF.RSnapshotOptions()
    # First tree recreates the file, the following ones are added to it
    snap.fMode = 'recreate'

    l_tree_name = ['odir/idir/a', 'dir/b', 'c']
    l_nevt      = [          100,     200, 300]

    l_rdf = [ _get_rdf(nevt) for nevt in l_nevt ]
    for rdf, tree_name in zip(l_rdf, l_tree_name):
        rdf.Snapshot(tree_name, path, ['x', 'y', 'z'], snap)
        snap.fMode = 'update'

    return TFile(path)
|