data-manipulation-utilities 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: data_manipulation_utilities
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Description-Content-Type: text/markdown
5
5
  Requires-Dist: logzero
6
6
  Requires-Dist: PyYAML
@@ -121,6 +121,24 @@ samples:
121
121
 
122
122
  ## PDFs
123
123
 
124
+ ### Model building
125
+
126
+ In order to do complex fits, one often needs models built by adding several PDFs, each with many parameters.
127
+ In these PDFs certain parameters (e.g. $\mu$ or $\sigma$) need to be shared. This project provides
128
+ `ModelFactory`, which can do this as shown below:
129
+
130
+ ```python
131
+ from dmu.stats.model_factory import ModelFactory
132
+
133
+ l_pdf = ['cbr'] + 2 * ['cbl']
134
+ l_shr = ['mu', 'sg']
135
+ mod = ModelFactory(obs = Data.obs, l_pdf = l_pdf, l_shared=l_shr)
136
+ pdf = mod.get_pdf()
137
+ ```
138
+
139
+ where the model is a sum of three `CrystalBall` PDFs, one with a right tail and two with a left tail.
140
+ The `mu` and `sg` parameters are shared.
141
+
124
142
  ### Printing PDFs
125
143
 
126
144
  One can print a zfit PDF by doing:
@@ -231,6 +249,44 @@ likelihood :
231
249
  nbins : 100 #If specified, will do binned likelihood fit instead of unbinned
