data-manipulation-utilities 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
  2. data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
  3. data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
  4. data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
  5. data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
  6. dmu/arrays/utilities.py +55 -0
  7. dmu/dataframe/dataframe.py +36 -0
  8. dmu/generic/utilities.py +69 -0
  9. dmu/logging/log_store.py +129 -0
  10. dmu/ml/cv_classifier.py +122 -0
  11. dmu/ml/cv_predict.py +152 -0
  12. dmu/ml/train_mva.py +257 -0
  13. dmu/ml/utilities.py +132 -0
  14. dmu/plotting/plotter.py +227 -0
  15. dmu/plotting/plotter_1d.py +113 -0
  16. dmu/plotting/plotter_2d.py +87 -0
  17. dmu/rdataframe/atr_mgr.py +79 -0
  18. dmu/rdataframe/utilities.py +72 -0
  19. dmu/rfile/rfprinter.py +91 -0
  20. dmu/rfile/utilities.py +34 -0
  21. dmu/stats/fitter.py +515 -0
  22. dmu/stats/function.py +314 -0
  23. dmu/stats/utilities.py +134 -0
  24. dmu/testing/utilities.py +119 -0
  25. dmu/text/transformer.py +182 -0
  26. dmu_data/__init__.py +0 -0
  27. dmu_data/ml/tests/train_mva.yaml +37 -0
  28. dmu_data/plotting/tests/2d.yaml +14 -0
  29. dmu_data/plotting/tests/fig_size.yaml +13 -0
  30. dmu_data/plotting/tests/high_stat.yaml +22 -0
  31. dmu_data/plotting/tests/name.yaml +14 -0
  32. dmu_data/plotting/tests/no_bounds.yaml +12 -0
  33. dmu_data/plotting/tests/simple.yaml +8 -0
  34. dmu_data/plotting/tests/title.yaml +14 -0
  35. dmu_data/plotting/tests/weights.yaml +13 -0
  36. dmu_data/text/transform.toml +4 -0
  37. dmu_data/text/transform.txt +6 -0
  38. dmu_data/text/transform_set.toml +8 -0
  39. dmu_data/text/transform_set.txt +6 -0
  40. dmu_data/text/transform_trf.txt +12 -0
  41. dmu_scripts/physics/check_truth.py +121 -0
  42. dmu_scripts/rfile/compare_root_files.py +299 -0
  43. dmu_scripts/rfile/print_trees.py +35 -0
  44. dmu_scripts/ssh/coned.py +168 -0
  45. dmu_scripts/text/transform_text.py +46 -0
dmu/stats/function.py ADDED
@@ -0,0 +1,314 @@
1
+ '''
2
+ Module containing the Function class
3
+ '''
4
+ import os
5
+ import json
6
+
7
+ from typing import Any
8
+
9
+ import numpy
10
+ import matplotlib.pyplot as plt
11
+
12
+ from scipy.interpolate import interp1d
13
+ from dmu.logging.log_store import LogStore
14
+
15
+ log = LogStore.add_logger('dmu:stats:function')
16
+ #---------------------------------------------------------
17
class FunOutOfBounds(Exception):
    '''
    Exception signaling that a function defined in the range [a, b]
    was evaluated at a point outside that range
    '''
