data-manipulation-utilities 0.2.6__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +800 -34
  2. data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
  3. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
  4. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
  5. dmu/__init__.py +0 -0
  6. dmu/generic/hashing.py +70 -0
  7. dmu/generic/utilities.py +175 -9
  8. dmu/generic/version_management.py +3 -5
  9. dmu/logging/log_store.py +34 -2
  10. dmu/logging/messages.py +96 -0
  11. dmu/ml/cv_classifier.py +3 -3
  12. dmu/ml/cv_diagnostics.py +224 -0
  13. dmu/ml/cv_performance.py +58 -0
  14. dmu/ml/cv_predict.py +149 -46
  15. dmu/ml/train_mva.py +587 -112
  16. dmu/ml/utilities.py +29 -10
  17. dmu/pdataframe/utilities.py +61 -3
  18. dmu/plotting/fwhm.py +64 -0
  19. dmu/plotting/matrix.py +1 -1
  20. dmu/plotting/plotter.py +25 -3
  21. dmu/plotting/plotter_1d.py +159 -14
  22. dmu/plotting/plotter_2d.py +5 -0
  23. dmu/rdataframe/utilities.py +54 -3
  24. dmu/rfile/ddfgetter.py +102 -0
  25. dmu/stats/fit_stats.py +129 -0
  26. dmu/stats/fitter.py +56 -23
  27. dmu/stats/gof_calculator.py +7 -0
  28. dmu/stats/model_factory.py +305 -50
  29. dmu/stats/parameters.py +100 -0
  30. dmu/stats/utilities.py +443 -12
  31. dmu/stats/wdata.py +187 -0
  32. dmu/stats/zfit.py +17 -0
  33. dmu/stats/zfit_models.py +68 -0
  34. dmu/stats/zfit_plotter.py +175 -56
  35. dmu/testing/utilities.py +120 -15
  36. dmu/workflow/__init__.py +0 -0
  37. dmu/workflow/cache.py +266 -0
  38. dmu_data/ml/tests/diagnostics_from_file.yaml +13 -0
  39. dmu_data/ml/tests/diagnostics_from_model.yaml +10 -0
  40. dmu_data/ml/tests/diagnostics_multiple_methods.yaml +10 -0
  41. dmu_data/ml/tests/diagnostics_overlay.yaml +33 -0
  42. dmu_data/ml/tests/train_mva.yaml +20 -12
  43. dmu_data/ml/tests/train_mva_def.yaml +75 -0
  44. dmu_data/ml/tests/train_mva_with_diagnostics.yaml +87 -0
  45. dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
  46. dmu_data/plotting/tests/2d.yaml +5 -5
  47. dmu_data/plotting/tests/line.yaml +15 -0
  48. dmu_data/plotting/tests/plug_fwhm.yaml +24 -0
  49. dmu_data/plotting/tests/plug_stats.yaml +19 -0
  50. dmu_data/plotting/tests/simple.yaml +4 -3
  51. dmu_data/plotting/tests/styling.yaml +18 -0
  52. dmu_data/rfile/friends.yaml +13 -0
  53. dmu_data/stats/fitter/test_simple.yaml +28 -0
  54. dmu_data/stats/kde_optimizer/control.json +1 -0
  55. dmu_data/stats/kde_optimizer/signal.json +1 -0
  56. dmu_data/stats/parameters/data.yaml +178 -0
  57. dmu_data/tests/config.json +6 -0
  58. dmu_data/tests/config.yaml +4 -0
  59. dmu_data/tests/pdf_to_tex.txt +34 -0
  60. dmu_scripts/kerberos/check_expiration +21 -0
  61. dmu_scripts/kerberos/convert_certificate +22 -0
  62. dmu_scripts/ml/compare_classifiers.py +85 -0
  63. data_manipulation_utilities-0.2.6.dist-info/RECORD +0 -57
  64. {data_manipulation_utilities-0.2.6.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
  65. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,68 @@
1
+ '''
2
+ Module meant to hold classes defining PDFs that can be used by ZFIT
3
+ '''
4
+
5
+ import zfit
6
+ from zfit import z
7
+
8
+ #-------------------------------------------------------------------
class HypExp(zfit.pdf.ZPDF):
    '''
    Exponential multiplied by a logistic (sigmoid) factor:

        f(x) = exp(-beta * (x - mu)) / (1 + exp(-alpha * (x - mu)))

    zfit parameters
    ---------------
    mu   : Position of the sigmoid, also used as reference point of the exponential
    alpha: Coefficient of (x - mu) inside the sigmoid
    beta : Slope of the exponential
    '''
    _N_OBS  = 1
    _PARAMS = ['mu', 'alpha', 'beta']

    def _unnormalized_pdf(self, x):
        x  = z.unstack_x(x)
        mu = self.params['mu']
        ap = self.params['alpha']
        bt = self.params['beta']

        u = x - mu
        # The exponential is evaluated at (x - mu) rather than at x. This only
        # rescales the shape by the constant exp(-beta * mu), which cancels in the
        # normalization, but it avoids exp(-beta * x) underflowing to zero when x
        # is large. It is also consistent with how ModExp and GenExp are written.
        val = z.exp(-bt * u) / (1 + z.exp(-ap * u))

        return val
23
+ #-------------------------------------------------------------------
class ModExp(zfit.pdf.ZPDF):
    '''
    Exponential modulated by a saturating rise:

        f(x) = (1 - exp(-alpha * (x - mu))) * exp(-beta * (x - mu))

    zfit parameters
    ---------------
    mu   : Reference point where the rising factor is zero
    alpha: Coefficient of (x - mu) in the rising factor
    beta : Slope of the exponential

    NOTE(review): for alpha > 0 the rising factor is negative when x < mu, so the
    PDF is presumably only meaningful if the observable range starts at or above mu.
    '''
    _N_OBS  = 1
    _PARAMS = ['mu', 'alpha', 'beta']

    def _unnormalized_pdf(self, x):
        arr_x = z.unstack_x(x)
        shift = arr_x - self.params['mu']

        rise  = 1 - z.exp(-self.params['alpha'] * shift)
        decay = z.exp(-self.params['beta'] * shift)

        return rise * decay
38
+ #-------------------------------------------------------------------
class GenExp(zfit.pdf.ZPDF):
    '''
    Same shape as ModExp, but in terms of the standardized variable u = (x - mu) / sg:

        f(x) = (1 - exp(-alpha * u)) * exp(-beta * u)

    zfit parameters
    ---------------
    mu   : Reference point where the rising factor is zero
    sg   : Scale dividing (x - mu)
    alpha: Coefficient of u in the rising factor
    beta : Coefficient of u in the exponential

    NOTE(review): for alpha > 0 and sg > 0 the rising factor is negative when x < mu;
    presumably the observable range is expected to start at or above mu.
    '''
    _N_OBS  = 1
    _PARAMS = ['mu', 'sg', 'alpha', 'beta']

    def _unnormalized_pdf(self, x):
        arr_x  = z.unstack_x(x)
        scaled = (arr_x - self.params['mu']) / self.params['sg']

        rise  = 1 - z.exp(-self.params['alpha'] * scaled)
        decay = z.exp(-self.params['beta'] * scaled)

        return rise * decay
54
+ #-------------------------------------------------------------------
class FermiDirac(zfit.pdf.ZPDF):
    '''
    Fermi-Dirac like step:

        f(x) = 1 / (1 + exp((x - mu) / ap))

    zfit parameters
    ---------------
    mu: Point where f takes half of its maximum value
    ap: Width of the transition, its sign sets the direction of the step
    '''
    _N_OBS  = 1
    _PARAMS = ['mu', 'ap']

    def _unnormalized_pdf(self, x):
        arr_x = z.unstack_x(x)
        arg   = (arr_x - self.params['mu']) / self.params['ap']

        return 1. / (1 + z.exp(arg))
68
+ #-------------------------------------------------------------------
dmu/stats/zfit_plotter.py CHANGED
@@ -1,27 +1,31 @@
1
1
  '''
2
2
  Module containing plot class, used to plot fits
3
3
  '''
4
- # pylint: disable=too-many-instance-attributes
4
+ # pylint: disable=too-many-instance-attributes, too-many-arguments
5
5
 
6
+ import math
6
7
  import warnings
7
8
  import pprint
8
9
 
9
10
  import zfit
10
11
  import hist
11
12
  import mplhep
13
+ import tensorflow as tf
12
14
  import pandas as pd
13
15
  import numpy as np
14
16
  import matplotlib.pyplot as plt
15
- import dmu.generic.utilities as gut
17
+ from zfit.core.basepdf import BasePDF as zpdf
16
18
 
19
+ import dmu.generic.utilities as gut
17
20
  from dmu.logging.log_store import LogStore
18
21
 
19
- log = LogStore.add_logger('dmu:fit_plotter')
22
+ log = LogStore.add_logger('dmu:zfit_plotter')
20
23
  #----------------------------------------
21
24
  class ZFitPlotter:
22
25
  '''
23
26
  Class used to plot fits done with zfit
24
27
  '''
28
+ #----------------------------------------
25
29
  def __init__(self, data=None, model=None, weights=None, result=None, suffix=''):
26
30
  '''
27
31
  obs: zfit space you are using to define the data and model
@@ -51,6 +55,8 @@ class ZFitPlotter:
51
55
  self._figsize = None
52
56
  self._leg_loc = None
53
57
 
58
+ self.dat_xerr : bool
59
+
54
60
  # zfit.settings.advanced_warnings['extend_wrapped_extended'] = False
55
61
  warnings.filterwarnings("ignore")
56
62
  #----------------------------------------
@@ -60,50 +66,90 @@ class ZFitPlotter:
60
66
  self._l_def_col = list(mcolors.TABLEAU_COLORS.keys())
61
67
  #----------------------------------------
def _data_to_zdata(self, obs, data, weights):
    '''
    Convert the input data into a zfit dataset

    Parameters
    ----------
    obs    : zfit space in which the data is defined
    data   : zfit.Data (returned unchanged), numpy array, pandas Series or pandas DataFrame
    weights: Weights passed to the zfit.Data constructor; not used when data is already a zfit.Data

    Returns
    ----------
    zfit.Data object

    Raises
    ----------
    ValueError: If data is of none of the supported types
    '''
    if isinstance(data, zfit.Data):
        return data

    if isinstance(data, np.ndarray):
        return zfit.Data.from_numpy(obs=obs, array=data, weights=weights)

    if isinstance(data, pd.Series):
        return zfit.Data.from_pandas(obs=obs, df=pd.DataFrame(data), weights=weights)

    if isinstance(data, pd.DataFrame):
        return zfit.Data.from_pandas(obs=obs, df=data, weights=weights)

    raise ValueError(f'Passed data is of unsupported type {type(data)}')
76
82
  #----------------------------------------
def _get_errors(
        self,
        nbins  : int = 100,
        l_range: list[tuple[float,float]]|None = None) -> list[tuple[float,float]]:
    '''
    Parameters
    ---------------------
    nbins  : Number of bins
    l_range: List of ranges where data should be picked, if None, will pick full range

    Returns
    ---------------------
    List with one (lower, upper) error tuple per bin of the histogram filled with the data.
    The errors are read back from the error bars that mplhep draws on a temporary figure.

    Raises
    ---------------------
    IndexError: If mplhep did not produce one error bar per bin, or a bar cannot be read
    '''
    dat, wgt  = self._get_range_data(l_range=l_range, blind=False)
    data_hist = hist.Hist.new.Regular(
        nbins,
        self.lower,
        self.upper,
        name      =self.obs.obs[0],
        underflow =False,
        overflow  =False)

    data_hist = data_hist.Weight()
    data_hist.fill(dat, weight=wgt)

    # The figure only exists so that mplhep computes the error bars; it is never shown.
    # Close it in a finally clause so that it is not leaked if plotting fails.
    tmp_fig, tmp_ax = plt.subplots()
    try:
        errorbars = mplhep.histplot(
            data_hist,
            yerr     =True,
            color    ='white',
            histtype ='errorbar',
            label    =None,
            ax       =tmp_ax)
    finally:
        plt.close(tmp_fig)

    lines  = errorbars[0].errorbar[2]
    segs   = lines[0].get_segments()
    values = data_hist.values()

    if len(segs) != nbins:
        raise IndexError(f'Expected {nbins} error bars, found {len(segs)}')

    l_error = []
    for val, seg in zip(values, segs):
        # Each segment is read as [[x, y_low], [x, y_high]]
        try:
            low = val - seg[0][1]
            up  = seg[1][1] - val
        except IndexError as exc:
            raise IndexError(f'Cannot read the upper/lower errors, found {seg}') from exc

        l_error.append((low, up))

    return l_error
105
138
  #----------------------------------------
106
- def _get_range_data(self, l_range, blind=True):
139
+ def _get_range_data(
140
+ self,
141
+ l_range : list[tuple[float,float]]|None,
142
+ blind : bool =True) -> tuple[np.ndarray, np.ndarray]:
143
+ '''
144
+ Parameters
145
+ -----------------
146
+ l_range: List of ranges, i.e. tuples of bounds
147
+ blind : If true (default) will blind the range specified, i.e. will exclude it
148
+
149
+ Returns
150
+ -----------------
151
+ Tuple with two numpy arrays defined in those ranges, with the observable and the weights.
152
+ '''
107
153
  sdat = self.data_np
108
154
  swgt = self.data_weight_np
109
155
  dmat = np.array([sdat, swgt]).T
@@ -115,6 +161,8 @@ class ZFitPlotter:
115
161
 
116
162
  if l_range is None:
117
163
  [dat, wgt] = dmat.T
164
+ self._check_data(dat=dat, wgt=wgt)
165
+
118
166
  return dat, wgt
119
167
 
120
168
  l_dat = []
@@ -130,23 +178,42 @@ class ZFitPlotter:
130
178
  dat_f = np.concatenate(l_dat)
131
179
  wgt_f = np.concatenate(l_wgt)
132
180
 
181
+ self._check_data(dat=dat_f, wgt=wgt_f)
182
+
133
183
  return dat_f, wgt_f
134
184
  #----------------------------------------
185
+ def _check_data(
186
+ self,
187
+ dat : np.ndarray,
188
+ wgt : np.ndarray) -> None:
189
+ '''
190
+ Checks for empty data, etc
191
+
192
+ Parameters
193
+ ------------
194
+ Numpy arrays with data and weights
195
+ '''
196
+
197
+ if dat.shape != wgt.shape:
198
+ raise ValueError(f'Shapes or data and weights differ: {dat.shape}/{wgt.shape}')
199
+
200
+ if len(dat) == 0:
201
+ raise ValueError('Dataset is empty')
202
+ #----------------------------------------
def _plot_data(self, ax, nbins=100, l_range=None):
    '''
    Draw the data on `ax` as a weighted histogram with error bars

    Parameters
    ----------
    ax     : Matplotlib axis where the data is drawn
    nbins  : Number of bins between self.lower and self.upper
    l_range: Ranges passed to _get_range_data with blind=True, i.e. data in them is excluded
    '''
    arr_dat, arr_wgt = self._get_range_data(l_range, blind=True)

    hst = hist.Hist.new.Regular(
        nbins,
        self.lower,
        self.upper,
        name      = self.obs.obs[0],
        underflow = False,
        overflow  = False).Weight()
    hst.fill(arr_dat, weight=arr_wgt)

    data_label = self._leg.get('Data', 'Data')
    mplhep.histplot(
        hst,
        yerr     = True,
        color    = 'black',
        histtype = 'errorbar',
        label    = data_label,
        ax       = ax,
        xerr     = self.dat_xerr)
150
217
  #----------------------------------------
151
218
  def _pull_hist(self, pdf_hist, nbins, data_yield, l_range=None):
152
219
  pdf_values= pdf_hist.values()
@@ -168,8 +235,16 @@ class ZFitPlotter:
168
235
  err = low if res > 0 else up
169
236
  pul = res / err
170
237
 
171
- if abs(pul) > 5:
172
- log.warning(f'Large pull: {pul:.1f}=({dat_val:.0f}-{pdf_val:.0f})/{err:.0f}')
238
+ # If the data is weighted
239
+ # and the data does not exist
240
+ # The pulls will have an error of zero => pull is inf
241
+ # Ignore these cases
242
+ if math.isinf(pul):
243
+ pass
244
+ elif abs(pul) > 5:
245
+ log.info(f'Pull: {pul:.2f}=({dat_val:.2f}-{pdf_val:.2f})/{err:.2f}')
246
+ else:
247
+ log.debug(f'Pull: {pul:.2f}=({dat_val:.2f}-{pdf_val:.2f})/{err:.2f}')
173
248
 
174
249
  pulls.append(pul)
175
250
  pull_errors[0].append(low / err)
@@ -200,7 +275,7 @@ class ZFitPlotter:
200
275
  #----------------------------------------
201
276
  def _get_zfit_gof(self):
202
277
  if not hasattr(self._result, 'gof'):
203
- return
278
+ return None
204
279
 
205
280
  chi2, ndof, pval = self._result.gof
206
281
 
@@ -211,14 +286,16 @@ class ZFitPlotter:
211
286
  def _get_text(self, ext_text):
212
287
  gof_text = self._get_zfit_gof()
213
288
 
214
- if ext_text is None and gof_text is None:
215
- return
216
- elif ext_text is not None and gof_text is None:
289
+ if ext_text is None and gof_text is None:
290
+ return None
291
+
292
+ if ext_text is not None and gof_text is None:
217
293
  return ext_text
218
- elif ext_text is None and gof_text is not None:
294
+
295
+ if ext_text is None and gof_text is not None:
219
296
  return gof_text
220
- else:
221
- return f'{ext_text}\n{gof_text}'
297
+
298
+ return f'{ext_text}\n{gof_text}'
222
299
  #----------------------------------------
223
300
  def _get_pars(self):
224
301
  '''
@@ -237,8 +314,8 @@ class ZFitPlotter:
237
314
  val = d_val['value']
238
315
  name= par if isinstance(par, str) else par.name
239
316
  try:
240
- err = d_val['hesse']['error']
241
- except:
317
+ err = d_val['minuit_hesse']['error']
318
+ except KeyError:
242
319
  log.warning(f'Cannot extract {name} Hesse errors, using zeros')
243
320
  pprint.pprint(d_val)
244
321
  err = 0
@@ -260,7 +337,7 @@ class ZFitPlotter:
260
337
  '''
261
338
  d_par = self._get_pars()
262
339
 
263
- line = f''
340
+ line = ''
264
341
  for name, [val, err] in d_par.items():
265
342
  if add_pars != 'all' and name not in add_pars:
266
343
  continue
@@ -328,7 +405,7 @@ class ZFitPlotter:
328
405
  nevt = self._get_component_yield(model, par)
329
406
 
330
407
  if model.name in self._l_plot_components and hasattr(model, 'pdfs'):
331
- l_model = [ (frc, pdf) for pdf, frc in zip(model.pdfs, model.params.values()) ]
408
+ l_model = [ (frc, pdf) for pdf, frc in zip(model.pdfs, model.params.values()) ]
332
409
  elif model.name in self._l_plot_components and not hasattr(model, 'pdfs'):
333
410
  log.warning(f'Cannot plot {model.name} as separate components, despite it was requested')
334
411
  l_model = [ (1, model)]
@@ -344,27 +421,51 @@ class ZFitPlotter:
344
421
  if stacked:
345
422
  ax.fill_between(self.x, y, alpha=1.0, label=self._leg.get(name, name), color=self._get_col(name))
346
423
  else:
347
- ax.plot(self.x, y, '-', label=self._leg.get(name, name), color=self._col.get(name))
424
+ ax.plot(self.x, y, ':', label=self._leg.get(name, name), color=self._col.get(name))
348
425
 
349
426
  if (blind_name is not None) and (was_blinded is False):
350
- log.error(f'Blinding was requested, but PDF {blind_name} was not found among:')
351
427
  for model in self.total_model.pdfs:
352
428
  log.info(model.name)
353
- raise
429
+
430
+ raise ValueError(f'Blinding was requested, but PDF {blind_name} was not found among:')
354
431
  #----------------------------------------
355
432
  def _get_col(self, name):
356
433
  if name in self._col:
357
434
  return self._col[name]
358
435
 
359
436
  col = self._l_def_col[0]
360
- del(self._l_def_col[0])
437
+ del self._l_def_col[0]
361
438
 
362
439
  return col
363
440
  #----------------------------------------
def _print_data(self) -> None:
    '''
    Summarize the stored data, meant to be called right before an exception is raised

    Logs the shapes of the data and weight arrays and the number of NaNs in the data,
    then closes every open figure and shows a weighted histogram of the data.

    NOTE(review): plt.show() may block in interactive sessions.
    '''
    log.info(f'Data shape : {self.data_np.shape}')
    log.info(f'Weights shape: {self.data_weight_np.shape}')

    n_nan = np.isnan(self.data_np).sum()
    log.info(f'NaNs: {n_nan}')

    # The program is about to fail, so any other open plot can be discarded
    plt.close('all')
    plt.hist(self.data_np, weights=self.data_weight_np)
    plt.show()
453
+ #----------------------------------------
def _evaluate_pdf(self, pdf : zpdf) -> np.ndarray:
    '''
    Evaluate `pdf` at the points stored in self.x

    Parameters
    ----------
    pdf: zfit PDF to evaluate

    Returns
    ----------
    Output of pdf.pdf(self.x)

    Raises
    ----------
    ValueError: If TensorFlow raises InvalidArgumentError while evaluating the PDF.
                Before raising, the x values are logged and _print_data is called.

    NOTE(review): a caller in this class wraps this method in an
    `except tf.errors.InvalidArgumentError` clause that can never trigger,
    because the error is converted into ValueError here.
    '''
    try:
        return pdf.pdf(self.x)
    except tf.errors.InvalidArgumentError as exc:
        log.info(f'X values: {self.x}')
        self._print_data()
        raise ValueError('Cannot evaluate PDF') from exc
463
+ #----------------------------------------
364
464
  def _plot_sub_components(self, y, nbins, stacked, nevt, l_model):
365
465
  l_y = []
366
466
  for frc, model in l_model:
367
- this_y = model.pdf(self.x) * nevt * frc / nbins * (self.upper - self.lower)
467
+ arr_y = self._evaluate_pdf(pdf = model)
468
+ this_y = arr_y * nevt * frc / nbins * (self.upper - self.lower)
368
469
 
369
470
  if stacked:
370
471
  y = this_y if y is None else y + this_y
@@ -381,7 +482,13 @@ class ZFitPlotter:
381
482
  return
382
483
 
383
484
  data_yield = self.data_weight_np.sum()
384
- y = model.pdf(self.x) * data_yield / nbins * (self.upper - self.lower)
485
+ try:
486
+ arr_y = self._evaluate_pdf(model)
487
+ y = arr_y * data_yield / nbins * (self.upper - self.lower)
488
+ except tf.errors.InvalidArgumentError as exc:
489
+ log.warning(f'Data yield: {data_yield:.0f}')
490
+ log.info(self.data_np)
491
+ raise RuntimeError('Cannot parse PDF') from exc
385
492
 
386
493
  name = model.name
387
494
  ax.plot(self.x, y, linestyle, label=self._leg.get(name, name), color=self._col.get(name))
@@ -392,7 +499,7 @@ class ZFitPlotter:
392
499
 
393
500
  if ylabel == "":
394
501
  width = (self.upper-self.lower)/nbins
395
- ylabel = f'Candidates / ({width:.3f} {unit})'
502
+ ylabel = f'Candidates / ({width:.0f} {unit})'
396
503
 
397
504
  return xlabel, ylabel
398
505
  #----------------------------------------
@@ -400,9 +507,8 @@ class ZFitPlotter:
400
507
  if plot_range is not None:
401
508
  try:
402
509
  self.lower, self.upper = plot_range
403
- except TypeError:
404
- log.error(f'plot_range argument is expected to be a tuple with two numeric values')
405
- raise TypeError
510
+ except TypeError as exc:
511
+ raise TypeError('plot_range argument is expected to be a tuple with two numeric values') from exc
406
512
 
407
513
  return np.linspace(self.lower, self.upper, 2000)
408
514
  #----------------------------------------
@@ -439,6 +545,8 @@ class ZFitPlotter:
439
545
  add_pars = None,
440
546
  ymax = None,
441
547
  skip_pulls = False,
548
+ pull_styling :bool= True,
549
+ yscale : str = None,
442
550
  axs = None,
443
551
  figsize:tuple = (13, 7),
444
552
  leg_loc:str = 'best',
@@ -455,6 +563,7 @@ class ZFitPlotter:
455
563
  d_leg : Customize legend
456
564
  d_col : Customize color
457
565
  plot_range : Set plot_range
566
+ pull_styling(bool) : Will add lines at +/-3 and set range to +/-5 for pull plots, by default True
458
567
  plot_components (list): List of strings, with names of PDFs, which are expected to be sums of PDFs and whose components should be plotted separately
459
568
  ext_text : Text that can be added to plot
460
569
  add_pars (list|str) : List of names of parameters to be added or string with value 'all' to add all fit parameters. If this is used, plot won't use LHCb style.
@@ -464,6 +573,7 @@ class ZFitPlotter:
464
573
  figsize (tuple) : Tuple with figure size, default (13, 7)
465
574
  leg_loc (str) : Location of legend, default 'best'
466
575
  xerr (bool or float) : Used to pass xerr to mplhep histplot. True will use error with bin size, False, no error, otherwise it's the size of the xerror bar
576
+ yscale (str) : Scale for y axis of main plot, either log or linear
467
577
  '''
468
578
  # pylint: disable=too-many-locals, too-many-positional-arguments, too-many-arguments
469
579
  d_leg = {} if d_leg is None else d_leg
@@ -512,6 +622,9 @@ class ZFitPlotter:
512
622
  self.axs[0].set(xlabel=xlabel, ylabel=ylabel)
513
623
  self.axs[0].set_xlim([self.lower, self.upper])
514
624
 
625
+ if yscale is not None:
626
+ self.axs[0].set_yscale(yscale)
627
+
515
628
  if title is not None:
516
629
  self.axs[0].set_title(title)
517
630
 
@@ -524,4 +637,10 @@ class ZFitPlotter:
524
637
 
525
638
  for ax in self.axs:
526
639
  ax.label_outer()
640
+
641
+ if pull_styling and not skip_pulls:
642
+ self.axs[1].axhline(y=-3, color='red' , linestyle='-', lw=2)
643
+ self.axs[1].axhline(y= 0, color='gray', linestyle='-', lw=1)
644
+ self.axs[1].axhline(y=+3, color='red' , linestyle='-', lw=2)
645
+ self.axs[1].set_ylim(-5, 5)
527
646
  #----------------------------------------
dmu/testing/utilities.py CHANGED
@@ -3,16 +3,21 @@ Module containing utility functions needed by unit tests
3
3
  '''
4
4
  import os
5
5
  import math
6
+ import glob
6
7
  from typing import Union
7
8
  from dataclasses import dataclass
8
9
  from importlib.resources import files
9
10
 
10
11
  from ROOT import RDF, TFile, RDataFrame
11
12
 
13
+ import uproot
14
+ import joblib
12
15
  import pandas as pnd
13
16
  import numpy
14
17
  import yaml
15
18
 
19
+ from dmu.ml.train_mva import TrainMva
20
+ from dmu.ml.cv_classifier import CVClassifier
16
21
  from dmu.logging.log_store import LogStore
17
22
 
18
23
  log = LogStore.add_logger('dmu:testing:utilities')
@@ -22,6 +27,15 @@ class Data:
22
27
  '''
23
28
  Class storing shared data
24
29
  '''
30
+ out_dir = '/tmp/tests/dmu/ml/cv_predict'
31
+
32
+ d_col = {
33
+ 'main' : ['index', 'a0', 'b0'],
34
+ 'frn1' : ['index', 'a1', 'b1'],
35
+ 'frn2' : ['index', 'a2', 'b2'],
36
+ 'frn3' : ['index', 'a3', 'b3'],
37
+ 'frn4' : ['index', 'a4', 'b4'],
38
+ }
25
39
  # -------------------------------
26
40
  def _double_data(df_1 : pnd.DataFrame) -> pnd.DataFrame:
27
41
  df_2 = df_1.copy()
@@ -39,7 +53,7 @@ def _add_nans(df : pnd.DataFrame, columns : list[str]) -> pnd.DataFrame:
39
53
  else:
40
54
  l_col_index = [ l_col.index(column) for column in columns ]
41
55
 
42
- log.debug('Replacing randomly with {size} NaNs')
56
+ log.debug(f'Replacing randomly with {size} NaNs')
43
57
  for _ in range(size):
44
58
  irow = numpy.random.randint(0, df.shape[0]) # Random row index
45
59
  icol = numpy.random.choice(l_col_index) # Random column index
@@ -48,25 +62,39 @@ def _add_nans(df : pnd.DataFrame, columns : list[str]) -> pnd.DataFrame:
48
62
 
49
63
  return df
50
64
  # -------------------------------
51
- def get_rdf(kind : Union[str,None] = None,
52
- repeated : bool = False,
53
- nentries : int = 3_000,
54
- add_nans : list[str] = None):
65
+ def get_rdf(
66
+ kind : Union[str,None] = None,
67
+ repeated : bool = False,
68
+ nentries : int = 3_000,
69
+ use_preffix : bool = False,
70
+ columns_with_nans : list[str] = None):
55
71
  '''
56
72
  Return ROOT dataframe with toy data
73
+
74
+ kind : sig, bkg or bkg_alt
75
+ repeated : Will add repeated rows
76
+ nentries : Number of rows
77
+ columns_with_nans : List of column names in [w, y, z]
57
78
  '''
79
+ # Needed for a specific test
80
+ xnm = 'preffix.x.suffix' if use_preffix else 'x'
58
81
 
59
82
  d_data = {}
60
83
  if kind == 'sig':
61
- d_data['w'] = numpy.random.normal(0, 1, size=nentries)
62
- d_data['x'] = numpy.random.normal(0, 1, size=nentries)
63
- d_data['y'] = numpy.random.normal(0, 1, size=nentries)
64
- d_data['z'] = numpy.random.normal(0, 1, size=nentries)
84
+ d_data[xnm] = numpy.random.normal(0.0, 1.0, size=nentries)
85
+ d_data['w'] = numpy.random.normal(0.0, 1.0, size=nentries)
86
+ d_data['y'] = numpy.random.normal(0.0, 1.0, size=nentries)
87
+ d_data['z'] = numpy.random.normal(0.0, 1.0, size=nentries)
65
88
  elif kind == 'bkg':
66
- d_data['w'] = numpy.random.normal(1, 1, size=nentries)
67
- d_data['x'] = numpy.random.normal(1, 1, size=nentries)
68
- d_data['y'] = numpy.random.normal(1, 1, size=nentries)
69
- d_data['z'] = numpy.random.normal(1, 1, size=nentries)
89
+ d_data[xnm] = numpy.random.normal(1.0, 1.0, size=nentries)
90
+ d_data['w'] = numpy.random.normal(1.0, 1.0, size=nentries)
91
+ d_data['y'] = numpy.random.normal(1.0, 1.0, size=nentries)
92
+ d_data['z'] = numpy.random.normal(1.0, 1.0, size=nentries)
93
+ elif kind == 'bkg_alt':
94
+ d_data[xnm] = numpy.random.normal(1.3, 1.3, size=nentries)
95
+ d_data['w'] = numpy.random.normal(1.3, 1.3, size=nentries)
96
+ d_data['y'] = numpy.random.normal(1.3, 1.3, size=nentries)
97
+ d_data['z'] = numpy.random.normal(1.3, 1.3, size=nentries)
70
98
  else:
71
99
  log.error(f'Invalid kind: {kind}')
72
100
  raise ValueError
@@ -76,8 +104,8 @@ def get_rdf(kind : Union[str,None] = None,
76
104
  if repeated:
77
105
  df = _double_data(df)
78
106
 
79
- if add_nans:
80
- df = _add_nans(df, columns=add_nans)
107
+ if columns_with_nans is not None:
108
+ df = _add_nans(df, columns=columns_with_nans)
81
109
 
82
110
  rdf = RDF.FromPandas(df)
83
111
 
@@ -126,3 +154,80 @@ def get_file_with_trees(path : str) -> TFile:
126
154
  snap.fMode = 'update'
127
155
 
128
156
  return TFile(path)
157
+ # -------------------------------
def get_models(
        rdf_sig : RDataFrame,
        rdf_bkg : RDataFrame,
        name    : str = 'train_mva',
        out_dir : str | None = None) -> tuple[list[CVClassifier], float]:
    '''
    Will train and return models together with the AUC in a tuple

    rdf_xxx : Signal or background dataframe used for training
    name    : Name of config file, e.g. train_mva, read from ml/tests/{name}.yaml
    out_dir : Directory where the training output will go, optional, defaults to Data.out_dir

    Returns the models loaded from the model*.pkl files in out_dir, sorted by path,
    together with the value returned by TrainMva.run.

    Raises FileNotFoundError if no model file is found after the training.
    '''
    out_dir = Data.out_dir if out_dir is None else out_dir

    cfg = get_config(f'ml/tests/{name}.yaml')
    cfg['saving']['output'] = out_dir

    obj = TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
    auc = obj.run()

    # glob returns paths in arbitrary order; sort them so the models come back reproducibly
    pkl_wc     = f'{out_dir}/model*.pkl'
    l_pkl_path = sorted(glob.glob(pkl_wc))
    if not l_pkl_path:
        raise FileNotFoundError(f'No model found with: {pkl_wc}')

    l_model = [ joblib.load(pkl_path) for pkl_path in l_pkl_path ]

    return l_model, auc
183
+ # -------------------------------
def _make_file(
        fpath    : str,
        tree     : str,
        nentries : int) -> None:
    '''
    Create the ROOT file `fpath` with a tree called `tree` filled with toy data

    The name of the directory containing the file is the sample name. It selects,
    through Data.d_col, which columns are written. The `index` column holds
    0, ..., nentries - 1; every other column holds standard normal random numbers.

    Raises
    ---------
    KeyError: If the sample is not one of the keys of Data.d_col
    '''
    fdir   = os.path.dirname(fpath)
    sample = os.path.basename(fdir)
    if sample not in Data.d_col:
        raise KeyError(f'Unknown sample {sample} in {fpath}, expected one of: {list(Data.d_col)}')

    data = {}
    for col_name in Data.d_col[sample]:
        if col_name == 'index':
            data[col_name] = numpy.arange(nentries)
        else:
            data[col_name] = numpy.random.normal(0, 1, nentries)

    # uproot does not create missing parent directories
    if fdir:
        os.makedirs(fdir, exist_ok=True)

    with uproot.recreate(fpath) as ofile:
        log.debug(f'Saving to: {fpath}:{tree}')
        ofile[tree] = data
203
+ # -------------------------------
def build_friend_structure(file_name : str, nentries : int) -> None:
    '''
    Will load YAML file with file structure needed to
    test code that relies on friend trees, e.g. DDFGetter

    Parameters:
    -------------------
    file_name (str): Name of YAML file with wanted structure, e.g. friends.yaml.
                     It is read from the rfile directory of the dmu_data package and
                     must have the keys `tree` (tree name), `samples` (directories)
                     and `files` (file names). One file is made per directory and file name.
    nentries (int) : Number of entries in file

    Raises:
    -------------------
    ValueError: If the YAML file is not a mapping or lacks the tree, samples or files entries
    '''
    cfg_path = files('dmu_data').joinpath(f'rfile/{file_name}')
    with open(cfg_path, encoding='utf-8') as ifile:
        data = yaml.safe_load(ifile)

    # An empty YAML file loads as None
    if not isinstance(data, dict):
        raise ValueError(f'Expected a mapping in: {cfg_path}')

    if 'tree' not in data:
        raise ValueError(f'tree entry missing in: {cfg_path}')

    if 'samples' not in data:
        raise ValueError(f'Samples section missing in: {cfg_path}')

    if 'files' not in data:
        raise ValueError(f'Files section missing in: {cfg_path}')

    tree_name = data['tree']

    for fdir in data['samples']:
        for fname in data['files']:
            path = f'{fdir}/{fname}'
            _make_file(fpath=path, tree=tree_name, nentries=nentries)
233
+ # ----------------------------------------------
File without changes