data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev720__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/METADATA +669 -42
  2. data_manipulation_utilities-0.2.8.dev720.dist-info/RECORD +45 -0
  3. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/WHEEL +1 -2
  4. data_manipulation_utilities-0.2.8.dev720.dist-info/entry_points.txt +8 -0
  5. dmu/generic/hashing.py +34 -8
  6. dmu/generic/utilities.py +164 -11
  7. dmu/logging/log_store.py +34 -2
  8. dmu/logging/messages.py +96 -0
  9. dmu/ml/cv_classifier.py +3 -3
  10. dmu/ml/cv_diagnostics.py +3 -0
  11. dmu/ml/cv_performance.py +58 -0
  12. dmu/ml/cv_predict.py +149 -46
  13. dmu/ml/train_mva.py +482 -100
  14. dmu/ml/utilities.py +29 -10
  15. dmu/pdataframe/utilities.py +28 -3
  16. dmu/plotting/fwhm.py +2 -2
  17. dmu/plotting/matrix.py +1 -1
  18. dmu/plotting/plotter.py +23 -3
  19. dmu/plotting/plotter_1d.py +96 -32
  20. dmu/plotting/plotter_2d.py +5 -0
  21. dmu/rdataframe/utilities.py +54 -3
  22. dmu/rfile/ddfgetter.py +102 -0
  23. dmu/stats/fit_stats.py +129 -0
  24. dmu/stats/fitter.py +55 -22
  25. dmu/stats/gof_calculator.py +7 -0
  26. dmu/stats/model_factory.py +153 -62
  27. dmu/stats/parameters.py +100 -0
  28. dmu/stats/utilities.py +443 -12
  29. dmu/stats/wdata.py +187 -0
  30. dmu/stats/zfit.py +17 -0
  31. dmu/stats/zfit_plotter.py +147 -36
  32. dmu/testing/utilities.py +102 -24
  33. dmu/workflow/__init__.py +0 -0
  34. dmu/workflow/cache.py +266 -0
  35. data_manipulation_utilities-0.2.7.data/scripts/publish +0 -89
  36. data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
  37. data_manipulation_utilities-0.2.7.dist-info/entry_points.txt +0 -6
  38. data_manipulation_utilities-0.2.7.dist-info/top_level.txt +0 -3
  39. dmu_data/ml/tests/diagnostics_from_file.yaml +0 -13
  40. dmu_data/ml/tests/diagnostics_from_model.yaml +0 -10
  41. dmu_data/ml/tests/diagnostics_multiple_methods.yaml +0 -10
  42. dmu_data/ml/tests/diagnostics_overlay.yaml +0 -33
  43. dmu_data/ml/tests/train_mva.yaml +0 -58
  44. dmu_data/ml/tests/train_mva_with_diagnostics.yaml +0 -82
  45. dmu_data/plotting/tests/2d.yaml +0 -24
  46. dmu_data/plotting/tests/fig_size.yaml +0 -13
  47. dmu_data/plotting/tests/high_stat.yaml +0 -22
  48. dmu_data/plotting/tests/legend.yaml +0 -12
  49. dmu_data/plotting/tests/name.yaml +0 -14
  50. dmu_data/plotting/tests/no_bounds.yaml +0 -12
  51. dmu_data/plotting/tests/normalized.yaml +0 -9
  52. dmu_data/plotting/tests/plug_fwhm.yaml +0 -24
  53. dmu_data/plotting/tests/plug_stats.yaml +0 -19
  54. dmu_data/plotting/tests/simple.yaml +0 -9
  55. dmu_data/plotting/tests/stats.yaml +0 -9
  56. dmu_data/plotting/tests/styling.yaml +0 -11
  57. dmu_data/plotting/tests/title.yaml +0 -14
  58. dmu_data/plotting/tests/weights.yaml +0 -13
  59. dmu_data/text/transform.toml +0 -4
  60. dmu_data/text/transform.txt +0 -6
  61. dmu_data/text/transform_set.toml +0 -8
  62. dmu_data/text/transform_set.txt +0 -6
  63. dmu_data/text/transform_trf.txt +0 -12
  64. dmu_scripts/git/publish +0 -89
  65. dmu_scripts/physics/check_truth.py +0 -121
  66. dmu_scripts/rfile/compare_root_files.py +0 -299
  67. dmu_scripts/rfile/print_trees.py +0 -35
  68. dmu_scripts/ssh/coned.py +0 -168
  69. dmu_scripts/text/transform_text.py +0 -46
  70. {dmu_data → dmu}/__init__.py +0 -0
dmu/ml/utilities.py CHANGED
@@ -14,11 +14,24 @@ log = LogStore.add_logger('dmu:ml:utilities')
 # ---------------------------------------------
 # Patch dataframe with features
 # ---------------------------------------------
-def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
+def tag_nans(
+        df      : pnd.DataFrame,
+        indexes : str) -> pnd.DataFrame:
     '''
-    Takes pandas dataframe, replaces NaNs with value introduced, by default 0
-    Returns array of indices where the replacement happened
+
+    Parameters
+    ----------------
+    df      : Pandas dataframe
+    indexes : Name of dataframe attribute where array of indices of NaN rows should go
+
+    Returns
+    ----------------
+    Dataframe:
+
+    - After filtering, i.e. with dropped rows.
+    - With array of indices dropped as attribute at `patched_indices`
     '''
+
     l_nan = df.index[df.isna().any(axis=1)].tolist()
     nnan  = len(l_nan)
     if nnan == 0:
@@ -29,15 +42,21 @@ def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:

     df_nan_frq = df.isna().sum()
     df_nan_frq = df_nan_frq[df_nan_frq > 0]
-    print(df_nan_frq)

+    log.info(df_nan_frq)
     log.warning(f'Attaching array with NaN {nnan} indexes and removing NaNs from dataframe')

-    df_pa = df.fillna(value)
+    arr_index_2 = numpy.array(l_nan)
+    if indexes in df.attrs:
+        arr_index_1 = df.attrs[indexes]
+        arr_index   = numpy.concatenate((arr_index_1, arr_index_2))
+        arr_index   = numpy.unique(arr_index)
+    else:
+        arr_index   = arr_index_2

-    df_pa.attrs['patched_indices'] = numpy.array(l_nan)
+    df.attrs[indexes] = arr_index

-    return df_pa
+    return df
 # ---------------------------------------------
 # Cleanup of dataframe with features
 # ---------------------------------------------
@@ -96,7 +115,7 @@ def _remove_repeated(df : pnd.DataFrame) -> pnd.DataFrame:
     return df_clean
 # ----------------------------------
 # ---------------------------------------------
-def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set, list]:
+def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set[str], list[str]]:
     '''
     Will return hashes for each row in the feature dataframe

@@ -113,9 +132,9 @@ def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set, list]:

     return res
 # ----------------------------------
-def hash_from_row(row):
+def hash_from_row(row : pnd.Series) -> str:
     '''
-    Will return a hash from a pandas dataframe row
+    Will return a hash in the form or a string from a pandas dataframe row
     corresponding to an event
     '''
     l_val = [ str(val) for val in row ]
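Editor's note: a minimal usage sketch of the renamed tag_nans, inferred only from the hunks above. It assumes the function is importable from dmu.ml.utilities (per the file list) and that repeated calls merge NaN-row indices under df.attrs[indexes], as the concatenate/unique branch suggests.

    import pandas as pnd
    from dmu.ml.utilities import tag_nans

    # Two rows contain NaNs; tag_nans records their indices in df.attrs
    df = pnd.DataFrame({'x' : [1.0, float('nan'), 3.0], 'y' : [4.0, 5.0, float('nan')]})
    df = tag_nans(df, indexes='nan_rows')

    print(df.attrs['nan_rows'])   # indices of the rows that contained NaNs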
dmu/pdataframe/utilities.py CHANGED
@@ -48,12 +48,13 @@ def to_yaml(df : pnd.DataFrame, path : str):
     Makes the directory path if not found and saves data in YAML file
     '''
     dir_path = os.path.dirname(path)
-    os.makedirs(dir_path, exist_ok=True)
+    if dir_path != '':
+        os.makedirs(dir_path, exist_ok=True)

     data = df.to_dict()

     with open(path, 'w', encoding='utf-8') as ofile:
-        yaml.safe_dump(data, ofile)
+        yaml.dump(data, ofile, Dumper=yaml.CDumper)
 # -------------------------------------
 def from_yaml(path : str) -> pnd.DataFrame:
     '''
@@ -61,9 +62,33 @@ def from_yaml(path : str) -> pnd.DataFrame:
     Makes dataframe from it and returns it
     '''
     with open(path, encoding='utf-8') as ifile:
-        data = yaml.safe_load(ifile)
+        data = yaml.load(ifile, Loader=yaml.CSafeLoader)

     df = pnd.DataFrame(data)

     return df
 # -------------------------------------
+def dropna(df : pnd.DataFrame, max_frac : float = 0.02) -> pnd.DataFrame:
+    '''
+    Parameters
+    ----------------
+    df      : Pandas dataframe potentially with NaNs
+    max_frac: Maximum fraction of the data that can be dropped, will raise exception beyond
+    '''
+
+    ini = len(df)
+    df  = df.dropna()
+    fin = len(df)
+
+    if ini == fin:
+        log.debug('No NaNs were found')
+        return df
+
+    # If fewer elements survive the filter, raise
+    if fin < ini * (1 - max_frac):
+        raise ValueError(f'Too man NaNs were detected: {ini} --> {fin}')
+
+    log.info(f'Found NaNs: {ini} --> {fin}')
+
+    return df
+# -------------------------------------
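Editor's note: a short, hedged usage sketch for the changed to_yaml/from_yaml and the new dropna; the module path comes from the file list and the values are made up. Note that yaml.CDumper and yaml.CSafeLoader only exist when PyYAML is built against LibYAML.

    import pandas as pnd
    import dmu.pdataframe.utilities as put

    df = pnd.DataFrame({'a' : [1.0, None, 3.0, 4.0], 'b' : [1.0, 2.0, 3.0, 4.0]})

    # One of four rows (25%) has a NaN: fine with max_frac=0.3,
    # but the default max_frac=0.02 would raise ValueError here
    df = put.dropna(df, max_frac=0.3)

    # Round trip through YAML using the C-backed dumper/loader
    put.to_yaml(df, '/tmp/df.yaml')
    df2 = put.from_yaml('/tmp/df.yaml')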
dmu/plotting/fwhm.py CHANGED
@@ -1,10 +1,10 @@
 '''
 Module with FWHM plugin class
 '''
-import zfit
 import numpy
 import matplotlib.pyplot as plt

+from dmu.stats.zfit import zfit
 from dmu.logging.log_store import LogStore

 log = LogStore.add_logger('dmu:plotting:fwhm')
@@ -49,7 +49,7 @@ class FWHM:

         log.info('Running FWHM pluggin')
         obs = zfit.Space('mass', limits=(minx, maxx))
-        pdf= zfit.pdf.KDE1DimExact(obs=obs, data=self._arr_val, weights=self._arr_wgt)
+        pdf= zfit.pdf.KDE1DimISJ(obs=obs, data=self._arr_val, weights=self._arr_wgt)

         xval = numpy.linspace(minx, maxx, 200)
         yval = pdf.pdf(xval)
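Editor's note: as far as I can tell from zfit's documentation, KDE1DimISJ selects its bandwidth with the Improved Sheather-Jones algorithm on a binned grid, whereas KDE1DimExact keeps one kernel per data point, so the ISJ variant typically scales much better for large samples.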
dmu/plotting/matrix.py CHANGED
@@ -102,7 +102,7 @@ class MatrixPlotter:

         fig, ax = plt.subplots() if fsize is None else plt.subplots(figsize=fsize)

-        palette = plt.cm.viridis
+        palette = plt.cm.viridis #pylint: disable=no-member
         im = ax.imshow(self._mat, cmap=palette, vmin=zmin, vmax=zmax)
         self._set_axes(ax)
dmu/plotting/plotter.py CHANGED
@@ -3,6 +3,7 @@ Module containing plotter class
 '''

 import os
+import json
 import math
 from typing import Union
@@ -185,14 +186,17 @@ class Plotter:

         return d_weight
    # --------------------------------------------
-    def _read_weights(self, name : str, rdf : RDataFrame) -> Union[numpy.ndarray, None]:
+    def _read_weights(self, name : str, rdf : RDataFrame) -> numpy.ndarray:
         v_col = rdf.GetColumnNames()
         l_col = [ col.c_str() for col in v_col ]

         if name not in l_col:
-            log.debug(f'Weight {name} not found')
-            return None
+            nentries = rdf.Count().GetValue()
+            log.debug(f'Weight {name} not found, using ones')
+
+            return numpy.ones(nentries)

+        log.debug(f'Weight {name} found')
         arr_wgt = rdf.AsNumpy([name])[name]

         return arr_wgt
@@ -230,4 +234,20 @@ class Plotter:
         plt.tight_layout()
         plt.savefig(plot_path)
         plt.close(var)
+    #-------------------------------------
+    def _data_to_json(self,
+                      data : dict[str,float],
+                      name : str) -> None:
+
+        # In case the values are numpy objects, which are not JSON
+        # serializable
+        data = { key : float(value) for key, value in data.items() }
+
+        plt_dir = self._d_cfg['saving']['plt_dir']
+        os.makedirs(plt_dir, exist_ok=True)
+
+        name      = name.replace(' ', '_')
+        json_path = f'{plt_dir}/{name}.json'
+        with open(json_path, 'w', encoding='utf-8') as ofile:
+            json.dump(data, ofile, indent=2, sort_keys=True)
    # --------------------------------------------
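Editor's note: the new _data_to_json persists each plugin's numbers next to the plots. A hedged sketch of reading one such file back; the path and keys are illustrative, following the stats_{varname}_{name} naming used in the plotter_1d hunks below.

    import json

    # Hypothetical file written by the stats plugin for variable 'mass',
    # category 'signal', under cfg['saving']['plt_dir'] = 'plots'
    with open('plots/stats_mass_signal.json', encoding='utf-8') as ifile:
        stats = json.load(ifile)

    print(stats.get('sum'), stats.get('mean'), stats.get('rms'))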
dmu/plotting/plotter_1d.py CHANGED
@@ -1,7 +1,9 @@
 '''
 Module containing plotter class
 '''
-import copy
+# pylint: disable=too-many-positional-arguments, too-many-arguments
+
+import cppyy
 from hist import Hist

 import numpy
@@ -56,12 +58,14 @@ class Plotter1D(Plotter):

         return minx, maxx, bins
    #-------------------------------------
-    def _run_plugins(self,
-                     arr_val : numpy.ndarray,
-                     arr_wgt : numpy.ndarray,
-                     hst,
-                     name : str,
-                     varname : str) -> None:
+    def _run_plugins(
+            self,
+            arr_val : numpy.ndarray,
+            arr_wgt : numpy.ndarray,
+            hst     : Hist,
+            name    : str,
+            varname : str) -> None:
+
         if 'plugin' not in self._d_cfg:
             log.debug('No plugins found')
             return
@@ -73,7 +77,13 @@ class Plotter1D(Plotter):

             log.debug(f'FWHM plugin found for variable {varname}')
             cfg = self._d_cfg['plugin']['fwhm'][varname]
-            self._run_fwhm(arr_val = arr_val, arr_wgt=arr_wgt, hst=hst, name=name, cfg = cfg)
+            self._run_fwhm(
+                arr_val = arr_val,
+                arr_wgt = arr_wgt,
+                hst     = hst,
+                name    = name,
+                varname = varname,
+                cfg     = cfg)

         if 'stats' in self._d_cfg['plugin']:
             if varname not in self._d_cfg['plugin']['stats']:
@@ -82,29 +92,55 @@ class Plotter1D(Plotter):

             log.debug(f'stats plugin found for variable {varname}')
             cfg = self._d_cfg['plugin']['stats'][varname]
-            self._run_stats(arr_val = arr_val, arr_wgt=arr_wgt, name=name, cfg = cfg)
+            self._run_stats(
+                arr_val = arr_val,
+                arr_wgt = arr_wgt,
+                name    = name,
+                varname = varname,
+                cfg     = cfg)
    #-------------------------------------
-    def _run_stats(self, arr_val : numpy.ndarray, arr_wgt : numpy.ndarray, name : str, cfg : dict[str:str]) -> None:
+    def _run_stats(
+            self,
+            arr_val : numpy.ndarray,
+            arr_wgt : numpy.ndarray,
+            varname : str,
+            name    : str,
+            cfg     : dict[str:str]) -> None:
+
         this_title = ''
+        data = {}
         if 'sum' in cfg:
             form = cfg['sum']
             sumv = numpy.sum(arr_wgt)
             this_title += form.format(sumv) + '; '
+            data['sum'] = sumv

         if 'mean' in cfg:
             form = cfg['mean']
             mean = numpy.average(arr_val, weights=arr_wgt)
             this_title += form.format(mean) + '; '
+            data['mean'] = mean

         if 'rms' in cfg:
             form = cfg['rms']
             mean = numpy.average(arr_val, weights=arr_wgt)
             rms  = numpy.sqrt(numpy.average((arr_val - mean) ** 2, weights=arr_wgt))
             this_title += form.format(rms ) + '; '
+            data['rms'] = rms
+
+        self._data_to_json(data = data, name = f'stats_{varname}_{name}')

         self._title+= f'\n{name}: {this_title}'
    #-------------------------------------
-    def _run_fwhm(self, arr_val : numpy.ndarray, arr_wgt : numpy.ndarray, hst, name : str, cfg : dict) -> None:
+    def _run_fwhm(
+            self,
+            arr_val : numpy.ndarray,
+            arr_wgt : numpy.ndarray,
+            hst     : Hist,
+            varname : str,
+            name    : str,
+            cfg     : dict) -> None:
+
         arr_bin_cnt = hst.values()
         maxy        = numpy.max(arr_bin_cnt)
         obj         = FWHM(cfg=cfg, val=arr_val, wgt=arr_wgt, maxy=maxy)
@@ -112,13 +148,17 @@ class Plotter1D(Plotter):

         form       = cfg['format']
         this_title = form.format(fwhm)
+        data       = {}

         if 'add_std' in cfg and cfg['add_std']:
             mu  = numpy.average(arr_val , weights=arr_wgt)
-            avg = numpy.average((arr_val - mu) ** 2, weights=arr_wgt)
-            std = numpy.sqrt(avg)
+            var = numpy.average((arr_val - mu) ** 2, weights=arr_wgt)
+            std = numpy.sqrt(var)
             form = form.replace('FWHM', 'STD')
             this_title+= '; ' + form.format(std)
+            data = {'mu' : mu, 'std' : std, 'fwhm' : fwhm}
+
+        self._data_to_json(data = data, name = f'fwhm_{varname}_{name}')

         self._title+= f'\n{name}: {this_title}'
    #-------------------------------------
@@ -137,51 +177,70 @@ class Plotter1D(Plotter):

         d_data = {}
         for name, rdf in self._d_rdf.items():
-            log.debug(f'Plotting: {var}/{name}')
-            d_data[name] = rdf.AsNumpy([var])[var]
+            try:
+                log.debug(f'Plotting: {var}/{name}')
+                d_data[name] = rdf.AsNumpy([var])[var]
+            except cppyy.gbl.std.runtime_error as exc:
+                raise ValueError(f'Cannot find variable {var} in category {name}') from exc

         minx, maxx, bins = self._get_binning(var, d_data)
         d_wgt = self._get_weights(var)

         l_bc_all = []
         for name, arr_val in d_data.items():
-            label   = self._label_from_name(name, arr_val)
+            label   = self._label_from_name(name)
             arr_wgt = d_wgt[name] if d_wgt is not None else numpy.ones_like(arr_val)
             arr_wgt = self._normalize_weights(arr_wgt, var)
             hst = Hist.new.Reg(bins=bins, start=minx, stop=maxx, name='x').Weight()
             hst.fill(x=arr_val, weight=arr_wgt)
             self._run_plugins(arr_val, arr_wgt, hst, name, var)
+            style = self._get_style_config(var=var, label=label)

-            if 'styling' in self._d_cfg['plots'][var]:
-                style = self._d_cfg['plots'][var]['styling']
-                style = copy.deepcopy(style)
-            else:
-                style = {'label' : label, 'histtype' : 'errorbar', 'marker' : '.', 'linestyle' : 'none'}
-
-            if 'label' not in style:
-                style['label'] = label
-
+            log.debug(f'Style: {style}')
             hst.plot(**style)
+
             l_bc_all += hst.values().tolist()

         max_y = max(l_bc_all)

         return max_y
    # --------------------------------------------
-    def _label_from_name(self, name : str, arr_val : numpy.ndarray) -> str:
+    def _get_style_config(self, var : str, label : str) -> dict[str,str]:
+        style = {
+            'label'     : label,
+            'histtype'  : 'errorbar',
+            'linestyle' : 'none'}
+
+        if 'styling' not in self._d_cfg['plots'][var]:
+            log.debug(f'Styling not specified for {var}')
+            return style
+
+        if label not in self._d_cfg['plots'][var]['styling']:
+            log.debug(f'Styling not specified for {var}/{label}')
+            return style
+
+        custom_style = self._d_cfg['plots'][var]['styling'][label]
+        style.update(custom_style)
+        log.debug(f'Using custom styling for {var}/{label}')
+
+        return style
+    # --------------------------------------------
+    def _label_from_name(self, name : str) -> str:
         if 'stats' not in self._d_cfg:
             return name

         d_stat = self._d_cfg['stats']
-        if 'nentries' not in d_stat:
+        if 'sumw' not in d_stat:
             return name

-        form = d_stat['nentries']
+        form = d_stat['sumw']

-        nentries = len(arr_val)
-        nentries = form.format(nentries)
+        arr_wgt  = self._d_wgt[name]
+        arr_wgt  = numpy.nan_to_num(arr_wgt, nan=0.0)
+        sumw     = numpy.sum(arr_wgt)
+        nentries = form.format(sumw)

-        return f'{name}{nentries}'
+        return f'{name:<15}{nentries:<10}'
    # --------------------------------------------
     def _normalize_weights(self, arr_wgt : numpy.ndarray, var : str) -> numpy.ndarray:
         cfg_var = self._d_cfg['plots'][var]
@@ -227,10 +286,15 @@ class Plotter1D(Plotter):

         var (str) : name of variable
         '''
+        var_cfg = self._d_cfg['plots'][var]
+        if 'vline' in var_cfg:
+            line_cfg = var_cfg['vline']
+            plt.axvline(**line_cfg)
+
         if 'style' in self._d_cfg and 'skip_lines' in self._d_cfg['style'] and self._d_cfg['style']['skip_lines']:
             return

-        if var in ['B_const_mass_M', 'B_M']:
+        if var in ['B_const_mass_M', 'B_M', 'B_Mass', 'B_Mass_smr']:
             plt.axvline(x=5280, color='r', label=r'$B^+$' , linestyle=':')
         elif var == 'Jpsi_M':
             plt.axvline(x=3096, color='r', label=r'$J/\psi$', linestyle=':')
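Editor's note: pulling the plotter_1d changes together, a hedged sketch of a configuration fragment that exercises the new code paths. Only keys read in the hunks above appear; the variable and category names are invented.

    cfg_fragment = {
        'plots'  : {
            'mass' : {
                # Per-label overrides, merged over the errorbar defaults
                # by the new _get_style_config
                'styling' : {
                    'signal' : {'histtype' : 'step', 'color' : 'red'},
                },
                # Forwarded verbatim to plt.axvline by the new vline branch
                'vline'   : {'x' : 5280, 'color' : 'gray', 'linestyle' : ':'},
            },
        },
        'saving' : {'plt_dir' : 'plots'},   # where _data_to_json writes its files
    }

One interplay worth noting: when stats.sumw is configured, _label_from_name appends the formatted sum of weights to the category name, and that decorated string is what _get_style_config looks up under styling, so the styling keys must then match the full label rather than the bare category name.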
dmu/plotting/plotter_2d.py CHANGED
@@ -70,6 +70,11 @@ class Plotter2D(Plotter):

         hst = Hist(ax_x, ax_y)
         hst.fill(arr_x, arr_y, weight=arr_w)

+        if hst.values().sum() == 0:
+            log.warning('Empty histogram, not using log scale')
+            mplhep.hist2dplot(hst)
+            return
+
         if use_log:
             mplhep.hist2dplot(hst, norm=LogNorm())
         else:
dmu/rdataframe/utilities.py CHANGED
@@ -16,7 +16,6 @@ from ROOT import RDataFrame, RDF, Numba
 from dmu.logging.log_store import LogStore

 log = LogStore.add_logger('dmu:rdataframe:utilities')
-
 # ---------------------------------------------------------------------
 @dataclass
 class Data:
@@ -98,12 +97,17 @@ def add_column_with_numba(

     return rdf
 # ---------------------------------------------------------------------
-def rdf_report_to_df(rep : RDF.RCutFlowReport) -> pnd.DataFrame:
+def rdf_report_to_df(rep : RDF.RCutFlowReport) -> Union[pnd.DataFrame, None]:
     '''
     Takes the output of rdf.Report(), i.e. an RDataFrame cutflow report.

-    Produces a pandas dataframe with
+    Produces a pandas dataframe with the total, failed, efficiency, and cummulative efficiency
+    If no cut was applied, i.e. the cutflow is empty, will return None and show warning
     '''
+    if rep.begin() == rep.end():
+        log.warning('Empty cutflow')
+        return None
+
     d_data = {'cut' : [], 'All' : [], 'Passed' : []}
     for cut in rep:
         name=cut.GetName()
@@ -119,3 +123,50 @@ def rdf_report_to_df(rep : RDF.RCutFlowReport) -> pnd.DataFrame:

     df['Cummulative'] = df['Efficiency'].cumprod()

     return df
+# ---------------------------------------------------------------------
+def random_filter(rdf : RDataFrame, entries : int) -> RDataFrame:
+    '''
+    Filters a dataframe, such that the output has **approximately** `entries` entries
+    '''
+    ntot = rdf.Count().GetValue()
+
+    if entries <= 0 or entries >= ntot:
+        log.warning(f'Requested {entries}/{ntot} random entries, not filtering')
+        return rdf
+
+    prob = float(entries) / ntot
+    name = f'filter_{entries}'
+
+    rdf  = rdf.Define(name, 'gRandom->Rndm();')
+    rdf  = rdf.Filter(f'{name} < {prob}', name)
+    nres = rdf.Count().GetValue()
+
+    log.debug(f'Requested {ntot}, picked {nres}')
+
+    return rdf
+# ---------------------------------------------------------------------
+def rdf_to_df(
+        rdf     : RDataFrame,
+        columns : list[str]) -> pnd.DataFrame:
+    '''
+    Parameters
+    ---------------
+    rdf      : ROOT dataframe
+    branches : List of columns to keep in pandas dataframe
+
+    Returns
+    ---------------
+    Pandas dataframe with subset of columns
+    '''
+    log.debug('Storing branches')
+    data = rdf.AsNumpy(columns)
+    df   = pnd.DataFrame(data)
+
+    if len(df) == 0:
+        rep     = rdf.Report()
+        cutflow = rdf_report_to_df(rep)
+        log.warning('Empty dataset:\n')
+        log.info(cutflow)
+
+    return df
+# ---------------------------------------------------------------------
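Editor's note: a hedged usage sketch of the two new helpers, assuming a working ROOT installation and the module path from the file list.

    from ROOT import RDataFrame
    import dmu.rdataframe.utilities as rut

    rdf = RDataFrame(1000)
    rdf = rdf.Define('x', 'gRandom->Rndm()')

    # Keep roughly 100 entries; the exact count fluctuates because each
    # entry is kept with probability entries/ntot
    rdf = rut.random_filter(rdf, entries=100)

    # Materialize a subset of columns as a pandas dataframe
    df = rut.rdf_to_df(rdf, columns=['x'])
    print(len(df))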
dmu/rfile/ddfgetter.py ADDED
@@ -0,0 +1,102 @@
+'''
+Module holding DDFGetter class
+'''
+# pylint: disable=unnecessary-lambda-assignment
+
+from functools import reduce
+
+import dask
+import dask.dataframe as ddf
+
+import uproot
+import yaml
+import pandas as pnd
+from dmu.logging.log_store import LogStore
+
+log=LogStore.add_logger('dmu:rfile:ddfgetter')
+# -------------------------------
+class DDFGetter:
+    '''
+    Class used to provide Dask DataFrames from YAML config files. It should handle:
+
+    - Friend trees
+    - Multiple files
+    '''
+    # ----------------------
+    def __init__(
+            self,
+            cfg         : dict      = None,
+            config_path : str       = None,
+            columns     : list[str] = None):
+        '''
+        Params
+        --------------
+        cfg         : Dictionary storing the configuration (optional)
+        config_path : Path to YAML configuration file (optional)
+        colums      : If passed, will only use these columns to build dataframe
+        '''
+        self._cfg     = self._load_config(path=config_path) if cfg is None else cfg
+        self._columns = columns
+    # ----------------------
+    def _load_config(self, path : str) -> dict:
+        with open(path, encoding='utf-8') as ifile:
+            data = yaml.safe_load(ifile)
+
+        return data
+    # ----------------------
+    def _get_columns_to_keep(self, tree) -> list[str]:
+        if self._columns is None:
+            return None
+
+        columns   = self._columns + self._cfg['primary_keys']
+        columns   = set(columns)
+        available = set(tree.keys())
+        columns   = columns & available
+
+        log.debug(f'Keeping only {columns}')
+
+        return list(columns)
+    # ----------------------
+    def _get_file_df(self, fpath : str) -> pnd.DataFrame:
+        with uproot.open(fpath) as file:
+            tname   = self._cfg['tree']
+            tree    = file[tname]
+            columns = self._get_columns_to_keep(tree)
+            df      = tree.arrays(columns, library='pd')
+
+        return df
+    # ----------------------
+    def _get_file_dfs(self, fname : str) -> list[pnd.DataFrame]:
+        log.debug(f'Getting dataframes for: {fname}')
+
+        l_fpath = [ f'{sample_dir}/{fname}' for sample_dir in self._cfg['samples'] ]
+        l_df    = [ self._get_file_df(fpath = fpath) for fpath in l_fpath ]
+
+        return l_df
+    # ----------------------
+    def _load_root_file(self, fname : str, ifname : int, size : int) -> pnd.DataFrame:
+        keys = self._cfg['primary_keys']
+        l_df = self._get_file_dfs(fname=fname)
+        fun  = lambda df_l, df_r : pnd.merge(df_l, df_r, on=keys)
+
+        log.info(f'Merging dataframes: {ifname}/{size}')
+        df = reduce(fun, l_df)
+        df = df.drop(columns=keys)
+
+        return df
+    # ----------------------
+    def get_dataframe(self) -> ddf:
+        '''
+        Returns dask dataframe
+        '''
+        l_fname = self._cfg['files']
+        nfiles  = len(l_fname)
+
+        log.debug('Building dataframes for single files')
+        l_dfs = [ dask.delayed(self._load_root_file)(fname = fname, ifname=ifname, size=nfiles) for ifname, fname in enumerate(l_fname) ]
+
+        log.debug('Bulding full dask dataframe')
+        output = ddf.from_delayed(l_dfs)
+
+        return output
+# -------------------------------
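Editor's note: a hedged usage sketch for DDFGetter. The configuration keys (tree, files, samples, primary_keys) are exactly the ones the class reads above; the concrete names and paths are invented.

    from dmu.rfile.ddfgetter import DDFGetter

    cfg = {
        'tree'         : 'DecayTree',                        # tree opened in every file
        'files'        : ['file_001.root', 'file_002.root'], # file names common to all samples
        'samples'      : ['/data/main', '/data/friend'],     # one directory per (friend) copy
        'primary_keys' : ['EVENTNUMBER', 'RUNNUMBER'],       # columns the friend copies are merged on
    }

    gtr     = DDFGetter(cfg=cfg, columns=['mass', 'pt'])
    dask_df = gtr.get_dataframe()

    # Dask is lazy: compute() triggers the actual reading and merging
    df = dask_df.compute()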