21
+ #---------------------------------------------------------
22
+ class Function:
23
+ '''
24
+ Class meant to represent a 1D function created from (x, y) coordinates
25
+ '''
26
+ #------------------------------------------------
27
+ def __init__(self, x : list | numpy.ndarray, y : list | numpy.ndarray, kind : str = 'cubic'):
28
+ '''
29
+ x (list) : List with x coordinates
30
+ y (list) : List with y coordinates
31
+ '''
32
+
33
+ x = self._array_to_list(x)
34
+ y = self._array_to_list(y)
35
+
36
+ if len(x) != len(y):
37
+ raise ValueError('X and Y coordinates have different lengths')
38
+
39
+ npoint = len(x)
40
+ if npoint < 4:
41
+ raise ValueError('Need at least four points, found {npoint}')
42
+
43
+ x, y = self._remove_duplicates(x=x, y=y)
44
+
45
+ self._max_entries = 400
46
+ self._l_x = x
47
+ self._l_y = y
48
+ self._kind= kind
49
+ self._tag = 'no_tag'
50
+
51
+ self._interpolator = interp1d(self._l_x, self._l_y, kind=self._kind)
52
+
53
+ self._update_data()
54
+ #------------------------------------------------
55
+ def __eq__(self, othr):
56
+ if not isinstance(othr, Function):
57
+ log.warning('Comparison not done with instance of Function')
58
+ return False
59
+
60
+ d_self = self.__dict__
61
+ d_othr = othr.__dict__
62
+
63
+ if '_interpolator' in d_self:
64
+ del d_self['_interpolator']
65
+
66
+ if '_interpolator' in d_othr:
67
+ del d_othr['_interpolator']
68
+
69
+ return d_self == d_othr
70
+ #------------------------------------------------
71
+ def __str__(self):
72
+ npoints = len(self._l_x)
73
+ max_x = max(self._l_x)
74
+ min_x = min(self._l_x)
75
+
76
+ max_y = max(self._l_y)
77
+ min_y = min(self._l_y)
78
+
79
+ line = f'\n{"Points":<20}{npoints:<20}\n'
80
+ line+= '-------------------------\n'
81
+ line+= f'{"x-max":<20}{max_x:<20}\n'
82
+ line+= f'{"x-min":<20}{min_x:<20}\n'
83
+ line+= f'{"y-max":<20}{max_y:<20}\n'
84
+ line+= f'{"y-min":<20}{min_y:<20}'
85
+
86
+ return line
87
+ #------------------------------------------------
88
+ def __call__(self, xval : float | numpy.ndarray | list, off_bounds_raise : bool = False) -> numpy.ndarray:
89
+ '''
90
+ Class taking value of x coordinates as a float, numpy array or list
91
+ It will interpolate y value and return value
92
+ '''
93
+ if not off_bounds_raise:
94
+ xval = self._push_in_bounds(xval)
95
+
96
+ self._check_xval_validity(xval)
97
+
98
+ return self._interpolator(xval)
99
+ #------------------------------------------------
100
+ def _push_in_bounds(self, xval : float | numpy.ndarray | list) -> numpy.ndarray:
101
+ '''
102
+ If the xval container, has elements above (below) the upper (lower) bound, these events will be set to the closest bound
103
+ '''
104
+
105
+ xval = numpy.array(xval).flatten().astype(float)
106
+
107
+ max_x = max(self._l_x)
108
+ min_x = min(self._l_x)
109
+
110
+ if ((min_x <= xval) & (xval <= max_x)).all():
111
+ log.debug('Input array within bounds, will not push elements')
112
+ return xval
113
+
114
+
115
+ xmod = numpy.clip(xval, min_x, max_x)
116
+
117
+ arr_diff = xval != xmod
118
+ arr_indx = numpy.where(arr_diff)[0]
119
+ ndiff = numpy.sum(arr_diff)
120
+ arr_indx = arr_indx[:20]
121
+
122
+ log.warning(f'Sending {ndiff} entries inside bounds [{min_x:.3e}, {max_x:.3e}]')
123
+
124
+ for indx in arr_indx:
125
+ org = xval[indx]
126
+ mod = xmod[indx]
127
+
128
+ log.info(f'{org:<20.5e}{"-->":<20}{mod:<20.5}')
129
+
130
+ return xmod
131
+ #------------------------------------------------
132
+ @staticmethod
133
+ def json_decoder(d_attr):
134
+ '''
135
+ Takes dictionary of attributes from JSON serialization
136
+ Returns instance of Function
137
+ '''
138
+
139
+ if '_l_x' not in d_attr:
140
+ raise KeyError('X values not found')
141
+
142
+ if '_l_y' not in d_attr:
143
+ raise KeyError('Y values not found')
144
+
145
+ if '_tag' not in d_attr:
146
+ raise KeyError('tag not found')
147
+
148
+ x = d_attr['_l_x' ]
149
+ y = d_attr['_l_y' ]
150
+ kind = d_attr['_kind']
151
+ tag = d_attr['_tag' ]
152
+
153
+ fun = Function(x=x, y=y, kind=kind)
154
+ fun.tag = tag
155
+
156
+ return fun
157
+ #------------------------------------------------
158
+ @property
159
+ def tag(self):
160
+ '''
161
+ Returns string simbolyzing tag of function
162
+ '''
163
+ return self._tag
164
+
165
+ @tag.setter
166
+ def tag(self, value : str):
167
+ '''
168
+ This sets the _tag property of the function
169
+ '''
170
+ self._tag = value
171
+ #------------------------------------------------
172
+ @staticmethod
173
+ def load(path : str):
174
+ '''
175
+ Will take path to JSON file with serialized function
176
+ Will return function instance
177
+ '''
178
+
179
+ if not os.path.isfile(path):
180
+ raise FileNotFoundError(f'Cannot find: {path}')
181
+
182
+ with open(path, encoding='utf-8') as ifile:
183
+ fun = json.loads(ifile.read(), object_hook=Function.json_decoder)
184
+
185
+ log.info(f'Loaded from: {path}')
186
+
187
+ return fun
188
+ #------------------------------------------------
189
+ def _array_to_list(self, x : Any):
190
+ '''
191
+ Transform from ndarray to list
192
+ Return x if already list
193
+ Raise otherwise
194
+ '''
195
+ if isinstance(x, list):
196
+ log.debug('Already found list')
197
+ return x
198
+
199
+ if isinstance(x, numpy.ndarray):
200
+ log.debug('Transforming argument to list')
201
+ return x.tolist()
202
+
203
+ raise ValueError('Object introduced is neither a list nor a numpy array')
204
+ #------------------------------------------------
205
+ def _update_data(self):
206
+ '''
207
+ If number of entries in dataset is larger than _max_entries:
208
+
209
+ Use interpolator to scan function and get new (x, y) pairs.
210
+ '''
211
+ norg = len(self._l_x)
212
+ if norg <= self._max_entries:
213
+ return
214
+
215
+ log.info(f'Trimming dataset: {norg} -> {self._max_entries}')
216
+
217
+ min_x = min(self._l_x)
218
+ max_x = max(self._l_x)
219
+
220
+ arr_x = numpy.linspace(min_x, max_x, self._max_entries)
221
+ arr_y = self(arr_x)
222
+
223
+ self._l_x = arr_x.tolist()
224
+ self._l_y = arr_y.tolist()
225
+ #------------------------------------------------
226
+ def _remove_duplicates(self, x : list, y : list):
227
+ '''
228
+ Takes two lists with the same sizes and remove (x, y) points with repeated
229
+ x coordinates.
230
+ Return tuple with x and y after removal
231
+ '''
232
+
233
+ norg = len(x)
234
+
235
+ d_tmp = dict(zip(x, y))
236
+
237
+ x = list(d_tmp.keys())
238
+ y = list(d_tmp.values())
239
+
240
+ nfnl = len(x)
241
+
242
+ if norg != nfnl:
243
+ log.warning(f'Found duplicates: {norg} -> {nfnl}')
244
+
245
+ return x, y
246
+ #------------------------------------------------
247
+ def _check_xval_validity(self, xval : float | numpy.ndarray | list):
248
+ '''
249
+ Will check that xval is an acceptable value for the function to be evaluated at
250
+ '''
251
+
252
+ if isinstance(xval, list):
253
+ xval = numpy.array(xval)
254
+
255
+ if not isinstance(xval, (float, numpy.ndarray)):
256
+ raise ValueError(f'x value is not a float or numpy array: {xval}')
257
+
258
+ check_within_bounds_vect = numpy.vectorize(self._check_within_bounds)
259
+ check_within_bounds_vect(xval)
260
+ #------------------------------------------------
261
+ def _check_within_bounds(self, xval : float):
262
+ '''
263
+ Check that xval is within bounds of function
264
+ '''
265
+
266
+ if xval < min(self._l_x) or xval > max(self._l_x):
267
+ print(self)
268
+ raise FunOutOfBounds(f'x value outside bounds: {xval}')
269
+ #------------------------------------------------
270
+ def _json_encoder(self, obj):
271
+ '''
272
+ Takes Function object
273
+ Returns dictionary of attributes for encoding
274
+ '''
275
+ d_data = obj.__dict__
276
+
277
+ if '_interpolator' in d_data:
278
+ del d_data['_interpolator']
279
+
280
+ return d_data
281
+ #------------------------------------------------
282
+ def _save_plot(self, path : str):
283
+ '''
284
+ Takes path to PNG, saves scatter plot of l_y vs l_x
285
+ '''
286
+
287
+ plt.plot(self._l_x, self._l_y)
288
+ plt.savefig(path)
289
+ plt.close()
290
+
291
+ log.info(f'Saved to: {path}')
292
+ #------------------------------------------------
293
+ def save(self, path : str, plot : bool = False):
294
+ '''
295
+ Saves current object to JSON
296
+
297
+ path (str): Path to file, needs to end in .json
298
+ '''
299
+
300
+ if not path.endswith('.json'):
301
+ raise ValueError(f'Output path does not end in .json: {path}')
302
+
303
+ dir_name = os.path.dirname(path)
304
+ os.makedirs(dir_name, exist_ok=True)
305
+
306
+ with open(path, 'w', encoding='utf-8') as ofile:
307
+ json.dump(self, ofile, indent=4, default=self._json_encoder)
308
+
309
+ if plot:
310
+ path = path.replace('.json', '.png')
311
+ self._save_plot(path)
312
+
313
+ log.info(f'Saved to: {path}')
314
+ #------------------------------------------------
dmu/stats/utilities.py ADDED
@@ -0,0 +1,134 @@
1
+ '''
2
+ Module with utility functions related to the dmu.stats project
3
+ '''
4
+ import os
5
+ import re
6
+ from typing import Union
7
+ import zfit
8
+
9
+ from dmu.logging.log_store import LogStore
10
+
11
+ log = LogStore.add_logger('dmu:stats:utilities')
12
+ #-------------------------------------------------------
13
+ #Zfit/print_pdf
14
+ #-------------------------------------------------------
15
def _get_const(par : zfit.Parameter, d_const : Union[None, dict[str, list[float]]]) -> str:
    '''
    Builds the text shown in the constraint column for a given parameter

    par    : zfit parameter, only its name is used
    d_const: Dictionary of constraints, keyed by parameter name, or None

    Returns 'none' if there is no constraint for the parameter. For a list or tuple
    constraint, both values in scientific notation. Otherwise, the constraint as a string.
    '''
    if d_const is None:
        return 'none'

    if par.name not in d_const:
        return 'none'

    constraint = d_const[par.name]
    if not isinstance(constraint, (list, tuple)):
        return str(constraint)

    # Unpacking requires exactly two entries
    mu_val, sg_val = constraint

    return f'{mu_val:.3e}; {sg_val:.3e}'
