data-manipulation-utilities 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,527 @@
1
+ '''
2
+ Module containing plot class, used to plot fits
3
+ '''
4
+ # pylint: disable=too-many-instance-attributes
5
+
6
+ import warnings
7
+ import pprint
8
+
9
+ import zfit
10
+ import hist
11
+ import mplhep
12
+ import pandas as pd
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import dmu.generic.utilities as gut
16
+
17
+ from dmu.logging.log_store import LogStore
18
+
19
# Module-level logger, registered with LogStore under the name 'dmu:fit_plotter'
log = LogStore.add_logger('dmu:fit_plotter')
20
+ #----------------------------------------
21
+ class ZFitPlotter:
22
+ '''
23
+ Class used to plot fits done with zfit
24
+ '''
25
def __init__(self, data=None, model=None, weights=None, result=None, suffix=''):
    '''
    Parameters:
    ------------------
    data   : Data the model was fitted to: numpy array, pandas Series, pandas DataFrame or zfit dataset
    model  : Total fit model, its `space` attribute is used as the observable
    weights: Weights, forwarded to `_data_to_zdata` together with the data
    result : Fit result, optional, used to read parameter values, errors and goodness of fit
    suffix : Stored in `_suffix`
    '''
    # pylint: disable=too-many-positional-arguments

    self.obs                = model.space
    self.data               = self._data_to_zdata(model.space, data, weights)
    self.lower, self.upper  = self.data.data_range.limit1d
    self.total_model        = model
    self.x                  = np.linspace(self.lower, self.upper, 2000)
    self.data_np            = zfit.run(self.data.unstack_x())
    # Unweighted datasets get unit weights, so that yields can always be calculated as sums of weights
    self.data_weight_np     = np.ones_like(self.data_np) if self.data.weights is None else zfit.run(self.data.weights)

    self.errors             = []
    # Read by _plot_data, defined here so that the attribute exists before plot() overrides it
    self.dat_xerr           = False
    self._l_def_col         = []
    self._result            = result
    self._suffix            = suffix
    self._leg               = {}
    self._col               = {}
    self._l_blind           = None
    self._l_plot_components = None
    self.axs                = None
    self._figsize           = None
    self._leg_loc           = None

    # zfit.settings.advanced_warnings['extend_wrapped_extended'] = False
    # No category or module is given: this ignores every warning, for the whole process
    warnings.filterwarnings("ignore")
56
+ #----------------------------------------
57
def _initialize(self):
    '''
    Refills the list of default colors with the names of matplotlib's Tableau colors
    '''
    from matplotlib.colors import TABLEAU_COLORS

    self._l_def_col = [*TABLEAU_COLORS]
61
+ #----------------------------------------
62
def _data_to_zdata(self, obs, data, weights):
    '''
    Converts the input data into a zfit dataset

    Parameters:
    ------------------
    obs    : Observable, passed to the zfit constructors
    data   : numpy array, pandas Series, pandas DataFrame or zfit dataset
    weights: Weights passed to the zfit constructors, not used if `data` is already a zfit dataset

    Returns:
    ------------------
    zfit dataset. If `data` already was one, it is returned as is.

    Raises:
    ------------------
    RuntimeError: If `data` is of none of the types above
    '''
    if isinstance(data, np.ndarray):
        return zfit.Data.from_numpy(obs=obs, array=data, weights=weights)

    if isinstance(data, pd.Series):
        return zfit.Data.from_pandas(obs=obs, df=pd.DataFrame(data), weights=weights)

    if isinstance(data, pd.DataFrame):
        return zfit.Data.from_pandas(obs=obs, df=data, weights=weights)

    if isinstance(data, zfit.data.Data):
        return data

    message = f'Passed data is of unsupported type {type(data)}'
    log.error(message)
    # A bare raise with no active exception ends up as a RuntimeError,
    # the same type is raised here, but carrying the actual reason
    raise RuntimeError(message)
