data-manipulation-utilities 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
  2. data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
  3. data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
  4. data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
  5. data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
  6. dmu/arrays/utilities.py +55 -0
  7. dmu/dataframe/dataframe.py +36 -0
  8. dmu/generic/utilities.py +69 -0
  9. dmu/logging/log_store.py +129 -0
  10. dmu/ml/cv_classifier.py +122 -0
  11. dmu/ml/cv_predict.py +152 -0
  12. dmu/ml/train_mva.py +257 -0
  13. dmu/ml/utilities.py +132 -0
  14. dmu/plotting/plotter.py +227 -0
  15. dmu/plotting/plotter_1d.py +113 -0
  16. dmu/plotting/plotter_2d.py +87 -0
  17. dmu/rdataframe/atr_mgr.py +79 -0
  18. dmu/rdataframe/utilities.py +72 -0
  19. dmu/rfile/rfprinter.py +91 -0
  20. dmu/rfile/utilities.py +34 -0
  21. dmu/stats/fitter.py +515 -0
  22. dmu/stats/function.py +314 -0
  23. dmu/stats/utilities.py +134 -0
  24. dmu/testing/utilities.py +119 -0
  25. dmu/text/transformer.py +182 -0
  26. dmu_data/__init__.py +0 -0
  27. dmu_data/ml/tests/train_mva.yaml +37 -0
  28. dmu_data/plotting/tests/2d.yaml +14 -0
  29. dmu_data/plotting/tests/fig_size.yaml +13 -0
  30. dmu_data/plotting/tests/high_stat.yaml +22 -0
  31. dmu_data/plotting/tests/name.yaml +14 -0
  32. dmu_data/plotting/tests/no_bounds.yaml +12 -0
  33. dmu_data/plotting/tests/simple.yaml +8 -0
  34. dmu_data/plotting/tests/title.yaml +14 -0
  35. dmu_data/plotting/tests/weights.yaml +13 -0
  36. dmu_data/text/transform.toml +4 -0
  37. dmu_data/text/transform.txt +6 -0
  38. dmu_data/text/transform_set.toml +8 -0
  39. dmu_data/text/transform_set.txt +6 -0
  40. dmu_data/text/transform_trf.txt +12 -0
  41. dmu_scripts/physics/check_truth.py +121 -0
  42. dmu_scripts/rfile/compare_root_files.py +299 -0
  43. dmu_scripts/rfile/print_trees.py +35 -0
  44. dmu_scripts/ssh/coned.py +168 -0
  45. dmu_scripts/text/transform_text.py +46 -0
