data-manipulation-utilities 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/PKG-INFO +3 -3
  2. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/README.md +2 -2
  3. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/pyproject.toml +1 -1
  4. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/data_manipulation_utilities.egg-info/PKG-INFO +3 -3
  5. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/ml/train_mva.py +33 -29
  6. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/minimizers.py +40 -11
  7. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/model_factory.py +74 -34
  8. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/ml/tests/train_mva.yaml +6 -3
  9. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/setup.cfg +0 -0
  10. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/data_manipulation_utilities.egg-info/SOURCES.txt +0 -0
  11. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/data_manipulation_utilities.egg-info/dependency_links.txt +0 -0
  12. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/data_manipulation_utilities.egg-info/entry_points.txt +0 -0
  13. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/data_manipulation_utilities.egg-info/requires.txt +0 -0
  14. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/data_manipulation_utilities.egg-info/top_level.txt +0 -0
  15. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/arrays/utilities.py +0 -0
  16. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/generic/utilities.py +0 -0
  17. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/generic/version_management.py +0 -0
  18. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/logging/log_store.py +0 -0
  19. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/ml/cv_classifier.py +0 -0
  20. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/ml/cv_predict.py +0 -0
  21. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/ml/utilities.py +0 -0
  22. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/pdataframe/utilities.py +0 -0
  23. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/plotting/matrix.py +0 -0
  24. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/plotting/plotter.py +0 -0
  25. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/plotting/plotter_1d.py +0 -0
  26. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/plotting/plotter_2d.py +0 -0
  27. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/plotting/utilities.py +0 -0
  28. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/rdataframe/atr_mgr.py +0 -0
  29. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/rdataframe/utilities.py +0 -0
  30. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/rfile/rfprinter.py +0 -0
  31. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/rfile/utilities.py +0 -0
  32. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/fitter.py +0 -0
  33. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/function.py +0 -0
  34. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/gof_calculator.py +0 -0
  35. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/utilities.py +0 -0
  36. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/stats/zfit_plotter.py +0 -0
  37. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/testing/utilities.py +0 -0
  38. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu/text/transformer.py +0 -0
  39. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/__init__.py +0 -0
  40. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/2d.yaml +0 -0
  41. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/fig_size.yaml +0 -0
  42. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/high_stat.yaml +0 -0
  43. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/legend.yaml +0 -0
  44. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/name.yaml +0 -0
  45. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/no_bounds.yaml +0 -0
  46. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/normalized.yaml +0 -0
  47. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/simple.yaml +0 -0
  48. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/stats.yaml +0 -0
  49. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/title.yaml +0 -0
  50. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/plotting/tests/weights.yaml +0 -0
  51. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/text/transform.toml +0 -0
  52. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/text/transform.txt +0 -0
  53. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/text/transform_set.toml +0 -0
  54. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/text/transform_set.txt +0 -0
  55. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_data/text/transform_trf.txt +0 -0
  56. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_scripts/git/publish +0 -0
  57. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_scripts/physics/check_truth.py +0 -0
  58. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_scripts/rfile/compare_root_files.py +0 -0
  59. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_scripts/rfile/print_trees.py +0 -0
  60. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_scripts/ssh/coned.py +0 -0
  61. {data_manipulation_utilities-0.2.5 → data_manipulation_utilities-0.2.6}/src/dmu_scripts/text/transform_text.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: data_manipulation_utilities
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Description-Content-Type: text/markdown
5
5
  Requires-Dist: logzero
6
6
  Requires-Dist: PyYAML
@@ -427,7 +427,7 @@ rdf_bkg = _get_rdf(kind='bkg')
427
427
  cfg = _get_config()
428
428
 
429
429
  obj= TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
430
- obj.run()
430
+ obj.run(skip_fit=False) # by default it will be false, if true, it will only make plots of features
431
431
  ```
432
432
 
433
433
  where the settings for the training go in a config dictionary, which when written to YAML looks like:
@@ -549,7 +549,7 @@ When evaluating the model with real data, problems might occur, we deal with the
549
549
  ```python
550
550
  model.cfg