76
+ #----------------------------------------
77
def _get_errors(self, nbins=100, l_range=None):
    '''
    Reads back the asymmetric errors that mplhep draws for each bin of the data histogram

    The data, not blinded, is histogrammed and drawn on a scratch figure.
    The ends of the error bars of that drawing are then compared with the bin contents.

    Parameters:
    ------------------
    nbins  : Number of bins of the histogram
    l_range: List of (low, high) tuples or None, passed to `_get_range_data`

    Returns:
    ------------------
    List with one (low, up) tuple per bin: distances from the bin content to the lower and upper end of the error bar
    '''
    arr_dat, arr_wgt = self._get_range_data(l_range, blind=False)

    axis = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=self.obs.obs[0], underflow=False, overflow=False)
    hst  = axis.Weight()
    hst.fill(arr_dat, weight=arr_wgt)

    # Scratch figure, closed as soon as mplhep has drawn the error bars
    fig, ax = plt.subplots()
    l_artist = mplhep.histplot(
        hst,
        yerr     = True,
        color    = 'white',
        histtype = "errorbar",
        label    = None,
        ax       = ax,
    )
    plt.close(fig)

    # NOTE(review): presumably a matplotlib errorbar container, whose third entry holds the bar lines - confirm
    bars     = l_artist[0].errorbar[2]
    segments = bars[0].get_segments()
    heights  = hst.values()

    # Each segment is a pair of points, lower end first, index 1 picks the y coordinate
    return [(heights[i] - segments[i][0][1], -heights[i] + segments[i][1][1]) for i in range(nbins)]
105
+ #----------------------------------------
106
+ def _get_range_data(self, l_range, blind=True):
107
+ sdat = self.data_np
108
+ swgt = self.data_weight_np
109
+ dmat = np.array([sdat, swgt]).T
110
+
111
+ if blind and self._l_blind is not None:
112
+ log.debug(f'Blinding data with: {self._l_blind}')
113
+ _, min_val, max_val = self._l_blind
114
+ dmat = dmat[(dmat.T[0] < min_val) | (dmat.T[0] > max_val)]
115
+
116
+ if l_range is None:
117
+ [dat, wgt] = dmat.T
118
+ return dat, wgt
119
+
120
+ l_dat = []
121
+ l_wgt = []
122
+ for lo, hi in l_range:
123
+ dmat_f = dmat[(dmat.T[0] > lo) & (dmat.T[0] < hi)]
124
+
125
+ [dat, wgt] = dmat_f.T
126
+
127
+ l_dat.append(dat)
128
+ l_wgt.append(wgt)
129
+
130
+ dat_f = np.concatenate(l_dat)
131
+ wgt_f = np.concatenate(l_wgt)
132
+
133
+ return dat_f, wgt_f
134
+ #----------------------------------------
135
def _plot_data(self, ax, nbins=100, l_range=None):
    '''
    Draws the data as black error bars, entries in the blinding window are excluded if blinding was requested

    Parameters:
    ------------------
    ax     : Axis where the data is drawn
    nbins  : Number of bins of the histogram
    l_range: List of (low, high) tuples or None, passed to `_get_range_data`
    '''
    arr_dat, arr_wgt = self._get_range_data(l_range, blind=True)

    axis = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=self.obs.obs[0], underflow=False, overflow=False)
    hst  = axis.Weight()
    hst.fill(arr_dat, weight=arr_wgt)

    # Legend entry can be overridden through the 'Data' key of the legend dictionary
    label = self._leg.get("Data", "Data")

    mplhep.histplot(
        hst,
        yerr     = True,
        color    = "black",
        histtype = "errorbar",
        label    = label,
        ax       = ax,
        xerr     = self.dat_xerr,
    )