31
+ #-------------------------------------------------------
32
def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zfit.Parameter]:
    '''
    Takes set of zfit parameters and list of regular expressions for the names of the parameters to blind

    Returns set with the parameters that are NOT blinded, i.e. the parameters
    whose names are matched, from the start, by none of the expressions.
    If the list is None or empty, the input set is returned.
    '''
    # An empty list would build the regex '()', which matches any name
    # and therefore would remove every parameter
    if not l_blind:
        return s_par

    rgx_ors = '|'.join(l_blind)
    regex   = re.compile(f'({rgx_ors})')

    return { par for par in s_par if not regex.match(par.name) }
46
+ #-------------------------------------------------------
47
def _get_pars(
        pdf   : zfit.pdf.BasePDF,
        blind : Union[None, list[str]]) -> tuple[list, list]:
    '''
    Takes PDF and list of regular expressions for the names of the parameters to blind

    Returns tuple with the list of floating and the list of fixed parameters,
    after the blinding has been applied, each one sorted by parameter name.
    '''
    s_floating = pdf.get_params(floating= True)
    s_fixed    = pdf.get_params(floating=False)

    def _by_name(par):
        return par.name

    l_floating = sorted(_blind_vars(s_floating, l_blind=blind), key=_by_name)
    l_fixed    = sorted(_blind_vars(s_fixed   , l_blind=blind), key=_by_name)

    return l_floating, l_fixed