232
250
  ```
233
251
 
252
+ ## Minimizers
253
+
254
+ These are alternative implementations of the minimizers in zfit meant to be used for special types of fits.
255
+
256
+ ### Annealing minimizer
257
+
258
+ This minimizer is meant to be used for fits to models with many parameters, where multiple minima are expected in the
259
+ likelihood. Its use is illustrated below:
260
+
261
+ ```python
262
+ from dmu.stats.minimizers import AnealingMinimizer
263
+
264
+ nll = _get_nll()
265
+ minimizer = AnealingMinimizer(ntries=10, pvalue=0.05)
266
+ res = minimizer.minimize(nll)
267
+ ```
268
+
269
+ This will:
270
+
271
+ - Take the `NLL` object.
272
+ - Try fitting at most 10 times
273
+ - After each fit, calculate the goodness of fit (in this case the p-value)
274
+ - Stop when the number of tries has been exhausted or the p-value reached is higher than `0.05`
275
+ - If the fit has not succeeded because of convergence, validity or goodness of fit issues,
276
+ randomize the parameters and try again.
277
+ - If the desired goodness of fit has not been achieved, pick the best result.
278
+ - Return the `FitResult` object and set the PDF to the final fit result.
279
+
280
+ The $\chi^2/N_{dof}$ can also be used as the stopping threshold, as in:
281
+
282
+ ```python
283
+ from dmu.stats.minimizers import AnealingMinimizer
284
+
285
+ nll = _get_nll()
286
+ minimizer = AnealingMinimizer(ntries=10, chi2ndof=1.00)
287
+ res = minimizer.minimize(nll)
288
+ ```
289
+
234
290
  ## Fit plotting
235
291
 
236
292
  The class `ZFitPlotter` can be used to plot fits done with zfit. For a complete set of examples of how to use
@@ -1,4 +1,4 @@
1
- data_manipulation_utilities-0.1.7.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
1
+ data_manipulation_utilities-0.1.9.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
2
2
  dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
3
3
  dmu/generic/utilities.py,sha256=0Xnq9t35wuebAqKxbyAiMk1ISB7IcXK4cFH25MT1fgw,1741
4
4
  dmu/logging/log_store.py,sha256=umdvjNDuV3LdezbG26b0AiyTglbvkxST19CQu9QATbA,4184
@@ -13,8 +13,11 @@ dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,238
13
13
  dmu/rdataframe/utilities.py,sha256=x8r379F2-vZPYzAdMFCn_V4Kx2Tx9t9pn_QHcZ1euew,2756
14
14
  dmu/rfile/rfprinter.py,sha256=mp5jd-oCJAnuokbdmGyL9i6tK2lY72jEfROuBIZ_ums,3941
15
15
  dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
16
- dmu/stats/fitter.py,sha256=LDvFNyhgO0OzXN7aH3kfHe6LzuPqdQfPcKR_IegDcaU,18204
16
+ dmu/stats/fitter.py,sha256=vHNZ16U3apoQyeyM8evq-if49doF48sKB3q9wmA96Fw,18387
17
17
  dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
18
+ dmu/stats/gof_calculator.py,sha256=4EN6OhULcztFvsAZ00rxgohJemnjtDNB5o0IBcv6kbk,4657
19
+ dmu/stats/minimizers.py,sha256=f9cilFY9Kp9UvbSIUsKBGFzOOg7EEWZJLPod-4k-LAQ,6216
20
+ dmu/stats/model_factory.py,sha256=LyDOf0f9I5dNUTS0MXHtSivD8aAcTLIagvMPtoXtThk,7426
18
21
  dmu/stats/utilities.py,sha256=LQy4kd3xSXqpApcWuYfZxkGQyjowaXv2Wr1c4Bj-4ys,4523
19
22
  dmu/stats/zfit_plotter.py,sha256=Xs6kisNEmNQXhYRCcjowxO6xHuyAyrfyQIFhGAR61U4,19719
20
23
  dmu/testing/utilities.py,sha256=WbMM4e9Cn3-B-12Vr64mB5qTKkV32joStlRkD-48lG0,3460
@@ -40,8 +43,8 @@ dmu_scripts/rfile/compare_root_files.py,sha256=T8lDnQxsRNMr37x1Y7YvWD8ySHrJOWZki
40
43
  dmu_scripts/rfile/print_trees.py,sha256=Ze4Ccl_iUldl4eVEDVnYBoe4amqBT1fSBR1zN5WSztk,941
41
44
  dmu_scripts/ssh/coned.py,sha256=lhilYNHWRCGxC-jtyJ3LQ4oUgWW33B2l1tYCcyHHsR0,4858
42
45
  dmu_scripts/text/transform_text.py,sha256=9akj1LB0HAyopOvkLjNOJiptZw5XoOQLe17SlcrGMD0,1456
43
- data_manipulation_utilities-0.1.7.dist-info/METADATA,sha256=6cSG5TvicYwa0Ru5352DXpVC1k0B6Zcz2HB4vkVWEkg,21183
44
- data_manipulation_utilities-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
45
- data_manipulation_utilities-0.1.7.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
46
- data_manipulation_utilities-0.1.7.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
47
- data_manipulation_utilities-0.1.7.dist-info/RECORD,,
46
+ data_manipulation_utilities-0.1.9.dist-info/METADATA,sha256=sxu2cZc14f4VfDD2J3MLGmW0jRHXJBpmDspXUt1D_0k,23046
47
+ data_manipulation_utilities-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
48
+ data_manipulation_utilities-0.1.9.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
49
+ data_manipulation_utilities-0.1.9.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
50
+ data_manipulation_utilities-0.1.9.dist-info/RECORD,,
dmu/stats/fitter.py CHANGED
@@ -4,6 +4,7 @@ Module holding zfitter class
4
4
 
5
5
  import pprint
6
6
  from typing import Union
7
+ from functools import lru_cache
7
8
 
8
9
  import numpy
9
10
  import zfit
@@ -100,8 +101,8 @@ class Fitter:
100
101
 
101
102
  return data
102
103
  #------------------------------
103
- def _bin_pdf(self, nbins):
104
- [[min_x]], [[max_x]] = self._pdf.space.limits
104
+ def _bin_pdf(self):
105
+ nbins, min_x, max_x = self._get_binning()
105
106
  _, arr_edg = numpy.histogram(self._data_np, bins = nbins, range=(min_x, max_x))
106
107
 
107
108
  size = arr_edg.size
@@ -117,23 +118,29 @@ class Fitter:
117
118
 
118
119
  return numpy.array(l_bc)
119
120
  #------------------------------
121
+ def _bin_data(self):
122
+ nbins, min_x, max_x = self._get_binning()
123
+ arr_data, _ = numpy.histogram(self._data_np, bins = nbins, range=(min_x, max_x))
124
+ arr_data = arr_data.astype(float)
125
+
126
+ return arr_data
127
+ #------------------------------
128
+ @lru_cache(maxsize=10)
120
129
  def _get_binning(self):
121
130
  min_x = numpy.min(self._data_np)
122
131
  max_x = numpy.max(self._data_np)
123
132
  nbins = self._ndof + self._get_float_pars()
124
133
 
134
+ log.debug(f'Nbins: {nbins}')
135
+ log.debug(f'Range: [{min_x:.3f}, {max_x:.3f}]')
136
+
125
137
  return nbins, min_x, max_x
126
138
  #------------------------------
127
139
  def _calc_gof(self):
128
140
  log.debug('Calculating GOF')
129
- nbins, min_x, max_x = self._get_binning()
130
141
 
131
- log.debug(f'Nbins: {nbins}')
132
- log.debug(f'Range: [{min_x:.3f}, {max_x:.3f}]')
133
-
134
- arr_data, _ = numpy.histogram(self._data_np, bins = nbins, range=(min_x, max_x))
135
- arr_data = arr_data.astype(float)
136
- arr_modl = self._bin_pdf(nbins)
142
+ arr_data = self._bin_data()
143
+ arr_modl = self._bin_pdf()
137
144
  norm = numpy.sum(arr_data) / numpy.sum(arr_modl)
138
145
  arr_modl = norm * arr_modl
139
146
  arr_res = arr_modl - arr_data
@@ -0,0 +1,145 @@
1
+ '''
2
+ Module holding GofCalculator class
3
+ '''
4
+ from functools import lru_cache
5
+
6
+ import zfit
7
+ import numpy
8
+ import pandas as pnd
9
+
10
+ from scipy import stats
11
+ from zfit.core.basepdf import BasePDF as zpdf
12
+ from zfit.core.parameter import Parameter as zpar
13
+ from dmu.logging.log_store import LogStore
14
+
15
+ log = LogStore.add_logger('dmu:stats:gofcalculator')
16
+ # ------------------------
17
+ class GofCalculator:
18
+ '''
19
+ Class used to calculate goodness of fit from zfit NLL
20
+ '''
21
+ # ---------------------
22
+ def __init__(self, nll, ndof : int = 10):
23
+ self._nll = nll
24
+ self._ndof = ndof
25
+
26
+ self._pdf = self._pdf_from_nll()
27
+ self._data_in = self._data_from_nll()
28
+ self._data_np = self._data_np_from_data(self._data_in)
29
+ self._data_zf = zfit.Data.from_numpy(obs=self._pdf.space, array=self._data_np)
30
+ # ---------------------
31
+ def _data_np_from_data(self, dat) -> numpy.ndarray:
32
+ if isinstance(dat, numpy.ndarray):
33
+ return dat
34
+
35
+ if isinstance(dat, zfit.Data):
36
+ return zfit.run(zfit.z.unstack_x(dat))
37
+
38
+ if isinstance(dat, pnd.DataFrame):
39
+ return dat.to_numpy()
40
+
41
+ if isinstance(dat, pnd.Series):
42
+ dat = pnd.DataFrame(dat)
43
+ return dat.to_numpy()
44
+
45
+ data_type = str(type(dat))
46
+ raise ValueError(f'Data is not a numpy array, zfit.Data or pandas.DataFrame, but {data_type}')
47
+ # ---------------------
48
+ def _pdf_from_nll(self) -> zpdf:
49
+ l_model = self._nll.model
50
+ if len(l_model) != 1:
51
+ raise ValueError('Not found one and only one model')
52
+
53
+ return l_model[0]
54
+ # ---------------------
55
+ def _data_from_nll(self) -> zpdf:
56
+ l_data = self._nll.data
57
+ if len(l_data) != 1:
58
+ raise ValueError('Not found one and only one dataset')
59
+
60
+ return l_data[0]
61
+ # ---------------------
62
+ def _get_float_pars(self) -> int:
63
+ npar = 0
64
+ s_par = self._pdf.get_params()
65
+ for par in s_par:
66
+ if par.floating:
67
+ npar+=1
68
+
69
+ return npar
70
+ # ---------------------
71
+ @lru_cache(maxsize=10)
72
+ def _get_binning(self) -> tuple[int, float, float]:
73
+ min_x = numpy.min(self._data_np)
74
+ max_x = numpy.max(self._data_np)
75
+ nbins = self._ndof + self._get_float_pars()
76
+
77
+ log.debug(f'Nbins: {nbins}')
78
+ log.debug(f'Range: [{min_x:.3f}, {max_x:.3f}]')
79
+
80
+ return nbins, min_x, max_x
81
+ # ---------------------
82
+ def _get_pdf_bin_contents(self) -> numpy.ndarray:
83
+ nbins, min_x, max_x = self._get_binning()
84
+ _, arr_edg = numpy.histogram(self._data_np, bins = nbins, range=(min_x, max_x))
85
+
86
+ size = arr_edg.size
87
+
88
+ l_bc = []
89
+ for i_edg in range(size - 1):
90
+ low = arr_edg[i_edg + 0]
91
+ hig = arr_edg[i_edg + 1]
92
+
93
+ var : zpar = self._pdf.integrate(limits = [low, hig])
94
+ val = var.numpy()[0]
95
+ l_bc.append(val * self._data_np.size)
96
+
97
+ return numpy.array(l_bc)
98
+ #------------------------------
99
+ def _get_data_bin_contents(self) -> numpy.ndarray:
100
+ nbins, min_x, max_x = self._get_binning()
101
+ arr_data, _ = numpy.histogram(self._data_np, bins = nbins, range=(min_x, max_x))
102
+ arr_data = arr_data.astype(float)
103
+
104
+ return arr_data
105
+ #------------------------------
106
+ @lru_cache(maxsize=30)
107
+ def _calculate_gof(self) -> tuple[float, int, float]:
108
+ log.debug('Calculating GOF')
109
+
110
+ arr_data = self._get_data_bin_contents()
111
+ arr_modl = self._get_pdf_bin_contents()
112
+
113
+ norm = numpy.sum(arr_data) / numpy.sum(arr_modl)
114
+ arr_modl = norm * arr_modl
115
+ arr_res = arr_modl - arr_data
116
+
117
+ arr_chi2 = numpy.divide(arr_res ** 2, arr_data, out=numpy.zeros_like(arr_data), where=arr_data!=0)
118
+ sum_chi2 = numpy.sum(arr_chi2)
119
+
120
+ pvalue = 1 - stats.chi2.cdf(sum_chi2, self._ndof)
121
+ pvalue = float(pvalue)
122
+
123
+ log.debug(f'Chi2: {sum_chi2:.3f}')
124
+ log.debug(f'Ndof: {self._ndof}')
125
+ log.debug(f'pval: {pvalue:<.3e}')
126
+
127
+ return sum_chi2, self._ndof, pvalue
128
+ # ---------------------
129
+ def get_gof(self, kind : str) -> float:
130
+ '''
131
+ Returns goodness of fit of a given kind
132
+
133
+ kind: Type of goodness of fit, e.g. pvalue
134
+ '''
135
+
136
+ chi2, ndof, pval = self._calculate_gof()
137
+
138
+ if kind == 'pvalue':
139
+ return pval
140
+
141
+ if kind == 'chi2/ndof':
142
+ return chi2/ndof
143
+
144
+ raise NotImplementedError(f'Invalid goodness of fit: {kind}')
145
+ # ------------------------
@@ -0,0 +1,183 @@
1
+ '''
2
+ Module containing derived classes from ZFit minimizer
3
+ '''
4
+ import numpy
5
+
6
+ import zfit
7
+ from zfit.result import FitResult
8
+ from zfit.core.basepdf import BasePDF as zpdf
9
+ from zfit.minimizers.baseminimizer import FailMinimizeNaN
10
+ from dmu.stats.gof_calculator import GofCalculator
11
+ from dmu.logging.log_store import LogStore
12
+
13
+ log = LogStore.add_logger('dmu:ml:minimizers')
14
+ # ------------------------
15
+ class AnealingMinimizer(zfit.minimize.Minuit):
16
+ '''
17
+ Class meant to minimize zfit likelihoods by using multiple retries,
18
+ each retry is preceded by the randomization of the fitting parameters
19
+ '''
20
+ # ------------------------
21
+ def __init__(self, ntries : int, pvalue : float = -1, chi2ndof : float = -1):
22
+ '''
23
+ ntries : Try this number of times
24
+ pvalue : Stop tries when this threshold is reached
25
+ chi2ndof: Use this value as a threshold to stop fits
26
+ '''
27
+ self._ntries = ntries
28
+ self._pvalue = pvalue
29
+ self._chi2ndof = chi2ndof
30
+
31
+ self._check_thresholds()
32
+
33
+ super().__init__()
34
+ # ------------------------
35
+ def _check_thresholds(self) -> None:
36
+ good_pvalue = 0 <= self._pvalue < 1
37
+ good_chi2dof = self._chi2ndof > 0
38
+
39
+ if good_pvalue and good_chi2dof:
40
+ raise ValueError('Threshold for both chi2 and pvalue were specified')
41
+
42
+ if good_pvalue:
43
+ log.debug(f'Will use threshold on pvalue with value: {self._pvalue}')
44
+ return
45
+
46
+ if good_chi2dof:
47
+ log.debug(f'Will use threshold on chi2ndof with value: {self._chi2ndof}')
48
+ return
49
+
50
+ raise ValueError('Neither pvalue nor chi2 thresholds are valid')
51
+ # ------------------------
52
+ def _is_good_gof(self, ch2 : float, pvl : float) -> bool:
53
+ is_good_pval = pvl > self._pvalue and self._pvalue > 0
54
+ is_good_chi2 = ch2 < self._chi2ndof and self._chi2ndof > 0
55
+ is_good = is_good_pval or is_good_chi2
56
+
57
+ if is_good_pval:
58
+ log.info(f'Stopping fit, found p-value: {pvl:.3f} > {self._pvalue:.3f}')
59
+
60
+ if is_good_chi2:
61
+ log.info(f'Stopping fit, found chi2/ndof: {ch2:.3f} > {self._chi2ndof:.3f}')
62
+
63
+ if not is_good:
64
+ log.debug(f'Could not read threshold, pvalue/chi2: {pvl:.3f}/{ch2:.3f}')
65
+
66
+ return is_good
67
+ # ------------------------
68
+ def _is_good_fit(self, res : FitResult) -> bool:
69
+ if not res.valid:
70
+ log.warning('Skipping invalid fit')
71
+ return False
72
+
73
+ if res.status != 0:
74
+ log.warning('Skipping fit with bad status')
75
+ return False
76
+
77
+ if not res.converged:
78
+ log.warning('Skipping non-converging fit')
79
+ return False
80
+
81
+ return True
82
+ # ------------------------
83
+ def _get_gof(self, nll) -> tuple[float, float]:
84
+ log.debug('Checking GOF')
85
+
86
+ gcl = GofCalculator(nll)
87
+ pvl = gcl.get_gof(kind='pvalue')
88
+ ch2 = gcl.get_gof(kind='chi2/ndof')
89
+
90
+ return ch2, pvl
91
+ # ------------------------
92
+ def _randomize_parameters(self, nll):
93
+ '''
94
+ Will move floating parameters of PDF according
95
+ to uniform PDF
96
+ '''
97
+
98
+ log.debug('Randomizing parameters')
99
+ l_model = nll.model
100
+ if len(l_model) != 1:
101
+ raise ValueError('Not found and and only one model')
102
+
103
+ model = l_model[0]
104
+ s_par = model.get_params(floating=True)
105
+ for par in s_par:
106
+ ival = par.value()
107
+ fval = numpy.random.uniform(par.lower, par.upper)
108
+ par.set_value(fval)
109
+ log.debug(f'{par.name:<20}{ival:<15.3f}{"->":<10}{fval:<15.3f}{"in":<5}{par.lower:<15.3e}{par.upper:<15.3e}')
110
+ # ------------------------
111
+ def _pick_best_fit(self, d_chi2_res : dict) -> FitResult:
112
+ nres = len(d_chi2_res)
113
+ if nres == 0:
114
+ raise ValueError('No fits found')
115
+
116
+ l_chi2_res= list(d_chi2_res.items())
117
+ l_chi2_res.sort()
118
+ chi2, res = l_chi2_res[0]
119
+
120
+ log.warning(f'Picking out best fit from {nres} fits with chi2: {chi2:.3f}')
121
+
122
+ return res
123
+ #------------------------------
124
+ def _set_pdf_pars(self, res : FitResult, pdf : zpdf) -> None:
125
+ '''
126
+ Will set the PDF floating parameter values as the result instance
127
+ '''
128
+ l_par_flt = list(pdf.get_params(floating= True))
129
+ l_par_fix = list(pdf.get_params(floating=False))
130
+ l_par = l_par_flt + l_par_fix
131
+
132
+ d_val = { par.name : dc['value'] for par, dc in res.params.items()}
133
+
134
+ log.debug('Setting PDF parameters to best result')
135
+ for par in l_par:
136
+ if par.name not in d_val:
137
+ par_val = par.value().numpy()
138
+ log.debug(f'Skipping {par.name} = {par_val:.3e}')
139
+ continue
140
+
141
+ val = d_val[par.name]
142
+ log.debug(f'{"":<4}{par.name:<20}{"->":<10}{val:<20.3e}')
143
+ par.set_value(val)
144
+ # ------------------------
145
+ def _pdf_from_nll(self, nll) -> zpdf:
146
+ l_model = nll.model
147
+ if len(l_model) != 1:
148
+ raise ValueError('Cannot extract one and only one PDF from NLL')
149
+
150
+ return l_model[0]
151
+ # ------------------------
152
+ def minimize(self, nll, **kwargs) -> FitResult:
153
+ '''
154
+ Will run minimization and return FitResult object
155
+ '''
156
+
157
+ d_chi2_res : dict[float,FitResult] = {}
158
+ for i_try in range(self._ntries):
159
+ log.info(f'try {i_try:02}/{self._ntries:02}')
160
+ try:
161
+ res = super().minimize(nll, **kwargs)
162
+ except (FailMinimizeNaN, ValueError, RuntimeError) as exc:
163
+ log.warning(exc)
164
+ self._randomize_parameters(nll)
165
+ continue
166
+
167
+ if not self._is_good_fit(res):
168
+ continue
169
+
170
+ chi2, pvl = self._get_gof(nll)
171
+ d_chi2_res[chi2] = res
172
+
173
+ if self._is_good_gof(chi2, pvl):
174
+ return res
175
+
176
+ self._randomize_parameters(nll)
177
+
178
+ res = self._pick_best_fit(d_chi2_res)
179
+ pdf = self._pdf_from_nll(nll)
180
+ self._set_pdf_pars(res, pdf)
181
+
182
+ return res
183
+ # ------------------------
@@ -0,0 +1,207 @@
1
+ '''
2
+ Module storing ModelFactory class
3
+ '''
4
+ # pylint: disable=too-many-lines, import-error
5
+
6
+ from typing import Callable, Union
7
+
8
+ import zfit
9
+ from zfit.core.interfaces import ZfitSpace as zobs
10
+ from zfit.core.basepdf import BasePDF as zpdf
11
+ from zfit.core.parameter import Parameter as zpar
12
+ from dmu.logging.log_store import LogStore
13
+
14
+ log=LogStore.add_logger('dmu:stats:model_factory')
15
+ #-----------------------------------------
16
+ class MethodRegistry:
17
+ '''
18
+ Class intended to store protected methods belonging to ModelFactory class
19
+ which is defined in this same module
20
+ '''
21
+ # Registry dictionary to hold methods
22
+ _d_method = {}
23
+
24
+ @classmethod
25
+ def register(cls, nickname : str):
26
+ '''
27
+ Decorator in charge of registering method for given nickname
28
+ '''
29
+ def decorator(method):
30
+ cls._d_method[nickname] = method
31
+ return method
32
+
33
+ return decorator
34
+
35
+ @classmethod
36
+ def get_method(cls, nickname : str) -> Union[Callable,None]:
37
+ '''
38
+ Will return method in charge of building PDF, for an input nickname
39
+ '''
40
+ return cls._d_method.get(nickname, None)
41
+ #-----------------------------------------
42
+ class ModelFactory:
43
+ '''
44
+ Class used to create Zfit PDFs by passing only the nicknames, e.g.:
45
+
46
+ ```python
47
+ from dmu.stats.model_factory import ModelFactory
48
+
49
+ l_pdf = ['dscb', 'gauss']
50
+ l_shr = ['mu']
51
+ mod = ModelFactory(obs = obs, l_pdf = l_pdf, l_shared=l_shr)
52
+ pdf = mod.get_pdf()
53
+ ```
54
+
55
+ where one can specify which parameters can be shared among the PDFs
56
+ '''
57
+ #-----------------------------------------
58
+ def __init__(self, obs : zobs, l_pdf : list[str], l_shared : list[str]):
59
+ '''
60
+ obs: zfit observable
61
+ l_pdf: List of PDF nicknames which are registered below
62
+ l_shared: List of parameter names that are shared
63
+ '''
64
+
65
+ self._l_pdf = l_pdf
66
+ self._l_shr = l_shared
67
+ self._l_can_be_shared = ['mu', 'sg']
68
+ self._obs = obs
69
+
70
+ self._d_par : dict[str,zpar] = {}
71
+ #-----------------------------------------
72
+ def _get_name(self, name : str, suffix : str) -> str:
73
+ for can_be_shared in self._l_can_be_shared:
74
+ if name.startswith(f'{can_be_shared}_') and can_be_shared in self._l_shr:
75
+ return can_be_shared
76
+
77
+ return f'{name}{suffix}'
78
+ #-----------------------------------------
79
+ def _get_parameter(self,
80
+ name : str,
81
+ suffix : str,
82
+ val : float,
83
+ low : float,
84
+ high : float) -> zpar:
85
+ name = self._get_name(name, suffix)
86
+ if name in self._d_par:
87
+ return self._d_par[name]
88
+
89
+ par = zfit.param.Parameter(name, val, low, high)
90
+
91
+ self._d_par[name] = par
92
+
93
+ return par
94
+ #-----------------------------------------
95
+ @MethodRegistry.register('exp')
96
+ def _get_exponential(self, suffix : str = '') -> zpdf:
97
+ c = self._get_parameter('c_exp', suffix, -0.005, -0.05, 0.00)
98
+ pdf = zfit.pdf.Exponential(c, self._obs)
99
+
100
+ return pdf
101
+ #-----------------------------------------
102
+ @MethodRegistry.register('pol1')
103
+ def _get_pol1(self, suffix : str = '') -> zpdf:
104
+ a = self._get_parameter('a_pol1', suffix, -0.005, -0.95, 0.00)
105
+ pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a])
106
+
107
+ return pdf
108
+ #-----------------------------------------
109
+ @MethodRegistry.register('pol2')
110
+ def _get_pol2(self, suffix : str = '') -> zpdf:
111
+ a = self._get_parameter('a_pol2', suffix, -0.005, -0.95, 0.00)
112
+ b = self._get_parameter('b_pol2', suffix, 0.000, -0.95, 0.95)
113
+ pdf = zfit.pdf.Chebyshev(obs=self._obs, coeffs=[a, b])
114
+
115
+ return pdf
116
+ #-----------------------------------------
117
+ @MethodRegistry.register('cbr')
118
+ def _get_cbr(self, suffix : str = '') -> zpdf:
119
+ mu = self._get_parameter('mu_cbr', suffix, 5300, 5250, 5350)
120
+ sg = self._get_parameter('sg_cbr', suffix, 10, 2, 300)
121
+ ar = self._get_parameter('ac_cbr', suffix, -2, -4., -1.)
122
+ nr = self._get_parameter('nc_cbr', suffix, 1, 0.5, 5.0)
123
+
124
+ pdf = zfit.pdf.CrystalBall(mu, sg, ar, nr, self._obs)
125
+
126
+ return pdf
127
+ #-----------------------------------------
128
+ @MethodRegistry.register('cbl')
129
+ def _get_cbl(self, suffix : str = '') -> zpdf:
130
+ mu = self._get_parameter('mu_cbl', suffix, 5300, 5250, 5350)
131
+ sg = self._get_parameter('sg_cbl', suffix, 10, 2, 300)
132
+ al = self._get_parameter('ac_cbl', suffix, 2, 1., 4.)
133
+ nl = self._get_parameter('nc_cbl', suffix, 1, 0.5, 5.0)
134
+
135
+ pdf = zfit.pdf.CrystalBall(mu, sg, al, nl, self._obs)
136
+
137
+ return pdf
138
+ #-----------------------------------------
139
+ @MethodRegistry.register('gauss')
140
+ def _get_gauss(self, suffix : str = '') -> zpdf:
141
+ mu = self._get_parameter('mu_gauss', suffix, 5300, 5250, 5350)
142
+ sg = self._get_parameter('sg_gauss', suffix, 10, 2, 300)
143
+
144
+ pdf = zfit.pdf.Gauss(mu, sg, self._obs)
145
+
146
+ return pdf
147
+ #-----------------------------------------
148
+ @MethodRegistry.register('dscb')
149
+ def _get_dscb(self, suffix : str = '') -> zpdf:
150
+ mu = self._get_parameter('mu_dscb', suffix, 5300, 5250, 5400)
151
+ sg = self._get_parameter('sg_dscb', suffix, 10, 2, 30)
152
+ ar = self._get_parameter('ar_dscb', suffix, 1, 0, 5)
153
+ al = self._get_parameter('al_dscb', suffix, 1, 0, 5)
154
+ nr = self._get_parameter('nr_dscb', suffix, 2, 1, 5)
155
+ nl = self._get_parameter('nl_dscb', suffix, 2, 0, 5)
156
+
157
+ pdf = zfit.pdf.DoubleCB(mu, sg, al, nl, ar, nr, self._obs)
158
+
159
+ return pdf
160
+ #-----------------------------------------
161
+ def _get_pdf_types(self) -> list[tuple[str,str]]:
162
+ d_name_freq = {}
163
+
164
+ l_type = []
165
+ for name in self._l_pdf:
166
+ if name not in d_name_freq:
167
+ d_name_freq[name] = 1
168
+ else:
169
+ d_name_freq[name]+= 1
170
+
171
+ frq = d_name_freq[name]
172
+ frq = f'_{frq}'
173
+
174
+ l_type.append((name, frq))
175
+
176
+ return l_type
177
+ #-----------------------------------------
178
+ def _get_pdf(self, kind : str, preffix : str) -> zpdf:
179
+ fun = MethodRegistry.get_method(kind)
180
+ if fun is None:
181
+ raise NotImplementedError(f'PDF of type {kind} is not implemented')
182
+
183
+ return fun(self, preffix)
184
+ #-----------------------------------------
185
+ def _add_pdf(self, l_pdf : list[zpdf]) -> zpdf:
186
+ nfrc = len(l_pdf)
187
+ if nfrc == 1:
188
+ log.debug('Requested only one PDF, skipping sum')
189
+ return l_pdf[0]
190
+
191
+ l_frc= [ zfit.param.Parameter(f'frc_{ifrc + 1}', 0.5, 0, 1) for ifrc in range(nfrc - 1) ]
192
+
193
+ pdf = zfit.pdf.SumPDF(l_pdf, fracs=l_frc)
194
+
195
+ return pdf
196
+ #-----------------------------------------
197
+ def get_pdf(self) -> zpdf:
198
+ '''
199
+ Given a list of strings representing PDFs, returns a zfit PDF which is
200
+ the sum of them
201
+ '''
202
+ l_type= self._get_pdf_types()
203
+ l_pdf = [ self._get_pdf(kind, preffix) for kind, preffix in l_type ]
204
+ pdf = self._add_pdf(l_pdf)
205
+
206
+ return pdf
207
+ #-----------------------------------------