150
+ #----------------------------------------
151
def _pull_hist(self, pdf_hist, nbins, data_yield, l_range=None):
    '''
    Calculates the pulls between the data and the binned model

    Parameters:
    ------------------
    pdf_hist  : Histogram made from the model, its contents are rescaled to add up to `data_yield`
    nbins     : Number of bins
    data_yield: Yield used to normalize the model
    l_range   : List of (low, high) tuples or None, passed to `_get_range_data`

    Returns:
    ------------------
    Tuple with histogram of pulls and list `[lower, upper]` with the errors of the pulls
    '''
    arr_pdf = pdf_hist.values()

    # NOTE(review): data is read with blind=False, also when blinding was requested - confirm this is intended
    arr_dat, arr_wgt = self._get_range_data(l_range, blind=False)

    axis = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=self.obs.obs[0], underflow=False, overflow=False)
    hst  = axis.Weight()
    hst.fill(arr_dat, weight=arr_wgt)
    data_values = hst.values()

    scale      = data_yield / sum(arr_pdf)
    pdf_values = [value * scale for value in arr_pdf]

    l_pull    = []
    l_err_low = []
    l_err_up  = []
    for (low, up), pdf_val, dat_val in zip(self.errors, pdf_values, data_values):
        res = float(dat_val - pdf_val)
        # Data above the model is compared with the lower error, data below with the upper one
        err = low if res > 0 else up
        pul = res / err

        if abs(pul) > 5:
            log.warning(f'Large pull: {pul:.1f}=({dat_val:.0f}-{pdf_val:.0f})/{err:.0f}')

        l_pull.append(pul)
        l_err_low.append(low / err)
        l_err_up.append(up / err)

    pull_hist      = hist.Hist(hist.axis.Regular(nbins, self.lower, self.upper, name="pulls"))
    pull_hist[...] = l_pull

    return pull_hist, [l_err_low, l_err_up]
183
+ #----------------------------------------
184
def _plot_pulls(self, ax, nbins, data_yield, l_range):
    '''
    Bins the total model and draws the pulls of the data with respect to it

    Parameters:
    ------------------
    ax        : Axis where the pulls are drawn
    nbins     : Number of bins
    data_yield: Yield used to normalize the binned model
    l_range   : List of (low, high) tuples or None, passed to `_pull_hist`
    '''
    name = self.obs.obs[0]

    # Binned version of the total model, with the same binning as the data
    space    = zfit.Space(name, binning=zfit.binned.RegularBinning(bins=nbins, start=self.lower, stop=self.upper, name=name))
    pdf_hist = zfit.pdf.BinnedFromUnbinnedPDF(self.total_model, space).to_hist()

    hst, l_error = self._pull_hist(pdf_hist, nbins, data_yield, l_range=l_range)

    mplhep.histplot(
        hst,
        ax       = ax,
        yerr     = np.array(l_error),
        histtype = "errorbar",
        color    = "black",
    )
200
+ #----------------------------------------
201
+ def _get_zfit_gof(self):
202
+ if not hasattr(self._result, 'gof'):
203
+ return
204
+
205
+ chi2, ndof, pval = self._result.gof
206
+
207
+ rchi2 = chi2/ndof
208
+
209
+ return f'$\chi^2$/NdoF={chi2:.2f}/{ndof}={rchi2:.2f}\np={pval:.3f}'
210
+ #----------------------------------------
211
+ def _get_text(self, ext_text):
212
+ gof_text = self._get_zfit_gof()
213
+
214
+ if ext_text is None and gof_text is None:
215
+ return
216
+ elif ext_text is not None and gof_text is None:
217
+ return ext_text
218
+ elif ext_text is None and gof_text is not None:
219
+ return gof_text
220
+ else:
221
+ return f'{ext_text}\n{gof_text}'
222
+ #----------------------------------------
223
+ def _get_pars(self):
224
+ '''
225
+ Will return a dictionary with:
226
+ ```
227
+ par_name -> [value, error]
228
+ ```
229
+
230
+ if error is not available, will assign zeros
231
+ '''
232
+ pdf = self.total_model
233
+
234
+ if self._result is not None:
235
+ d_par = {}
236
+ for par, d_val in self._result.params.items():
237
+ val = d_val['value']
238
+ name= par if isinstance(par, str) else par.name
239
+ try:
240
+ err = d_val['hesse']['error']
241
+ except:
242
+ log.warning(f'Cannot extract {name} Hesse errors, using zeros')
243
+ pprint.pprint(d_val)
244
+ err = 0
245
+
246
+ d_par[name] = [val, err]
247
+ else:
248
+ s_par = pdf.get_params()
249
+ d_par = {par.name : [par.value(), 0] for par in s_par}
250
+
251
+ return d_par
252
+ #----------------------------------------
253
def _add_pars_box(self, add_pars):
    '''
    Writes the values and errors of the parameters as text on the current figure

    Parameters:
    ------------------
    add_pars (list|str): List of names of parameters to be added or string with value 'all' to add all fit parameters.
    '''
    show_all = add_pars == 'all'

    # One line per parameter: name, value, +/- and error
    text = ''.join(
            f'{name:<20}{val:>10.3e}{"+/-":>5}{err:>10.3e}\n'
            for name, [val, err] in self._get_pars().items()
            if show_all or name in add_pars)

    # Position is given as fractions of the figure size
    plt.text(0.65, 0.75, text, fontsize=12, transform=plt.gcf().transFigure)