@@ -0,0 +1,113 @@
1
+ '''
2
+ Module containing plotter class
3
+ '''
4
+
5
+ import numpy
6
+ import matplotlib.pyplot as plt
7
+
8
+ from dmu.logging.log_store import LogStore
9
+ from dmu.plotting.plotter import Plotter
10
+
11
+ log = LogStore.add_logger('dmu:plotting:Plotter1D')
12
+ # --------------------------------------------
13
class Plotter1D(Plotter):
    '''
    Class used to plot columns in ROOT dataframes
    '''
    # --------------------------------------------
    def __init__(self, d_rdf=None, cfg=None):
        '''
        Parameters:

        d_rdf (dict): Dictionary mapping the kind of sample with the ROOT dataframe
        cfg   (dict): Dictionary with configuration, e.g. binning, ranges, etc
        '''

        super().__init__(d_rdf=d_rdf, cfg=cfg)
    #-------------------------------------
    def _get_labels(self, var : str) -> tuple[str,str]:
        '''
        Returns the (x, y) axis labels for a variable

        Uses the `labels` entry of the variable's plot configuration when present,
        otherwise falls back to the variable name and `Entries`
        '''
        d_cfg = self._d_cfg['plots'][var]
        if 'labels' not in d_cfg:
            return var, 'Entries'

        xname, yname = d_cfg['labels']

        return xname, yname
    #-------------------------------------
    def _plot_var(self, var):
        '''
        Will plot a variable from a dictionary of dataframes, one step histogram per sample

        Parameters
        --------------------
        var (str) : name of column

        The plot configuration for the variable must hold a `binning` entry of the form
        [min, max, nbins]. The entries `yscale`, `normalized` and `title` are optional and
        default to 'linear', False and an empty string.
        '''
        # pylint: disable=too-many-locals

        d_cfg = self._d_cfg['plots'][var]

        minx, maxx, bins = d_cfg['binning']
        yscale           = d_cfg.get('yscale'    , 'linear')
        normalized       = d_cfg.get('normalized', False)
        title            = d_cfg.get('title'     , '')
        xname, yname     = self._get_labels(var)

        d_data = { name : rdf.AsNumpy([var])[var] for name, rdf in self._d_rdf.items() }

        # An upper bound that does not exceed the lower one means "bounds not set".
        # The configured lower value is then forwarded to the base class helper as `qnt`
        if maxx <= minx + 1e-5:
            log.info(f'Bounds not set for {var}, will calculated them')
            minx, maxx = self._find_bounds(d_data = d_data, qnt=minx)
            log.info(f'Using bounds [{minx:.3e}, {maxx:.3e}]')
        else:
            log.debug(f'Using bounds [{minx:.3e}, {maxx:.3e}]')

        l_bc_all = []
        d_wgt    = self._get_weights(var)
        for name, arr_val in d_data.items():
            arr_wgt = d_wgt[name] if d_wgt is not None else None

            self._print_weights(arr_wgt, var, name)
            l_bc, _, _ = plt.hist(arr_val, weights=arr_wgt, bins=bins, range=(minx, maxx), density=normalized, histtype='step', label=name)
            l_bc_all.extend(numpy.asarray(l_bc).tolist())

        plt.yscale(yscale)
        plt.xlabel(xname)
        plt.ylabel(yname)

        if yscale == 'linear':
            plt.ylim(bottom=0)

        # With no samples there are no bin contents and max() would raise.
        # A non-positive maximum would place the top at or below the bottom of the axis.
        # In both cases the automatic upper limit is kept.
        max_y = max(l_bc_all, default=0)
        if max_y > 0:
            plt.ylim(top=1.2 * max_y)

        plt.title(title)
    # --------------------------------------------
    def _plot_lines(self, var : str):
        '''
        Will plot vertical lines for some variables

        var (str) : name of variable

        A dotted red line is drawn at 5280 for `B_const_mass_M` and `B_M`
        and at 3096 for `Jpsi_M`, nothing is drawn for any other variable
        '''
        if var in ['B_const_mass_M', 'B_M']:
            plt.axvline(x=5280, color='r', label=r'$B^+$'   , linestyle=':')
        elif var == 'Jpsi_M':
            plt.axvline(x=3096, color='r', label=r'$J/\psi$', linestyle=':')
    # --------------------------------------------
    def run(self):
        '''
        Will run plotting

        For every variable under the `plots` section of the configuration a figure
        named after the variable is created, filled and saved through the base class
        '''

        fig_size = self._get_fig_size()
        for var in self._d_cfg['plots']:
            log.debug(f'Plotting: {var}')
            plt.figure(var, figsize=fig_size)
            self._plot_var(var)
            self._plot_lines(var)
            self._save_plot(var)