551
551
  ```
552
- - For whatever entries that are still NaN, they will be _patched_ with zeros and evaluated. However, before returning, the probabilities will be
552
+ - For whatever features that are still NaN, they will be _patched_ with zeros when evaluated. However, the returned probabilities will be
553
553
  saved as -1. I.e. entries with NaNs will have probabilities of -1.
554
554
 
555
555
  # Pandas dataframes
@@ -407,7 +407,7 @@ rdf_bkg = _get_rdf(kind='bkg')
407
407
  cfg = _get_config()
408
408
 
409
409
  obj= TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
410
- obj.run()
410
+ obj.run(skip_fit=False) # by default it will be false, if true, it will only make plots of features
411
411
  ```
412
412
 
413
413
  where the settings for the training go in a config dictionary, which when written to YAML looks like:
@@ -529,7 +529,7 @@ When evaluating the model with real data, problems might occur, we deal with the
529
529
  ```python
530
530
  model.cfg
531
531
  ```
532
- - For whatever entries that are still NaN, they will be _patched_ with zeros and evaluated. However, before returning, the probabilities will be
532
+ - For whatever features that are still NaN, they will be _patched_ with zeros when evaluated. However, the returned probabilities will be
533
533
  saved as -1. I.e. entries with NaNs will have probabilities of -1.
534
534
 
535
535
  # Pandas dataframes
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = 'data_manipulation_utilities'
3
- version = '0.2.5'
3
+ version = '0.2.6'
4
4
  readme = 'README.md'
