data-manipulation-utilities 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: data_manipulation_utilities
- Version: 0.2.4
+ Version: 0.2.6
  Description-Content-Type: text/markdown
  Requires-Dist: logzero
  Requires-Dist: PyYAML
@@ -26,7 +26,7 @@ These are tools that can be used for different data analysis tasks.
 
  ## Pushing
 
- From the root directory of a version controlled project (i.e. a directory with the `.git` subdirectory)
+ From the root directory of a version-controlled project (i.e. a directory with the `.git` subdirectory)
  using a `pyproject.toml` file, run:
 
  ```bash
@@ -36,10 +36,10 @@ publish
  such that:
 
  1. The `pyproject.toml` file is checked and the version of the project is extracted.
- 1. If a tag named as the version exists move to the steps below.
+ 1. If a tag named after the version exists, move on to the steps below.
  1. If it does not, make a new tag with the name as the version
 
- Then, for each remote it pushes the tags and the commits.
+ Then, for each remote, it pushes the tags and the commits.
 
  *Why?*
 
@@ -137,7 +137,17 @@ pdf = mod.get_pdf()
  ```
 
  where the model is a sum of three `CrystalBall` PDFs, one with a right tail and two with a left tail.
- The `mu` and `sg` parameters are shared.
+ The `mu` and `sg` parameters are shared. The elementary components that can be plugged in are:
+
+ ```
+ exp  : Exponential
+ pol1 : Polynomial of degree 1
+ pol2 : Polynomial of degree 2
+ cbr  : CrystalBall with right tail
+ cbl  : CrystalBall with left tail
+ gauss: Gaussian
+ dscb : Double-sided CrystalBall
+ ```
 
  ### Printing PDFs
 
@@ -299,7 +309,7 @@ this will:
  - Try fitting at most 10 times
  - After each fit, calculate the goodness of fit (in this case the p-value)
  - Stop when the number of tries has been exhausted or the p-value reached is higher than `0.05`
- - If the fit has not succeeded because of convergence, validity or goodness of fit issues,
+ - If the fit has not succeeded because of convergence, validity, or goodness-of-fit issues,
  randomize the parameters and try again.
  - If the desired goodness of fit has not been achieved, pick the best result.
  - Return the `FitResult` object and set the PDF to the final fit result.
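The retry-and-randomize strategy above can be sketched in plain Python; the `fit_once`, `goodness`, and `randomize` callables here are hypothetical stand-ins for the zfit minimization, the p-value computation, and the parameter randomization, not the package's actual API:

```python
def fit_with_retries(fit_once, goodness, randomize, ntries=10, pvalue_threshold=0.05):
    """Retry a fit up to `ntries` times, randomizing between attempts; keep the best result."""
    best = None
    for _ in range(ntries):
        result = fit_once()
        if best is None or goodness(result) > goodness(best):
            best = result            # remember the best result seen so far
        if goodness(result) > pvalue_threshold:
            break                    # goodness of fit is acceptable, stop early
        randomize()                  # perturb the starting parameters before retrying
    return best
```

With a sequence of fake p-values `[0.01, 0.2, 0.9]`, the loop stops at the second attempt, since `0.2` already exceeds the threshold.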
@@ -337,11 +347,11 @@ bkg = zfit.pdf.Exponential(obs=obs, lam=lm)
  nbk = zfit.Parameter('nbk', 1000, 0, 10000)
  ebkg= bkg.create_extended(nbk, name='expo')
 
- # Add them
+ # Add them
  pdf = zfit.pdf.SumPDF([ebkg, esig])
  sam = pdf.create_sampler()
 
- # Plot them
+ # Plot them
  obj = ZFitPlotter(data=sam, model=pdf)
  d_leg = {'gauss': 'New Gauss'}
  obj.plot(nbins=50, d_leg=d_leg, stacked=True, plot_range=(0, 10), ext_text='Extra text here')
@@ -353,7 +363,7 @@ obj.axs[1].plot([0, 10], [0, 0], linestyle='--', color='black')
  this class supports:
 
  - Handling title, legend, and plot size.
- - Adding pulls.
+ - Adding pulls.
  - Stacking and overlaying of PDFs.
  - Blinding.
 
@@ -417,7 +427,7 @@ rdf_bkg = _get_rdf(kind='bkg')
  cfg = _get_config()
 
  obj= TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
- obj.run()
+ obj.run(skip_fit=False) # False by default; if True, only plot the features and skip the training
  ```
 
  where the settings for the training go in a config dictionary, which when written to YAML looks like:
@@ -434,7 +444,7 @@ dataset:
    nan:
      x : 0
      y : 0
-     z : -999
+     z : -999
  training :
    nfold    : 10
    features : [x, y, z]
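The `nfold` setting drives a cross-validated training, one model per fold. A minimal sketch of the idea with scikit-learn, using synthetic data (this is an illustration, not the package's actual `TrainMva` internals):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import StratifiedKFold

# Synthetic signal/background-like sample: two Gaussian blobs with 3 features
rng      = np.random.default_rng(0)
features = np.vstack([rng.normal(0, 1, (50, 3)), rng.normal(2, 1, (50, 3))])
labels   = np.array([0] * 50 + [1] * 50)

models = []
skf = StratifiedKFold(n_splits=5)  # 'nfold' in the config
for train_index, _ in skf.split(features, labels):
    # train one classifier per fold, on that fold's training entries only
    model = GradientBoostingClassifier()
    model.fit(features[train_index], labels[train_index])
    models.append(model)
```

Each stored model is then evaluated only on entries it was not trained on, which is the point of the folding.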
@@ -497,7 +507,7 @@ When training on real data, several things might go wrong and the code will try
  will end up in different folds. The tool checks whether a model is evaluated for an entry that was used for training and raises an exception. Thus, repeated
  entries will be removed before training.
 
- - **NaNs**: Entries with NaNs will break the training with the scikit `GradientBoostClassifier` base class. Thus, we:
+ - **NaNs**: Entries with NaNs will break the training with the scikit-learn `GradientBoostingClassifier` base class. Thus, we:
    - Can use the `nan` section shown above to replace `NaN` values with something else
    - For whatever remains, we remove the entries from the training.
@@ -539,7 +549,7 @@ When evaluating the model with real data, problems might occur, we deal with the
  ```python
  model.cfg
  ```
- - For whatever entries that are still NaN, they will be _patched_ with zeros and evaluated. However, before returning, the probabilities will be
+ - For features that are still NaN, the values will be _patched_ with zeros during evaluation. However, the returned probabilities will be
  saved as -1. I.e. entries with NaNs will have probabilities of -1.
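The patch-then-flag behavior can be sketched with numpy; `patch_and_tag` here mirrors the idea of the package's helper of the same name, but the exact signature is an assumption:

```python
import numpy as np

def patch_and_tag(features: np.ndarray, value: float = 0.0):
    """Replace NaNs with `value`; return the patched array and a mask of affected rows."""
    mask    = np.isnan(features).any(axis=1)          # rows that had at least one NaN
    patched = np.nan_to_num(features, nan=value)      # safe to feed to the model
    return patched, mask

features      = np.array([[1.0, 2.0], [np.nan, 3.0]])
patched, mask = patch_and_tag(features)
probabilities = np.array([0.7, 0.9])                  # hypothetical model output
probabilities[mask] = -1                              # flag entries that had NaNs
```

The second entry is evaluated on the patched zeros, but its probability is overwritten with `-1` so downstream code can recognize it.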
 
  # Pandas dataframes
@@ -674,6 +684,9 @@ ptr.run()
  where the config dictionary `cfg_dat` in YAML would look like:
 
  ```yaml
+ general:
+   # This will set the figure size
+   size : [20, 10]
  selection:
    # Will do at most 50K random entries. Will only happen if the dataset has more than 50K entries
    max_ran_entries : 50000
@@ -703,6 +716,16 @@ plots:
      yscale : 'linear'
      labels : ['x + y', 'Entries']
      normalized : true # This should normalize to the area
+ # Some vertical dashed lines are drawn by default
+ # If you see them, you can turn them off with this
+ style:
+   skip_lines : true
+   # This can pass arguments to the legend-making function `plt.legend()` in matplotlib
+   legend:
+     # The line below would place the legend outside the figure to avoid overlaps with the histogram
+     bbox_to_anchor : [1.2, 1]
+ stats:
+   nentries : '{:.2e}' # This will add the number of entries to the legend box
  ```
 
  It's up to the user to build this dictionary and load it.
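The `stats.nentries` option is a Python format string applied to the entry count and appended to each legend label; a minimal sketch of that label construction (`label_with_entries` is a hypothetical name, mirroring the package's internal `_label_from_name` helper):

```python
def label_with_entries(name: str, nentries: int, fmt: str = '{:.2e}') -> str:
    """Append the formatted entry count to a legend label, as `stats.nentries` does."""
    return f'{name}{fmt.format(nentries)}'
```

For example, a dataset named `data` with 50000 entries would be labeled `data5.00e+04` in the legend box.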
@@ -724,14 +747,19 @@ The config would look like:
  ```yaml
  saving:
    plt_dir : tests/plotting/2d
+ selection:
+   cuts:
+     xlow : x > -1.5
  general:
    size : [20, 10]
  plots_2d:
    # Column x and y
    # Name of column where weights are, null for no weights
    # Name of output plot, e.g. xy_x.png
-   - [x, y, weights, 'xy_w']
-   - [x, y, null, 'xy_r']
+   # Boolean flag signaling whether to use a log scale for the z axis
+   - [x, y, weights, 'xy_w', false]
+   - [x, y, null, 'xy_r', false]
+   - [x, y, null, 'xy_l', true]
  axes:
    x :
      binning : [-5.0, 8.0, 40]
@@ -823,7 +851,7 @@ Directory/Treename
  B_ENDVERTEX_CHI2DOF Double_t
  ```
 
- ## Comparing ROOT files
+ ## Comparing ROOT files
 
  Given two ROOT files, the command below:
 
@@ -885,7 +913,7 @@ last_file = get_latest_file(dir_path = file_dir, wc='name_*.txt')
  # of directories in `dir_path`, e.g.:
 
  oversion=get_last_version(dir_path=dir_path, version_only=True)  # This will return only the version, e.g. v3.2
- oversion=get_last_version(dir_path=dir_path, version_only=False) # This will return full path, e.g. /a/b/c/v3.2
+ oversion=get_last_version(dir_path=dir_path, version_only=False) # This will return the full path, e.g. /a/b/c/v3.2
  ```
 
  The function above should work for numeric (e.g. `v1.2`) and non-numeric (e.g. `va`, `vb`) versions.
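One way such a helper can pick the latest version across numeric and non-numeric names is a natural sort; this is a sketch of the idea under that assumption, not the package's actual implementation:

```python
import re

def pick_last_version(names):
    """Pick the highest version from names like 'v1.2', 'v10.0', or 'vb' with a natural sort."""
    def key(name):
        # split into digit and non-digit runs so 'v10' sorts above 'v9';
        # tag each part so ints and strings never compare directly
        parts = re.findall(r'\d+|\D+', name)
        return [(0, int(p)) if p.isdigit() else (1, p) for p in parts]
    return max(names, key=key)
```

A plain lexicographic `max` would wrongly rank `v3.2` above `v10.0`; the natural-sort key fixes that while still handling `va`/`vb` alphabetically.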
@@ -1,17 +1,17 @@
- data_manipulation_utilities-0.2.4.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
+ data_manipulation_utilities-0.2.6.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
  dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
  dmu/generic/utilities.py,sha256=0Xnq9t35wuebAqKxbyAiMk1ISB7IcXK4cFH25MT1fgw,1741
  dmu/generic/version_management.py,sha256=G_HjGY-hu8lotZuTdVAg0B8yD0AltE866q2vJxvTg1g,3749
  dmu/logging/log_store.py,sha256=umdvjNDuV3LdezbG26b0AiyTglbvkxST19CQu9QATbA,4184
- dmu/ml/cv_classifier.py,sha256=8Jwx6xMhJaRLktlRdq0tFl32v6t8i63KmpxrlnXlomU,3759
- dmu/ml/cv_predict.py,sha256=4G7F_1yOvnLftsDC6zUpdvkxuHXGkPemhj0RsYySYDM,6708
- dmu/ml/train_mva.py,sha256=SZ5cQHl7HBxn0c5Hh4HlN1aqMZaJUAlNmsfjnUSQrTg,16894
- dmu/ml/utilities.py,sha256=l348bufD95CuSYdIrHScQThIy2nKwGKXZn-FQg3CEwg,3930
+ dmu/ml/cv_classifier.py,sha256=ZbzEm_jW9yoTC7k_xBA7hFpc1bDNayiVR3tbaj1_ieE,4228
+ dmu/ml/cv_predict.py,sha256=4wwYL_jcUExDqLJVfClxEUWSd_QAx8yKHO3rX-mx4vw,6711
+ dmu/ml/train_mva.py,sha256=Tjtm_cXIiC5syaeUXsPAK4NKLbgkDdly17qbiOIT_Go,17608
+ dmu/ml/utilities.py,sha256=PK_61fW7gBV9aGZyez3PI8zAT7_Fc6IlQzDB7f8iBTM,4133
  dmu/pdataframe/utilities.py,sha256=ypvLiFfJ82ga94qlW3t5dXnvEFwYOXnbtJb2zHwsbqk,987
  dmu/plotting/matrix.py,sha256=pXuUJn-LgOvrI9qGkZQw16BzLjOjeikYQ_ll2VIcIXU,4978
- dmu/plotting/plotter.py,sha256=ytMxtzHEY8ZFU0ZKEBE-ROjMszXl5kHTMnQnWe173nU,7208
- dmu/plotting/plotter_1d.py,sha256=g6H2xAgsL9a6vRkpbqHICb3qwV_qMiQPZxxw_oOSf9M,5115
- dmu/plotting/plotter_2d.py,sha256=J-gKnagoHGfJFU7HBrhDFpGYH5Rxy0_zF5l8eE_7ZHE,2944
+ dmu/plotting/plotter.py,sha256=3WRbNOrFBWgI3iW5TbEgT4w_eF7-XUPs_32JL1AW3yY,7359
+ dmu/plotting/plotter_1d.py,sha256=2AnVxulyhKtwN-2Srhfm6fqdEREZNhcpJolBsJrWcsc,5745
+ dmu/plotting/plotter_2d.py,sha256=mZhp3D5I-JodOnFTEF1NqHtcLtuI-2WNpCQsrsoXNtw,3017
  dmu/plotting/utilities.py,sha256=SI9dvtZq2gr-PXVz71KE4o0i09rZOKgqJKD1jzf6KXk,1167
  dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,2386
  dmu/rdataframe/utilities.py,sha256=pNcQARMP7txMhy6k27UnDcYf0buNy5U2fshaJDl_h8o,3661
@@ -20,21 +20,23 @@ dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
  dmu/stats/fitter.py,sha256=vHNZ16U3apoQyeyM8evq-if49doF48sKB3q9wmA96Fw,18387
  dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
  dmu/stats/gof_calculator.py,sha256=4EN6OhULcztFvsAZ00rxgohJemnjtDNB5o0IBcv6kbk,4657
- dmu/stats/minimizers.py,sha256=f9cilFY9Kp9UvbSIUsKBGFzOOg7EEWZJLPod-4k-LAQ,6216
- dmu/stats/model_factory.py,sha256=LyDOf0f9I5dNUTS0MXHtSivD8aAcTLIagvMPtoXtThk,7426
+ dmu/stats/minimizers.py,sha256=db9R2G0SOV-k0BKi6m4EyB_yp6AtZdP23_28B0315oo,7094
+ dmu/stats/model_factory.py,sha256=QobbhhMFUg61icB_P2grNFsftf_kl6gELjj1mkC9YSw,9115
  dmu/stats/utilities.py,sha256=LQy4kd3xSXqpApcWuYfZxkGQyjowaXv2Wr1c4Bj-4ys,4523
  dmu/stats/zfit_plotter.py,sha256=Xs6kisNEmNQXhYRCcjowxO6xHuyAyrfyQIFhGAR61U4,19719
- dmu/testing/utilities.py,sha256=WbMM4e9Cn3-B-12Vr64mB5qTKkV32joStlRkD-48lG0,3460
+ dmu/testing/utilities.py,sha256=moImLqGX9LAt5zJtE5j0gHHkUJ5kpbodryhiVswOsyM,3696
  dmu/text/transformer.py,sha256=4lrGknbAWRm0-rxbvgzOO-eR1-9bkYk61boJUEV3cQ0,6100
  dmu_data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dmu_data/ml/tests/train_mva.yaml,sha256=k5H4Gu9Gj57B9iqabhcTQEFN674Cv_uJ2Xcumb02zF4,1279
- dmu_data/plotting/tests/2d.yaml,sha256=VApcAfJFbjNcjMCTBSRm2P37MQlGavMZv6msbZwLSgw,402
+ dmu_data/ml/tests/train_mva.yaml,sha256=jtYwBY_VELCgXY24e7eQYEvKQLsPtbFXgXEeOkYunvY,1291
+ dmu_data/plotting/tests/2d.yaml,sha256=HSAtER-8CEqIGBY_jdcIdSVOHMfYPYhmgeZghTpVYh8,516
  dmu_data/plotting/tests/fig_size.yaml,sha256=7ROq49nwZ1A2EbPiySmu6n3G-Jq6YAOkc3d2X3YNZv0,294
  dmu_data/plotting/tests/high_stat.yaml,sha256=bLglBLCZK6ft0xMhQ5OltxE76cWsBMPMjO6GG0OkDr8,522
+ dmu_data/plotting/tests/legend.yaml,sha256=wGpj58ig-GOlqbWoN894zrCet2Fj9f5QtY0rig_UC-c,213
  dmu_data/plotting/tests/name.yaml,sha256=mkcPAVg8wBAmlSbSRQ1bcaMl4vOS6LXMtpqQeDrrtO4,312
  dmu_data/plotting/tests/no_bounds.yaml,sha256=8e1QdphBjz-suDr857DoeUC2DXiy6SE-gvkORJQYv80,257
  dmu_data/plotting/tests/normalized.yaml,sha256=Y0eKtyV5pvlSxvqfsLjytYtv8xYF3HZ5WEdCJdeHGQI,193
  dmu_data/plotting/tests/simple.yaml,sha256=N_TvNBh_2dU0-VYgu_LMrtY0kV_hg2HxVuEoDlr1HX8,138
+ dmu_data/plotting/tests/stats.yaml,sha256=fSZjoV-xPnukpCH2OAXsz_SNPjI113qzDg8Ln3spaaA,165
  dmu_data/plotting/tests/title.yaml,sha256=bawKp9aGpeRrHzv69BOCbFX8sq9bb3Es9tdsPTE7jIk,333
  dmu_data/plotting/tests/weights.yaml,sha256=RWQ1KxbCq-uO62WJ2AoY4h5Umc37zG35s-TpKnNMABI,312
  dmu_data/text/transform.toml,sha256=R-832BZalzHZ6c5gD6jtT_Hj8BCsM5vxa1v6oeiwaP4,94
@@ -48,8 +50,8 @@ dmu_scripts/rfile/compare_root_files.py,sha256=T8lDnQxsRNMr37x1Y7YvWD8ySHrJOWZki
  dmu_scripts/rfile/print_trees.py,sha256=Ze4Ccl_iUldl4eVEDVnYBoe4amqBT1fSBR1zN5WSztk,941
  dmu_scripts/ssh/coned.py,sha256=lhilYNHWRCGxC-jtyJ3LQ4oUgWW33B2l1tYCcyHHsR0,4858
  dmu_scripts/text/transform_text.py,sha256=9akj1LB0HAyopOvkLjNOJiptZw5XoOQLe17SlcrGMD0,1456
- data_manipulation_utilities-0.2.4.dist-info/METADATA,sha256=Gc-ZuL88YHEK3pOK1IfQmaN6rKCcVVqrFS2VlT70jyk,29229
- data_manipulation_utilities-0.2.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- data_manipulation_utilities-0.2.4.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
- data_manipulation_utilities-0.2.4.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
- data_manipulation_utilities-0.2.4.dist-info/RECORD,,
+ data_manipulation_utilities-0.2.6.dist-info/METADATA,sha256=P9-pWYbzx2C-dntMHgb85WS64CVPRu8BeRPxvJOk3VE,30187
+ data_manipulation_utilities-0.2.6.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+ data_manipulation_utilities-0.2.6.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
+ data_manipulation_utilities-0.2.6.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
+ data_manipulation_utilities-0.2.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (75.8.2)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
dmu/ml/cv_classifier.py CHANGED
@@ -1,15 +1,15 @@
  '''
  Module holding cv_classifier class
  '''
-
+ import os
  from typing import Union
  from sklearn.ensemble import GradientBoostingClassifier
 
+ import yaml
  from dmu.logging.log_store import LogStore
  import dmu.ml.utilities as ut
 
  log = LogStore.add_logger('dmu:ml:CVClassifier')
-
  # ---------------------------------------
  class CVSameData(Exception):
      '''
@@ -61,6 +61,20 @@ class CVClassifier(GradientBoostingClassifier):
 
      return self._cfg
  # ----------------------------------
+ def save_cfg(self, path : str):
+     '''
+     Will save configuration used to train this classifier to YAML
+
+     path: Path to YAML file
+     '''
+     dir_name = os.path.dirname(path)
+     os.makedirs(dir_name, exist_ok=True)
+
+     with open(path, 'w', encoding='utf-8') as ofile:
+         yaml.safe_dump(self._cfg, ofile, indent=2)
+
+     log.info(f'Saved config to: {path}')
+ # ----------------------------------
  def __str__(self):
      nhash = len(self._s_hash)
 
dmu/ml/cv_predict.py CHANGED
@@ -73,11 +73,11 @@ class CVPredict:
      log.debug('Not doing any NaN replacement')
      return df
 
- log.debug(60 * '-')
+ log.info(60 * '-')
  log.info('Doing NaN replacements')
- log.debug(60 * '-')
+ log.info(60 * '-')
  for var, val in self._d_nan_rep.items():
-     log.debug(f'{var:<20}{"--->":20}{val:<20.3f}')
+     log.info(f'{var:<20}{"--->":20}{val:<20.3f}')
      df[var] = df[var].fillna(val)
 
  return df
@@ -155,7 +155,7 @@ class CVPredict:
  ndif = len(s_dif_hash)
  ndat = len(s_dat_hash)
  nmod = len(s_mod_hash)
- log.debug(f'{ndif:<20}{"=":10}{ndat:<20}{"-":10}{nmod:<20}')
+ log.debug(f'{ndif:<10}{"=":5}{ndat:<10}{"-":5}{nmod:<10}')
 
  df_ft_group= df_ft.loc[df_ft.index.isin(s_dif_hash)]
 
@@ -173,7 +173,7 @@ class CVPredict:
      return arr_prb
 
  nentries = len(self._arr_patch)
- log.warning(f'Patching {nentries} probabilities')
+ log.warning(f'Patching {nentries} probabilities with -1')
  arr_prb[self._arr_patch] = -1
 
  return arr_prb
dmu/ml/train_mva.py CHANGED
@@ -1,7 +1,7 @@
  '''
  Module with TrainMva class
  '''
- # pylint: disable = too-many-locals
+ # pylint: disable = too-many-locals, no-name-in-module
  # pylint: disable = too-many-arguments, too-many-positional-arguments
 
  import os
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
  from sklearn.metrics import roc_curve, auc
  from sklearn.model_selection import StratifiedKFold
 
- from ROOT import RDataFrame
+ from ROOT import RDataFrame, RDF
 
  import dmu.ml.utilities as ut
  import dmu.pdataframe.utilities as put
@@ -33,61 +33,71 @@ class TrainMva:
  Interface to scikit learn used to train classifier
  '''
  # ---------------------------------------------
- def __init__(self, bkg=None, sig=None, cfg=None):
+ def __init__(self, bkg : RDataFrame, sig : RDataFrame, cfg : dict):
      '''
      bkg (ROOT dataframe): Holds real data
      sig (ROOT dataframe): Holds simulation
      cfg (dict)          : Dictionary storing configuration for training
      '''
-     if bkg is None:
-         raise ValueError('Background dataframe is not a ROOT dataframe')
-
-     if sig is None:
-         raise ValueError('Signal dataframe is not a ROOT dataframe')
-
-     if not isinstance(cfg, dict):
-         raise ValueError('Config dictionary is not a dictionary')
+     self._cfg       = cfg
+     self._l_ft_name = self._cfg['training']['features']
 
-     self._rdf_bkg = bkg
-     self._rdf_sig = sig
-     self._cfg     = cfg
+     df_ft_sig, l_lab_sig = self._get_sample_inputs(rdf = sig, label = 1)
+     df_ft_bkg, l_lab_bkg = self._get_sample_inputs(rdf = bkg, label = 0)
 
-     self._l_ft_name = self._cfg['training']['features']
+     self._df_ft = pnd.concat([df_ft_sig, df_ft_bkg], axis=0)
+     self._l_lab = numpy.array(l_lab_sig + l_lab_bkg)
 
-     self._df_ft, self._l_lab = self._get_inputs()
+     self._rdf_bkg = self._get_rdf(rdf = bkg, df=df_ft_bkg)
+     self._rdf_sig = self._get_rdf(rdf = sig, df=df_ft_sig)
  # ---------------------------------------------
- def _get_inputs(self) -> tuple[pnd.DataFrame, npa]:
-     log.info('Getting signal')
-     df_sig, arr_lab_sig = self._get_sample_inputs(self._rdf_sig, label = 1)
+ def _get_rdf(self, rdf : RDataFrame, df : pnd.DataFrame) -> RDataFrame:
+     '''
+     Takes original ROOT dataframe and pre-processed features dataframe
+     Adds missing branches to latter and returns expanded ROOT dataframe
+     '''
 
-     log.info('Getting background')
-     df_bkg, arr_lab_bkg = self._get_sample_inputs(self._rdf_bkg, label = 0)
+     l_pnd_col = df.columns.tolist()
+     l_rdf_col = [ name.c_str() for name in rdf.GetColumnNames() ]
+     l_mis_col = [ col for col in l_rdf_col if col not in l_pnd_col ]
 
-     df      = pnd.concat([df_sig, df_bkg], axis=0)
-     arr_lab = numpy.concatenate([arr_lab_sig, arr_lab_bkg])
+     log.debug(f'Adding extra-nonfeature columns: {l_mis_col}')
 
-     return df, arr_lab
+     d_data = rdf.AsNumpy(l_mis_col)
+     df_ext = pnd.DataFrame(d_data)
+     df_all = pnd.concat([df, df_ext], axis=1)
+
+     return RDF.FromPandas(df_all)
  # ---------------------------------------------
  def _pre_process_nans(self, df : pnd.DataFrame) -> pnd.DataFrame:
+     if 'dataset' not in self._cfg:
+         return df
+
      if 'nan' not in self._cfg['dataset']:
          log.debug('dataset/nan section not found, not pre-processing NaNs')
          return df
 
      d_name_val = self._cfg['dataset']['nan']
-     for name, val in d_name_val.items():
-         log.debug(f'{val:<20}{"<---":<10}{name:<100}')
-         df[name] = df[name].fillna(val)
+     log.info(70 * '-')
+     log.info('Doing NaN replacements')
+     log.info(70 * '-')
+     for var, val in d_name_val.items():
+         nna = df[var].isna().sum()
+
+         log.info(f'{var:<20}{"--->":20}{val:<20.3f}{nna}')
+         df[var] = df[var].fillna(val)
+     log.info(70 * '-')
 
      return df
  # ---------------------------------------------
- def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, npa]:
+ def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, list[int]]:
      d_ft = rdf.AsNumpy(self._l_ft_name)
      df   = pnd.DataFrame(d_ft)
      df   = self._pre_process_nans(df)
      df   = ut.cleanup(df)
      l_lab= len(df) * [label]
 
-     return df, numpy.array(l_lab)
+     return df, l_lab
  # ---------------------------------------------
  def _get_model(self, arr_index : npa) -> cls:
      model = cls(cfg = self._cfg)
@@ -406,6 +416,9 @@ class TrainMva:
      self._save_hyperparameters_to_tex()
  # ---------------------------------------------
  def _save_nan_conversion(self) -> None:
+     if 'dataset' not in self._cfg:
+         return
+
      if 'nan' not in self._cfg['dataset']:
          log.debug('NaN section not found, not saving it')
          return
@@ -434,13 +447,18 @@ class TrainMva:
      os.makedirs(val_dir, exist_ok=True)
      put.df_to_tex(df, f'{val_dir}/hyperparameters.tex')
  # ---------------------------------------------
- def run(self):
+ def run(self, skip_fit : bool = False) -> None:
      '''
      Will do the training
+
+     skip_fit: False by default; if True, only plot the features and save tables, skipping the fit
      '''
      self._save_settings_to_tex()
      self._plot_features()
 
+     if skip_fit:
+         return
+
      l_mod = self._get_models()
      for ifold, mod in enumerate(l_mod):
          self._save_model(mod, ifold)
dmu/ml/utilities.py CHANGED
@@ -16,7 +16,7 @@ log = LogStore.add_logger('dmu:ml:utilities')
  # ---------------------------------------------
  def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
      '''
-     Takes panda dataframe, replaces NaNs with value introduced, by default 0
+     Takes pandas dataframe, replaces NaNs with value introduced, by default 0
      Returns array of indices where the replacement happened
      '''
      l_nan = df.index[df.isna().any(axis=1)].tolist()
@@ -25,7 +25,13 @@ def patch_and_tag(df : pnd.DataFrame, value : float = 0) -> pnd.DataFrame:
      log.debug('No NaNs found')
      return df
 
- log.warning(f'Found {nnan} NaNs, patching them with {value}')
+ log.warning(f'Found {nnan} NaNs')
+
+ df_nan_frq = df.isna().sum()
+ df_nan_frq = df_nan_frq[df_nan_frq > 0]
+ print(df_nan_frq)
+
+ log.warning(f'Attaching array with NaN {nnan} indexes and removing NaNs from dataframe')
 
  df_pa = df.fillna(value)
@@ -57,7 +63,7 @@ def _remove_nans(df : pnd.DataFrame) -> pnd.DataFrame:
  log.info('Found columns with NaNs')
  for name in l_na_name:
      nan_count = df[name].isna().sum()
-     log.info(f'{nan_count:<10}{name:<100}')
+     log.info(f'{nan_count:<10}{name}')
 
  ninit = len(df)
  df    = df.dropna()
@@ -75,10 +81,10 @@ def _remove_repeated(df : pnd.DataFrame) -> pnd.DataFrame:
  nfinl = len(s_hash)
 
  if ninit == nfinl:
-     log.debug('No cleaning needed for dataframe')
+     log.debug('No overlap between training and application found')
      return df
 
- log.warning(f'Repeated entries found, cleaning up: {ninit} -> {nfinl}')
+ log.warning(f'Overlap between training and application found, cleaning up: {ninit} -> {nfinl}')
 
  df['hash_index'] = l_hash
  df = df.set_index('hash_index', drop=True)
dmu/plotting/plotter.py CHANGED
@@ -107,7 +107,7 @@ class Plotter:
 
      d_cut = self._d_cfg['selection']['cuts']
 
-     log.info('Applying cuts')
+     log.debug('Applying cuts')
      for name, cut in d_cut.items():
          log.debug(f'{name:<50}{cut:<150}')
          rdf = rdf.Filter(cut, name)
@@ -212,7 +212,11 @@ class Plotter:
 
      var (str) : Name of variable, needed for plot name
      '''
-     plt.legend()
+     d_leg = {}
+     if 'style' in self._d_cfg and 'legend' in self._d_cfg['style']:
+         d_leg = self._d_cfg['style']['legend']
+
+     plt.legend(**d_leg)
 
      plt_dir = self._d_cfg['saving']['plt_dir']
      os.makedirs(plt_dir, exist_ok=True)
dmu/plotting/plotter_1d.py CHANGED
@@ -77,17 +77,33 @@ class Plotter1D(Plotter):
 
      l_bc_all = []
      for name, arr_val in d_data.items():
+         label   = self._label_from_name(name, arr_val)
          arr_wgt = d_wgt[name] if d_wgt is not None else numpy.ones_like(arr_val)
          arr_wgt = self._normalize_weights(arr_wgt, var)
-         hst = Hist.new.Reg(bins=bins, start=minx, stop=maxx, name='x', label=name).Weight()
+         hst = Hist.new.Reg(bins=bins, start=minx, stop=maxx, name='x').Weight()
          hst.fill(x=arr_val, weight=arr_wgt)
-         hst.plot(label=name)
+         hst.plot(label=label)
          l_bc_all += hst.values().tolist()
 
      max_y = max(l_bc_all)
 
      return max_y
  # --------------------------------------------
+ def _label_from_name(self, name : str, arr_val : numpy.ndarray) -> str:
+     if 'stats' not in self._d_cfg:
+         return name
+
+     d_stat = self._d_cfg['stats']
+     if 'nentries' not in d_stat:
+         return name
+
+     form = d_stat['nentries']
+
+     nentries = len(arr_val)
+     nentries = form.format(nentries)
+
+     return f'{name}{nentries}'
+ # --------------------------------------------
  def _normalize_weights(self, arr_wgt : numpy.ndarray, var : str) -> numpy.ndarray:
      cfg_var = self._d_cfg['plots'][var]
      if 'normalized' not in cfg_var:
@@ -104,7 +120,6 @@ class Plotter1D(Plotter):
 
      return arr_wgt
  # --------------------------------------------
-
  def _style_plot(self, var : str, max_y : float) -> None:
      d_cfg  = self._d_cfg['plots'][var]
      yscale = d_cfg['yscale' ] if 'yscale' in d_cfg else 'linear'
@@ -124,12 +139,15 @@ class Plotter1D(Plotter):
      plt.legend()
      plt.title(title)
  # --------------------------------------------
- def _plot_lines(self, var : str):
+ def _plot_lines(self, var : str) -> None:
      '''
      Will plot vertical lines for some variables
 
      var (str) : name of variable
      '''
+     if 'style' in self._d_cfg and 'skip_lines' in self._d_cfg['style'] and self._d_cfg['style']['skip_lines']:
+         return
+
      if var in ['B_const_mass_M', 'B_M']:
          plt.axvline(x=5280, color='r', label=r'$B^+$' , linestyle=':')
      elif var == 'Jpsi_M':
dmu/plotting/plotter_2d.py CHANGED
@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
 
  from hist import Hist
  from ROOT import RDataFrame
+ from matplotlib.colors import LogNorm
  from dmu.logging.log_store import LogStore
  from dmu.plotting.plotter import Plotter
 
@@ -28,11 +29,8 @@ class Plotter2D(Plotter):
      cfg (dict): Dictionary with configuration, e.g. binning, ranges, etc
      '''
 
-     if not isinstance(cfg, dict):
-         raise ValueError('Config dictionary not passed')
-
-     self._d_cfg : dict       = cfg
-     self._rdf   : RDataFrame = super()._preprocess_rdf(rdf)
+     super().__init__({'single_rdf' : rdf}, cfg)
+     self._rdf : RDataFrame = self._d_rdf['single_rdf']
 
      self._wgt : numpy.ndarray
  # --------------------------------------------
@@ -61,7 +59,7 @@ class Plotter2D(Plotter):
 
      return arr_wgt
  # --------------------------------------------
- def _plot_vars(self, varx : str, vary : str, wgt_name : str) -> None:
+ def _plot_vars(self, varx : str, vary : str, wgt_name : str, use_log : bool) -> None:
      log.info(f'Plotting {varx} vs {vary} with weights {wgt_name}')
 
      ax_x = self._get_axis(varx)
@@ -72,7 +70,10 @@ class Plotter2D(Plotter):
      hst = Hist(ax_x, ax_y)
      hst.fill(arr_x, arr_y, weight=arr_w)
 
-     mplhep.hist2dplot(hst)
+     if use_log:
+         mplhep.hist2dplot(hst, norm=LogNorm())
+     else:
+         mplhep.hist2dplot(hst)
  # --------------------------------------------
  def run(self):
      '''
@@ -80,8 +81,8 @@ class Plotter2D(Plotter):
      '''
 
      fig_size = self._get_fig_size()
-     for [varx, vary, wgt_name, plot_name] in self._d_cfg['plots_2d']:
+     for [varx, vary, wgt_name, plot_name, use_log] in self._d_cfg['plots_2d']:
          plt.figure(plot_name, figsize=fig_size)
-         self._plot_vars(varx, vary, wgt_name)
+         self._plot_vars(varx, vary, wgt_name, use_log)
          self._save_plot(plot_name)
  # --------------------------------------------
dmu/stats/minimizers.py CHANGED
@@ -1,12 +1,16 @@
  '''
  Module containing derived classes from ZFit minimizer
  '''
+ from typing import Union
  import numpy
 
  import zfit
+ import matplotlib.pyplot as plt
+
  from zfit.result import FitResult
  from zfit.core.basepdf import BasePDF as zpdf
  from zfit.minimizers.baseminimizer import FailMinimizeNaN
+ from dmu.stats.utilities import print_pdf
  from dmu.stats.gof_calculator import GofCalculator
  from dmu.logging.log_store import LogStore
 
@@ -29,6 +33,7 @@ class AnealingMinimizer(zfit.minimize.Minuit):
      self._chi2ndof = chi2ndof
 
      self._check_thresholds()
+     self._l_bad_fit_res : list[FitResult] = []
 
      super().__init__()
  # ------------------------
@@ -66,19 +71,24 @@ class AnealingMinimizer(zfit.minimize.Minuit):
      return is_good
  # ------------------------
  def _is_good_fit(self, res : FitResult) -> bool:
+     good_fit = True
+
      if not res.valid:
-         log.warning('Skipping invalid fit')
-         return False
+         log.debug('Skipping invalid fit')
+         good_fit = False
 
      if res.status != 0:
-         log.warning('Skipping fit with bad status')
-         return False
+         log.debug('Skipping fit with bad status')
+         good_fit = False
 
      if not res.converged:
-         log.warning('Skipping non-converging fit')
-         return False
+         log.debug('Skipping non-converging fit')
+         good_fit = False
 
-     return True
+     if not good_fit:
+         self._l_bad_fit_res.append(res)
+
+     return good_fit
  # ------------------------
  def _get_gof(self, nll) -> tuple[float, float]:
      log.debug('Checking GOF')
@@ -108,10 +118,11 @@ class AnealingMinimizer(zfit.minimize.Minuit):
          par.set_value(fval)
          log.debug(f'{par.name:<20}{ival:<15.3f}{"->":<10}{fval:<15.3f}{"in":<5}{par.lower:<15.3e}{par.upper:<15.3e}')
  # ------------------------
- def _pick_best_fit(self, d_chi2_res : dict) -> FitResult:
+ def _pick_best_fit(self, d_chi2_res : dict) -> Union[FitResult,None]:
      nres = len(d_chi2_res)
      if nres == 0:
-         raise ValueError('No fits found')
+         log.error('No fits found')
+         return None
 
      l_chi2_res= list(d_chi2_res.items())
      l_chi2_res.sort()
@@ -149,6 +160,15 @@ class AnealingMinimizer(zfit.minimize.Minuit):
149
160
 
150
161
  return l_model[0]
151
162
  # ------------------------
163
+ def _print_failed_fit_diagnostics(self, nll) -> None:
164
+ for res in self._l_bad_fit_res:
165
+ print(res)
166
+
167
+ arr_mass = nll.data[0].numpy()
168
+
169
+ plt.hist(arr_mass, bins=60)
170
+ plt.show()
171
+ # ------------------------
152
172
  def minimize(self, nll, **kwargs) -> FitResult:
153
173
  '''
154
174
  Will run minimization and return FitResult object
@@ -156,18 +176,20 @@ class AnealingMinimizer(zfit.minimize.Minuit):
156
176
 
157
177
  d_chi2_res : dict[float,FitResult] = {}
158
178
  for i_try in range(self._ntries):
159
- log.info(f'try {i_try:02}/{self._ntries:02}')
160
179
  try:
161
180
  res = super().minimize(nll, **kwargs)
162
181
  except (FailMinimizeNaN, ValueError, RuntimeError) as exc:
163
- log.warning(exc)
182
+ log.error(f'{i_try:02}/{self._ntries:02}{"Failed":>20}')
183
+ log.debug(exc)
164
184
  self._randomize_parameters(nll)
165
185
  continue
166
186
 
167
187
  if not self._is_good_fit(res):
188
+ log.warning(f'{i_try:02}/{self._ntries:02}{"Bad fit":>20}')
168
189
  continue
169
190
 
170
191
  chi2, pvl = self._get_gof(nll)
192
+ log.info(f'{i_try:02}/{self._ntries:02}{chi2:>20.3f}')
171
193
  d_chi2_res[chi2] = res
172
194
 
173
195
  if self._is_good_gof(chi2, pvl):
@@ -176,6 +198,13 @@ class AnealingMinimizer(zfit.minimize.Minuit):
176
198
  self._randomize_parameters(nll)
177
199
 
178
200
  res = self._pick_best_fit(d_chi2_res)
201
+ if res is None:
202
+ self._print_failed_fit_diagnostics(nll)
203
+ pdf = nll.model[0]
204
+ print_pdf(pdf)
205
+
206
+ raise ValueError('Fit failed')
207
+
179
208
  pdf = self._pdf_from_nll(nll)
180
209
  self._set_pdf_pars(res, pdf)
181
210
 
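The reworked `minimize` above retries the fit up to `_ntries` times, randomizing the starting parameters after each failed or poor attempt, keeping every good result keyed by its chi2, and raising only when no attempt succeeds. A minimal standalone sketch of that retry-and-keep-best pattern (plain Python, no zfit; `fit_once` and `randomize` are hypothetical stand-ins for the real minimizer calls):

```python
def anneal_minimize(fit_once, randomize, ntries=10, chi2_threshold=1.5):
    # Try the fit several times; after every failure or poor result,
    # randomize the starting point and retry, keeping each successful
    # result keyed by its goodness of fit.
    d_chi2_res = {}
    for _ in range(ntries):
        try:
            res, chi2 = fit_once()
        except ValueError:
            randomize()
            continue

        d_chi2_res[chi2] = res

        if chi2 < chi2_threshold:
            break                       # good enough, stop early

        randomize()

    if not d_chi2_res:
        raise ValueError('Fit failed')  # no attempt ever succeeded

    best_chi2 = min(d_chi2_res)         # smallest chi2 wins
    return d_chi2_res[best_chi2]

# Hypothetical fit that fails twice, then converges with chi2 = 1.0
calls = {'n': 0}
def fit_once():
    calls['n'] += 1
    if calls['n'] < 3:
        raise ValueError('bad starting point')
    return 'fit-result', 1.0

best = anneal_minimize(fit_once, randomize=lambda: None)
```

The real class additionally stores the bad `FitResult` objects so that `_print_failed_fit_diagnostics` can dump them when everything fails.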
@@ -1,7 +1,7 @@
  '''
  Module storing ZModel class
  '''
- # pylint: disable=too-many-lines, import-error
+ # pylint: disable=too-many-lines, import-error, too-many-positional-arguments, too-many-arguments

  from typing import Callable, Union

@@ -37,7 +37,16 @@ class MethodRegistry:
          '''
          Will return method in charge of building PDF, for an input nickname
          '''
-         return cls._d_method.get(nickname, None)
+         method = cls._d_method.get(nickname, None)
+
+         if method is not None:
+             return method
+
+         log.warning('Available PDFs:')
+         for value in cls._d_method:
+             log.info(f'    {value}')
+
+         return method
  #-----------------------------------------
  class ModelFactory:
      '''
@@ -48,33 +57,56 @@ class ModelFactory:

          l_pdf = ['dscb', 'gauss']
          l_shr = ['mu']
-         mod   = ModelFactory(obs = obs, l_pdf = l_pdf, l_shared=l_shr)
+         mod   = ModelFactory(preffix = 'signal', obs = obs, l_pdf = l_pdf, l_shared = l_shr, l_float = [])
          pdf   = mod.get_pdf()
      ```

      where one can specify which parameters can be shared among the PDFs
      '''
      #-----------------------------------------
-     def __init__(self, obs : zobs, l_pdf : list[str], l_shared : list[str]):
+     def __init__(self,
+                  preffix  : str,
+                  obs      : zobs,
+                  l_pdf    : list[str],
+                  l_shared : list[str],
+                  l_float  : list[str]):
          '''
+         preffix: used to identify PDF, will be used to name every parameter
          obs: zfit observable
          l_pdf: List of PDF nicknames which are registered below
          l_shared: List of parameter names that are shared
+         l_float: List of parameter names to allow to float
          '''

+         self._preffix = preffix
          self._l_pdf   = l_pdf
          self._l_shr   = l_shared
-         self._l_can_be_shared = ['mu', 'sg']
+         self._l_flt   = l_float
          self._obs     = obs

          self._d_par : dict[str,zpar] = {}
      #-----------------------------------------
-     def _get_name(self, name : str, suffix : str) -> str:
-         for can_be_shared in self._l_can_be_shared:
-             if name.startswith(f'{can_be_shared}_') and can_be_shared in self._l_shr:
-                 return can_be_shared
+     def _split_name(self, name : str) -> tuple[str,str]:
+         l_part = name.split('_')
+         pname  = l_part[0]
+         xname  = '_'.join(l_part[1:])

-         return f'{name}{suffix}'
+         return pname, xname
+     #-----------------------------------------
+     def _get_parameter_name(self, name : str, suffix : str) -> str:
+         pname, xname = self._split_name(name)
+
+         log.debug(f'Using physical name: {pname}')
+
+         if pname in self._l_shr:
+             name = f'{pname}_{self._preffix}'
+         else:
+             name = f'{pname}_{xname}_{self._preffix}{suffix}'
+
+         if pname in self._l_flt:
+             return f'{name}_flt'
+
+         return name
      #-----------------------------------------
      def _get_parameter(self,
                         name   : str,
@@ -82,7 +114,10 @@ class ModelFactory:
                         val    : float,
                         low    : float,
                         high   : float) -> zpar:
-         name = self._get_name(name, suffix)
+
+         name = self._get_parameter_name(name, suffix)
+         log.debug(f'Assigning name: {name}')
+
          if name in self._d_par:
              return self._d_par[name]

@@ -94,15 +129,15 @@ class ModelFactory:
      #-----------------------------------------
      @MethodRegistry.register('exp')
      def _get_exponential(self, suffix : str = '') -> zpdf:
-         c   = self._get_parameter('c_exp', suffix, -0.005, -0.05, 0.00)
-         pdf = zfit.pdf.Exponential(c, self._obs)
+         c   = self._get_parameter('c_exp', suffix, -0.005, -0.20, 0.00)
+         pdf = zfit.pdf.Exponential(c, self._obs, name=f'exp{suffix}')

          return pdf
      #-----------------------------------------
      @MethodRegistry.register('pol1')
      def _get_pol1(self, suffix : str = '') -> zpdf:
          a   = self._get_parameter('a_pol1', suffix, -0.005, -0.95, 0.00)
-         pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a])
+         pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a], name=f'pol1{suffix}')

          return pdf
      #-----------------------------------------
@@ -110,51 +145,62 @@ class ModelFactory:
      def _get_pol2(self, suffix : str = '') -> zpdf:
          a   = self._get_parameter('a_pol2', suffix, -0.005, -0.95, 0.00)
          b   = self._get_parameter('b_pol2', suffix,  0.000, -0.95, 0.95)
-         pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a, b])
+         pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a, b], name=f'pol2{suffix}')

          return pdf
      #-----------------------------------------
      @MethodRegistry.register('cbr')
      def _get_cbr(self, suffix : str = '') -> zpdf:
-         mu  = self._get_parameter('mu_cbr', suffix, 5300, 5250, 5350)
+         mu  = self._get_parameter('mu_cbr', suffix, 5300, 5100, 5350)
          sg  = self._get_parameter('sg_cbr', suffix,   10,    2,  300)
-         ar  = self._get_parameter('ac_cbr', suffix,   -2,  -4., -1.)
-         nr  = self._get_parameter('nc_cbr', suffix,    1,  0.5, 5.0)
+         ar  = self._get_parameter('ac_cbr', suffix,   -2, -14., -0.1)
+         nr  = self._get_parameter('nc_cbr', suffix,    1,  0.5,  150)
+
+         pdf = zfit.pdf.CrystalBall(mu, sg, ar, nr, self._obs, name=f'cbr{suffix}')
+
+         return pdf
+     #-----------------------------------------
+     @MethodRegistry.register('suj')
+     def _get_suj(self, suffix : str = '') -> zpdf:
+         mu  = self._get_parameter('mu_suj', suffix, 5300, 4000, 6000)
+         sg  = self._get_parameter('sg_suj', suffix,   10,    2, 5000)
+         gm  = self._get_parameter('gm_suj', suffix,    1,  -10,   10)
+         dl  = self._get_parameter('dl_suj', suffix,    1,  0.1,   10)

-         pdf = zfit.pdf.CrystalBall(mu, sg, ar, nr, self._obs)
+         pdf = zfit.pdf.JohnsonSU(mu, sg, gm, dl, self._obs, name=f'suj{suffix}')

          return pdf
      #-----------------------------------------
      @MethodRegistry.register('cbl')
      def _get_cbl(self, suffix : str = '') -> zpdf:
-         mu  = self._get_parameter('mu_cbl', suffix, 5300, 5250, 5350)
+         mu  = self._get_parameter('mu_cbl', suffix, 5300, 5100, 5350)
          sg  = self._get_parameter('sg_cbl', suffix,   10,    2,  300)
-         al  = self._get_parameter('ac_cbl', suffix,    2,   1.,  4.)
-         nl  = self._get_parameter('nc_cbl', suffix,    1,  0.5, 5.0)
+         al  = self._get_parameter('ac_cbl', suffix,    2,  0.1, 14.)
+         nl  = self._get_parameter('nc_cbl', suffix,    1,  0.5, 150)

-         pdf = zfit.pdf.CrystalBall(mu, sg, al, nl, self._obs)
+         pdf = zfit.pdf.CrystalBall(mu, sg, al, nl, self._obs, name=f'cbl{suffix}')

          return pdf
      #-----------------------------------------
      @MethodRegistry.register('gauss')
      def _get_gauss(self, suffix : str = '') -> zpdf:
-         mu  = self._get_parameter('mu_gauss', suffix, 5300, 5250, 5350)
+         mu  = self._get_parameter('mu_gauss', suffix, 5300, 5100, 5350)
          sg  = self._get_parameter('sg_gauss', suffix,   10,    2,  300)

-         pdf = zfit.pdf.Gauss(mu, sg, self._obs)
+         pdf = zfit.pdf.Gauss(mu, sg, self._obs, name=f'gauss{suffix}')

          return pdf
      #-----------------------------------------
      @MethodRegistry.register('dscb')
      def _get_dscb(self, suffix : str = '') -> zpdf:
-         mu  = self._get_parameter('mu_dscb', suffix, 5300, 5250, 5400)
-         sg  = self._get_parameter('sg_dscb', suffix,   10,    2,   30)
+         mu  = self._get_parameter('mu_dscb', suffix, 4000, 4000, 5400)
+         sg  = self._get_parameter('sg_dscb', suffix,   10,    2,  500)
          ar  = self._get_parameter('ar_dscb', suffix,    1,    0,    5)
          al  = self._get_parameter('al_dscb', suffix,    1,    0,    5)
-         nr  = self._get_parameter('nr_dscb', suffix,    2,    1,    5)
-         nl  = self._get_parameter('nl_dscb', suffix,    2,    0,    5)
+         nr  = self._get_parameter('nr_dscb', suffix,    2,    1,  150)
+         nl  = self._get_parameter('nl_dscb', suffix,    2,    0,  150)

-         pdf = zfit.pdf.DoubleCB(mu, sg, al, nl, ar, nr, self._obs)
+         pdf = zfit.pdf.DoubleCB(mu, sg, al, nl, ar, nr, self._obs, name=f'dscb{suffix}')

          return pdf
      #-----------------------------------------
@@ -190,7 +236,7 @@ class ModelFactory:

          l_frc = [ zfit.param.Parameter(f'frc_{ifrc + 1}', 0.5, 0, 1) for ifrc in range(nfrc - 1) ]

-         pdf = zfit.pdf.SumPDF(l_pdf, fracs=l_frc)
+         pdf = zfit.pdf.SumPDF(l_pdf, name=self._preffix, fracs=l_frc)

          return pdf
      #-----------------------------------------
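The new `_get_parameter_name` logic above decides how fit parameters are named: shared parameters collapse to the physical name plus the prefix, per-component parameters keep their component tag and suffix, and floating parameters gain a `_flt` marker. A standalone sketch of that naming rule (a hypothetical free function mirroring the method; `'mu_cbr'` means physical name `mu` on component `cbr`):

```python
def get_parameter_name(name, preffix, suffix, l_shared, l_float):
    # 'name' is e.g. 'mu_cbr': physical name, underscore, component tag
    l_part = name.split('_')
    pname  = l_part[0]               # physical parameter name, e.g. 'mu'
    xname  = '_'.join(l_part[1:])    # component tag, e.g. 'cbr'

    if pname in l_shared:
        out = f'{pname}_{preffix}'   # one parameter shared by all components
    else:
        out = f'{pname}_{xname}_{preffix}{suffix}'

    if pname in l_float:
        out = f'{out}_flt'           # mark parameters allowed to float

    return out

# Shared 'mu' ignores the component tag and suffix entirely
shared_name = get_parameter_name('mu_cbr', 'signal', '_1', l_shared=['mu'], l_float=[])
# Non-shared, floating 'sg' keeps everything and gains the '_flt' marker
float_name  = get_parameter_name('sg_cbr', 'signal', '_1', l_shared=['mu'], l_float=['sg'])
```

Because the shared name is suffix-independent, `_get_parameter` finds it in its cache on the second lookup and every component ends up holding the same zfit parameter object.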
dmu/testing/utilities.py CHANGED
@@ -2,6 +2,7 @@
  '''
  Module containing utility functions needed by unit tests
  '''
  import os
+ import math
  from typing import Union
  from dataclasses import dataclass
  from importlib.resources import files
@@ -21,56 +22,64 @@ class Data:
      '''
      Class storing shared data
      '''
-     nentries = 3000
  # -------------------------------
- def _double_data(d_data : dict) -> dict:
-     df_1 = pnd.DataFrame(d_data)
-     df_2 = pnd.DataFrame(d_data)
-
+ def _double_data(df_1 : pnd.DataFrame) -> pnd.DataFrame:
+     df_2 = df_1.copy()
      df   = pnd.concat([df_1, df_2], axis=0)

-     d_data = { name : df[name].to_numpy() for name in df.columns }
-
-     return d_data
+     return df
  # -------------------------------
- def _add_nans(d_data : dict) -> dict:
-     df_good = pnd.DataFrame(d_data)
-     df_bad  = pnd.DataFrame(d_data)
-     df_bad[:] = numpy.nan
+ def _add_nans(df : pnd.DataFrame, columns : list[str]) -> pnd.DataFrame:
+     size = len(df) * 0.2
+     size = math.floor(size)
+
+     l_col = df.columns.tolist()
+     if columns is None:
+         l_col_index = range(len(l_col))
+     else:
+         l_col_index = [ l_col.index(column) for column in columns ]

-     df     = pnd.concat([df_good, df_bad])
-     d_data = { name : df[name].to_numpy() for name in df.columns }
+     log.debug(f'Replacing randomly with {size} NaNs')
+     for _ in range(size):
+         irow = numpy.random.randint(0, df.shape[0]) # Random row index
+         icol = numpy.random.choice(l_col_index)     # Random column index

-     return d_data
+         df.iat[irow, icol] = numpy.nan
+
+     return df
  # -------------------------------
  def get_rdf(kind     : Union[str,None] = None,
              repeated : bool            = False,
-             add_nans : bool            = False):
+             nentries : int             = 3_000,
+             add_nans : list[str]       = None):
      '''
      Return ROOT dataframe with toy data
      '''
+
      d_data = {}
      if kind == 'sig':
-         d_data['w'] = numpy.random.normal(0, 1, size=Data.nentries)
-         d_data['x'] = numpy.random.normal(0, 1, size=Data.nentries)
-         d_data['y'] = numpy.random.normal(0, 1, size=Data.nentries)
-         d_data['z'] = numpy.random.normal(0, 1, size=Data.nentries)
+         d_data['w'] = numpy.random.normal(0, 1, size=nentries)
+         d_data['x'] = numpy.random.normal(0, 1, size=nentries)
+         d_data['y'] = numpy.random.normal(0, 1, size=nentries)
+         d_data['z'] = numpy.random.normal(0, 1, size=nentries)
      elif kind == 'bkg':
-         d_data['w'] = numpy.random.normal(1, 1, size=Data.nentries)
-         d_data['x'] = numpy.random.normal(1, 1, size=Data.nentries)
-         d_data['y'] = numpy.random.normal(1, 1, size=Data.nentries)
-         d_data['z'] = numpy.random.normal(1, 1, size=Data.nentries)
+         d_data['w'] = numpy.random.normal(1, 1, size=nentries)
+         d_data['x'] = numpy.random.normal(1, 1, size=nentries)
+         d_data['y'] = numpy.random.normal(1, 1, size=nentries)
+         d_data['z'] = numpy.random.normal(1, 1, size=nentries)
      else:
          log.error(f'Invalid kind: {kind}')
          raise ValueError

+     df = pnd.DataFrame(d_data)
+
      if repeated:
-         d_data = _double_data(d_data)
+         df = _double_data(df)

      if add_nans:
-         d_data = _add_nans(d_data)
+         df = _add_nans(df, columns=add_nans)

-     rdf = RDF.FromNumpy(d_data)
+     rdf = RDF.FromPandas(df)

      return rdf
  # -------------------------------
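`_add_nans` now scatters NaNs over randomly chosen cells (about 20% of the row count), optionally restricted to given columns, instead of appending a fully-NaN copy of the dataframe. A self-contained sketch of the same cell-level injection, assuming pandas and numpy are available; the `frac` and `seed` parameters are additions here for illustration, not part of the package API:

```python
import math

import numpy
import pandas as pnd

def add_nans(df, columns=None, frac=0.2, seed=42):
    # Replace a number of randomly chosen cells with NaN; the number is
    # frac * len(df), and the same cell may be drawn more than once.
    rng  = numpy.random.default_rng(seed)
    size = math.floor(len(df) * frac)

    l_col = df.columns.tolist()
    if columns is None:
        l_col_index = list(range(len(l_col)))
    else:
        l_col_index = [ l_col.index(column) for column in columns ]

    for _ in range(size):
        irow = int(rng.integers(0, df.shape[0]))  # random row index
        icol = int(rng.choice(l_col_index))       # random column index

        df.iat[irow, icol] = numpy.nan

    return df

df = pnd.DataFrame({'x': numpy.zeros(100), 'y': numpy.zeros(100)})
df = add_nans(df, columns=['x'])   # only column 'x' receives NaNs
```

Because draws can repeat a cell, the resulting NaN count is at most, not exactly, `frac * len(df)`.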
@@ -1,6 +1,7 @@
  dataset:
    nan :
-     x : 0
+     x : -3
+     y : -3
    training :
      nfold    : 3
      features : [x, y, z]
@@ -49,4 +50,3 @@ plotting:
      binning : [-4, 4, 100]
      yscale  : 'linear'
      labels  : ['z', '']
-
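The `dataset.nan` section above maps feature names to replacement values (here NaNs in both `x` and `y` become -3). In pandas terms this presumably corresponds to a per-column `fillna` with a dict, sketched here under that assumption:

```python
import numpy
import pandas as pnd

# Per-column NaN replacements, mirroring the dataset.nan block above
d_nan = {'x': -3, 'y': -3}

df = pnd.DataFrame({
    'x': [1.0, numpy.nan],
    'y': [numpy.nan, 2.0],
    'z': [numpy.nan, 4.0],
})

# Only the columns listed in the mapping are filled; 'z' keeps its NaN
df = df.fillna(d_nan)
```

Filling with a sentinel like -3, well outside the feature binning, keeps the rows usable for training while making imputed values distinguishable.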
@@ -1,13 +1,17 @@
  saving:
-     plt_dir : tests/plotting/2d_weighted
+     plt_dir : /tmp/dmu/tests/plotting/2d_weighted
+ selection:
+     cuts:
+         xlow : x > -1.5
  definitions:
      z : x + y
  general:
      size : [20, 10]
  plots_2d:
-     - [x, y, weights, 'xy_w']
-     - [x, y, null,    'xy_r']
-     - [x, z, null,    'xz_r']
+     - [x, y, weights, 'xy_wgt', false]
+     - [x, y, null,    'xy_raw', false]
+     - [x, z, null,    'xz_raw', false]
+     - [x, z, null,    'xz_log', true]
  axes:
      x :
          binning : [-3.0, 3.0, 40]
@@ -0,0 +1,12 @@
+ saving:
+     plt_dir : tests/plotting/legend
+ general:
+     size : [20, 10]
+ plots:
+     x :
+         binning : [-5.0, 8.0, 40]
+     y :
+         binning : [-5.0, 8.0, 40]
+ style:
+     legend:
+         bbox_to_anchor : [1.2, 1]
@@ -0,0 +1,9 @@
+ saving:
+     plt_dir : tests/plotting/stats
+ plots:
+     x :
+         binning : [-5.0, 8.0, 40]
+     y :
+         binning : [-5.0, 8.0, 40]
+ stats:
+     nentries : '{:.2e}'
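The `stats` block in the new config above attaches an entry-count label to each plot, with `nentries` given as a Python format specification. `'{:.2e}'` is the standard format-spec mini-language for scientific notation with two digits after the decimal point:

```python
# '{:.2e}' renders a number in scientific notation with two decimals
nentries = 24_000
label    = '{:.2e}'.format(nentries)   # -> '2.40e+04'
```

Using a format spec in the config keeps the label style configurable without touching the plotting code.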