271
+ #----------------------------------------
272
def _get_axis(self, add_pars, skip_pulls):
    '''
    Makes the figure and the axes where the fit will be drawn, after applying mplhep's LHCb2 style

    Parameters:
    ------------------
    add_pars  : If not None, room is left to the right of the axes and the box with parameters is added
    skip_pulls: If True, a single axis is made and `add_pars` is not used

    Returns:
    ------------------
    Axes that can be indexed: only the fit axis if pulls are skipped, otherwise fit axis followed by pulls axis
    '''
    plt.style.use(mplhep.style.LHCb2)

    if skip_pulls:
        _, ax_fit = plt.subplots(1)
        return [ax_fit]

    if add_pars is None:
        grid = plt.figure().add_gridspec(nrows=2, ncols=1, hspace=0.1, height_ratios=[4, 1])
        return grid.subplots(sharex=True).flat

    # NOTE(review): this is the only branch where the requested figure size is used
    plt.figure(figsize=self._figsize)

    # Axes take 25 out of 40 columns, the rest is left for the parameters box
    shape   = (4, 40)
    ax_fit  = plt.subplot2grid(shape, (0, 0), rowspan=3, colspan=25)
    ax_pull = plt.subplot2grid(shape, (3, 0), rowspan=1, colspan=25)
    plt.subplots_adjust(hspace=0.2)

    self._add_pars_box(add_pars)

    return [ax_fit, ax_pull]