113
+ # --------------------------------------------
@@ -0,0 +1,87 @@
1
+ '''
2
+ Module containing Plotter2D class
3
+ '''
4
+ from typing import Union
5
+
6
+ import hist
7
+ import numpy
8
+ import mplhep
9
+ import matplotlib.pyplot as plt
10
+
11
+ from hist import Hist
12
+ from ROOT import RDataFrame
13
+ from dmu.logging.log_store import LogStore
14
+ from dmu.plotting.plotter import Plotter
15
+
16
+ log = LogStore.add_logger('dmu:plotting:Plotter2D')
17
+ # --------------------------------------------
18
class Plotter2D(Plotter):
    '''
    Plots pairs of columns of a ROOT dataframe as two dimensional histograms
    '''
    # --------------------------------------------
    def __init__(self, rdf=None, cfg=None):
        '''
        Parameters:

        rdf       : ROOT dataframe holding the columns to plot
        cfg (dict): Dictionary with configuration, this class reads the `axes` and `plots_2d` entries

        Raises ValueError if cfg is not a dictionary
        '''
        # NOTE(review): the initializer of the base class is not called here, confirm this is intended

        if not isinstance(cfg, dict):
            raise ValueError('Config dictionary not passed')

        self._rdf   : RDataFrame = rdf
        self._d_cfg : dict       = cfg

        self._wgt   : numpy.ndarray
    # --------------------------------------------
    def _get_axis(self, var : str):
        '''
        Builds a regular axis for a variable from its `binning` ([min, max, nbins]) and `label` settings
        '''
        d_axis            = self._d_cfg['axes'][var]
        low, high, n_bins = d_axis['binning']
        text              = d_axis['label']

        return hist.axis.Regular(n_bins, low, high, name=text, label=text)
    # --------------------------------------------
    def _get_data(self, varx : str, vary : str) -> tuple[numpy.ndarray, numpy.ndarray]:
        '''
        Reads both columns from the dataframe and returns them as a pair of arrays
        '''
        d_arr = self._rdf.AsNumpy([varx, vary])

        return d_arr[varx], d_arr[vary]
    # --------------------------------------------
    def _get_dataset_weights(self, wgt_name : Union[str, None]) -> Union[numpy.ndarray, None]:
        '''
        Returns the column used as weights, or None when no column name is given
        '''
        if wgt_name is not None:
            log.debug(f'Adding weights form column {wgt_name}')
            return self._rdf.AsNumpy([wgt_name])[wgt_name]

        log.debug('Skipping weights')
        return None
    # --------------------------------------------
    def _plot_vars(self, varx : str, vary : str, wgt_name : str) -> None:
        '''
        Fills a 2D histogram of varx against vary, optionally weighted, and draws it
        '''
        log.info(f'Plotting {varx} vs {vary} with weights {wgt_name}')

        axis_x         = self._get_axis(varx)
        axis_y         = self._get_axis(vary)
        val_x, val_y   = self._get_data(varx, vary)
        weights        = self._get_dataset_weights(wgt_name)

        histogram = Hist(axis_x, axis_y)
        histogram.fill(val_x, val_y, weight=weights)

        mplhep.hist2dplot(histogram)
    # --------------------------------------------
    def run(self):
        '''
        Will run plotting

        Each entry of `plots_2d` in the configuration is unpacked as
        [x variable, y variable, weight column, plot name]
        '''

        size = self._get_fig_size()
        for entry in self._d_cfg['plots_2d']:
            varx, vary, wgt_name, plot_name = entry

            plt.figure(plot_name, figsize=size)
            self._plot_vars(varx, vary, wgt_name)
            self._save_plot(plot_name)
87
+ # --------------------------------------------
@@ -0,0 +1,79 @@
1
+ '''
2
+ Module with AtrMgr class
3
+ '''
4
+
5
+ import os
6
+
7
+ from ROOT import RDataFrame
8
+
9
+ import dmu.generic.utilities as gut
10
+ from dmu.logging.log_store import LogStore
11
+
12
+ #TODO: Skip attributes that start with Take< in a better way
13
+ log = LogStore.add_logger('dmu:rdataframe:atr_mgr')
14
+ #------------------------
15
class AtrMgr:
    '''
    Class intended to store attributes of ROOT dataframes and attach them back after a Filtering, definition, redefinition, etc operation
    These operations create new dataframes and therefore drop attributes.
    '''
    #------------------------
    def __init__(self, rdf : RDataFrame):
        '''
        Takes the dataframe whose attributes will be stored
        '''
        self.d_in_atr = {}

        self._store_atr(rdf)
    #------------------------
    def _store_atr(self, rdf : RDataFrame):
        '''
        Caches the attributes of the dataframe in this instance
        '''
        self.d_in_atr = self._get_atr(rdf)
    #------------------------
    def _skip_attr(self, name : str) -> bool:
        '''
        True for names that both start and end with a double underscore
        '''
        return name.startswith('__') and name.endswith('__')
    #------------------------
    def _get_atr(self, rdf : RDataFrame):
        '''
        Returns dictionary mapping every name listed by dir(rdf), except dunder names, to its value
        '''
        return { atr : getattr(rdf, atr) for atr in dir(rdf) if not self._skip_attr(atr) }
    #------------------------
    def add_atr(self, rdf : RDataFrame) -> RDataFrame:
        '''
        Takes dataframe and adds back the attributes

        Only attributes that were stored and are missing in the dataframe passed are set.
        Names of the form `Take<...>` are never attached. The same dataframe object is returned.
        '''
        d_ou_atr   = self._get_atr(rdf)
        key_to_add = set(self.d_in_atr) - set(d_ou_atr)

        for key in key_to_add:
            if key.startswith('Take<') and key.endswith('>'):
                continue

            log.info(f'Adding attribute {key}')
            setattr(rdf, key, self.d_in_atr[key])

        return rdf
    #------------------------
    def to_json(self, json_path : str) -> None:
        '''
        Saves the attributes inside current instance to JSON. Takes JSON path as argument

        Only attributes whose value is a list, str, int or float are written
        '''
        json_dir = os.path.dirname(json_path)
        # A bare file name has an empty directory part and os.makedirs('') would raise
        if json_dir:
            os.makedirs(json_dir, exist_ok=True)

        t_type   = (list, str, int, float)
        d_fl_atr = { key : val for key, val in self.d_in_atr.items() if isinstance(val, t_type) and isinstance(key, t_type) }

        gut.dump_json(d_fl_atr, json_path)
