data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +641 -44
  2. data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
  3. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
  4. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
  5. dmu/__init__.py +0 -0
  6. dmu/generic/hashing.py +34 -8
  7. dmu/generic/utilities.py +164 -11
  8. dmu/logging/log_store.py +34 -2
  9. dmu/logging/messages.py +96 -0
  10. dmu/ml/cv_classifier.py +3 -3
  11. dmu/ml/cv_diagnostics.py +3 -0
  12. dmu/ml/cv_performance.py +58 -0
  13. dmu/ml/cv_predict.py +149 -46
  14. dmu/ml/train_mva.py +482 -100
  15. dmu/ml/utilities.py +29 -10
  16. dmu/pdataframe/utilities.py +28 -3
  17. dmu/plotting/fwhm.py +2 -2
  18. dmu/plotting/matrix.py +1 -1
  19. dmu/plotting/plotter.py +23 -3
  20. dmu/plotting/plotter_1d.py +96 -32
  21. dmu/plotting/plotter_2d.py +5 -0
  22. dmu/rdataframe/utilities.py +54 -3
  23. dmu/rfile/ddfgetter.py +102 -0
  24. dmu/stats/fit_stats.py +129 -0
  25. dmu/stats/fitter.py +55 -22
  26. dmu/stats/gof_calculator.py +7 -0
  27. dmu/stats/model_factory.py +153 -62
  28. dmu/stats/parameters.py +100 -0
  29. dmu/stats/utilities.py +443 -12
  30. dmu/stats/wdata.py +187 -0
  31. dmu/stats/zfit.py +17 -0
  32. dmu/stats/zfit_plotter.py +147 -36
  33. dmu/testing/utilities.py +102 -24
  34. dmu/workflow/__init__.py +0 -0
  35. dmu/workflow/cache.py +266 -0
  36. dmu_data/ml/tests/train_mva.yaml +9 -7
  37. dmu_data/ml/tests/train_mva_def.yaml +75 -0
  38. dmu_data/ml/tests/train_mva_with_diagnostics.yaml +10 -5
  39. dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
  40. dmu_data/plotting/tests/2d.yaml +5 -5
  41. dmu_data/plotting/tests/line.yaml +15 -0
  42. dmu_data/plotting/tests/styling.yaml +8 -1
  43. dmu_data/rfile/friends.yaml +13 -0
  44. dmu_data/stats/fitter/test_simple.yaml +28 -0
  45. dmu_data/stats/kde_optimizer/control.json +1 -0
  46. dmu_data/stats/kde_optimizer/signal.json +1 -0
  47. dmu_data/stats/parameters/data.yaml +178 -0
  48. dmu_data/tests/config.json +6 -0
  49. dmu_data/tests/config.yaml +4 -0
  50. dmu_data/tests/pdf_to_tex.txt +34 -0
  51. dmu_scripts/kerberos/check_expiration +21 -0
  52. dmu_scripts/kerberos/convert_certificate +22 -0
  53. dmu_scripts/ml/compare_classifiers.py +85 -0
  54. data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
  55. {data_manipulation_utilities-0.2.7.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
  56. {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
dmu/stats/utilities.py CHANGED
@@ -1,18 +1,112 @@
  '''
  Module with utility functions related to the dmu.stats project
  '''
+ # pylint: disable=import-error
+
  import os
  import re
+ import pickle
  from typing import Union
- import zfit

- from dmu.logging.log_store import LogStore
+ import numpy
+ import pandas as pnd
+ import matplotlib.pyplot as plt
+
+ import dmu.pdataframe.utilities as put
+ import dmu.generic.utilities as gut
+
+ from dmu.stats.zfit import zfit
+ from dmu.stats.fitter import Fitter
+ from dmu.stats.zfit_plotter import ZFitPlotter
+ from dmu.logging.log_store import LogStore
+
+ import tensorflow as tf
+
+ from omegaconf import OmegaConf, DictConfig
+ from zfit.core.interfaces import ZfitData as zdata
+ from zfit.core.interfaces import ZfitSpace as zobs
+ from zfit.core.interfaces import ZfitPDF as zpdf
+ from zfit.core.parameter import Parameter as zpar
+ from zfit.result import FitResult as zres

  log = LogStore.add_logger('dmu:stats:utilities')
  #-------------------------------------------------------
+ class Data:
+     '''
+     Data class
+     '''
+     weight_name = 'weight'
+ #-------------------------------------------------------
+ def name_from_obs(obs : zobs) -> str:
+     '''
+     Takes zfit observable, returns its name
+     It is assumed this is a 1D observable
+     '''
+     if not isinstance(obs.obs, tuple):
+         raise ValueError(f'Cannot retrieve name for: {obs}')
+
+     if len(obs.obs) != 1:
+         raise ValueError(f'Observable is not 1D: {obs.obs}')
+
+     return obs.obs[0]
+ #-------------------------------------------------------
+ def range_from_obs(obs : zobs) -> tuple[float,float]:
+     '''
+     Takes zfit observable, returns tuple with two floats, representing range
+     '''
+     if not isinstance(obs.limits, tuple):
+         raise ValueError(f'Cannot retrieve name for: {obs}')
+
+     if len(obs.limits) != 2:
+         raise ValueError(f'Observable has more than one range: {obs.limits}')
+
+     minx, maxx = obs.limits
+
+     return float(minx[0][0]), float(maxx[0][0])
+ #-------------------------------------------------------
+ def yield_from_zdata(data : zdata) -> float:
+     '''
+     Parameter
+     --------------
+     data : Zfit dataset
+
+     Returns
+     --------------
+     Yield, i.e. number of entries or sum of weights if weighted dataset
+     '''
+
+     if data.weights is not None:
+         val = data.weights.numpy().sum()
+     else:
+         arr_val = data.to_numpy()
+         val = len(arr_val)
+
+     if val < 0:
+         raise ValueError(f'Yield cannot be negative, found {val}')
+
+     return val
+ #-------------------------------------------------------
+ # Check PDF
+ #-------------------------------------------------------
+ def is_pdf_usable(pdf : zpdf) -> zpdf:
+     '''
+     Will check if the PDF is usable
+     '''
+     [[[minx]], [[maxx]]]= pdf.space.limits
+
+     arr_x = numpy.linspace(minx, maxx, 100)
+
+     try:
+         pdf.pdf(arr_x)
+     except tf.errors.InvalidArgumentError:
+         log.warning('PDF cannot be evaluated')
+         return False
+
+     return True
+ #-------------------------------------------------------
  #Zfit/print_pdf
  #-------------------------------------------------------
- def _get_const(par : zfit.Parameter, d_const : Union[None, dict[str, list[float]]]) -> str:
+ def _get_const(par : zpar , d_const : Union[None, dict[str, tuple[float,float]]]) -> str:
      '''
      Takes zfit parameter and dictionary of constraints
      Returns a formatted string with the value of the constraint on that parameter
@@ -23,13 +117,13 @@ def _get_const(par : zfit.Parameter, d_const : Union[None, dict[str, list[float]
      obj = d_const[par.name]
      if isinstance(obj, (list, tuple)):
          [mu, sg] = obj
-         val = f'{mu:.3e}; {sg:.3e}'
+         val = f'{mu:.3e}___{sg:.3e}' # This separator needs to be readable and not a space
      else:
          val = str(obj)

      return val
  #-------------------------------------------------------
- def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zfit.Parameter]:
+ def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zpar]:
      '''
      Takes set of zfit parameters and list of parameter names to blind
      returns set of zfit parameters that should be blinded
@@ -45,7 +139,7 @@ def _blind_vars(s_par : set, l_blind : Union[list[str], None] = None) -> set[zfi
      return s_par_blind
  #-------------------------------------------------------
  def _get_pars(
-     pdf : zfit.pdf.BasePDF,
+     pdf : zpdf,
      blind : Union[None, list[str]]) -> tuple[list, list]:

      s_par_flt = pdf.get_params(floating= True)
@@ -63,7 +157,7 @@ def _get_pars(
      return l_par_flt, l_par_fix
  #-------------------------------------------------------
  def _get_messages(
-     pdf : zfit.pdf.BasePDF,
+     pdf : zpdf,
      l_par_flt : list,
      l_par_fix : list,
      d_const : Union[None, dict[str,list[float]]] = None) -> list[str]:
@@ -95,11 +189,11 @@ def _get_messages(
      return l_msg
  #-------------------------------------------------------
  def print_pdf(
-     pdf : zfit.pdf.BasePDF,
-     d_const : Union[None, dict[str,list[float]]] = None,
-     txt_path : Union[str,None] = None,
-     level : int = 20,
-     blind : Union[None, list[str]] = None):
+     pdf : zpdf,
+     d_const : Union[None, dict[str,tuple[float, float]]] = None,
+     txt_path : Union[str,None] = None,
+     level : int = 20,
+     blind : Union[None, list[str]] = None):
      '''
      Function used to print zfit PDFs

@@ -131,4 +225,341 @@ def print_pdf(
          log.debug(msg)
      else:
          raise ValueError(f'Invalid level: {level}')
+ #---------------------------------------------
+ def _parameters_from_result(result : zres) -> dict[str,tuple[float,float]]:
+     d_par = {}
+     log.debug('Reading parameters from:')
+     if log.getEffectiveLevel() == 10:
+         print(result)
+
+     log.debug(60 * '-')
+     log.debug('Reading parameters')
+     log.debug(60 * '-')
+     for name, d_val in result.params.items():
+         value = d_val['value']
+         error = None
+         if 'hesse' in d_val:
+             error = d_val['hesse']['error']
+
+         if 'minuit_hesse' in d_val:
+             error = d_val['minuit_hesse']['error']
+
+         log.debug(f'{name:<20}{value:<20.3f}{error}')
+
+         d_par[name] = value, error
+
+     return d_par
+ #---------------------------------------------
+ def save_fit(
+     data : zdata,
+     model : zpdf|None,
+     res : zres|None,
+     fit_dir : str,
+     d_const : dict[str,tuple[float,float]]|None = None) -> None:
+     '''
+     Function used to save fit results, meant to reduce boiler plate code
+
+     Plots: If:
+
+     ptr = ZFitPlotter(data=dat, model=pdf)
+     ptr.plot()
+
+     was done before calling this method, the plot will also be saved
+
+     Parameters
+     --------------------
+     model: PDF to be plotted, if None, will skip steps
+     '''
+     os.makedirs(fit_dir, exist_ok=True)
+     log.info(f'Saving fit to: {fit_dir}')
+
+     if plt.get_fignums():
+         fit_path = f'{fit_dir}/fit.png'
+         log.info(f'Saving fit to: {fit_path}')
+         plt.savefig(fit_path)
+         plt.close('all')
+     else:
+         log.info('No fit plot found')
+
+     _save_result(fit_dir=fit_dir, res=res)
+
+     df = data.to_pandas(weightsname=Data.weight_name)
+     opath = f'{fit_dir}/data.json'
+     log.debug(f'Saving data to: {opath}')
+     df.to_json(opath, indent=2)
+
+     if model is None:
+         return
+
+     print_pdf(model, txt_path=f'{fit_dir}/post_fit.txt', d_const=d_const)
+     pdf_to_tex(path=f'{fit_dir}/post_fit.txt', d_par={'mu' : r'$\mu$'}, skip_fixed=True)
+ #-------------------------------------------------------
+ def _save_result(fit_dir : str, res : zres|None) -> None:
+     '''
+     Saves result as yaml, JSON, pkl
+
+     Parameters
+     ---------------
+     fit_dir: Directory where fit result will go
+     res : Zfit result object
+     '''
+     if res is None:
+         log.info('No result object found, not saving parameters in pkl or JSON')
+         return
+
+     # TODO: Remove this once there be a safer way to freeze
+     # see https://github.com/zfit/zfit/issues/632
+     try:
+         res.freeze()
+     except AttributeError:
+         pass
+
+     with open(f'{fit_dir}/fit.pkl', 'wb') as ofile:
+         pickle.dump(res, ofile)
+
+     d_par = _parameters_from_result(result=res)
+     opath = f'{fit_dir}/parameters.json'
+     log.debug(f'Saving parameters to: {opath}')
+     gut.dump_json(d_par, opath)
+
+     opath = f'{fit_dir}/parameters.yaml'
+     cres = zres_to_cres(res=res)
+     OmegaConf.save(cres, opath)
+ #-------------------------------------------------------
+ # Make latex table from text file
+ #-------------------------------------------------------
+ def _reformat_expo(val : str) -> str:
+     regex = r'([\d\.]+)e([-,\d]+)'
+     mtch = re.match(regex, val)
+     if not mtch:
+         raise ValueError(f'Cannot extract value and exponent from: {val}')
+
+     [val, exp] = mtch.groups()
+     exp = int(exp)
+
+     return f'{val}\cdot 10^{{{exp}}}'
  #-------------------------------------------------------
+ def _format_float_str(val : str) -> str:
+     '''
+     Takes number as string and returns a formatted version
+     '''
+
+     fval = float(val)
+     if abs(fval) > 1000:
+         return f'{fval:,.0f}'
+
+     val = f'{fval:.3g}'
+     if 'e' in val:
+         val = _reformat_expo(val)
+
+     return val
+ #-------------------------------------------------------
+ def _info_from_line(line : str) -> tuple|None:
+     regex = r'(^\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)\s+(\S+)'
+     mtch = re.match(regex, line)
+     if not mtch:
+         return None
+
+     log.debug(f'Reading information from: {line}')
+
+     [par, _, low, high, floating, cons] = mtch.groups()
+
+     low = _format_float_str(low)
+     high = _format_float_str(high)
+
+     if cons != 'none':
+         [mu, sg] = cons.split('___')
+
+         mu = _format_float_str(mu)
+         sg = _format_float_str(sg)
+
+         cons = f'$\mu={mu}; \sigma={sg}$'
+
+     return par, low, high, floating, cons
+ #-------------------------------------------------------
+ def _df_from_lines(l_line : list[str]) -> pnd.DataFrame:
+     df = pnd.DataFrame(columns=['Parameter', 'Low', 'High', 'Floating', 'Constraint'])
+
+     for line in l_line:
+         info = _info_from_line(line=line)
+         if info is None:
+             continue
+
+         par, low, high, floating, cons = info
+
+         df.loc[len(df)] = {'Parameter' : par,
+                            'Low'       : low,
+                            'High'      : high,
+                            'Floating'  : floating,
+                            'Constraint': cons,
+                            }
+
+     return df
+ #-------------------------------------------------------
+ def pdf_to_tex(path : str, d_par : dict[str,str], skip_fixed : bool = True) -> None:
+     '''
+     Takes
+
+     path: path to a `txt` file produced by stats/utilities:print_pdf
+     d_par: Dictionary mapping parameter names in this file to proper latex names
+
+     Creates a latex table with the same name as `path` but `txt` extension replaced by `tex`
+     '''
+
+     path = str(path)
+     with open(path, encoding='utf-8') as ifile:
+         l_line = ifile.read().splitlines()
+         l_line = l_line[4:] # Remove header
+
+     df = _df_from_lines(l_line)
+     df['Parameter']=df.Parameter.apply(lambda x : d_par.get(x, x.replace('_', ' ')))
+
+     out_path = path.replace('.txt', '.tex')
+
+     if skip_fixed:
+         df = df[df.Floating == '1']
+         df = df.drop(columns='Floating')
+
+     df_1 = df[df.Constraint == 'none']
+     df_2 = df[df.Constraint != 'none']
+
+     df_1 = df_1.sort_values(by='Parameter', ascending=True)
+     df_2 = df_2.sort_values(by='Parameter', ascending=True)
+     df = pnd.concat([df_1, df_2])
+
+     put.df_to_tex(df, out_path)
+ #---------------------------------------------
+ # Fake/Placeholder fit
+ #---------------------------------------------
+ def get_model(
+     kind : str,
+     obs : zobs|None = None,
+     lam : float = -0.0001) -> zpdf:
+     '''
+     Returns zfit PDF for tests
+
+     Parameters:
+
+     kind: 'signal' for Gaussian, 's+b' for Gaussian plus exponential
+     obs : If provided, will use it, by default None and will be built in function
+     lam : Decay constant of exponential component, set to -0.0001 by default
+     '''
+     if obs is None:
+         obs = zfit.Space('mass', limits=(4500, 7000))
+
+     mu = zfit.Parameter('mu', 5200, 4500, 6000)
+     sg = zfit.Parameter('sg', 50, 10, 200)
+     gaus = zfit.pdf.Gauss(obs=obs, mu=mu, sigma=sg)
+
+     if kind == 'signal':
+         return gaus
+
+     c = zfit.Parameter('c', lam, -0.01, 0)
+     expo= zfit.pdf.Exponential(obs=obs, lam=c)
+
+     if kind == 's+b':
+         nexpo = zfit.param.Parameter('nbkg', 1000, 0, 1000_000)
+         ngaus = zfit.param.Parameter('nsig', 1000, 0, 1000_000)
+
+         bkg = expo.create_extended(nexpo)
+         sig = gaus.create_extended(ngaus)
+         pdf = zfit.pdf.SumPDF([bkg, sig])
+
+         return pdf
+
+     raise NotImplementedError(f'Invalid kind of fit: {kind}')
+ #---------------------------------------------
+ def _pdf_to_data(pdf : zpdf, add_weights : bool) -> zdata:
+     nentries = 10_000
+     data = pdf.create_sampler(n=nentries)
+     if not add_weights:
+         return data
+
+     arr_wgt = numpy.random.normal(loc=1, scale=0.1, size=nentries)
+     data = data.with_weights(arr_wgt)
+
+     return data
+ #---------------------------------------------
+ def placeholder_fit(
+     kind : str,
+     fit_dir : str,
+     df : pnd.DataFrame|None = None,
+     plot_fit : bool = True) -> None:
+     '''
+     Function meant to run toy fits that produce output needed as an input
+     to develop tools on top of them
+
+     kind: Kind of fit, e.g. s+b for the simples signal plus background fit
+     fit_dir: Directory where the output of the fit will go
+     df: pandas dataframe if passed, will reuse that data, needed to test data caching
+     plot_fit: Will plot the fit or not, by default True
+     '''
+     pdf = get_model(kind)
+     print_pdf(pdf, txt_path=f'{fit_dir}/pre_fit.txt')
+     if df is None:
+         log.warning('Using user provided data')
+         data = _pdf_to_data(pdf=pdf, add_weights=True)
+     else:
+         data = zfit.Data.from_pandas(df, obs=pdf.space, weights=Data.weight_name)
+
+     d_const = {'sg' : [50, 3]}
+
+     obj = Fitter(pdf, data)
+     res = obj.fit(cfg={'constraints' : d_const})
+
+     if plot_fit:
+         obj = ZFitPlotter(data=data, model=pdf)
+         obj.plot(nbins=50, stacked=True)
+
+     save_fit(data=data, model=pdf, res=res, fit_dir=fit_dir, d_const=d_const)
+ #---------------------------------------------
+ def _reformat_values(d_par : dict) -> dict:
+     '''
+     Parameters
+     --------------
+     d_par: Dictionary formatted as:
+
+     {'minuit_hesse': {'cl': 0.6,
+                       'error': np.float64(0.04),
+                       'weightcorr': <WeightCorr.FALSE: False>},
+      'value' : 0.34},
+
+     Returns
+     --------------
+     Dictionary formatted as:
+
+     {
+      'error' : 0.04,
+      'value' : 0.34
+     }
+     '''
+
+     error = d_par['minuit_hesse']['error']
+     error = float(error)
+
+     value = d_par['value']
+
+     return {'value' : value, 'error' : error}
+ #---------------------------------------------
+ def zres_to_cres(res : zres) -> DictConfig:
+     '''
+     Parameters
+     --------------
+     res : Zfit result object
+
+     Returns
+     --------------
+     OmegaConfig's DictConfig instance
+     '''
+     # This should prevent crash when result object was already frozen
+     try:
+         res.freeze()
+     except AttributeError:
+         pass
+
+     par = res.params
+     d_par = { name : _reformat_values(d_par=d_par) for name, d_par in par.items()}
+     cfg = OmegaConf.create(d_par)
+
+     return cfg
+ #---------------------------------------------
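
A minimal usage sketch of the placeholder-fit helper added above, assuming the module is importable as dmu.stats.utilities and that the output directory (here a hypothetical /tmp/dmu_placeholder_fit) is writable:

    import os
    import dmu.stats.utilities as sut

    out_dir = '/tmp/dmu_placeholder_fit'  # hypothetical output location
    os.makedirs(out_dir, exist_ok=True)

    # Runs a toy signal-plus-background fit and saves the plot, the pickled
    # result, the dataset JSON and the parameters in JSON/YAML under out_dir
    sut.placeholder_fit(kind='s+b', fit_dir=out_dir)
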
dmu/stats/wdata.py ADDED
@@ -0,0 +1,187 @@
+ '''
+ Module with Wdata class
+ '''
+ from typing import Union
+
+ import zfit
+ import numpy
+ import pandas as pnd
+ from zfit.core.interfaces import ZfitSpace as zobs
+ from zfit.core.data import Data as zdata
+
+ from dmu.logging.log_store import LogStore
+
+
+ log=LogStore.add_logger('dmu:stats:wdata')
+ # -------------------------------
+ class Wdata:
+     '''
+     Class meant to symbolize weighted data
+     '''
+     # -------------------------------
+     def __init__(self,
+                  data : Union[numpy.ndarray, pnd.DataFrame],
+                  weights : numpy.ndarray = None,
+                  extra_columns : pnd.DataFrame = None):
+         '''
+         data :
+         weights: Numpy array with weights, if not passed, will use ones as weights
+         extra_columns: Extra information that can be attached to the Wdata object in the form of a pandas dataframe, default None
+         '''
+         self._data = data
+         self._weights = self._get_weights(weights)
+         self._df = self._get_df_extr(extra_columns)
+     # -------------------------------
+     def _get_df_extr(self, df : pnd.DataFrame) -> Union[pnd.DataFrame,None]:
+         if df is None:
+             return None
+
+         if not isinstance(df, pnd.DataFrame):
+             arg_type = type(df)
+             raise ValueError(f'Expected a pandas dataframe, got {arg_type}')
+
+         if len(df) != self.size:
+             raise ValueError('Input dataframe differs in length from data')
+
+         return df
+     # -------------------------------
+     def _get_weights(self, weights : numpy.ndarray) -> numpy.ndarray:
+         if weights is None:
+             log.info('Weights not found, using ones')
+             return numpy.ones(self.size)
+
+         if not isinstance(weights, numpy.ndarray):
+             raise ValueError('Weights argument is not a numpy array')
+
+         weights_size = len(weights)
+         if weights_size != self.size:
+             raise ValueError(f'Data size and weights size differ: {self.size} != {weights_size}')
+
+         return weights
+     # -------------------------------
+     def _build_new_array(self, arr_other : numpy.ndarray, kind : str) -> numpy.ndarray:
+         arr_this = getattr(self, kind)
+         arr = numpy.concatenate([arr_this, arr_other])
+
+         return arr
+     # -------------------------------
+     def __add__(self, other : 'Wdata') -> 'Wdata':
+         '''
+         Takes instance of Wdata and adds it to this instance
+         returning sum.
+
+         Addition is defined as concatenating both data and weights.
+         '''
+         if not isinstance(other, Wdata):
+             other_type = type(other)
+             raise NotImplementedError(f'Cannot add Wdata instance to {other_type} instance')
+
+         log.debug('Adding instances of Wdata')
+         data = self._build_new_array(arr_other = other._data , kind='_data' )
+         weights = self._build_new_array(arr_other = other._weights, kind='_weights')
+         df = self._build_extra_df(df_other = other._df)
+
+         return Wdata(data=data, weights=weights, extra_columns=df)
+     # -------------------------------
+     def __str__(self) -> str:
+         message = '\n'
+         message+= f'{"Size ":<20}{self.size:<20}\n'
+         message+= f'{"Sumw ":<20}{self.sumw:<20.3f}\n'
+         if self._df is None:
+             return message
+
+         message+= f'{"Columns":<20}{" ":<20}\n'
+         for column in self._df.columns:
+             message += ' ' + column + '\n'
+
+         return message
+     # -------------------------------
+     def _build_extra_df(self, df_other : pnd.DataFrame) -> Union[pnd.DataFrame,None]:
+         if df_other is None and self._df is None:
+             return None
+
+         fail_1 = df_other is None and self._df is not None
+         fail_2 = df_other is not None and self._df is None
+
+         if fail_1 or fail_2:
+             raise ValueError('One of the two Wdata instances does not contain extra column information')
+
+         df = pnd.concat([df_other, self._df], axis=0, ignore_index=True)
+
+         return df
+     # -------------------------------
+     def _is_extra_data_equal(self, df_other : pnd.DataFrame, rtol : float) -> bool:
+         df_this = self._df
+
+         if df_other is None and df_this is None:
+             return True
+
+         fail_1 = df_other is None and df_this is not None
+         fail_2 = df_other is not None and df_this is None
+
+         if fail_1 or fail_2:
+             log.warning('One of the weighted data compared does not have extra columns information')
+             return False
+
+         return numpy.allclose(df_this.values, df_other.values, rtol=rtol)
+     # -------------------------------
+     def __eq__(self, other) -> bool:
+         '''
+         Checks that the data and weights are the same within a 1e-5 relative tolerance
+         '''
+         rtol = 1e-5
+         equal_data = numpy.allclose(other._data , self._data , rtol=rtol)
+         equal_weights = numpy.allclose(other._weights, self._weights, rtol=rtol)
+         equal_extra_data = self._is_extra_data_equal(df_other=other.extra_columns, rtol=rtol)
+
+         return equal_data and equal_weights and equal_extra_data
+     # -------------------------------
+     @property
+     def extra_columns(self) -> Union[pnd.DataFrame,None]:
+         '''
+         Dataframe with extra columns, or None, if not passed
+         '''
+         return self._df
+     # -------------------------------
+     @property
+     def size(self) -> int:
+         '''
+         Returns number of entries in dataset
+         '''
+         return len(self._data)
+     # -------------------------------
+     @property
+     def sumw(self) -> int:
+         '''
+         Returns sum of weights
+         '''
+         return numpy.sum(self._weights)
+     # -------------------------------
+     def update_weights(self, weights : numpy.ndarray, replace : bool) -> 'Wdata':
+         '''
+         Takes array of weights to either:
+         - Replace existing array
+         - Update by multiply by existing array
+
+         depending on the replace argument value. It returns a new instance of Wdata
+         '''
+         if self._weights.shape != weights.shape:
+             raise ValueError(f'Invalid shape for array of weights, expected/got: {self._weights.shape}/{weights.shape}')
+
+         new_weights = weights if replace else weights * self._weights
+
+         data = Wdata(data=self._data, weights=new_weights)
+
+         return data
+     # -------------------------------
+     def to_zfit(self, obs : zobs) -> zdata:
+         '''
+         Function that takes a zfit observable and uses it
+         to build a zfit data instance tha it then returns
+         '''
+         log.debug('Building zfit dataset')
+
+         data = zfit.data.Data(obs=obs, data=self._data, weights=self._weights)
+
+         return data
+     # -------------------------------
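
A short sketch of how the new Wdata container could be used, assuming it is importable as dmu.stats.wdata.Wdata; the observable range mirrors the one used by get_model above:

    import numpy
    import zfit
    from dmu.stats.wdata import Wdata

    arr_mass = numpy.random.normal(loc=5200, scale=50, size=1_000)
    arr_wgt  = numpy.random.uniform(0.5, 1.5, size=1_000)

    # Build weighted data; addition concatenates both data and weights
    wdata = Wdata(data=arr_mass, weights=arr_wgt)
    total = wdata + wdata
    print(total.size, total.sumw)

    # Export to a zfit dataset for fitting
    obs = zfit.Space('mass', limits=(4500, 7000))
    zdata = total.to_zfit(obs=obs)
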
dmu/stats/zfit.py ADDED
@@ -0,0 +1,17 @@
+ '''
+ Module intended to wrap zfit
+
+ Needed in order to silence tensorflow messages
+ '''
+ # pylint: disable=unused-import, wrong-import-order
+
+ try:
+     import ROOT
+ except ImportError:
+     pass
+
+ import dmu.generic.utilities as gut
+ with gut.silent_import():
+     import tensorflow
+
+ import zfit
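
Other modules in this change (for example dmu/stats/utilities.py above) import zfit through this wrapper rather than directly; a short sketch of that pattern, with the observable and parameter values taken from get_model above:

    # Importing through the wrapper imports TensorFlow silently first
    from dmu.stats.zfit import zfit

    obs = zfit.Space('mass', limits=(4500, 7000))
    mu  = zfit.Parameter('mu', 5200, 4500, 6000)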