293
+ #----------------------------------------
294
+ def _get_component_yield(self, model, par):
295
+ if model.is_extended:
296
+ par = model.get_yield()
297
+ nevt = float(par.value())
298
+ return nevt
299
+
300
+ yild = self.total_model.get_yield()
301
+ if yild is None:
302
+ nevs = self.data_weight_np.sum()
303
+ else:
304
+ nevs = yild.value().numpy()
305
+
306
+ frac = par.value().numpy()
307
+
308
+ return frac * nevs
309
+ #----------------------------------------
310
+ def _plot_model_components(self, nbins, stacked):
311
+ if not hasattr(self.total_model, 'pdfs'):
312
+ return
313
+
314
+ if self._l_blind is not None:
315
+ [blind_name, _, _] = self._l_blind
316
+ else:
317
+ blind_name = None
318
+
319
+ y = None
320
+ l_y = []
321
+ was_blinded = False
322
+ for model, par in zip(self.total_model.pdfs, self.total_model.params.values()):
323
+ if model.name == blind_name:
324
+ was_blinded = True
325
+ log.debug(f'Skipping blinded PDF: {blind_name}')
326
+ continue
327
+
328
+ nevt = self._get_component_yield(model, par)
329
+
330
+ if model.name in self._l_plot_components and hasattr(model, 'pdfs'):
331
+ l_model = [ (frc, pdf) for pdf, frc in zip(model.pdfs, model.params.values()) ]
332
+ elif model.name in self._l_plot_components and not hasattr(model, 'pdfs'):
333
+ log.warning(f'Cannot plot {model.name} as separate components, despite it was requested')
334
+ l_model = [ (1, model)]
335
+ else:
336
+ l_model = [ (1, model)]
337
+
338
+ l_y += self._plot_sub_components(y, nbins, stacked, nevt, l_model)
339
+ y,_ = l_y[-1]
340
+
341
+ l_y.reverse()
342
+ ax = self.axs[0]
343
+ for y, name in l_y:
344
+ if stacked:
345
+ ax.fill_between(self.x, y, alpha=1.0, label=self._leg.get(name, name), color=self._get_col(name))
346
+ else:
347
+ ax.plot(self.x, y, '-', label=self._leg.get(name, name), color=self._col.get(name))
348
+
349
+ if (blind_name is not None) and (was_blinded is False):
350
+ log.error(f'Blinding was requested, but PDF {blind_name} was not found among:')
351
+ for model in self.total_model.pdfs:
352
+ log.info(model.name)
353
+ raise
354
+ #----------------------------------------
355
+ def _get_col(self, name):
356
+ if name in self._col:
357
+ return self._col[name]
358
+
359
+ col = self._l_def_col[0]
360
+ del(self._l_def_col[0])
361
+
362
+ return col
363
+ #----------------------------------------
364
+ def _plot_sub_components(self, y, nbins, stacked, nevt, l_model):
365
+ l_y = []
366
+ for frc, model in l_model:
367
+ this_y = model.pdf(self.x) * nevt * frc / nbins * (self.upper - self.lower)
368
+
369
+ if stacked:
370
+ y = this_y if y is None else y + this_y
371
+ else:
372
+ y = this_y
373
+
374
+ l_y.append((y, model.name))
375
+
376
+ return l_y
377
+ #----------------------------------------
378
+ def _plot_model(self, ax, model, nbins=100, linestyle='-'):
379
+ if self._l_blind is not None:
380
+ log.debug(f'Blinding: {model.name}')
381
+ return
382
+
383
+ data_yield = self.data_weight_np.sum()
384
+ y = model.pdf(self.x) * data_yield / nbins * (self.upper - self.lower)
385
+
386
+ name = model.name
387
+ ax.plot(self.x, y, linestyle, label=self._leg.get(name, name), color=self._col.get(name))
388
+ #----------------------------------------
389
+ def _get_labels(self, xlabel, ylabel, unit, nbins):
390
+ if xlabel == "":
391
+ xlabel = f"{self.obs.obs[0]} [{unit}]"
392
+
393
+ if ylabel == "":
394
+ width = (self.upper-self.lower)/nbins
395
+ ylabel = f'Candidates / ({width:.3f} {unit})'
396
+
397
+ return xlabel, ylabel
398
+ #----------------------------------------
399
+ def _get_xcoor(self, plot_range):
400
+ if plot_range is not None:
401
+ try:
402
+ self.lower, self.upper = plot_range
403
+ except TypeError:
404
+ log.error(f'plot_range argument is expected to be a tuple with two numeric values')
405
+ raise TypeError
406
+
407
+ return np.linspace(self.lower, self.upper, 2000)
408
+ #----------------------------------------
409
+ def _get_data_yield(self, mas_tup):
410
+ if mas_tup is None:
411
+ return self.data_weight_np.sum()
412
+
413
+ minx, maxx = mas_tup
414
+ arr_data = np.array([self.data_np, self.data_weight_np]).T
415
+
416
+ arr_data = arr_data[arr_data[:, 0] > minx]
417
+ arr_data = arr_data[arr_data[:, 0] < maxx]
418
+
419
+ [_, arr_wgt] = arr_data.T
420
+
421
+ return arr_wgt.sum()
422
+ #----------------------------------------
423
@gut.timeit
def plot(self,
         title             = None,
         stacked           = False,
         blind             = None,
         no_data           = False,
         ranges            = None,
         nbins: int        = 100,
         unit: str         = r'$\rm{MeV}/\it{c}^{2}$',
         xlabel: str       = "",
         ylabel: str       = "",
         d_leg: dict       = None,
         d_col: dict       = None,
         plot_range: tuple = None,
         plot_components   = None,
         ext_text : str    = None,
         add_pars          = None,
         ymax              = None,
         skip_pulls        = False,
         axs               = None,
         figsize:tuple     = (13, 7),
         leg_loc:str       = 'best',
         xerr: bool        = False):
    '''
    Draws the fit: model, its components, data and pulls. The axes used are left in the `axs` attribute.

    Parameters:
    ------------------
    title (str)           : Title of the plot, optional
    stacked (bool)        : If True, the components are stacked. Ignored if the model is not a sum of PDFs
    blind (list)          : Name of the PDF to blind, followed by minimum and maximum of the blinding window
    no_data (bool)        : If True, neither data nor pulls are drawn
    ranges                : List of tuples with ranges, if any was used for the fit, e.g. [(0, 3), (7, 10)]
    nbins                 : Number of bins
    unit                  : Unit for x axis, default is MeV/c^2
    xlabel                : Label of x axis, a default is built if empty
    ylabel                : Label of y axis, a default is built if empty
    d_leg                 : Dictionary with legend entries, keyed by PDF name, `Data` is the key for the data
    d_col                 : Dictionary with colors, keyed by PDF name
    plot_range            : Tuple with lower and upper edges of plot
    plot_components (list): Names of PDFs which are sums of PDFs and whose components should be drawn separately
    ext_text              : Text added as the title of the legend, before the goodness of fit
    add_pars (list|str)   : List of names of parameters to be added or string with value 'all' to add all fit parameters
    ymax (float)          : If specified, the y axis will go from zero to this value
    skip_pulls (bool)     : Will not draw pulls if True, default False
    axs                   : Axes to draw on, if None they are made here
    figsize (tuple)       : Tuple with figure size, default (13, 7)
    leg_loc (str)         : Location of legend, default 'best'
    xerr (bool or float)  : Passed as xerr to mplhep histplot when drawing the data
    '''
    # pylint: disable=too-many-locals, too-many-positional-arguments, too-many-arguments
    if d_leg is None:
        d_leg = {}

    if d_col is None:
        d_col = {}

    if plot_components is None:
        plot_components = []

    # Only sums of PDFs can be stacked
    is_sum = hasattr(self.total_model, 'pdfs')
    if not is_sum:
        stacked = False

    self._figsize           = figsize
    self._leg_loc           = leg_loc

    self._initialize()

    self._l_plot_components = plot_components
    self._leg               = d_leg
    self._col               = d_col
    self.x                  = self._get_xcoor(plot_range)
    self.axs                = axs if axs is not None else self._get_axis(add_pars, skip_pulls)
    self._l_blind           = blind
    total_entries           = self._get_data_yield(plot_range)
    self.errors             = self._get_errors(nbins, ranges)
    self.dat_xerr           = xerr

    ax_fit     = self.axs[0]
    with_pulls = not skip_pulls

    if not stacked:
        log.debug('Plotting full model, for non-stacked case')
        self._plot_model(ax_fit, self.total_model, nbins)

    log.debug('Plotting model components')
    self._plot_model_components(nbins, stacked)

    if not no_data:
        log.debug('Plotting data')
        self._plot_data(ax_fit, nbins, ranges)

    if not (skip_pulls or no_data):
        log.debug('Plotting pulls')
        self._plot_pulls(self.axs[1], nbins, total_entries, ranges)

    text           = self._get_text(ext_text)
    xlabel, ylabel = self._get_labels(xlabel, ylabel, unit, nbins)

    ax_fit.legend(title=text, fontsize=20, title_fontsize=20, loc=self._leg_loc)
    ax_fit.set(xlabel=xlabel, ylabel=ylabel)
    ax_fit.set_xlim([self.lower, self.upper])

    if title is not None:
        ax_fit.set_title(title)

    if ymax is not None:
        ax_fit.set_ylim([0, ymax])

    if with_pulls:
        ax_pull = self.axs[1]
        ax_pull.set(xlabel=xlabel, ylabel="pulls")
        ax_pull.set_xlim([self.lower, self.upper])

    for ax in self.axs:
        ax.label_outer()
527
+ #----------------------------------------
@@ -1,14 +1,20 @@
1
1
  saving:
2
2
  plt_dir : tests/plotting/2d_weighted
3
+ definitions:
4
+ z : x + y
3
5
  general:
4
6
  size : [20, 10]
5
7
  plots_2d:
6
8
  - [x, y, weights, 'xy_w']
7
9
  - [x, y, null, 'xy_r']
10
+ - [x, z, null, 'xz_r']
8
11
  axes:
9
12
  x :
10
- binning : [-5.0, 8.0, 40]
13
+ binning : [-3.0, 3.0, 40]
11
14
  label : 'x'
12
15
  y :
13
16
  binning : [-5.0, 8.0, 40]
14
17
  label : 'y'
18
+ z :
19
+ binning : [-5.0, 16.0, 40]
20
+ label : 'z'