5
5
  dependencies= [
6
6
  'logzero',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: data_manipulation_utilities
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Description-Content-Type: text/markdown
5
5
  Requires-Dist: logzero
6
6
  Requires-Dist: PyYAML
@@ -427,7 +427,7 @@ rdf_bkg = _get_rdf(kind='bkg')
427
427
  cfg = _get_config()
428
428
 
429
429
  obj= TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
430
- obj.run()
430
+ obj.run(skip_fit=False) # by default it will be false, if true, it will only make plots of features
431
431
  ```
432
432
 
433
433
  where the settings for the training go in a config dictionary, which when written to YAML looks like:
@@ -549,7 +549,7 @@ When evaluating the model with real data, problems might occur, we deal with the
549
549
  ```python
550
550
  model.cfg
551
551
  ```
552
- - For whatever entries that are still NaN, they will be _patched_ with zeros and evaluated. However, before returning, the probabilities will be
552
+ - For whatever features that are still NaN, they will be _patched_ with zeros when evaluated. However, the returned probabilities will be
553
553
  saved as -1. I.e. entries with NaNs will have probabilities of -1.
554
554
 
555
555
  # Pandas dataframes
@@ -1,7 +1,7 @@
1
1
  '''
2
2
  Module with TrainMva class
3
3
  '''
4
- # pylint: disable = too-many-locals
4
+ # pylint: disable = too-many-locals, no-name-in-module
5
5
  # pylint: disable = too-many-arguments, too-many-positional-arguments
6
6
 
7
7
  import os
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
14
14
  from sklearn.metrics import roc_curve, auc
15
15
  from sklearn.model_selection import StratifiedKFold
16
16
 
17
- from ROOT import RDataFrame
17
+ from ROOT import RDataFrame, RDF
18
18
 
19
19
  import dmu.ml.utilities as ut
20
20
  import dmu.pdataframe.utilities as put
@@ -33,40 +33,41 @@ class TrainMva:
33
33
  Interface to scikit learn used to train classifier
34
34
  '''
35
35
  # ---------------------------------------------
36
- def __init__(self, bkg=None, sig=None, cfg=None):
36
+ def __init__(self, bkg : RDataFrame, sig : RDataFrame, cfg : dict):
37
37
  '''
38
38
  bkg (ROOT dataframe): Holds real data
39
39
  sig (ROOT dataframe): Holds simulation
40
40
  cfg (dict) : Dictionary storing configuration for training
41
41
  '''
42
- if bkg is None:
43
- raise ValueError('Background dataframe is not a ROOT dataframe')
44
-
45
- if sig is None:
46
- raise ValueError('Signal dataframe is not a ROOT dataframe')
47
-
48
- if not isinstance(cfg, dict):
49
- raise ValueError('Config dictionary is not a dictionary')
42
+ self._cfg = cfg
43
+ self._l_ft_name = self._cfg['training']['features']
50
44
 
51
- self._rdf_bkg = bkg
52
- self._rdf_sig = sig
53
- self._cfg = cfg
45
+ df_ft_sig, l_lab_sig = self._get_sample_inputs(rdf = sig, label = 1)
46
+ df_ft_bkg, l_lab_bkg = self._get_sample_inputs(rdf = bkg, label = 0)
54
47
 
55
- self._l_ft_name = self._cfg['training']['features']
48
+ self._df_ft = pnd.concat([df_ft_sig, df_ft_bkg], axis=0)
49
+ self._l_lab = numpy.array(l_lab_sig + l_lab_bkg)
56
50
 
57
- self._df_ft, self._l_lab = self._get_inputs()
51
+ self._rdf_bkg = self._get_rdf(rdf = bkg, df=df_ft_bkg)
52
+ self._rdf_sig = self._get_rdf(rdf = sig, df=df_ft_sig)
58
53
  # ---------------------------------------------
59
- def _get_inputs(self) -> tuple[pnd.DataFrame, npa]:
60
- log.info('Getting signal')
61
- df_sig, arr_lab_sig = self._get_sample_inputs(self._rdf_sig, label = 1)
54
+ def _get_rdf(self, rdf : RDataFrame, df : pnd.DataFrame) -> RDataFrame:
55
+ '''
56
+ Takes original ROOT dataframe and pre-processed features dataframe
57
+ Adds missing branches to latter and returns expanded ROOT dataframe
58
+ '''
62
59
 
63
- log.info('Getting background')
64
- df_bkg, arr_lab_bkg = self._get_sample_inputs(self._rdf_bkg, label = 0)
60
+ l_pnd_col = df.columns.tolist()
61
+ l_rdf_col = [ name.c_str() for name in rdf.GetColumnNames() ]
62
+ l_mis_col = [ col for col in l_rdf_col if col not in l_pnd_col ]
65
63
 
66
- df = pnd.concat([df_sig, df_bkg], axis=0)
67
- arr_lab = numpy.concatenate([arr_lab_sig, arr_lab_bkg])
64
+ log.debug(f'Adding extra-nonfeature columns: {l_mis_col}')
68
65
 
69
- return df, arr_lab
66
+ d_data = rdf.AsNumpy(l_mis_col)
67
+ df_ext = pnd.DataFrame(d_data)
68
+ df_all = pnd.concat([df, df_ext], axis=1)
69
+
70
+ return RDF.FromPandas(df_all)
70
71
  # ---------------------------------------------
71
72
  def _pre_process_nans(self, df : pnd.DataFrame) -> pnd.DataFrame:
72
73
  if 'dataset' not in self._cfg:
@@ -77,23 +78,26 @@ class TrainMva:
77
78
  return df
78
79
 
79
80
  d_name_val = self._cfg['dataset']['nan']
80
- log.info(60 * '-')
81
+ log.info(70 * '-')
81
82
  log.info('Doing NaN replacements')
82
- log.info(60 * '-')
83
+ log.info(70 * '-')
83
84
  for var, val in d_name_val.items():
84
- log.info(f'{var:<20}{"--->":20}{val:<20.3f}')
85
+ nna = df[var].isna().sum()
86
+
87
+ log.info(f'{var:<20}{"--->":20}{val:<20.3f}{nna}')
85
88
  df[var] = df[var].fillna(val)
89
+ log.info(70 * '-')
86
90
 
87
91
  return df
88
92
  # ---------------------------------------------
89
- def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, npa]:
93
+ def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, list[int]]:
90
94
  d_ft = rdf.AsNumpy(self._l_ft_name)
91
95
  df = pnd.DataFrame(d_ft)
92
96
  df = self._pre_process_nans(df)
93
97
  df = ut.cleanup(df)
94
98
  l_lab= len(df) * [label]
95
99
 
96
- return df, numpy.array(l_lab)
100
+ return df, l_lab
97
101
  # ---------------------------------------------
98
102
  def _get_model(self, arr_index : npa) -> cls:
99
103
  model = cls(cfg = self._cfg)
@@ -1,12 +1,16 @@
1
1
  '''
2
2
  Module containing derived classes from ZFit minimizer
3
3
  '''
4
+ from typing import Union
4
5
  import numpy
5
6
 
6
7
  import zfit
8
+ import matplotlib.pyplot as plt
9
+
7
10
  from zfit.result import FitResult
8
11
  from zfit.core.basepdf import BasePDF as zpdf
9
12
  from zfit.minimizers.baseminimizer import FailMinimizeNaN
13
+ from dmu.stats.utilities import print_pdf
10
14
  from dmu.stats.gof_calculator import GofCalculator
11
15
  from dmu.logging.log_store import LogStore
12
16
 
@@ -29,6 +33,7 @@ class AnealingMinimizer(zfit.minimize.Minuit):
29
33
  self._chi2ndof = chi2ndof
30
34
 
31
35
  self._check_thresholds()
36
+ self._l_bad_fit_res : list[FitResult] = []
32
37
 
33
38
  super().__init__()
34
39
  # ------------------------
@@ -66,19 +71,24 @@ class AnealingMinimizer(zfit.minimize.Minuit):
66
71
  return is_good
67
72
  # ------------------------
68
73
  def _is_good_fit(self, res : FitResult) -> bool:
74
+ good_fit = True
75
+
69
76
  if not res.valid:
70
- log.warning('Skipping invalid fit')
71
- return False
77
+ log.debug('Skipping invalid fit')
78
+ good_fit = False
72
79
 
73
80
  if res.status != 0:
74
- log.warning('Skipping fit with bad status')
75
- return False
81
+ log.debug('Skipping fit with bad status')
82
+ good_fit = False
76
83
 
77
84
  if not res.converged:
78
- log.warning('Skipping non-converging fit')
79
- return False
85
+ log.debug('Skipping non-converging fit')
86
+ good_fit = False
80
87
 
81
- return True
88
+ if not good_fit:
89
+ self._l_bad_fit_res.append(res)
90
+
91
+ return good_fit
82
92
  # ------------------------
83
93
  def _get_gof(self, nll) -> tuple[float, float]:
84
94
  log.debug('Checking GOF')
@@ -108,10 +118,11 @@ class AnealingMinimizer(zfit.minimize.Minuit):
108
118
  par.set_value(fval)
109
119
  log.debug(f'{par.name:<20}{ival:<15.3f}{"->":<10}{fval:<15.3f}{"in":<5}{par.lower:<15.3e}{par.upper:<15.3e}')
110
120
  # ------------------------
111
- def _pick_best_fit(self, d_chi2_res : dict) -> FitResult:
121
+ def _pick_best_fit(self, d_chi2_res : dict) -> Union[FitResult,None]:
112
122
  nres = len(d_chi2_res)
113
123
  if nres == 0:
114
- raise ValueError('No fits found')
124
+ log.error('No fits found')
125
+ return None
115
126
 
116
127
  l_chi2_res= list(d_chi2_res.items())
117
128
  l_chi2_res.sort()
@@ -149,6 +160,15 @@ class AnealingMinimizer(zfit.minimize.Minuit):
149
160
 
150
161
  return l_model[0]
151
162
  # ------------------------
163
+ def _print_failed_fit_diagnostics(self, nll) -> None:
164
+ for res in self._l_bad_fit_res:
165
+ print(res)
166
+
167
+ arr_mass = nll.data[0].numpy()
168
+
169
+ plt.hist(arr_mass, bins=60)
170
+ plt.show()
171
+ # ------------------------
152
172
  def minimize(self, nll, **kwargs) -> FitResult:
153
173
  '''
154
174
  Will run minimization and return FitResult object
@@ -156,18 +176,20 @@ class AnealingMinimizer(zfit.minimize.Minuit):
156
176
 
157
177
  d_chi2_res : dict[float,FitResult] = {}
158
178
  for i_try in range(self._ntries):
159
- log.info(f'try {i_try:02}/{self._ntries:02}')
160
179
  try:
161
180
  res = super().minimize(nll, **kwargs)
162
181
  except (FailMinimizeNaN, ValueError, RuntimeError) as exc:
163
- log.warning(exc)
182
+ log.error(f'{i_try:02}/{self._ntries:02}{"Failed":>20}')
183
+ log.debug(exc)
164
184
  self._randomize_parameters(nll)
165
185
  continue
166
186
 
167
187
  if not self._is_good_fit(res):
188
+ log.warning(f'{i_try:02}/{self._ntries:02}{"Bad fit":>20}')
168
189
  continue
169
190
 
170
191
  chi2, pvl = self._get_gof(nll)
192
+ log.info(f'{i_try:02}/{self._ntries:02}{chi2:>20.3f}')
171
193
  d_chi2_res[chi2] = res
172
194
 
173
195
  if self._is_good_gof(chi2, pvl):
@@ -176,6 +198,13 @@ class AnealingMinimizer(zfit.minimize.Minuit):
176
198
  self._randomize_parameters(nll)
177
199
 
178
200
  res = self._pick_best_fit(d_chi2_res)
201
+ if res is None:
202
+ self._print_failed_fit_diagnostics(nll)
203
+ pdf = nll.model[0]
204
+ print_pdf(pdf)
205
+
206
+ raise ValueError('Fit failed')
207
+
179
208
  pdf = self._pdf_from_nll(nll)
180
209
  self._set_pdf_pars(res, pdf)
181
210
 
@@ -37,7 +37,16 @@ class MethodRegistry:
37
37
  '''
38
38
  Will return method in charge of building PDF, for an input nickname
39
39
  '''
40
- return cls._d_method.get(nickname, None)
40
+ method = cls._d_method.get(nickname, None)
41
+
42
+ if method is not None:
43
+ return method
44
+
45
+ log.warning('Available PDFs:')
46
+ for value in cls._d_method:
47
+ log.info(f' {value}')
48
+
49
+ return method
41
50
  #-----------------------------------------
42
51
  class ModelFactory:
43
52
  '''
@@ -48,39 +57,56 @@ class ModelFactory:
48
57
 
49
58
  l_pdf = ['dscb', 'gauss']
50
59
  l_shr = ['mu']
51
- mod = ModelFactory(obs = obs, l_pdf = l_pdf, l_shared=l_shr)
60
+ mod = ModelFactory(preffix = 'signal', obs = obs, l_pdf = l_pdf, l_shared=l_shr)
52
61
  pdf = mod.get_pdf()
53
62
  ```
54
63
 
55
64
  where one can specify which parameters can be shared among the PDFs
56
65
  '''
57
66
  #-----------------------------------------
58
- def __init__(self, obs : zobs, l_pdf : list[str], l_shared : list[str]):
67
+ def __init__(self,
68
+ preffix : str,
69
+ obs : zobs,
70
+ l_pdf : list[str],
71
+ l_shared : list[str],
72
+ l_float : list[str]):
59
73
  '''
74
+ preffix: used to identify PDF, will be used to name every parameter
60
75
  obs: zfit obserbable
61
76
  l_pdf: List of PDF nicknames which are registered below
62
77
  l_shared: List of parameter names that are shared
78
+ l_float: List of parameter names to allow to float
63
79
  '''
64
80
 
81
+ self._preffix = preffix
65
82
  self._l_pdf = l_pdf
66
83
  self._l_shr = l_shared
67
- self._l_can_be_shared = ['mu', 'sg']
84
+ self._l_flt = l_float
68
85
  self._obs = obs
69
86
 
70
87
  self._d_par : dict[str,zpar] = {}
71
88
  #-----------------------------------------
72
- def _fltname_from_name(self, name : str) -> str:
73
- if name in ['mu', 'sg']:
74
- return f'{name}_flt'
89
+ def _split_name(self, name : str) -> tuple[str,str]:
90
+ l_part = name.split('_')
91
+ pname = l_part[0]
92
+ xname = '_'.join(l_part[1:])
75
93
 
76
- return name
94
+ return pname, xname
77
95
  #-----------------------------------------
78
- def _get_name(self, name : str, suffix : str) -> str:
79
- for can_be_shared in self._l_can_be_shared:
80
- if name.startswith(f'{can_be_shared}_') and can_be_shared in self._l_shr:
81
- return self._fltname_from_name(can_be_shared)
96
+ def _get_parameter_name(self, name : str, suffix : str) -> str:
97
+ pname, xname = self._split_name(name)
98
+
99
+ log.debug(f'Using physical name: {pname}')
82
100
 
83
- return self._fltname_from_name(f'{name}{suffix}')
101
+ if pname in self._l_shr:
102
+ name = f'{pname}_{self._preffix}'
103
+ else:
104
+ name = f'{pname}_{xname}_{self._preffix}{suffix}'
105
+
106
+ if pname in self._l_flt:
107
+ return f'{name}_flt'
108
+
109
+ return name
84
110
  #-----------------------------------------
85
111
  def _get_parameter(self,
86
112
  name : str,
@@ -88,7 +114,10 @@ class ModelFactory:
88
114
  val : float,
89
115
  low : float,
90
116
  high : float) -> zpar:
91
- name = self._get_name(name, suffix)
117
+
118
+ name = self._get_parameter_name(name, suffix)
119
+ log.debug(f'Assigning name: {name}')
120
+
92
121
  if name in self._d_par:
93
122
  return self._d_par[name]
94
123
 
@@ -100,15 +129,15 @@ class ModelFactory:
100
129
  #-----------------------------------------
101
130
  @MethodRegistry.register('exp')
102
131
  def _get_exponential(self, suffix : str = '') -> zpdf:
103
- c = self._get_parameter('c_exp', suffix, -0.005, -0.05, 0.00)
104
- pdf = zfit.pdf.Exponential(c, self._obs)
132
+ c = self._get_parameter('c_exp', suffix, -0.005, -0.20, 0.00)
133
+ pdf = zfit.pdf.Exponential(c, self._obs, name=f'exp{suffix}')
105
134
 
106
135
  return pdf
107
136
  #-----------------------------------------
108
137
  @MethodRegistry.register('pol1')
109
138
  def _get_pol1(self, suffix : str = '') -> zpdf:
110
139
  a = self._get_parameter('a_pol1', suffix, -0.005, -0.95, 0.00)
111
- pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a])
140
+ pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a], name=f'pol1{suffix}')
112
141
 
113
142
  return pdf
114
143
  #-----------------------------------------
@@ -116,51 +145,62 @@ class ModelFactory:
116
145
  def _get_pol2(self, suffix : str = '') -> zpdf:
117
146
  a = self._get_parameter('a_pol2', suffix, -0.005, -0.95, 0.00)
118
147
  b = self._get_parameter('b_pol2', suffix, 0.000, -0.95, 0.95)
119
- pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a, b])
148
+ pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a, b], name=f'pol2{suffix}')
120
149
 
121
150
  return pdf
122
151
  #-----------------------------------------
123
152
  @MethodRegistry.register('cbr')
124
153
  def _get_cbr(self, suffix : str = '') -> zpdf:
125
- mu = self._get_parameter('mu_cbr', suffix, 5300, 5250, 5350)
154
+ mu = self._get_parameter('mu_cbr', suffix, 5300, 5100, 5350)
126
155
  sg = self._get_parameter('sg_cbr', suffix, 10, 2, 300)
127
- ar = self._get_parameter('ac_cbr', suffix, -2, -4., -1.)
128
- nr = self._get_parameter('nc_cbr', suffix, 1, 0.5, 5.0)
156
+ ar = self._get_parameter('ac_cbr', suffix, -2, -14., -0.1)
157
+ nr = self._get_parameter('nc_cbr', suffix, 1, 0.5, 150)
158
+
159
+ pdf = zfit.pdf.CrystalBall(mu, sg, ar, nr, self._obs, name=f'cbr{suffix}')
160
+
161
+ return pdf
162
+ #-----------------------------------------
163
+ @MethodRegistry.register('suj')
164
+ def _get_suj(self, suffix : str = '') -> zpdf:
165
+ mu = self._get_parameter('mu_suj', suffix, 5300, 4000, 6000)
166
+ sg = self._get_parameter('sg_suj', suffix, 10, 2, 5000)
167
+ gm = self._get_parameter('gm_suj', suffix, 1, -10, 10)
168
+ dl = self._get_parameter('dl_suj', suffix, 1, 0.1, 10)
129
169
 
130
- pdf = zfit.pdf.CrystalBall(mu, sg, ar, nr, self._obs)
170
+ pdf = zfit.pdf.JohnsonSU(mu, sg, gm, dl, self._obs, name=f'suj{suffix}')
131
171
 
132
172
  return pdf
133
173
  #-----------------------------------------
134
174
  @MethodRegistry.register('cbl')
135
175
  def _get_cbl(self, suffix : str = '') -> zpdf:
136
- mu = self._get_parameter('mu_cbl', suffix, 5300, 5250, 5350)
176
+ mu = self._get_parameter('mu_cbl', suffix, 5300, 5100, 5350)
137
177
  sg = self._get_parameter('sg_cbl', suffix, 10, 2, 300)
138
- al = self._get_parameter('ac_cbl', suffix, 2, 1., 14.)
139
- nl = self._get_parameter('nc_cbl', suffix, 1, 0.5, 15.)
178
+ al = self._get_parameter('ac_cbl', suffix, 2, 0.1, 14.)
179
+ nl = self._get_parameter('nc_cbl', suffix, 1, 0.5, 150)
140
180
 
141
- pdf = zfit.pdf.CrystalBall(mu, sg, al, nl, self._obs)
181
+ pdf = zfit.pdf.CrystalBall(mu, sg, al, nl, self._obs, name=f'cbl{suffix}')
142
182
 
143
183
  return pdf
144
184
  #-----------------------------------------
145
185
  @MethodRegistry.register('gauss')
146
186
  def _get_gauss(self, suffix : str = '') -> zpdf:
147
- mu = self._get_parameter('mu_gauss', suffix, 5300, 5250, 5350)
187
+ mu = self._get_parameter('mu_gauss', suffix, 5300, 5100, 5350)
148
188
  sg = self._get_parameter('sg_gauss', suffix, 10, 2, 300)
149
189
 
150
- pdf = zfit.pdf.Gauss(mu, sg, self._obs)
190
+ pdf = zfit.pdf.Gauss(mu, sg, self._obs, name=f'gauss{suffix}')
151
191
 
152
192
  return pdf
153
193
  #-----------------------------------------
154
194
  @MethodRegistry.register('dscb')
155
195
  def _get_dscb(self, suffix : str = '') -> zpdf:
156
- mu = self._get_parameter('mu_dscb', suffix, 5300, 5250, 5400)
157
- sg = self._get_parameter('sg_dscb', suffix, 10, 2, 30)
196
+ mu = self._get_parameter('mu_dscb', suffix, 4000, 4000, 5400)
197
+ sg = self._get_parameter('sg_dscb', suffix, 10, 2, 500)
158
198
  ar = self._get_parameter('ar_dscb', suffix, 1, 0, 5)
159
199
  al = self._get_parameter('al_dscb', suffix, 1, 0, 5)
160
- nr = self._get_parameter('nr_dscb', suffix, 2, 1, 15)
161
- nl = self._get_parameter('nl_dscb', suffix, 2, 0, 15)
200
+ nr = self._get_parameter('nr_dscb', suffix, 2, 1, 150)
201
+ nl = self._get_parameter('nl_dscb', suffix, 2, 0, 150)
162
202
 
163
- pdf = zfit.pdf.DoubleCB(mu, sg, al, nl, ar, nr, self._obs)
203
+ pdf = zfit.pdf.DoubleCB(mu, sg, al, nl, ar, nr, self._obs, name=f'dscb{suffix}')
164
204
 
165
205
  return pdf
166
206
  #-----------------------------------------
@@ -196,7 +236,7 @@ class ModelFactory:
196
236
 
197
237
  l_frc= [ zfit.param.Parameter(f'frc_{ifrc + 1}', 0.5, 0, 1) for ifrc in range(nfrc - 1) ]
198
238
 
199
- pdf = zfit.pdf.SumPDF(l_pdf, fracs=l_frc)
239
+ pdf = zfit.pdf.SumPDF(l_pdf, name=self._preffix, fracs=l_frc)
200
240
 
201
241
  return pdf
202
242
  #-----------------------------------------
@@ -1,7 +1,7 @@
1
1
  dataset:
2
2
  nan :
3
- x : 1
4
- y : 2
3
+ x : -3
4
+ y : -3
5
5
  training :
6
6
  nfold : 3
7
7
  features : [x, y, z]
@@ -34,6 +34,10 @@ plotting:
34
34
  saving:
35
35
  plt_dir : '/tmp/dmu/ml/tests/train_mva/features'
36
36
  plots:
37
+ w :
38
+ binning : [-4, 4, 100]
39
+ yscale : 'linear'
40
+ labels : ['w', '']
37
41
  x :
38
42
  binning : [-4, 4, 100]
39
43
  yscale : 'linear'
@@ -46,4 +50,3 @@ plotting:
46
50
  binning : [-4, 4, 100]
47
51
  yscale : 'linear'
48
52
  labels : ['z', '']
49
-