64
+ #-------------------------------------------------------
65
def _get_messages(
        pdf       : zfit.pdf.BasePDF,
        l_par_flt : list,
        l_par_fix : list,
        d_const   : Union[None, dict[str,list[float]]] = None) -> list[str]:
    '''
    Builds the lines of the PDF printout

    pdf      : PDF, its name and space are shown in the header
    l_par_flt: Parameters shown in the first group of rows
    l_par_fix: Parameters shown in the second group of rows, after an empty line
    d_const  : Optional dictionary with constraints, keyed by parameter name

    Returns list of strings, one per line
    '''
    def _par_line(par) -> str:
        # One row of the table: name, value, limits, floating flag and constraint
        value = par.value().numpy()
        low   = par.lower
        hig   = par.upper
        const = _get_const(par, d_const)

        return f'{par.name:<50}{value:>15.3e}{low:>15.3e}{hig:>15.3e}{par.floating:>5}{const:>25}'

    str_space = str(pdf.space)
    separator = '-' * 20

    l_msg = [
            separator,
            f'PDF: {pdf.name}',
            f'OBS: {str_space}',
            f'{"Name":<50}{"Value":>15}{"Low":>15}{"High":>15}{"Floating":>5}{"Constraint":>25}',
            separator]

    l_msg.extend(_par_line(par) for par in l_par_flt)
    l_msg.append('')
    l_msg.extend(_par_line(par) for par in l_par_fix)

    return l_msg