79
+ #------------------------
@@ -0,0 +1,72 @@
1
+ '''
2
+ Module containing utility functions to be used with ROOT dataframes
3
+ '''
4
+
5
+ import re
6
+ from dataclasses import dataclass
7
+
8
+ import awkward as ak
9
+ import numpy
10
+
11
+ from ROOT import RDataFrame
12
+
13
+ from dmu.logging.log_store import LogStore
14
+
15
+ log = LogStore.add_logger('dmu:rdataframe:utilities')
16
+
17
+ # ---------------------------------------------------------------------
18
@dataclass
class Data:
    '''
    Holder of constants shared across this module
    '''
    # Python and numpy scalar types
    # NOTE(review): not referenced in the visible part of this module, confirm where it is consumed
    l_good_type = [
        int,
        numpy.bool_,
        numpy.int32,
        numpy.uint32,
        numpy.int64,
        numpy.uint64,
        numpy.float32,
        numpy.float64,
    ]

    # Maps a type name to the numpy type associated with it
    d_cast_type = {
        'bool' : numpy.int32,
    }
25
+ # ---------------------------------------------------------------------
26
def add_column(rdf : RDataFrame, arr_val : numpy.ndarray | None, name : str, d_opt : dict | None = None):
    '''
    Will take a dataframe, an array of numbers and a string
    Will add the array as a column to the dataframe

    d_opt (dict) : Used to configure adding columns, it is only read, never modified
         exclude_re : Regex with pattern of column names that we won't pick

    Columns whose name matches `tmva_\\d+_\\d+` at the start are always dropped.
    The output is a new dataframe, made by converting the columns kept to awkward arrays and back.

    Raises ValueError if the array of values is None
    '''

    if arr_val is None:
        raise ValueError('Array of values not introduced')

    # Read the option without writing a default back into the dictionary owned by the caller
    user_rgx = None if d_opt is None else d_opt.get('exclude_re')
    tmva_rgx = r'tmva_\d+_\d+'

    l_col = []
    for col_name in rdf.GetColumnNames():
        col = col_name.c_str()

        if user_rgx is not None and re.match(user_rgx, col):
            log.debug(f'Dropping: {col}')
            continue

        if re.match(tmva_rgx, col):
            log.debug(f'Dropping: {col}')
            continue

        log.debug(f'Picking: {col}')
        l_col.append(col)

    data   = ak.from_rdataframe(rdf, columns=l_col)
    d_data = { col : data[col] for col in l_col }

    # Arrays of generic Python objects are converted to float before being attached
    if arr_val.dtype == 'object':
        arr_val = arr_val.astype(float)

    d_data[name] = arr_val

    return ak.to_rdataframe(d_data)
