data-manipulation-utilities 0.2.6__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +800 -34
  2. data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
  3. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
  4. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
  5. dmu/__init__.py +0 -0
  6. dmu/generic/hashing.py +70 -0
  7. dmu/generic/utilities.py +175 -9
  8. dmu/generic/version_management.py +3 -5
  9. dmu/logging/log_store.py +34 -2
  10. dmu/logging/messages.py +96 -0
  11. dmu/ml/cv_classifier.py +3 -3
  12. dmu/ml/cv_diagnostics.py +224 -0
  13. dmu/ml/cv_performance.py +58 -0
  14. dmu/ml/cv_predict.py +149 -46
  15. dmu/ml/train_mva.py +587 -112
  16. dmu/ml/utilities.py +29 -10
  17. dmu/pdataframe/utilities.py +61 -3
  18. dmu/plotting/fwhm.py +64 -0
  19. dmu/plotting/matrix.py +1 -1
  20. dmu/plotting/plotter.py +25 -3
  21. dmu/plotting/plotter_1d.py +159 -14
  22. dmu/plotting/plotter_2d.py +5 -0
  23. dmu/rdataframe/utilities.py +54 -3
  24. dmu/rfile/ddfgetter.py +102 -0
  25. dmu/stats/fit_stats.py +129 -0
  26. dmu/stats/fitter.py +56 -23
  27. dmu/stats/gof_calculator.py +7 -0
  28. dmu/stats/model_factory.py +305 -50
  29. dmu/stats/parameters.py +100 -0
  30. dmu/stats/utilities.py +443 -12
  31. dmu/stats/wdata.py +187 -0
  32. dmu/stats/zfit.py +17 -0
  33. dmu/stats/zfit_models.py +68 -0
  34. dmu/stats/zfit_plotter.py +175 -56
  35. dmu/testing/utilities.py +120 -15
  36. dmu/workflow/__init__.py +0 -0
  37. dmu/workflow/cache.py +266 -0
  38. dmu_data/ml/tests/diagnostics_from_file.yaml +13 -0
  39. dmu_data/ml/tests/diagnostics_from_model.yaml +10 -0
  40. dmu_data/ml/tests/diagnostics_multiple_methods.yaml +10 -0
  41. dmu_data/ml/tests/diagnostics_overlay.yaml +33 -0
  42. dmu_data/ml/tests/train_mva.yaml +20 -12
  43. dmu_data/ml/tests/train_mva_def.yaml +75 -0
  44. dmu_data/ml/tests/train_mva_with_diagnostics.yaml +87 -0
  45. dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
  46. dmu_data/plotting/tests/2d.yaml +5 -5
  47. dmu_data/plotting/tests/line.yaml +15 -0
  48. dmu_data/plotting/tests/plug_fwhm.yaml +24 -0
  49. dmu_data/plotting/tests/plug_stats.yaml +19 -0
  50. dmu_data/plotting/tests/simple.yaml +4 -3
  51. dmu_data/plotting/tests/styling.yaml +18 -0
  52. dmu_data/rfile/friends.yaml +13 -0
  53. dmu_data/stats/fitter/test_simple.yaml +28 -0
  54. dmu_data/stats/kde_optimizer/control.json +1 -0
  55. dmu_data/stats/kde_optimizer/signal.json +1 -0
  56. dmu_data/stats/parameters/data.yaml +178 -0
  57. dmu_data/tests/config.json +6 -0
  58. dmu_data/tests/config.yaml +4 -0
  59. dmu_data/tests/pdf_to_tex.txt +34 -0
  60. dmu_scripts/kerberos/check_expiration +21 -0
  61. dmu_scripts/kerberos/convert_certificate +22 -0
  62. dmu_scripts/ml/compare_classifiers.py +85 -0
  63. data_manipulation_utilities-0.2.6.dist-info/RECORD +0 -57
  64. {data_manipulation_utilities-0.2.6.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
  65. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
dmu/ml/utilities.py CHANGED
@@ -14,11 +14,24 @@ log = LogStore.add_logger('dmu:ml:utilities')
  # ---------------------------------------------
  # Patch dataframe with features
  # ---------------------------------------------
- def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
+ def tag_nans(
+         df      : pnd.DataFrame,
+         indexes : str) -> pnd.DataFrame:
      '''
-     Takes pandas dataframe, replaces NaNs with value introduced, by default 0
-     Returns array of indices where the replacement happened
+
+     Parameters
+     ----------------
+     df      : Pandas dataframe
+     indexes : Name of dataframe attribute where array of indices of NaN rows should go
+
+     Returns
+     ----------------
+     Dataframe:
+
+     - After filtering, i.e. with dropped rows.
+     - With array of indices dropped as attribute at `patched_indices`
      '''
+
      l_nan = df.index[df.isna().any(axis=1)].tolist()
      nnan  = len(l_nan)
      if nnan == 0:
@@ -29,15 +42,21 @@ def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
 
      df_nan_frq = df.isna().sum()
      df_nan_frq = df_nan_frq[df_nan_frq > 0]
-     print(df_nan_frq)
 
+     log.info(df_nan_frq)
      log.warning(f'Attaching array with NaN {nnan} indexes and removing NaNs from dataframe')
 
-     df_pa = df.fillna(value)
+     arr_index_2 = numpy.array(l_nan)
+     if indexes in df.attrs:
+         arr_index_1 = df.attrs[indexes]
+         arr_index   = numpy.concatenate((arr_index_1, arr_index_2))
+         arr_index   = numpy.unique(arr_index)
+     else:
+         arr_index   = arr_index_2
 
-     df_pa.attrs['patched_indices'] = numpy.array(l_nan)
+     df.attrs[indexes] = arr_index
 
-     return df_pa
+     return df
  # ---------------------------------------------
  # Cleanup of dataframe with features
  # ---------------------------------------------
@@ -96,7 +115,7 @@ def _remove_repeated(df : pnd.DataFrame) -> pnd.DataFrame:
      return df_clean
  # ----------------------------------
  # ---------------------------------------------
- def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set, list]:
+ def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set[str], list[str]]:
      '''
      Will return hashes for each row in the feature dataframe
 
@@ -113,9 +132,9 @@ def get_hashes(df_ft : pnd.DataFrame, rvalue : str ='set') -> Union[set, list]:
 
      return res
  # ----------------------------------
- def hash_from_row(row):
+ def hash_from_row(row : pnd.Series) -> str:
      '''
-     Will return a hash from a pandas dataframe row
+     Will return a hash in the form of a string from a pandas dataframe row
      corresponding to an event
      '''
      l_val = [ str(val) for val in row ]
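Below, a minimal sketch of how the renamed `tag_nans` helper could be used. The data and the attribute name `nan_rows` are made up for illustration; per the docstring, the returned dataframe has the NaN rows dropped and their indices recorded under the chosen attribute.

import numpy
import pandas as pnd
from dmu.ml.utilities import tag_nans

# Toy dataframe with one NaN row (illustrative data only)
df = pnd.DataFrame({'x' : [1.0, numpy.nan, 3.0], 'y' : [4.0, 5.0, 6.0]})

# Record indices of NaN rows in df.attrs['nan_rows']; repeated calls merge the indices
df = tag_nans(df, indexes='nan_rows')

print(df.attrs['nan_rows'])   # e.g. array([1])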
dmu/pdataframe/utilities.py CHANGED
@@ -2,20 +2,28 @@
  Module containing utilities for pandas dataframes
  '''
  import os
+ import yaml
  import pandas as pnd
 
  from dmu.logging.log_store import LogStore
 
  log=LogStore.add_logger('dmu:pdataframe:utilities')
-
  # -------------------------------------
- def df_to_tex(df : pnd.DataFrame, path : str, hide_index : bool = True, d_format : dict[str,str]=None, caption : str =None) -> None:
+ def df_to_tex(df         : pnd.DataFrame,
+               path       : str,
+               hide_index : bool = True,
+               d_format   : dict[str,str]= None,
+               **kwargs   : str ) -> None:
      '''
      Saves pandas dataframe to latex
 
      Parameters
      -------------
+     df         : Dataframe with data
+     path (str) : Path to latex file
+     hide_index : If true (default), index of dataframe won't appear in table
      d_format (dict) : Dictionary specifying the formatting of the table, e.g. `{'col1': '{}', 'col2': '{:.3f}', 'col3' : '{:.3f}'}`
+     kwargs     : Arguments needed in `to_latex`
      '''
 
      if path is not None:
@@ -30,7 +38,57 @@ def df_to_tex(df : pnd.DataFrame, path : str, hide_index : bool = True, d_format
          st=st.format(formatter=d_format)
 
      log.info(f'Saving to: {path}')
-     buf = st.to_latex(buf=path, caption=caption, hrules=True)
+     buf = st.to_latex(buf=path, hrules=True, **kwargs)
 
      return buf
  # -------------------------------------
+ def to_yaml(df : pnd.DataFrame, path : str):
+     '''
+     Takes a dataframe and the path to a yaml file
+     Makes the directory path if not found and saves data in YAML file
+     '''
+     dir_path = os.path.dirname(path)
+     if dir_path != '':
+         os.makedirs(dir_path, exist_ok=True)
+
+     data = df.to_dict()
+
+     with open(path, 'w', encoding='utf-8') as ofile:
+         yaml.dump(data, ofile, Dumper=yaml.CDumper)
+ # -------------------------------------
+ def from_yaml(path : str) -> pnd.DataFrame:
+     '''
+     Takes path to a yaml file
+     Makes dataframe from it and returns it
+     '''
+     with open(path, encoding='utf-8') as ifile:
+         data = yaml.load(ifile, Loader=yaml.CSafeLoader)
+
+     df = pnd.DataFrame(data)
+
+     return df
+ # -------------------------------------
+ def dropna(df : pnd.DataFrame, max_frac : float = 0.02) -> pnd.DataFrame:
+     '''
+     Parameters
+     ----------------
+     df      : Pandas dataframe potentially with NaNs
+     max_frac: Maximum fraction of the data that can be dropped, will raise exception beyond
+     '''
+
+     ini = len(df)
+     df  = df.dropna()
+     fin = len(df)
+
+     if ini == fin:
+         log.debug('No NaNs were found')
+         return df
+
+     # If too many rows were dropped, raise
+     if fin < ini * (1 - max_frac):
+         raise ValueError(f'Too many NaNs were detected: {ini} --> {fin}')
+
+     log.info(f'Found NaNs: {ini} --> {fin}')
+
+     return df
+ # -------------------------------------
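A minimal sketch of the new pandas helpers in use; the paths, column names, and the `put` alias are illustrative. Note that `caption` is now expected to arrive through `**kwargs`, which `df_to_tex` forwards to the pandas Styler's `to_latex`.

import pandas as pnd
import dmu.pdataframe.utilities as put

df = pnd.DataFrame({'name' : ['a', 'b'], 'value' : [1.0, 2.0]})  # toy data

put.to_yaml(df, '/tmp/dmu_example/table.yaml')      # directory is created if missing
df2 = put.from_yaml('/tmp/dmu_example/table.yaml')  # round-trip back to a dataframe

df3 = put.dropna(df2, max_frac=0.02)                # raises if more than 2% of rows contain NaNs

# caption is forwarded to Styler.to_latex through **kwargs
put.df_to_tex(df3, '/tmp/dmu_example/table.tex', caption='Example table')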
dmu/plotting/fwhm.py ADDED
@@ -0,0 +1,64 @@
+ '''
+ Module with FWHM plugin class
+ '''
+ import numpy
+ import matplotlib.pyplot as plt
+
+ from dmu.stats.zfit        import zfit
+ from dmu.logging.log_store import LogStore
+
+ log = LogStore.add_logger('dmu:plotting:fwhm')
+ # --------------------------------------------
+ class FWHM:
+     '''
+     Class meant to be used to calculate Full Width at Half Maximum
+     as a Plotter1d plugin
+     '''
+     # -------------------------
+     def __init__(self, cfg : dict, val : numpy.ndarray, wgt : numpy.ndarray, maxy : float):
+         self._cfg     = cfg
+         self._arr_val = val
+         self._arr_wgt = wgt
+         self._maxy    = maxy
+     # -------------------------
+     def _normalize_yval(self, arr_pdf_val : numpy.ndarray) -> numpy.ndarray:
+         max_pdf_val  = numpy.max(arr_pdf_val)
+         arr_pdf_val *= self._maxy / max_pdf_val
+
+         return arr_pdf_val
+     # -------------------------
+     def _get_fwhm(self, arr_x : numpy.ndarray, arr_y : numpy.ndarray) -> float:
+         maxy = numpy.max(arr_y)
+         arry = numpy.where(arr_y > maxy/2.)[0]
+         imax = arry[ 0]
+         imin = arry[-1]
+
+         x1 = arr_x[imax]
+         x2 = arr_x[imin]
+
+         if self._cfg['plot']:
+             plt.plot([x1, x2], [maxy/2, maxy/2], linestyle=':', linewidth=1, color='k')
+
+         return x2 - x1
+     # -------------------------
+     def run(self) -> float:
+         '''
+         Runs plugin and returns the FWHM
+         '''
+         [minx, maxx] = self._cfg['obs']
+
+         log.info('Running FWHM plugin')
+         obs = zfit.Space('mass', limits=(minx, maxx))
+         pdf = zfit.pdf.KDE1DimISJ(obs=obs, data=self._arr_val, weights=self._arr_wgt)
+
+         xval = numpy.linspace(minx, maxx, 200)
+         yval = pdf.pdf(xval)
+         yval = self._normalize_yval(yval)
+
+         if self._cfg['plot']:
+             plt.plot(xval, yval, linestyle='-', linewidth=2, color='gray')
+
+         fwhm = self._get_fwhm(xval, yval)
+
+         return fwhm
+ # --------------------------------------------
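A hedged sketch of how the FWHM plugin could be driven on its own, outside of Plotter1D. The toy data and the observable range are made up; only the constructor signature and the `obs`/`plot` configuration keys come from the class above.

import numpy
from dmu.plotting.fwhm import FWHM

val = numpy.random.normal(loc=5280, scale=30, size=10_000)  # toy mass-like values
wgt = numpy.ones_like(val)                                   # unit weights

cfg  = {'obs' : [5100, 5500], 'plot' : False}  # keys read by FWHM.run and _get_fwhm
obj  = FWHM(cfg=cfg, val=val, wgt=wgt, maxy=1.0)
fwhm = obj.run()                               # roughly 2.355 * 30 for Gaussian toy data

print(f'FWHM = {fwhm:.1f}')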
dmu/plotting/matrix.py CHANGED
@@ -102,7 +102,7 @@ class MatrixPlotter:
 
          fig, ax = plt.subplots() if fsize is None else plt.subplots(figsize=fsize)
 
-         palette = plt.cm.viridis
+         palette = plt.cm.viridis #pylint: disable=no-member
          im      = ax.imshow(self._mat, cmap=palette, vmin=zmin, vmax=zmax)
          self._set_axes(ax)
 
dmu/plotting/plotter.py CHANGED
@@ -3,6 +3,7 @@ Module containing plotter class
  '''
 
  import os
+ import json
  import math
  from typing import Union
 
@@ -29,6 +30,8 @@ class Plotter:
          self._d_cfg = cfg
          self._d_rdf : dict[str, RDataFrame] = { name : self._preprocess_rdf(rdf) for name, rdf in d_rdf.items()}
          self._d_wgt : Union[dict[str, Union[numpy.ndarray, None]], None]
+
+         self._title : str = ''
      #-------------------------------------
      def _check_quantile(self, qnt : float):
          '''
@@ -183,14 +186,17 @@ class Plotter:
 
          return d_weight
      # --------------------------------------------
-     def _read_weights(self, name : str, rdf : RDataFrame) -> Union[numpy.ndarray, None]:
+     def _read_weights(self, name : str, rdf : RDataFrame) -> numpy.ndarray:
          v_col = rdf.GetColumnNames()
          l_col = [ col.c_str() for col in v_col ]
 
          if name not in l_col:
-             log.debug(f'Weight {name} not found')
-             return None
+             nentries = rdf.Count().GetValue()
+             log.debug(f'Weight {name} not found, using ones')
+
+             return numpy.ones(nentries)
 
+         log.debug(f'Weight {name} found')
          arr_wgt = rdf.AsNumpy([name])[name]
 
          return arr_wgt
@@ -228,4 +234,20 @@ class Plotter:
          plt.tight_layout()
          plt.savefig(plot_path)
          plt.close(var)
+     #-------------------------------------
+     def _data_to_json(self,
+                       data : dict[str,float],
+                       name : str) -> None:
+
+         # In case the values are numpy objects, which are not JSON
+         # serializable
+         data = { key : float(value) for key, value in data.items() }
+
+         plt_dir = self._d_cfg['saving']['plt_dir']
+         os.makedirs(plt_dir, exist_ok=True)
+
+         name      = name.replace(' ', '_')
+         json_path = f'{plt_dir}/{name}.json'
+         with open(json_path, 'w', encoding='utf-8') as ofile:
+             json.dump(data, ofile, indent=2, sort_keys=True)
  # --------------------------------------------
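For reference, a minimal sketch of what the new `_data_to_json` helper ends up writing. The directory and the statistics values are made up; only the float conversion, the space-to-underscore renaming, and the `<plt_dir>/<name>.json` destination follow from the hunk above.

import json
import os
import numpy

plt_dir = '/tmp/dmu_plots'                                               # assumed cfg['saving']['plt_dir']
data    = {'mean' : numpy.float64(5279.5), 'rms' : numpy.float64(30.2)}  # numpy scalars from the plugins
name    = 'stats_B_M signal'

data = {key : float(value) for key, value in data.items()}  # make values JSON serializable
os.makedirs(plt_dir, exist_ok=True)

with open(f'{plt_dir}/{name.replace(" ", "_")}.json', 'w', encoding='utf-8') as ofile:
    json.dump(data, ofile, indent=2, sort_keys=True)
# -> /tmp/dmu_plots/stats_B_M_signal.json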
dmu/plotting/plotter_1d.py CHANGED
@@ -1,7 +1,9 @@
  '''
  Module containing plotter class
  '''
+ # pylint: disable=too-many-positional-arguments, too-many-arguments
 
+ import cppyy
  from hist import Hist
 
  import numpy
@@ -9,6 +11,7 @@ import matplotlib.pyplot as plt
 
  from dmu.logging.log_store import LogStore
  from dmu.plotting.plotter  import Plotter
+ from dmu.plotting.fwhm     import FWHM
 
  log = LogStore.add_logger('dmu:plotting:Plotter1D')
  # --------------------------------------------
@@ -55,6 +58,110 @@ class Plotter1D(Plotter):
 
          return minx, maxx, bins
      #-------------------------------------
+     def _run_plugins(
+             self,
+             arr_val : numpy.ndarray,
+             arr_wgt : numpy.ndarray,
+             hst     : Hist,
+             name    : str,
+             varname : str) -> None:
+
+         if 'plugin' not in self._d_cfg:
+             log.debug('No plugins found')
+             return
+
+         if 'fwhm' in self._d_cfg['plugin']:
+             if varname not in self._d_cfg['plugin']['fwhm']:
+                 log.debug(f'No FWHM plugin found for variable {varname}')
+                 return
+
+             log.debug(f'FWHM plugin found for variable {varname}')
+             cfg = self._d_cfg['plugin']['fwhm'][varname]
+             self._run_fwhm(
+                 arr_val = arr_val,
+                 arr_wgt = arr_wgt,
+                 hst     = hst,
+                 name    = name,
+                 varname = varname,
+                 cfg     = cfg)
+
+         if 'stats' in self._d_cfg['plugin']:
+             if varname not in self._d_cfg['plugin']['stats']:
+                 log.debug(f'No stats plugin found for variable {varname}')
+                 return
+
+             log.debug(f'stats plugin found for variable {varname}')
+             cfg = self._d_cfg['plugin']['stats'][varname]
+             self._run_stats(
+                 arr_val = arr_val,
+                 arr_wgt = arr_wgt,
+                 name    = name,
+                 varname = varname,
+                 cfg     = cfg)
+     #-------------------------------------
+     def _run_stats(
+             self,
+             arr_val : numpy.ndarray,
+             arr_wgt : numpy.ndarray,
+             varname : str,
+             name    : str,
+             cfg     : dict[str,str]) -> None:
+
+         this_title = ''
+         data       = {}
+         if 'sum' in cfg:
+             form         = cfg['sum']
+             sumv         = numpy.sum(arr_wgt)
+             this_title  += form.format(sumv) + '; '
+             data['sum']  = sumv
+
+         if 'mean' in cfg:
+             form         = cfg['mean']
+             mean         = numpy.average(arr_val, weights=arr_wgt)
+             this_title  += form.format(mean) + '; '
+             data['mean'] = mean
+
+         if 'rms' in cfg:
+             form         = cfg['rms']
+             mean         = numpy.average(arr_val, weights=arr_wgt)
+             rms          = numpy.sqrt(numpy.average((arr_val - mean) ** 2, weights=arr_wgt))
+             this_title  += form.format(rms ) + '; '
+             data['rms']  = rms
+
+         self._data_to_json(data = data, name = f'stats_{varname}_{name}')
+
+         self._title += f'\n{name}: {this_title}'
+     #-------------------------------------
+     def _run_fwhm(
+             self,
+             arr_val : numpy.ndarray,
+             arr_wgt : numpy.ndarray,
+             hst     : Hist,
+             varname : str,
+             name    : str,
+             cfg     : dict) -> None:
+
+         arr_bin_cnt = hst.values()
+         maxy        = numpy.max(arr_bin_cnt)
+         obj         = FWHM(cfg=cfg, val=arr_val, wgt=arr_wgt, maxy=maxy)
+         fwhm        = obj.run()
+
+         form       = cfg['format']
+         this_title = form.format(fwhm)
+         data       = {}
+
+         if 'add_std' in cfg and cfg['add_std']:
+             mu   = numpy.average(arr_val            , weights=arr_wgt)
+             var  = numpy.average((arr_val - mu) ** 2, weights=arr_wgt)
+             std  = numpy.sqrt(var)
+             form = form.replace('FWHM', 'STD')
+             this_title += '; ' + form.format(std)
+             data        = {'mu' : mu, 'std' : std, 'fwhm' : fwhm}
+
+         self._data_to_json(data = data, name = f'fwhm_{varname}_{name}')
+
+         self._title += f'\n{name}: {this_title}'
+     #-------------------------------------
      def _plot_var(self, var : str) -> float:
          '''
          Will plot a variable from a dictionary of dataframes
@@ -70,39 +177,70 @@ class Plotter1D(Plotter):
 
          d_data = {}
          for name, rdf in self._d_rdf.items():
-             d_data[name] = rdf.AsNumpy([var])[var]
+             try:
+                 log.debug(f'Plotting: {var}/{name}')
+                 d_data[name] = rdf.AsNumpy([var])[var]
+             except cppyy.gbl.std.runtime_error as exc:
+                 raise ValueError(f'Cannot find variable {var} in category {name}') from exc
 
          minx, maxx, bins = self._get_binning(var, d_data)
          d_wgt            = self._get_weights(var)
 
          l_bc_all = []
          for name, arr_val in d_data.items():
-             label   = self._label_from_name(name, arr_val)
+             label   = self._label_from_name(name)
              arr_wgt = d_wgt[name] if d_wgt is not None else numpy.ones_like(arr_val)
              arr_wgt = self._normalize_weights(arr_wgt, var)
              hst     = Hist.new.Reg(bins=bins, start=minx, stop=maxx, name='x').Weight()
              hst.fill(x=arr_val, weight=arr_wgt)
-             hst.plot(label=label)
+             self._run_plugins(arr_val, arr_wgt, hst, name, var)
+             style = self._get_style_config(var=var, label=label)
+
+             log.debug(f'Style: {style}')
+             hst.plot(**style)
+
              l_bc_all += hst.values().tolist()
 
          max_y = max(l_bc_all)
 
          return max_y
      # --------------------------------------------
-     def _label_from_name(self, name : str, arr_val : numpy.ndarray) -> str:
+     def _get_style_config(self, var : str, label : str) -> dict[str,str]:
+         style = {
+             'label'     : label,
+             'histtype'  : 'errorbar',
+             'linestyle' : 'none'}
+
+         if 'styling' not in self._d_cfg['plots'][var]:
+             log.debug(f'Styling not specified for {var}')
+             return style
+
+         if label not in self._d_cfg['plots'][var]['styling']:
+             log.debug(f'Styling not specified for {var}/{label}')
+             return style
+
+         custom_style = self._d_cfg['plots'][var]['styling'][label]
+         style.update(custom_style)
+         log.debug(f'Using custom styling for {var}/{label}')
+
+         return style
+     # --------------------------------------------
+     def _label_from_name(self, name : str) -> str:
          if 'stats' not in self._d_cfg:
              return name
 
          d_stat = self._d_cfg['stats']
-         if 'nentries' not in d_stat:
+         if 'sumw' not in d_stat:
              return name
 
-         form = d_stat['nentries']
+         form = d_stat['sumw']
 
-         nentries = len(arr_val)
-         nentries = form.format(nentries)
+         arr_wgt  = self._d_wgt[name]
+         arr_wgt  = numpy.nan_to_num(arr_wgt, nan=0.0)
+         sumw     = numpy.sum(arr_wgt)
+         nentries = form.format(sumw)
 
-         return f'{name}{nentries}'
+         return f'{name:<15}{nentries:<10}'
      # --------------------------------------------
      def _normalize_weights(self, arr_wgt : numpy.ndarray, var : str) -> numpy.ndarray:
          cfg_var = self._d_cfg['plots'][var]
@@ -131,9 +269,12 @@ class Plotter1D(Plotter):
          if yscale == 'linear':
              plt.ylim(bottom=0)
 
-         title = ''
+         title = self._title
          if 'title' in d_cfg:
-             title = d_cfg['title']
+             this_title = d_cfg['title']
+             title     += f'\n {this_title}'
+
+         title = title.lstrip('\n')
 
          plt.ylim(top=1.2 * max_y)
          plt.legend()
@@ -145,10 +286,15 @@ class Plotter1D(Plotter):
 
          var (str) : name of variable
          '''
+         var_cfg = self._d_cfg['plots'][var]
+         if 'vline' in var_cfg:
+             line_cfg = var_cfg['vline']
+             plt.axvline(**line_cfg)
+
          if 'style' in self._d_cfg and 'skip_lines' in self._d_cfg['style'] and self._d_cfg['style']['skip_lines']:
              return
 
-         if var in ['B_const_mass_M', 'B_M']:
+         if var in ['B_const_mass_M', 'B_M', 'B_Mass', 'B_Mass_smr']:
              plt.axvline(x=5280, color='r', label=r'$B^+$' , linestyle=':')
          elif var == 'Jpsi_M':
              plt.axvline(x=3096, color='r', label=r'$J/\psi$', linestyle=':')
@@ -160,8 +306,7 @@ class Plotter1D(Plotter):
 
          fig_size = self._get_fig_size()
          for var in self._d_cfg['plots']:
-             log.debug(f'Plotting: {var}')
-
+             self._title = ''
              plt.figure(var, figsize=fig_size)
              max_y = self._plot_var(var)
              self._style_plot(var, max_y)
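Putting the Plotter1D changes together, a hedged sketch of the configuration fragments the new code paths read. The variable name `B_M`, the category label `signal`, and all numeric values are illustrative; only the keys (`saving.plt_dir`, `plugin.fwhm`, `plugin.stats`, `plots.<var>.styling`, `plots.<var>.vline`) follow from the hunks above. A `stats.sumw` format string, when present, switches the legend label to show the sum of weights, which also becomes the key used for `styling` lookups.

# Hypothetical Plotter1D configuration fragment
cfg = {
    'saving' : {'plt_dir' : '/tmp/dmu_plots'},   # plugins write their JSON summaries here
    'plots'  : {
        'B_M' : {
            # ... the usual per-variable settings (binning, labels, ...) go here
            'styling' : {'signal' : {'histtype' : 'step', 'linestyle' : '-'}},  # overrides the errorbar default
            'vline'   : {'x' : 5280, 'color' : 'r', 'linestyle' : ':'},         # forwarded to plt.axvline
        },
    },
    'plugin' : {
        'fwhm'  : {'B_M' : {'obs'     : [5100, 5500],
                            'plot'    : True,
                            'format'  : 'FWHM={:.1f}',
                            'add_std' : True}},
        'stats' : {'B_M' : {'sum'  : 'S={:.0f}',
                            'mean' : 'mean={:.1f}',
                            'rms'  : 'rms={:.1f}'}},
    },
}
# Plugin results end up both in the plot title (via self._title) and in
# <plt_dir>/fwhm_B_M_<category>.json and <plt_dir>/stats_B_M_<category>.json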
dmu/plotting/plotter_2d.py CHANGED
@@ -70,6 +70,11 @@ class Plotter2D(Plotter):
          hst = Hist(ax_x, ax_y)
          hst.fill(arr_x, arr_y, weight=arr_w)
 
+         if hst.values().sum() == 0:
+             log.warning('Empty histogram, not using log scale')
+             mplhep.hist2dplot(hst)
+             return
+
          if use_log:
              mplhep.hist2dplot(hst, norm=LogNorm())
          else:
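The guard above avoids handing an all-zero histogram to `LogNorm`, which has no positive bin contents to autoscale against. A minimal sketch of the same check outside the plotter, with toy (unfilled) data:

import mplhep
from hist import Hist
from matplotlib.colors import LogNorm

hst = Hist.new.Reg(10, 0, 1, name='x').Reg(10, 0, 1, name='y').Weight()
# nothing filled: every bin is zero

if hst.values().sum() == 0:
    mplhep.hist2dplot(hst)                  # fall back to a linear colour scale
else:
    mplhep.hist2dplot(hst, norm=LogNorm())  # safe only with positive bin contents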
dmu/rdataframe/utilities.py CHANGED
@@ -16,7 +16,6 @@ from ROOT import RDataFrame, RDF, Numba
  from dmu.logging.log_store import LogStore
 
  log = LogStore.add_logger('dmu:rdataframe:utilities')
-
  # ---------------------------------------------------------------------
  @dataclass
  class Data:
@@ -98,12 +97,17 @@ def add_column_with_numba(
 
      return rdf
  # ---------------------------------------------------------------------
- def rdf_report_to_df(rep : RDF.RCutFlowReport) -> pnd.DataFrame:
+ def rdf_report_to_df(rep : RDF.RCutFlowReport) -> Union[pnd.DataFrame, None]:
      '''
      Takes the output of rdf.Report(), i.e. an RDataFrame cutflow report.
 
-     Produces a pandas dataframe with
+     Produces a pandas dataframe with the total, failed, efficiency, and cumulative efficiency
+     If no cut was applied, i.e. the cutflow is empty, will return None and show a warning
      '''
+     if rep.begin() == rep.end():
+         log.warning('Empty cutflow')
+         return None
+
      d_data = {'cut' : [], 'All' : [], 'Passed' : []}
      for cut in rep:
          name=cut.GetName()
@@ -119,3 +123,50 @@ def rdf_report_to_df(rep : RDF.RCutFlowReport) -> pnd.DataFrame:
      df['Cummulative'] = df['Efficiency'].cumprod()
 
      return df
+ # ---------------------------------------------------------------------
+ def random_filter(rdf : RDataFrame, entries : int) -> RDataFrame:
+     '''
+     Filters a dataframe, such that the output has **approximately** `entries` entries
+     '''
+     ntot = rdf.Count().GetValue()
+
+     if entries <= 0 or entries >= ntot:
+         log.warning(f'Requested {entries}/{ntot} random entries, not filtering')
+         return rdf
+
+     prob = float(entries) / ntot
+     name = f'filter_{entries}'
+
+     rdf  = rdf.Define(name, 'gRandom->Rndm();')
+     rdf  = rdf.Filter(f'{name} < {prob}', name)
+     nres = rdf.Count().GetValue()
+
+     log.debug(f'Requested {ntot}, picked {nres}')
+
+     return rdf
+ # ---------------------------------------------------------------------
+ def rdf_to_df(
+         rdf     : RDataFrame,
+         columns : list[str]) -> pnd.DataFrame:
+     '''
+     Parameters
+     ---------------
+     rdf     : ROOT dataframe
+     columns : List of columns to keep in pandas dataframe
+
+     Returns
+     ---------------
+     Pandas dataframe with subset of columns
+     '''
+     log.debug('Storing branches')
+     data = rdf.AsNumpy(columns)
+     df   = pnd.DataFrame(data)
+
+     if len(df) == 0:
+         rep     = rdf.Report()
+         cutflow = rdf_report_to_df(rep)
+         log.warning('Empty dataset:\n')
+         log.info(cutflow)
+
+     return df
+ # ---------------------------------------------------------------------
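A minimal sketch of the new RDataFrame helpers in use; the file name, tree name, and branch names are placeholders.

from ROOT import RDataFrame

import dmu.rdataframe.utilities as ut

# Hypothetical input: a ROOT file 'input.root' containing a tree 'tree'
rdf = RDataFrame('tree', 'input.root')

# Keep roughly 1000 randomly chosen entries (the exact count fluctuates)
rdf = ut.random_filter(rdf, entries=1000)

# Materialize a subset of branches as a pandas dataframe; if nothing survives
# the selection, rdf_to_df logs the cutflow obtained from rdf_report_to_df
df = ut.rdf_to_df(rdf, columns=['B_M', 'B_PT'])
print(df.head())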