data-manipulation-utilities 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_manipulation_utilities-0.0.1.dist-info/METADATA +713 -0
- data_manipulation_utilities-0.0.1.dist-info/RECORD +45 -0
- data_manipulation_utilities-0.0.1.dist-info/WHEEL +5 -0
- data_manipulation_utilities-0.0.1.dist-info/entry_points.txt +6 -0
- data_manipulation_utilities-0.0.1.dist-info/top_level.txt +3 -0
- dmu/arrays/utilities.py +55 -0
- dmu/dataframe/dataframe.py +36 -0
- dmu/generic/utilities.py +69 -0
- dmu/logging/log_store.py +129 -0
- dmu/ml/cv_classifier.py +122 -0
- dmu/ml/cv_predict.py +152 -0
- dmu/ml/train_mva.py +257 -0
- dmu/ml/utilities.py +132 -0
- dmu/plotting/plotter.py +227 -0
- dmu/plotting/plotter_1d.py +113 -0
- dmu/plotting/plotter_2d.py +87 -0
- dmu/rdataframe/atr_mgr.py +79 -0
- dmu/rdataframe/utilities.py +72 -0
- dmu/rfile/rfprinter.py +91 -0
- dmu/rfile/utilities.py +34 -0
- dmu/stats/fitter.py +515 -0
- dmu/stats/function.py +314 -0
- dmu/stats/utilities.py +134 -0
- dmu/testing/utilities.py +119 -0
- dmu/text/transformer.py +182 -0
- dmu_data/__init__.py +0 -0
- dmu_data/ml/tests/train_mva.yaml +37 -0
- dmu_data/plotting/tests/2d.yaml +14 -0
- dmu_data/plotting/tests/fig_size.yaml +13 -0
- dmu_data/plotting/tests/high_stat.yaml +22 -0
- dmu_data/plotting/tests/name.yaml +14 -0
- dmu_data/plotting/tests/no_bounds.yaml +12 -0
- dmu_data/plotting/tests/simple.yaml +8 -0
- dmu_data/plotting/tests/title.yaml +14 -0
- dmu_data/plotting/tests/weights.yaml +13 -0
- dmu_data/text/transform.toml +4 -0
- dmu_data/text/transform.txt +6 -0
- dmu_data/text/transform_set.toml +8 -0
- dmu_data/text/transform_set.txt +6 -0
- dmu_data/text/transform_trf.txt +12 -0
- dmu_scripts/physics/check_truth.py +121 -0
- dmu_scripts/rfile/compare_root_files.py +299 -0
- dmu_scripts/rfile/print_trees.py +35 -0
- dmu_scripts/ssh/coned.py +168 -0
- dmu_scripts/text/transform_text.py +46 -0
dmu/stats/fitter.py
ADDED
@@ -0,0 +1,515 @@
|
|
1
|
+
'''
|
2
|
+
Module holding the Fitter class, an interface used to fit zfit PDFs to data
|
3
|
+
'''
|
4
|
+
|
5
|
+
import pprint
|
6
|
+
from typing import Union
|
7
|
+
|
8
|
+
import numpy
|
9
|
+
import zfit
|
10
|
+
import pandas as pd
|
11
|
+
|
12
|
+
from scipy import stats
|
13
|
+
from zfit.minimizers.strategy import FailMinimizeNaN
|
14
|
+
from zfit.result import FitResult
|
15
|
+
from zfit.core.data import Data
|
16
|
+
from dmu.logging.log_store import LogStore
|
17
|
+
|
18
|
+
# Logger shared by all the classes and methods in this module
log = LogStore.add_logger('dmu:statistics:fitter')
|
19
|
+
#------------------------------
|
20
|
+
class FitterGofError(Exception):
    '''
    Raised when the goodness of fit of a fit cannot be computed
    '''
24
|
+
#------------------------------
|
25
|
+
class FitterFailedFit(Exception):
    '''
    Raised when no usable fit result could be obtained
    '''