72
+ # ---------------------------------------------------------------------
dmu/rfile/rfprinter.py ADDED
@@ -0,0 +1,91 @@
1
+ '''
2
+ Module containing RFPrinter
3
+ '''
4
+ import os
5
+
6
+ from ROOT import TFile
7
+
8
+ from dmu.logging.log_store import LogStore
9
+
10
+ log = LogStore.add_logger('dmu:rfprinter')
11
+ #--------------------------------------------------
12
class RFPrinter:
    '''
    Class meant to print summary of ROOT file
    '''
    #-----------------------------------------
    def __init__(self, path : str):
        '''
        Takes path to root file

        The summary is written next to it: a trailing `.root` extension is swapped for `.txt`,
        for any other name `.txt` is appended

        Raises FileNotFoundError if the path is not an existing file
        '''
        if not os.path.isfile(path):
            raise FileNotFoundError(f'Cannot find {path}')

        self._root_path = path

        # Only the extension is touched. A plain string replacement would also rewrite
        # `.root` inside directory names, and would give back the input path (so the summary
        # would overwrite the ROOT file) whenever the name holds no `.root`
        stem, ext       = os.path.splitext(path)
        self._text_path = f'{stem}.txt' if ext == '.root' else f'{path}.txt'
    #-----------------------------------------
    def _get_trees(self, ifile):
        '''
        Takes TFile object, returns list of TTree objects

        Directories are searched recursively. The title of each tree is overwritten
        with `<name of container>/<name of tree>`
        '''
        l_tree = []
        for key in ifile.GetListOfKeys():
            obj = key.ReadObj()
            if obj.InheritsFrom("TTree"):
                fname = ifile.GetName()
                tname = obj.GetName()

                obj.SetTitle(f'{fname}/{tname}')
                l_tree.append(obj)
            elif obj.InheritsFrom("TDirectory"):
                l_tree += self._get_trees(obj)

        return l_tree
    #---------------------------------
    def _get_tree_info(self, tree):
        '''
        Takes ROOT tree, returns list of strings with information about tree

        Each string holds the name of a branch and the type name of the leaf named like the branch
        '''
        l_line = []
        for branch in tree.GetListOfBranches():
            bname = branch.GetName()
            # NOTE(review): assumes every branch owns a leaf with the same name as the branch, confirm
            leaf  = branch.GetLeaf(bname)
            btype = leaf.GetTypeName()

            l_line.append(f'{"":4}{bname:<100}{btype:<40}')

        return l_line
    #-----------------------------------------
    def _save_info(self, l_info):
        '''
        Takes list of strings, saves it to text file, one string per line
        '''

        with open(self._text_path, 'w', encoding='utf-8') as ofile:
            ofile.writelines(f'{info}\n' for info in l_info)

        log.info(f'Saved to: {self._text_path}')
    #-----------------------------------------
    def save(self, to_screen=False):
        '''
        Will save a text file with the summary of the ROOT file contents

        to_screen (bool) : If true, will print to screen, default=False
        '''
        l_info = []
        log.info(f'Reading from : {self._root_path}')
        with TFile.Open(self._root_path) as ifile:
            # The trees are inspected while the file is still open
            for tree in self._get_trees(ifile):
                l_info += self._get_tree_info(tree)

        self._save_info(l_info)
        if to_screen:
            for info in l_info:
                log.info(info)
91
+ #-----------------------------------------
dmu/rfile/utilities.py ADDED
@@ -0,0 +1,34 @@
1
+ '''
2
+ Module with utilities needed to manipulate ROOT files
3
+ '''
4
+ from ROOT import TTree, TDirectoryFile
5
+
6
def get_trees_from_file(ifile : TDirectoryFile) -> dict[str,TTree]:
    '''
    Collects the trees stored in a TFile or directory object

    Returns a dictionary mapping `<name of container>/<name of tree>` to the tree
    Directories found inside are searched recursively and their trees added to the same dictionary

    Raises ValueError if the argument does not inherit from TDirectoryFile
    '''
    if not ifile.InheritsFrom('TDirectoryFile'):
        type_name = str(type(ifile))
        raise ValueError(f'Unrecognized object type {type_name}')

    container = ifile.GetName()
    d_found   = {}

    for entry in ifile.GetListOfKeys():
        item = entry.ReadObj()

        # Directories are checked before trees, anything else is ignored
        if item.InheritsFrom('TDirectoryFile'):
            d_found.update(get_trees_from_file(item))
            continue

        if item.InheritsFrom('TTree'):
            tree_name = item.GetName()
            d_found[f'{container}/{tree_name}'] = item

    return d_found