96
+ #-------------------------------------------------------
97
def print_pdf(
        pdf      : zfit.pdf.BasePDF,
        d_const  : Union[None, dict[str,list[float]]] = None,
        txt_path : Union[str,None]                    = None,
        level    : int                                = 20,
        blind    : Union[None, list[str]]             = None):
    '''
    Function used to print zfit PDFs

    Parameters
    -------------------
    pdf (zfit.PDF): PDF
    d_const (dict): Optional dictionary mapping {par_name : [mu, sg]}
    txt_path (str): Optionally, dump output to text in this path, in which case nothing is printed on screen
    level (int) : Optionally set the level at which the printing happens in screen, default 20 (info).
                  10 prints at debug level, 30 is also accepted and prints at debug level, as it used to.
    blind (list) : List of regular expressions matching variable names to blind in printout

    Raises
    -------------------
    ValueError: If printing on screen and level is not among 10, 20 or 30
    '''
    l_par_flt, l_par_fix = _get_pars(pdf, blind)
    l_msg                = _get_messages(pdf, l_par_flt, l_par_fix, d_const)

    if txt_path is not None:
        log.debug(f'Saving to: {txt_path}')
        message  = '\n'.join(l_msg)
        # For a bare file name the directory is empty and makedirs would raise
        dir_path = os.path.dirname(txt_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        with open(txt_path, 'w', encoding='utf-8') as ofile:
            ofile.write(message)

        return

    if level == 20:
        printer = log.info
    elif level in (10, 30):
        # NOTE(review): 30 is the value of WARNING in the logging module,
        # it is still sent to debug to preserve the existing behavior
        printer = log.debug
    else:
        raise ValueError(f'Invalid level: {level}')

    for msg in l_msg:
        printer(msg)
134
+ #-------------------------------------------------------
dmu/testing/utilities.py ADDED
@@ -0,0 +1,119 @@
1
+ '''
2
+ Module containing utility functions needed by unit tests
3
+ '''
4
+ import os
5
+ from typing import Union
6
+ from dataclasses import dataclass
7
+ from importlib.resources import files
8
+
9
+ from ROOT import RDF, TFile, RDataFrame
10
+
11
+ import pandas as pnd
12
+ import numpy
13
+ import yaml
14
+
15
+ from dmu.logging.log_store import LogStore
16
+
17
+ log = LogStore.add_logger('dmu:testing:utilities')
18
+ # -------------------------------
19
@dataclass
class Data:
    '''
    Holder of values shared among the functions of this module
    '''
    # Number of entries generated per column by get_rdf
    # Left without annotation, i.e. a class attribute and not a dataclass field
    nentries = 3000
25
+ # -------------------------------
26
+ def _double_data(d_data : dict) -> dict:
27
+ df_1 = pnd.DataFrame(d_data)
28
+ df_2 = pnd.DataFrame(d_data)
29
+
30
+ df = pnd.concat([df_1, df_2], axis=0)
31
+
32
+ d_data = { name : df[name].to_numpy() for name in df.columns }
33
+
34
+ return d_data
35
+ # -------------------------------
36
+ def _add_nans(d_data : dict) -> dict:
37
+ df_good = pnd.DataFrame(d_data)
38
+ df_bad = pnd.DataFrame(d_data)
39
+ df_bad[:] = numpy.nan
40
+
41
+ df = pnd.concat([df_good, df_bad])
42
+ d_data = { name : df[name].to_numpy() for name in df.columns }
43
+
44
+ return d_data
45
+ # -------------------------------
46
def get_rdf(kind : Union[str,None] = None,
            repeated : bool = False,
            add_nans : bool = False):
    '''
    Return ROOT dataframe with toy data

    kind (str)     : Either 'sig' or 'bkg'. The columns w, x, y and z are drawn from a normal
                     distribution of width 1, centered at 0 for 'sig' and at 1 for 'bkg'
    repeated (bool): If True, the entries are duplicated
    add_nans (bool): If True, as many entries filled with NaNs as already present are appended

    Raises ValueError if kind is neither 'sig' nor 'bkg'
    '''
    if   kind == 'sig':
        mean = 0
    elif kind == 'bkg':
        mean = 1
    else:
        log.error(f'Invalid kind: {kind}')
        # The message is also attached to the exception, it used to be raised empty
        raise ValueError(f'Invalid kind: {kind}')

    # Columns are generated in this order, so that the sequence of random draws does not change
    d_data = { name : numpy.random.normal(mean, 1, size=Data.nentries) for name in ['w', 'x', 'y', 'z'] }

    if repeated:
        d_data = _double_data(d_data)

    if add_nans:
        d_data = _add_nans(d_data)

    rdf = RDF.FromNumpy(d_data)

    return rdf
76
+ # -------------------------------
77
def get_config(name : Union[str,None] = None):
    '''
    Loads YAML config file stored in the `dmu_data` package

    name (str): Path to the YAML file, relative to `dmu_data`

    Returns the parsed content of the file
    Raises ValueError if no name is passed
    '''
    if name is None:
        raise ValueError('Name not pased')

    yaml_path = str(files('dmu_data').joinpath(name))

    with open(yaml_path, encoding='utf-8') as stream:
        return yaml.safe_load(stream)
91
+ # -------------------------------
92
def _get_rdf(nentries : int) -> RDataFrame:
    '''
    Takes number of entries
    Returns ROOT dataframe with that many entries and the columns x, y and z, defined as 0, 1 and 2 respectively
    '''
    rdf = RDataFrame(nentries)
    for column, expression in (('x', '0'), ('y', '1'), ('z', '2')):
        rdf = rdf.Define(column, expression)

    return rdf
99
+ # -------------------------------
100
def get_file_with_trees(path : str) -> TFile:
    '''
    Picks path to toy ROOT file, e.g. /a/b/c/file.root, creates the file and
    returns handle to it

    The file holds the trees odir/idir/a, dir/b and c with 100, 200 and 300 entries
    respectively, each with the columns x, y and z. The directory of the file is made if missing.
    '''
    # For a bare file name the directory is empty and makedirs would raise
    dir_name = os.path.dirname(path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    snap       = RDF.RSnapshotOptions()
    snap.fMode = 'recreate'

    l_tree_name = ['odir/idir/a', 'dir/b', 'c']
    l_nevt      = [          100,     200, 300]

    l_rdf = [ _get_rdf(nevt) for nevt in l_nevt ]
    for rdf, tree_name in zip(l_rdf, l_tree_name):
        rdf.Snapshot(tree_name, path, ['x', 'y', 'z'], snap)
        # Only the first tree recreates the file, the rest are added to it
        snap.fMode = 'update'

    return TFile(path)