|
29
|
+
#------------------------------
|
30
|
+
class Fitter:
    '''
    Class meant to be an interface to underlying fitters

    Usage:

    obj = Fitter(pdf, data)
    res = obj.fit(cfg)
    '''
    # pylint: disable=too-many-instance-attributes
    #------------------------------
    def __init__(self, pdf, data):
        '''
        pdf : zfit PDF to be fitted, its space is used as the observable
        data: Data to fit, one of numpy.ndarray, zfit.Data, pandas.DataFrame or pandas.Series.
              Only one dimensional data (or a single column) is supported
        '''
        self._data_in = data
        self._pdf     = pdf

        # Filled by _initialize(), which runs on the first call to fit()
        self._data_zf : zfit.data.Data
        self._data_np : numpy.ndarray
        self._obs     : zfit.Space
        # Name -> parameter map, refreshed by _get_float_pars()
        self._d_par   : dict = {}

        # Degrees of freedom of the chi2 used for the GoF; the number of bins
        # used is this plus the number of floating parameters
        self._ndof           = 10
        # For p-values below this, the bin-by-bin GoF table is printed at debug level
        self._pval_threshold = 0.01
        self._initialized    = False
    #------------------------------
    def _initialize(self):
        '''
        Validates and converts the input data, does nothing if already done
        '''
        if self._initialized:
            return

        self._check_data()

        self._initialized = True
    #------------------------------
    def _check_data(self):
        '''
        Converts the input data into:

        - A numpy array without NaN/inf values, stored in self._data_np, used for the GoF
        - A zfit.Data object, stored in self._data_zf, used to build the likelihood

        Raises ValueError if the input is not of a supported type or shape
        '''
        if isinstance(self._data_in, numpy.ndarray):
            data_np = self._data_in
        elif isinstance(self._data_in, zfit.Data):
            # convert original data to numpy array, needed by _calc_gof
            data_np = zfit.run(zfit.z.unstack_x(self._data_in))
        elif isinstance(self._data_in, pd.DataFrame):
            data_np = self._data_in.to_numpy()
        elif isinstance(self._data_in, pd.Series):
            self._data_in = pd.DataFrame(self._data_in)
            data_np = self._data_in.to_numpy()
        else:
            data_type = str(type(self._data_in))
            raise ValueError(f'Data is not a numpy array, zfit.Data or pandas.DataFrame, but {data_type}')

        data_np = self._check_numpy_data(data_np)
        self._data_np = data_np
        if not isinstance(self._data_in, zfit.Data):
            # Built from the cleaned array, i.e. without NaN and inf values
            self._data_zf = zfit.Data.from_numpy(obs=self._pdf.space, array=data_np)
        else:
            self._data_zf = self._data_in
    #------------------------------
    def _check_numpy_data(self, data):
        '''
        Checks that data is one dimensional, or a single column, and removes NaN and inf values

        Returns one dimensional array with only finite values
        Raises ValueError if the array shape is not (N,) or (N, 1)
        '''
        shp       = data.shape
        is_1d     = len(shp) == 1
        is_column = len(shp) == 2 and shp[1] == 1
        if not (is_1d or is_column):
            log.error(f'Invalid data shape: {shp}')
            raise ValueError(f'Invalid data shape: {shp}, expected (N,) or (N, 1)')

        ival = data.shape[0]

        # Boolean masking also flattens (N, 1) arrays into (N,)
        data = data[numpy.isfinite(data)]

        fval = data.shape[0]

        if ival != fval:
            log.warning(f'Data was trimmed for inf and nan: {ival} -> {fval}')

        return data
    #------------------------------
    def _bin_pdf(self, arr_edg : numpy.ndarray) -> numpy.ndarray:
        '''
        Integrates the PDF between consecutive bin edges

        arr_edg: Array of bin edges, e.g. as returned by numpy.histogram

        Returns array with one entry per bin, the PDF integral times the number of data entries
        '''
        ndata = self._data_np.size

        l_bc = []
        for low, hig in zip(arr_edg[:-1], arr_edg[1:]):
            var = self._pdf.integrate(limits = [low, hig])
            val = var.numpy()[0]
            l_bc.append(val * ndata)

        return numpy.array(l_bc)
    #------------------------------
    def _get_binning(self):
        '''
        Returns tuple (nbins, min_x, max_x) used to bin data and model for the GoF.

        The range is the one spanned by the data, restricted to the limits of the
        PDF observable, so that the PDF is never integrated outside its space.
        '''
        [[min_obs]], [[max_obs]] = self._pdf.space.limits

        min_x = max(numpy.min(self._data_np), min_obs)
        max_x = min(numpy.max(self._data_np), max_obs)
        nbins = self._ndof + self._get_float_pars()

        return nbins, min_x, max_x
    #------------------------------
    def _calc_gof(self):
        '''
        Calculates a chi2 goodness of fit between the binned data and the binned PDF.
        The model is scaled to the number of entries in the data histogram and
        bins without data do not contribute to the chi2.

        Returns tuple (chi2, ndof, pvalue)
        Raises FitterGofError if the GoF cannot be calculated
        '''
        log.debug('Calculating GOF')
        if self._data_np.size == 0:
            raise FitterGofError('No data available to calculate GOF')

        nbins, min_x, max_x = self._get_binning()

        log.debug(f'Nbins: {nbins}')
        log.debug(f'Range: [{min_x:.3f}, {max_x:.3f}]')

        if min_x >= max_x:
            raise FitterGofError(f'Invalid range for GOF calculation: [{min_x:.3f}, {max_x:.3f}]')

        # Data and model must be binned with exactly the same edges
        arr_data, arr_edg = numpy.histogram(self._data_np, bins = nbins, range=(min_x, max_x))
        arr_data = arr_data.astype(float)
        arr_modl = self._bin_pdf(arr_edg)
        sum_modl = numpy.sum(arr_modl)
        if not numpy.isfinite(sum_modl) or sum_modl <= 0:
            raise FitterGofError(f'Cannot normalize model, its integral in the GOF range is {sum_modl}')

        norm     = numpy.sum(arr_data) / sum_modl
        arr_modl = norm * arr_modl
        arr_res  = arr_modl - arr_data

        # Neyman chi2, bins without data are skipped to avoid dividing by zero
        arr_chi2 = numpy.divide(arr_res ** 2, arr_data, out=numpy.zeros_like(arr_data), where=arr_data!=0)
        sum_chi2 = numpy.sum(arr_chi2)
        # Survival function, numerically more precise than 1 - cdf for small p-values
        pvalue   = stats.chi2.sf(sum_chi2, self._ndof)

        log.debug(f'{"Data":<20}{"Model":<20}{"chi2":<20}')
        if pvalue < self._pval_threshold:
            for data, modl, chi2 in zip(arr_data, arr_modl, arr_chi2):
                log.debug(f'{data:<20.0f}{modl:<20.3f}{chi2:<20.3f}')

        log.debug(f'Chi2: {sum_chi2:.3f}')
        log.debug(f'Ndof: {self._ndof}')
        log.debug(f'pval: {pvalue:<.3e}')

        return (sum_chi2, self._ndof, pvalue)
    #------------------------------
    def _get_float_pars(self):
        '''
        Returns the number of floating parameters among the PDF parameters.
        Also refreshes self._d_par, the map between parameter names and parameters.
        '''
        s_par = self._pdf.get_params()
        self._d_par = {par.name : par for par in s_par}

        return sum(1 for par in s_par if par.floating)
    #------------------------------
    def _reshuffle_pdf_pars(self):
        '''
        Will move floating parameters of PDF according
        to uniform PDF between the parameter bounds.
        Parameters without lower or upper bound are left untouched.
        '''

        s_par = self._pdf.get_params(floating=True)
        for par in s_par:
            if par.lower is None or par.upper is None:
                log.debug(f'Not reshuffling {par.name}, it has no bounds')
                continue

            ival = par.value().numpy()
            fval = numpy.random.uniform(par.lower, par.upper)
            par.set_value(fval)
            log.debug(f'{par.name:<20}{ival:<15.3f}{"->":<10}{fval:<15.3f}{"in":<5}{par.lower:<15.3e}{par.upper:<15.3e}')
    #------------------------------
    def _set_pdf_pars(self, res):
        '''
        Will set the PDF parameter values to the ones stored in the result instance.
        Parameters not present in the result are left unchanged.
        '''
        l_par_flt = list(self._pdf.get_params(floating= True))
        l_par_fix = list(self._pdf.get_params(floating=False))
        l_par     = l_par_flt + l_par_fix

        d_val = { par.name : dc['value'] for par, dc in res.params.items()}

        log.debug('Setting PDF parameters to best result')
        for par in l_par:
            if par.name not in d_val:
                log.debug(f'Skipping {par.name} = {par.value().numpy():.3e}')
                continue

            val = d_val[par.name]
            log.debug(f'{"":<4}{par.name:<20}{"->":<10}{val:<20.3e}')
            par.set_value(val)
    #------------------------------
    def _get_constraints(self, cfg : dict) -> list[zfit.constraint.GaussianConstraint]:
        '''
        Takes dictionary of constraints as floats, returns list of GaussianConstraint objects

        cfg['constraints'] maps parameter names to (mean, sigma). A sigma of zero
        sets the parameter to the mean and fixes it, instead of constraining it.

        Raises ValueError if a constrained parameter is not a floating parameter of the PDF
        '''
        if 'constraints' not in cfg:
            log.debug('Not using any constraint')
            return []

        d_const = cfg['constraints']
        s_par   = self._pdf.get_params(floating=True)
        d_par   = { par.name : par for par in s_par}

        log.info('Adding constraints:')
        l_const = []
        for par_name, (par_mu, par_sg) in d_const.items():
            if par_name not in d_par:
                log.error(s_par)
                raise ValueError(f'Parameter {par_name} not found among floating parameters of model, above')

            par = d_par[par_name]

            if par_sg == 0:
                # Fixing alone would leave the parameter at whatever value it had
                par.set_value(par_mu)
                par.floating = False
                log.info(f'{"":<4}{par_name:<15}{par_mu:<15.3e}{par_sg:<15.3e}')
                continue

            const = zfit.constraint.GaussianConstraint(params=par, observation=float(par_mu), uncertainty=float(par_sg))
            log.info(f'{"":<4}{par_name:<25}{par_mu:<15.3e}{par_sg:<15.3e}')
            l_const.append(const)

        return l_const
    #------------------------------
    def _get_ranges(self, cfg : dict) -> list:
        '''
        Returns list of (low, high) fit ranges from cfg['ranges'],
        or [None] (i.e. full range) if not specified
        '''
        if 'ranges' not in cfg:
            return [None]

        ranges = cfg['ranges']
        log.info('-' * 30)
        log.info(f'{"Low edge":>15}{"High edge":>15}')
        log.info('-' * 30)
        for rng in ranges:
            log.info(f'{rng[0]:>15.3e}{rng[1]:>15.3e}')

        return ranges
    #------------------------------
    def _get_subdataset(self, cfg : dict) -> Data:
        '''
        Returns a random subsample, without replacement, of cfg['nentries'] entries.
        The full dataset is returned if 'nentries' is not in cfg, or if the dataset
        is not larger than the requested size. Weights, if any, are kept paired with their entries.
        '''
        if 'nentries' not in cfg:
            return self._data_zf

        nentries_out = cfg['nentries']
        arr_inp      = self._data_zf.to_numpy().flatten()
        nentries_inp = len(arr_inp)
        if nentries_inp <= nentries_out:
            log.warning(f'Input dataset is smaller than output dataset, {nentries_inp} <= {nentries_out}')
            return self._data_zf

        # Sample indices, not values, so that entries and weights stay paired
        arr_ind = numpy.random.choice(nentries_inp, size=nentries_out, replace=False)
        arr_out = arr_inp[arr_ind]

        if self._data_zf.weights is None:
            arr_wgt = None
        else:
            arr_wgt = numpy.asarray(self._data_zf.weights)[arr_ind]

        data = zfit.Data.from_numpy(obs=self._data_zf.space, array=arr_out, weights=arr_wgt)

        return data
    #------------------------------
    def _get_binned_observable(self, nbins : int):
        '''
        Returns binned version of the PDF space, with nbins regular bins between its limits
        '''
        obs = self._pdf.space
        [[minx]], [[maxx]] = obs.limits

        binning = zfit.binned.RegularBinning(nbins, minx, maxx, name=obs.label)
        obs_bin = zfit.Space(obs.label, binning=binning)

        return obs_bin
    #------------------------------
    def _get_nbins(self, cfg : dict) -> Union[None, int]:
        '''
        Returns cfg['likelihood']['nbins'] if present, otherwise None (unbinned fit)
        '''
        if 'likelihood' not in cfg:
            return None

        if 'nbins' not in cfg['likelihood']:
            return None

        return cfg['likelihood']['nbins']
    #------------------------------
    def _get_nll(self, data_zf, constraints, frange, cfg):
        '''
        Builds a single likelihood: binned or unbinned depending on cfg, extended or not
        depending on the PDF.

        Raises ValueError if a fit range is requested together with a binned likelihood
        '''
        nbins = self._get_nbins(cfg)
        if nbins is None:
            log.info('No binning was specified, will do unbinned fit')
            pdf = self._pdf
        else:
            log.info(f'Using {nbins} bins for fit')
            obs     = self._get_binned_observable(nbins)
            pdf     = zfit.pdf.BinnedFromUnbinnedPDF(self._pdf, obs)
            data_zf = data_zf.to_binned(obs)

        if not self._pdf.is_extended and nbins is None:
            nll = zfit.loss.UnbinnedNLL( model=pdf, data=data_zf, constraints=constraints, fit_range=frange)
            return nll

        if self._pdf.is_extended and nbins is None:
            nll = zfit.loss.ExtendedUnbinnedNLL(model=pdf, data=data_zf, constraints=constraints, fit_range=frange)
            return nll

        if frange is not None:
            raise ValueError('Fit range cannot be defined for binned likelihoods')

        if not self._pdf.is_extended:
            nll = zfit.loss.BinnedNLL( model=pdf, data=data_zf, constraints=constraints)
            return nll

        nll = zfit.loss.ExtendedBinnedNLL( model=pdf, data=data_zf, constraints=constraints)
        return nll
    #------------------------------
    def _get_full_nll(self, cfg : dict):
        '''
        Returns the sum of the likelihoods built for every fit range in cfg
        '''
        constraints = self._get_constraints(cfg)
        ranges      = self._get_ranges(cfg)
        data_zf     = self._get_subdataset(cfg)
        l_nll       = [ self._get_nll(data_zf, constraints, frange, cfg) for frange in ranges ]
        nll         = sum(l_nll[1:], l_nll[0])

        return nll
    #------------------------------
    def _print_pars(self, cfg : dict):
        '''
        Will print current values parameters in cfg['print_pars'] list, if present
        '''

        if 'print_pars' not in cfg:
            return

        l_par_name = cfg['print_pars']
        # Built here, instead of relying on self._d_par having been filled elsewhere
        d_par      = {par.name : par for par in self._pdf.get_params()}
        d_par_val  = { name : par.value().numpy() for name, par in d_par.items() if name in l_par_name}

        header = ''.join(f'{var:<10}'    for var in d_par_val)
        parval = ''.join(f'{val:<10.3f}' for val in d_par_val.values())

        log.info(header)
        log.info(parval)
    #------------------------------
    def _minimize(self, nll, cfg : dict) -> tuple[FitResult, tuple]:
        '''
        Minimizes nll with Minuit and computes the GoF with the PDF parameters
        as left by the minimizer

        Returns tuple (result, (chi2, ndof, pvalue))
        Raises FitterGofError if the GoF cannot be calculated
        '''
        mnm = zfit.minimize.Minuit()
        res = mnm.minimize(nll)

        try:
            gof = self._calc_gof()
        except FitterGofError as exc:
            raise FitterGofError('Cannot calculate GOF') from exc

        chi2, _, pval = gof
        stat          = res.status

        log.info(f'{chi2:<10.3f}{pval:<10.3e}{stat:<10}')
        self._print_pars(cfg)

        return res, gof
    #------------------------------
    def _fit_retries(self, cfg : dict) -> tuple[dict, FitResult]:
        '''
        Repeats the fit, reshuffling the floating parameters between attempts, until
        the p-value exceeds cfg['strategy']['retry']['pvalue_thresh'] or ntries is reached.

        Returns tuple with:
        - Dictionary mapping chi2 -> result for the accepted fits
        - Last fit result obtained
        Raises FitterFailedFit if no minimization finished without an exception
        '''
        ntries       = cfg['strategy']['retry']['ntries']
        pvalue_thresh= cfg['strategy']['retry']['pvalue_thresh']
        ignore_status= cfg['strategy']['retry']['ignore_status']

        nll        = self._get_full_nll(cfg = cfg)
        d_chi2_res = {}
        last_res   = None
        for i_try in range(ntries):
            try:
                res, gof = self._minimize(nll, cfg)
            except (FailMinimizeNaN, FitterGofError, RuntimeError):
                self._reshuffle_pdf_pars()
                log.warning(f'Fit {i_try:03}/{ntries:03} failed due to exception')
                continue

            last_res = res
            bad_fit  = res.status != 0 or not res.valid

            if not ignore_status and bad_fit:
                self._reshuffle_pdf_pars()
                log.info(f'Fit {i_try:03}/{ntries:03} failed, status/validity: {res.status}/{res.valid}')
                continue

            chi2, _, pval = gof
            d_chi2_res[chi2]=res

            if pval > pvalue_thresh:
                log.info(f'Reached {pval:.3f} (> {pvalue_thresh:.3f}) threshold after {i_try + 1} attempts')
                return {chi2 : res}, res

            self._reshuffle_pdf_pars()

        if last_res is None:
            raise FitterFailedFit('Cannot find any valid fit')

        return d_chi2_res, last_res
    #------------------------------
    def _pick_best_fit(self, d_pval_res : dict, last_res : FitResult) -> FitResult:
        '''
        Picks the result with the smallest chi2 from d_pval_res (chi2 -> result),
        or last_res if the dictionary is empty, and sets the PDF parameters to it

        Returns the picked result
        '''
        nsucc = len(d_pval_res)
        if nsucc == 0:
            log.warning('None of the fits succeeded, returning last result')
            self._set_pdf_pars(last_res)

            return last_res

        # Sort on chi2 only, FitResult objects cannot be compared
        l_chi2_res = sorted(d_pval_res.items(), key=lambda item: item[0])
        _, res     = l_chi2_res[0]

        log.debug(f'Picking out best fit from {nsucc} fits')
        for chi2, _ in l_chi2_res:
            log.debug(f'{chi2:.3f}')

        self._set_pdf_pars(res)

        return res
    #------------------------------
    def _fit_in_steps(self, cfg : dict) -> FitResult:
        '''
        Fits increasingly large random subsamples, cfg['strategy']['steps']['nsteps'],
        after each one narrowing the bounds of the non-yield shape parameters to
        value +/- nsigma * error. Finally fits the full sample.

        Returns the result of the fit to the full sample, with hesse errors
        '''
        l_nsample = cfg['strategy']['steps']['nsteps']
        l_nsigma  = cfg['strategy']['steps']['nsigma']
        l_yield   = cfg['strategy']['steps']['yields']

        if len(l_nsample) != len(l_nsigma):
            log.warning(f'Number of steps ({len(l_nsample)}) and of nsigma values ({len(l_nsigma)}) differ, extra values are ignored')

        for nsample, nsigma in zip(l_nsample, l_nsigma):
            log.info(f'Fitting with {nsample} samples')
            cfg_step = dict(cfg)
            cfg_step['nentries'] = nsample

            nll    = self._get_full_nll(cfg = cfg_step)
            res, _ = self._minimize(nll, cfg_step)
            res.hesse(method='minuit_hesse')
            self._update_par_bounds(res, nsigma=nsigma, yields=l_yield)

        log.info('Fitting full sample')
        nll    = self._get_full_nll(cfg = cfg)
        res, _ = self._minimize(nll, cfg)
        res.hesse(method='minuit_hesse')

        return res
    #------------------------------
    def _result_to_value_error(self, res : FitResult) -> dict[str, list[float]]:
        '''
        Returns dictionary mapping parameter name -> [value, hesse error]

        Raises KeyError if the result has no value or hesse error for a parameter
        '''
        d_par = {}
        for par, d_val in res.params.items():
            try:
                val = d_val['value']
                err = d_val['hesse']['error']
            except KeyError as exc:
                pprint.pprint(d_val)
                raise KeyError('Cannot extract value, hesse or error from dictionary above') from exc

            d_par[par.name] = [val, err]

        return d_par
    #------------------------------
    def _update_par_bounds(self, res : FitResult, nsigma : float, yields : list[str]) -> None:
        '''
        Sets the bounds of the floating, non-yield, PDF parameters to
        value +/- nsigma * error, taken from the fit result
        '''
        s_shape_par = self._pdf.get_params(is_yield=False, floating=True)
        d_shp_par   = { par.name : par for par in s_shape_par if par.name not in yields}
        d_fit_par   = self._result_to_value_error(res)

        log.info(60 * '-')
        log.info(f'{"Parameter":<20}{"Low bound":<20}{"High bound":<20}')
        log.info(60 * '-')
        for name, [val, err] in d_fit_par.items():
            if name not in d_shp_par:
                log.debug(f'Skipping {name} parameter')
                continue

            low = val - nsigma * err
            hig = val + nsigma * err

            shape       = d_shp_par[name]
            shape.lower = low
            shape.upper = hig

            log.info(f'{name:<20}{low:<20.3e}{hig:<20.3e}')
    #------------------------------
    def fit(self, cfg : Union[dict, None] = None):
        '''
        Runs the fit using the configuration specified by the cfg dictionary

        Optional keys in cfg:
        constraints: name -> (mean, sigma), Gaussian constraints (sigma = 0 fixes the parameter)
        ranges     : list of (low, high) fit ranges
        nentries   : fit a random subsample of this size
        likelihood : {'nbins' : N} to do a binned fit
        print_pars : names of parameters printed after each minimization
        strategy   : either {'retry' : {'ntries', 'pvalue_thresh', 'ignore_status'}}
                     or {'steps' : {'nsteps', 'nsigma', 'yields'}}

        Returns fit result
        '''
        self._initialize()

        cfg = {} if cfg is None else cfg

        log.info(f'{"chi2":<10}{"pval":<10}{"stat":<10}')
        if 'strategy' not in cfg:
            nll    = self._get_full_nll(cfg = cfg)
            res, _ = self._minimize(nll, cfg)
            res.hesse(method='minuit_hesse')
        elif 'retry' in cfg['strategy']:
            d_pval_res, last_res = self._fit_retries(cfg)
            res = self._pick_best_fit(d_pval_res, last_res)
        elif 'steps' in cfg['strategy']:
            res = self._fit_in_steps(cfg)
        else:
            raise ValueError('Unsupported fitting strategy')

        return res
|
515
|
+
#------------------------------
|