data-manipulation-utilities 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: data_manipulation_utilities
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Description-Content-Type: text/markdown
5
5
  Requires-Dist: logzero
6
6
  Requires-Dist: PyYAML
@@ -41,7 +41,7 @@ such that:
41
41
 
42
42
  Then, for each remote it pushes the tags and the commits.
43
43
 
44
- *Why?*
44
+ *Why?*
45
45
 
46
46
  1. Tags should be named as the project's version
47
47
  1. As soon as a new version is created, that version needs to be tagged.
@@ -231,6 +231,49 @@ likelihood :
231
231
  nbins : 100 #If specified, will do binned likelihood fit instead of unbinned
232
232
  ```
233
233
 
234
+ ## Fit plotting
235
+
236
+ The class `ZFitPlotter` can be used to plot fits done with zfit. For a complete set of examples of how to use
237
+ this class check the [tests](tests/stats/test_fit_plotter.py). A simple example of its usage is below:
238
+
239
+ ```python
240
+ from dmu.stats.zfit_plotter import ZFitPlotter
241
+
242
+ obs = zfit.Space('m', limits=(0, 10))
243
+
244
+ # Create signal PDF
245
+ mu = zfit.Parameter("mu", 5.0, 0, 10)
246
+ sg = zfit.Parameter("sg", 0.5, 0, 5)
247
+ sig = zfit.pdf.Gauss(obs=obs, mu=mu, sigma=sg)
248
+ nsg = zfit.Parameter('nsg', 1000, 0, 10000)
249
+ esig= sig.create_extended(nsg, name='gauss')
250
+
251
+ # Create background PDF
252
+ lm = zfit.Parameter('lm', -0.1, -1, 0)
253
+ bkg = zfit.pdf.Exponential(obs=obs, lam=lm)
254
+ nbk = zfit.Parameter('nbk', 1000, 0, 10000)
255
+ ebkg= bkg.create_extended(nbk, name='expo')
256
+
257
+ # Add them
258
+ pdf = zfit.pdf.SumPDF([ebkg, esig])
259
+ sam = pdf.create_sampler()
260
+
261
+ # Plot them
262
+ obj = ZFitPlotter(data=sam, model=pdf)
263
+ d_leg = {'gauss': 'New Gauss'}
264
+ obj.plot(nbins=50, d_leg=d_leg, stacked=True, plot_range=(0, 10), ext_text='Extra text here')
265
+
266
+ # add a line to pull hist
267
+ obj.axs[1].plot([0, 10], [0, 0], linestyle='--', color='black')
268
+ ```
269
+
270
+ This class supports:
271
+
272
+ - Handling of title, legend and plot size.
273
+ - Adding pulls.
274
+ - Stacking and overlaying of PDFs.
275
+ - Blinding.
276
+
234
277
  ## Arrays
235
278
 
236
279
  ### Scaling by non-integer
@@ -1,4 +1,4 @@
1
- data_manipulation_utilities-0.1.6.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
1
+ data_manipulation_utilities-0.1.7.data/scripts/publish,sha256=-3K_Y2_4CfWCV50rPB8CRuhjxDu7xMGswinRwPovgLs,1976
2
2
  dmu/arrays/utilities.py,sha256=PKoYyybPptA2aU-V3KLnJXBudWxTXu4x1uGdIMQ49HY,1722
3
3
  dmu/generic/utilities.py,sha256=0Xnq9t35wuebAqKxbyAiMk1ISB7IcXK4cFH25MT1fgw,1741
4
4
  dmu/logging/log_store.py,sha256=umdvjNDuV3LdezbG26b0AiyTglbvkxST19CQu9QATbA,4184
@@ -6,9 +6,9 @@ dmu/ml/cv_classifier.py,sha256=n81m7i2M6Zq96AEd9EZGwXSrbG5m9jkS5RdeXvbsAXU,3712
6
6
  dmu/ml/cv_predict.py,sha256=Bqxu-f6qquKJokFljhCzL_kiGcjLJLQFhVBD130fsyw,4893
7
7
  dmu/ml/train_mva.py,sha256=d_n-A07DFweikz5nXap4OE_Mqx8VprFT7zbxmnQAbac,9638
8
8
  dmu/ml/utilities.py,sha256=Nue7O9zi1QXgjGRPH6wnSAW9jusMQ2ZOSDJzBqJKIi0,3687
9
- dmu/plotting/plotter.py,sha256=laa6Kl7P-ZOIhaOFBVjOH4XQ4kPCV7wBNvLIMBnyCwM,7181
10
- dmu/plotting/plotter_1d.py,sha256=G-i94uzm2TjNaog1A4agAKar_G0qNdkAqIPCmzhe85Y,3660
11
- dmu/plotting/plotter_2d.py,sha256=SWPKns-CfpUZHgBXvwm3gceH3k2eL_mKGXQ8sWpZJB0,2919
9
+ dmu/plotting/plotter.py,sha256=ytMxtzHEY8ZFU0ZKEBE-ROjMszXl5kHTMnQnWe173nU,7208
10
+ dmu/plotting/plotter_1d.py,sha256=O7rTgCBlpCko1RSpj2TzcUIfx9sKoz2jAgw73Pz7Ynk,4472
11
+ dmu/plotting/plotter_2d.py,sha256=J-gKnagoHGfJFU7HBrhDFpGYH5Rxy0_zF5l8eE_7ZHE,2944
12
12
  dmu/rdataframe/atr_mgr.py,sha256=FdhaQWVpsm4OOe1IRbm7rfrq8VenTNdORyI-lZ2Bs1M,2386
13
13
  dmu/rdataframe/utilities.py,sha256=x8r379F2-vZPYzAdMFCn_V4Kx2Tx9t9pn_QHcZ1euew,2756
14
14
  dmu/rfile/rfprinter.py,sha256=mp5jd-oCJAnuokbdmGyL9i6tK2lY72jEfROuBIZ_ums,3941
@@ -16,11 +16,12 @@ dmu/rfile/utilities.py,sha256=XuYY7HuSBj46iSu3c60UYBHtI6KIPoJU_oofuhb-be0,945
16
16
  dmu/stats/fitter.py,sha256=LDvFNyhgO0OzXN7aH3kfHe6LzuPqdQfPcKR_IegDcaU,18204
17
17
  dmu/stats/function.py,sha256=yzi_Fvp_ASsFzbWFivIf-comquy21WoeY7is6dgY0Go,9491
18
18
  dmu/stats/utilities.py,sha256=LQy4kd3xSXqpApcWuYfZxkGQyjowaXv2Wr1c4Bj-4ys,4523
19
+ dmu/stats/zfit_plotter.py,sha256=Xs6kisNEmNQXhYRCcjowxO6xHuyAyrfyQIFhGAR61U4,19719
19
20
  dmu/testing/utilities.py,sha256=WbMM4e9Cn3-B-12Vr64mB5qTKkV32joStlRkD-48lG0,3460
20
21
  dmu/text/transformer.py,sha256=4lrGknbAWRm0-rxbvgzOO-eR1-9bkYk61boJUEV3cQ0,6100
21
22
  dmu_data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
23
  dmu_data/ml/tests/train_mva.yaml,sha256=TCniCVpXMEFxZcHa8IIqollKA7ci4OkBnRznLEkXM9o,925
23
- dmu_data/plotting/tests/2d.yaml,sha256=lTMNheK3DB8klp4O5QjMDwBI1A1Oh2_Wp2F2Ro9VQKM,282
24
+ dmu_data/plotting/tests/2d.yaml,sha256=VApcAfJFbjNcjMCTBSRm2P37MQlGavMZv6msbZwLSgw,402
24
25
  dmu_data/plotting/tests/fig_size.yaml,sha256=7ROq49nwZ1A2EbPiySmu6n3G-Jq6YAOkc3d2X3YNZv0,294
25
26
  dmu_data/plotting/tests/high_stat.yaml,sha256=bLglBLCZK6ft0xMhQ5OltxE76cWsBMPMjO6GG0OkDr8,522
26
27
  dmu_data/plotting/tests/name.yaml,sha256=mkcPAVg8wBAmlSbSRQ1bcaMl4vOS6LXMtpqQeDrrtO4,312
@@ -39,8 +40,8 @@ dmu_scripts/rfile/compare_root_files.py,sha256=T8lDnQxsRNMr37x1Y7YvWD8ySHrJOWZki
39
40
  dmu_scripts/rfile/print_trees.py,sha256=Ze4Ccl_iUldl4eVEDVnYBoe4amqBT1fSBR1zN5WSztk,941
40
41
  dmu_scripts/ssh/coned.py,sha256=lhilYNHWRCGxC-jtyJ3LQ4oUgWW33B2l1tYCcyHHsR0,4858
41
42
  dmu_scripts/text/transform_text.py,sha256=9akj1LB0HAyopOvkLjNOJiptZw5XoOQLe17SlcrGMD0,1456
42
- data_manipulation_utilities-0.1.6.dist-info/METADATA,sha256=1ttATABwWcdqqPJM72_4s_ZQjtbFp9MzkfsprkDJTv8,19946
43
- data_manipulation_utilities-0.1.6.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
44
- data_manipulation_utilities-0.1.6.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
45
- data_manipulation_utilities-0.1.6.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
46
- data_manipulation_utilities-0.1.6.dist-info/RECORD,,
43
+ data_manipulation_utilities-0.1.7.dist-info/METADATA,sha256=6cSG5TvicYwa0Ru5352DXpVC1k0B6Zcz2HB4vkVWEkg,21183
44
+ data_manipulation_utilities-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
45
+ data_manipulation_utilities-0.1.7.dist-info/entry_points.txt,sha256=1TIZDed651KuOH-DgaN5AoBdirKmrKE_oM1b6b7zTUU,270
46
+ data_manipulation_utilities-0.1.7.dist-info/top_level.txt,sha256=n_x5J6uWtSqy9mRImKtdA2V2NJNyU8Kn3u8DTOKJix0,25
47
+ data_manipulation_utilities-0.1.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.6.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
dmu/plotting/plotter.py CHANGED
@@ -65,7 +65,7 @@ class Plotter:
65
65
 
66
66
  return minx, maxx
67
67
  #-------------------------------------
68
- def _preprocess_rdf(self, rdf):
68
+ def _preprocess_rdf(self, rdf : RDataFrame) -> RDataFrame:
69
69
  '''
70
70
  rdf (RDataFrame): ROOT dataframe
71
71
 
@@ -2,6 +2,9 @@
2
2
  Module containing plotter class
3
3
  '''
4
4
 
5
+ import hist
6
+ from hist import Hist
7
+
5
8
  import numpy
6
9
  import matplotlib.pyplot as plt
7
10
 
@@ -33,58 +36,75 @@ class Plotter1D(Plotter):
33
36
 
34
37
  return xname, yname
35
38
  #-------------------------------------
36
- def _plot_var(self, var):
39
+ def _is_normalized(self, var : str) -> bool:
40
+ d_cfg = self._d_cfg['plots'][var]
41
+ normalized=False
42
+ if 'normalized' in d_cfg:
43
+ normalized = d_cfg['normalized']
44
+
45
+ return normalized
46
+ #-------------------------------------
47
def _get_binning(self, var : str, d_data : dict[str, numpy.ndarray]) -> tuple[float, float, int]:
    '''
    Reads the binning of a variable from the configuration

    Parameters
    --------------------
    var (str)    : name of the variable, used as key in the `plots` section of the configuration
    d_data (dict): arrays with the values of the variable, passed to `_find_bounds` when needed

    Return
    --------------------
    Tuple with lower edge, upper edge and the third entry of the configured `binning`
    '''
    low, high, nbins = self._d_cfg['plots'][var]['binning']

    # An upper edge that does not exceed the lower one is taken as "bounds not set".
    # In that case the configured lower value is handed to _find_bounds as `qnt`
    if high <= low + 1e-5:
        log.info(f'Bounds not set for {var}, will calculated them')
        low, high = self._find_bounds(d_data = d_data, qnt=low)
        log.info(f'Using bounds [{low:.3e}, {high:.3e}]')

        return low, high, nbins

    log.debug(f'Using bounds [{low:.3e}, {high:.3e}]')

    return low, high, nbins
58
+ #-------------------------------------
59
+ def _plot_var(self, var : str) -> float:
37
60
  '''
38
61
  Will plot a variable from a dictionary of dataframes
39
62
  Parameters
40
63
  --------------------
41
64
  var (str) : name of column
65
+
66
+ Return
67
+ --------------------
68
+ Largest bin content among all bins and among all histograms plotted
42
69
  '''
43
70
  # pylint: disable=too-many-locals
44
71
 
45
- d_cfg = self._d_cfg['plots'][var]
46
-
47
- minx, maxx, bins = d_cfg['binning']
48
- yscale = d_cfg['yscale' ] if 'yscale' in d_cfg else 'linear'
49
- xname, yname = self._get_labels(var)
50
-
51
- normalized=False
52
- if 'normalized' in d_cfg:
53
- normalized = d_cfg['normalized']
54
-
55
- title = ''
56
- if 'title' in d_cfg:
57
- title = d_cfg['title']
58
-
59
72
  d_data = {}
60
73
  for name, rdf in self._d_rdf.items():
61
74
  d_data[name] = rdf.AsNumpy([var])[var]
62
75
 
63
- if maxx <= minx + 1e-5:
64
- log.info(f'Bounds not set for {var}, will calculated them')
65
- minx, maxx = self._find_bounds(d_data = d_data, qnt=minx)
66
- log.info(f'Using bounds [{minx:.3e}, {maxx:.3e}]')
67
- else:
68
- log.debug(f'Using bounds [{minx:.3e}, {maxx:.3e}]')
76
+ minx, maxx, bins = self._get_binning(var, d_data)
77
+ d_wgt = self._get_weights(var)
69
78
 
70
79
  l_bc_all = []
71
- d_wgt = self._get_weights(var)
72
80
  for name, arr_val in d_data.items():
73
- arr_wgt = d_wgt[name] if d_wgt is not None else None
74
-
75
- self._print_weights(arr_wgt, var, name)
76
- l_bc, _, _ = plt.hist(arr_val, weights=arr_wgt, bins=bins, range=(minx, maxx), density=normalized, histtype='step', label=name)
77
- l_bc_all += numpy.array(l_bc).tolist()
81
+ arr_wgt = d_wgt[name] if d_wgt is not None else numpy.ones_like(arr_val)
82
+ hst = Hist.new.Reg(bins=bins, start=minx, stop=maxx, name='x', label=name).Weight()
83
+ hst.fill(x=arr_val, weight=arr_wgt)
84
+ hst.plot(label=name)
85
+ l_bc_all += hst.values().tolist()
78
86
 
79
- plt.yscale(yscale)
80
- plt.xlabel(xname)
81
- plt.ylabel(yname)
87
+ max_y = max(l_bc_all)
82
88
 
89
+ return max_y
90
+ # --------------------------------------------
91
+ def _style_plot(self, var : str, max_y : float) -> None:
92
+ d_cfg = self._d_cfg['plots'][var]
93
+ yscale = d_cfg['yscale' ] if 'yscale' in d_cfg else 'linear'
94
+
95
+ xname, yname = self._get_labels(var)
96
+ plt.xlabel(xname)
97
+ plt.ylabel(yname)
98
+ plt.yscale(yscale)
83
99
  if yscale == 'linear':
84
100
  plt.ylim(bottom=0)
85
101
 
86
- max_y = max(l_bc_all)
102
+ title = ''
103
+ if 'title' in d_cfg:
104
+ title = d_cfg['title']
105
+
87
106
  plt.ylim(top=1.2 * max_y)
107
+ plt.legend()
88
108
  plt.title(title)
89
109
  # --------------------------------------------
90
110
  def _plot_lines(self, var : str):
@@ -106,8 +126,10 @@ class Plotter1D(Plotter):
106
126
  fig_size = self._get_fig_size()
107
127
  for var in self._d_cfg['plots']:
108
128
  log.debug(f'Plotting: {var}')
129
+
109
130
  plt.figure(var, figsize=fig_size)
110
- self._plot_var(var)
131
+ max_y = self._plot_var(var)
132
+ self._style_plot(var, max_y)
111
133
  self._plot_lines(var)
112
134
  self._save_plot(var)
113
135
  # --------------------------------------------
@@ -31,8 +31,8 @@ class Plotter2D(Plotter):
31
31
  if not isinstance(cfg, dict):
32
32
  raise ValueError('Config dictionary not passed')
33
33
 
34
- self._rdf : RDataFrame = rdf
35
34
  self._d_cfg : dict = cfg
35
+ self._rdf : RDataFrame = super()._preprocess_rdf(rdf)
36
36
 
37
37
  self._wgt : numpy.ndarray
38
38
  # --------------------------------------------
@@ -0,0 +1,527 @@
1
+ '''
2
+ Module containing plot class, used to plot fits
3
+ '''
4
+ # pylint: disable=too-many-instance-attributes
5
+
6
+ import warnings
7
+ import pprint
8
+
9
+ import zfit
10
+ import hist
11
+ import mplhep
12
+ import pandas as pd
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import dmu.generic.utilities as gut
16
+
17
+ from dmu.logging.log_store import LogStore
18
+
19
+ log = LogStore.add_logger('dmu:fit_plotter')
20
+ #----------------------------------------
21
+ class ZFitPlotter:
22
+ '''
23
+ Class used to plot fits done with zfit
24
+ '''
25
def __init__(self, data=None, model=None, weights=None, result=None, suffix=''):
    '''
    Parameters
    ------------------
    data   : Dataset that was fitted. Numpy array, pandas series or dataframe, or zfit data
    model  : Total fit model, its `space` attribute is used as the observable
    weights: Weights passed to zfit when `data` is a numpy array or a pandas object.
             Not used when `data` is already zfit data
    result : Optional fit result, read to get the goodness of fit and the parameter values
    suffix : Stored as `_suffix`
    '''
    # pylint: disable=too-many-positional-arguments

    self.obs                = model.space
    self.data               = self._data_to_zdata(model.space, data, weights)
    self.lower, self.upper  = self.data.data_range.limit1d
    self.total_model        = model
    self.x                  = np.linspace(self.lower, self.upper, 2000)
    self.data_np            = zfit.run(self.data.unstack_x())
    # Unweighted data gets unit weights, such that the rest of the class can always use weights
    self.data_weight_np     = np.ones_like(self.data_np) if self.data.weights is None else zfit.run(self.data.weights)

    self.errors             = []
    self._l_def_col         = []
    self._result            = result
    self._suffix            = suffix
    self._leg               = {}
    self._col               = {}
    self._l_blind           = None
    self._l_plot_components = None
    self.axs                = None
    self._figsize           = None
    self._leg_loc           = None
    # Overwritten by plot(), declared here so that the attribute always exists
    self.dat_xerr           = False

    # zfit.settings.advanced_warnings['extend_wrapped_extended'] = False
    # NOTE: this installs a global filter, it silences every warning of the process, not only zfit's
    warnings.filterwarnings("ignore")
56
+ #----------------------------------------
57
def _initialize(self):
    '''
    Refills the list of default colors with the names of matplotlib's Tableau colors
    '''
    from matplotlib.colors import TABLEAU_COLORS

    self._l_def_col = list(TABLEAU_COLORS)
61
+ #----------------------------------------
62
def _data_to_zdata(self, obs, data, weights):
    '''
    Converts the input dataset into zfit data

    Parameters
    ------------------
    obs    : zfit space used to build the zfit data
    data   : Numpy array, pandas series, pandas dataframe or zfit data
    weights: Forwarded to zfit for numpy and pandas inputs, not used for zfit data

    Return
    ------------------
    zfit data, the input object itself if it was already zfit data

    Raises
    ------------------
    RuntimeError: If the type of `data` is not among the supported ones
    '''
    if isinstance(data, np.ndarray):
        return zfit.Data.from_numpy(obs=obs, array=data, weights=weights)

    if isinstance(data, pd.Series):
        return zfit.Data.from_pandas(obs=obs, df=pd.DataFrame(data), weights=weights)

    if isinstance(data, pd.DataFrame):
        return zfit.Data.from_pandas(obs=obs, df=data, weights=weights)

    if isinstance(data, zfit.data.Data):
        return data

    log.error(f'Passed data is of usupported type {type(data)}')
    # A bare `raise` was used here. Without an active exception it ends up as a
    # RuntimeError with an unrelated message, the type is kept and the message made explicit
    raise RuntimeError(f'Unsupported data type: {type(data)}')
76
+ #----------------------------------------
77
def _get_errors(self, nbins=100, l_range=None):
    '''
    Builds the list of errors of the binned data, one entry per bin

    Parameters
    ------------------
    nbins  : Number of bins of the data histogram
    l_range: Optional list of (low, high) ranges used to select the data, None to use all of it

    Return
    ------------------
    List of tuples (low, up) with the distances from the bin content to the lower and
    upper end of the error bar that mplhep draws for that bin
    '''
    # Blinding is not applied here, the errors are calculated from all the data in the ranges
    dat, wgt  = self._get_range_data(l_range, blind=False)
    data_hist = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=self.obs.obs[0], underflow=False, overflow=False)
    data_hist = data_hist.Weight()
    data_hist.fill(dat, weight=wgt)

    # The histogram is drawn on a temporary figure, closed right after,
    # only to read the error bars back from the returned artists
    tmp_fig, tmp_ax = plt.subplots()
    errorbars = mplhep.histplot(
        data_hist,
        yerr=True,
        color='white',
        histtype="errorbar",
        label=None,
        ax=tmp_ax,
    )
    plt.close(tmp_fig)

    # NOTE(review): presumably errorbar[2] holds the bar line collections of matplotlib's
    # errorbar container and each segment is ((x, y_low), (x, y_up)), confirm against mplhep
    lines  = errorbars[0].errorbar[2]
    segs   = lines[0].get_segments()
    values = data_hist.values()

    l_error=[]
    for i in range(nbins):
        # y coordinate of first and second point of the segment, relative to the bin content
        low =  values[i] - segs[i][0][1]
        up  = -values[i] + segs[i][1][1]
        l_error.append((low, up))

    return l_error
105
+ #----------------------------------------
106
+ def _get_range_data(self, l_range, blind=True):
107
+ sdat = self.data_np
108
+ swgt = self.data_weight_np
109
+ dmat = np.array([sdat, swgt]).T
110
+
111
+ if blind and self._l_blind is not None:
112
+ log.debug(f'Blinding data with: {self._l_blind}')
113
+ _, min_val, max_val = self._l_blind
114
+ dmat = dmat[(dmat.T[0] < min_val) | (dmat.T[0] > max_val)]
115
+
116
+ if l_range is None:
117
+ [dat, wgt] = dmat.T
118
+ return dat, wgt
119
+
120
+ l_dat = []
121
+ l_wgt = []
122
+ for lo, hi in l_range:
123
+ dmat_f = dmat[(dmat.T[0] > lo) & (dmat.T[0] < hi)]
124
+
125
+ [dat, wgt] = dmat_f.T
126
+
127
+ l_dat.append(dat)
128
+ l_wgt.append(wgt)
129
+
130
+ dat_f = np.concatenate(l_dat)
131
+ wgt_f = np.concatenate(l_wgt)
132
+
133
+ return dat_f, wgt_f
134
+ #----------------------------------------
135
def _plot_data(self, ax, nbins=100, l_range=None):
    '''
    Draws the data, after blinding, as points with error bars

    Parameters
    ------------------
    ax     : Axis where the data is drawn
    nbins  : Number of bins of the data histogram
    l_range: Optional list of (low, high) ranges used to select the data
    '''
    values, weights = self._get_range_data(l_range, blind=True)

    axis_name = self.obs.obs[0]
    histogram = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=axis_name, underflow=False, overflow=False).Weight()
    histogram.fill(values, weight=weights)

    # The legend entry can be renamed through the "Data" key of the legend dictionary
    data_label = self._leg.get("Data", "Data")

    mplhep.histplot(histogram, yerr=True, color="black", histtype="errorbar", label=data_label, ax=ax, xerr=self.dat_xerr)
150
+ #----------------------------------------
151
def _pull_hist(self, pdf_hist, nbins, data_yield, l_range=None):
    '''
    Calculates the pulls between the binned data and the binned model

    Parameters
    ------------------
    pdf_hist  : Histogram made from the model, only its bin values are used
    nbins     : Number of bins of the data and pull histograms
    data_yield: Yield to which the model histogram is scaled
    l_range   : Optional list of (low, high) ranges used to select the data

    Return
    ------------------
    Tuple with the histogram of pulls and a list with two lists, the lower and
    upper errors of each bin divided by the error used for the pull of that bin
    '''
    raw_pdf         = pdf_hist.values()
    values, weights = self._get_range_data(l_range, blind=False)

    data_hist = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=self.obs.obs[0], underflow=False, overflow=False).Weight()
    data_hist.fill(values, weight=weights)
    observed  = data_hist.values()

    # Scale the model such that its bins add up to the data yield
    scale    = data_yield / sum(raw_pdf)
    expected = [ value * scale for value in raw_pdf ]

    pulls   = []
    rel_low = []
    rel_up  = []
    for (low, up), pdf_val, dat_val in zip(self.errors, expected, observed):
        res = float(dat_val - pdf_val)
        # Data above the model is compared using the lower error, otherwise the upper one
        err = low if res > 0 else up
        pul = res / err

        if abs(pul) > 5:
            log.warning(f'Large pull: {pul:.1f}=({dat_val:.0f}-{pdf_val:.0f})/{err:.0f}')

        pulls.append(pul)
        rel_low.append(low / err)
        rel_up.append(up / err)

    pull_hist      = hist.Hist(hist.axis.Regular(nbins, self.lower, self.upper, name="pulls"))
    pull_hist[...] = pulls

    return pull_hist, [rel_low, rel_up]
183
+ #----------------------------------------
184
def _plot_pulls(self, ax, nbins, data_yield, l_range):
    '''
    Bins the total model and draws the pulls with respect to the data

    Parameters
    ------------------
    ax        : Axis where the pulls are drawn
    nbins     : Number of bins
    data_yield: Yield to which the binned model is scaled
    l_range   : List of (low, high) ranges used to select the data, or None
    '''
    name         = self.obs.obs[0]
    binned_space = zfit.Space(name, binning=zfit.binned.RegularBinning(bins=nbins, start=self.lower, stop=self.upper, name=name))
    model_hist   = zfit.pdf.BinnedFromUnbinnedPDF(self.total_model, binned_space).to_hist()

    pulls, errors = self._pull_hist(model_hist, nbins, data_yield, l_range=l_range)

    mplhep.histplot(pulls, color="black", histtype="errorbar", yerr=np.array(errors), ax=ax)
200
+ #----------------------------------------
201
+ def _get_zfit_gof(self):
202
+ if not hasattr(self._result, 'gof'):
203
+ return
204
+
205
+ chi2, ndof, pval = self._result.gof
206
+
207
+ rchi2 = chi2/ndof
208
+
209
+ return f'$\chi^2$/NdoF={chi2:.2f}/{ndof}={rchi2:.2f}\np={pval:.3f}'
210
+ #----------------------------------------
211
+ def _get_text(self, ext_text):
212
+ gof_text = self._get_zfit_gof()
213
+
214
+ if ext_text is None and gof_text is None:
215
+ return
216
+ elif ext_text is not None and gof_text is None:
217
+ return ext_text
218
+ elif ext_text is None and gof_text is not None:
219
+ return gof_text
220
+ else:
221
+ return f'{ext_text}\n{gof_text}'
222
+ #----------------------------------------
223
+ def _get_pars(self):
224
+ '''
225
+ Will return a dictionary with:
226
+ ```
227
+ par_name -> [value, error]
228
+ ```
229
+
230
+ if error is not available, will assign zeros
231
+ '''
232
+ pdf = self.total_model
233
+
234
+ if self._result is not None:
235
+ d_par = {}
236
+ for par, d_val in self._result.params.items():
237
+ val = d_val['value']
238
+ name= par if isinstance(par, str) else par.name
239
+ try:
240
+ err = d_val['hesse']['error']
241
+ except:
242
+ log.warning(f'Cannot extract {name} Hesse errors, using zeros')
243
+ pprint.pprint(d_val)
244
+ err = 0
245
+
246
+ d_par[name] = [val, err]
247
+ else:
248
+ s_par = pdf.get_params()
249
+ d_par = {par.name : [par.value(), 0] for par in s_par}
250
+
251
+ return d_par
252
+ #----------------------------------------
253
def _add_pars_box(self, add_pars):
    '''
    Writes the values and errors of the parameters as text on the current figure

    Parameters:
    ------------------
    add_pars (list|str): List of names of parameters to be added or string with value 'all' to add all fit parameters.
    '''
    # One line per selected parameter: name, value, "+/-" and error in fixed width columns
    l_row = [
        f'{name:<20}{val:>10.3e}{"+/-":>5}{err:>10.3e}\n'
        for name, (val, err) in self._get_pars().items()
        if add_pars == 'all' or name in add_pars
    ]

    # Position is given in figure coordinates
    plt.text(0.65, 0.75, ''.join(l_row), fontsize=12, transform=plt.gcf().transFigure)
271
+ #----------------------------------------
272
def _get_axis(self, add_pars, skip_pulls):
    '''
    Creates the figure and the axes used for the plot

    Parameters
    ------------------
    add_pars  : None, or the parameters that will be written next to the plot, see `_add_pars_box`
    skip_pulls: If True, only one axis is made

    Return
    ------------------
    Axes in a container that can be indexed. One axis when pulls are skipped,
    otherwise the axis for the fit followed by the axis for the pulls
    '''
    # The style is set for every layout
    plt.style.use(mplhep.style.LHCb2)
    if skip_pulls:
        _, (ax) = plt.subplots(1)
        return [ax]

    if add_pars is None:
        # Two rows sharing the x axis, the fit panel is four times taller than the pull panel
        fig = plt.figure()
        gs  = fig.add_gridspec(nrows=2, ncols=1, hspace=0.1, height_ratios=[4, 1])
        axs = gs.subplots(sharex=True)

        return axs.flat

    # NOTE(review): this is the only layout that uses self._figsize, the layouts above
    # create the figure without it. Confirm if that is intended
    fig = plt.figure(figsize=self._figsize)
    # Grid of 4 rows and 40 columns. The panels use the first 25 columns,
    # the columns left are where the box with the parameters is written
    ax1 = plt.subplot2grid((4,40),(0, 0), rowspan=3, colspan=25)
    ax2 = plt.subplot2grid((4,40),(3, 0), rowspan=1, colspan=25)
    plt.subplots_adjust(hspace=0.2)

    self._add_pars_box(add_pars)

    return [ax1, ax2]
293
+ #----------------------------------------
294
+ def _get_component_yield(self, model, par):
295
+ if model.is_extended:
296
+ par = model.get_yield()
297
+ nevt = float(par.value())
298
+ return nevt
299
+
300
+ yild = self.total_model.get_yield()
301
+ if yild is None:
302
+ nevs = self.data_weight_np.sum()
303
+ else:
304
+ nevs = yild.value().numpy()
305
+
306
+ frac = par.value().numpy()
307
+
308
+ return frac * nevs
309
+ #----------------------------------------
310
def _plot_model_components(self, nbins, stacked):
    '''
    Draws the components of the total model, does nothing if the model has no `pdfs` attribute

    Parameters
    ------------------
    nbins  : Number of bins, used to normalize the curves
    stacked: If True the components are drawn as filled areas, each on top of the
             previous ones. Otherwise they are drawn as lines

    Raises
    ------------------
    RuntimeError: If blinding was requested for a PDF that is not among the components
    '''
    if not hasattr(self.total_model, 'pdfs'):
        return

    if self._l_blind is not None:
        [blind_name, _, _] = self._l_blind
    else:
        blind_name = None

    y           = None
    l_y         = []
    was_blinded = False
    for model, par in zip(self.total_model.pdfs, self.total_model.params.values()):
        if model.name == blind_name:
            was_blinded = True
            log.debug(f'Skipping blinded PDF: {blind_name}')
            continue

        nevt = self._get_component_yield(model, par)

        split_requested = model.name in self._l_plot_components
        if split_requested and hasattr(model, 'pdfs'):
            # The component is itself a sum, each of its PDFs is drawn with its own fraction
            l_model = [ (frc, pdf) for pdf, frc in zip(model.pdfs, model.params.values()) ]
        else:
            if split_requested:
                log.warning(f'Cannot plot {model.name} as separate components, despite it was requested')

            l_model = [ (1, model)]

        l_y += self._plot_sub_components(y, nbins, stacked, nevt, l_model)
        # The last curve is the starting point of the next component when stacking
        y,_  = l_y[-1]

    # The curves are drawn in reverse order, in the stacked case this
    # prevents the taller areas from covering the shorter ones
    l_y.reverse()
    ax = self.axs[0]
    for y, name in l_y:
        if stacked:
            ax.fill_between(self.x, y, alpha=1.0, label=self._leg.get(name, name), color=self._get_col(name))
        else:
            ax.plot(self.x, y, '-', label=self._leg.get(name, name), color=self._col.get(name))

    if blind_name is not None and not was_blinded:
        log.error(f'Blinding was requested, but PDF {blind_name} was not found among:')
        for model in self.total_model.pdfs:
            log.info(model.name)

        # A bare `raise` was used here. Without an active exception it ends up as a
        # RuntimeError with an unrelated message, the type is kept and the message made explicit
        raise RuntimeError(f'Blinding was requested, but PDF {blind_name} was not found in the model')
354
+ #----------------------------------------
355
+ def _get_col(self, name):
356
+ if name in self._col:
357
+ return self._col[name]
358
+
359
+ col = self._l_def_col[0]
360
+ del(self._l_def_col[0])
361
+
362
+ return col
363
+ #----------------------------------------
364
+ def _plot_sub_components(self, y, nbins, stacked, nevt, l_model):
365
+ l_y = []
366
+ for frc, model in l_model:
367
+ this_y = model.pdf(self.x) * nevt * frc / nbins * (self.upper - self.lower)
368
+
369
+ if stacked:
370
+ y = this_y if y is None else y + this_y
371
+ else:
372
+ y = this_y
373
+
374
+ l_y.append((y, model.name))
375
+
376
+ return l_y
377
+ #----------------------------------------
378
+ def _plot_model(self, ax, model, nbins=100, linestyle='-'):
379
+ if self._l_blind is not None:
380
+ log.debug(f'Blinding: {model.name}')
381
+ return
382
+
383
+ data_yield = self.data_weight_np.sum()
384
+ y = model.pdf(self.x) * data_yield / nbins * (self.upper - self.lower)
385
+
386
+ name = model.name
387
+ ax.plot(self.x, y, linestyle, label=self._leg.get(name, name), color=self._col.get(name))
388
+ #----------------------------------------
389
+ def _get_labels(self, xlabel, ylabel, unit, nbins):
390
+ if xlabel == "":
391
+ xlabel = f"{self.obs.obs[0]} [{unit}]"
392
+
393
+ if ylabel == "":
394
+ width = (self.upper-self.lower)/nbins
395
+ ylabel = f'Candidates / ({width:.3f} {unit})'
396
+
397
+ return xlabel, ylabel
398
+ #----------------------------------------
399
+ def _get_xcoor(self, plot_range):
400
+ if plot_range is not None:
401
+ try:
402
+ self.lower, self.upper = plot_range
403
+ except TypeError:
404
+ log.error(f'plot_range argument is expected to be a tuple with two numeric values')
405
+ raise TypeError
406
+
407
+ return np.linspace(self.lower, self.upper, 2000)
408
+ #----------------------------------------
409
+ def _get_data_yield(self, mas_tup):
410
+ if mas_tup is None:
411
+ return self.data_weight_np.sum()
412
+
413
+ minx, maxx = mas_tup
414
+ arr_data = np.array([self.data_np, self.data_weight_np]).T
415
+
416
+ arr_data = arr_data[arr_data[:, 0] > minx]
417
+ arr_data = arr_data[arr_data[:, 0] < maxx]
418
+
419
+ [_, arr_wgt] = arr_data.T
420
+
421
+ return arr_wgt.sum()
422
+ #----------------------------------------
423
@gut.timeit
def plot(self,
         title             = None,
         stacked           = False,
         blind             = None,
         no_data           = False,
         ranges            = None,
         nbins: int        = 100,
         unit: str         = r'$\rm{MeV}/\it{c}^{2}$',
         xlabel: str       = "",
         ylabel: str       = "",
         d_leg: dict       = None,
         d_col: dict       = None,
         plot_range: tuple = None,
         plot_components   = None,
         ext_text : str    = None,
         add_pars          = None,
         ymax              = None,
         skip_pulls        = False,
         axs               = None,
         figsize:tuple     = (13, 7),
         leg_loc:str       = 'best',
         xerr: bool        = False):
    '''
    Plots the model, its components, the data and the pulls. The axes used
    are available afterwards in the `axs` attribute

    title (str)           : Title, not set if None
    stacked (bool)        : If true will stack the PDFs. Ignored if the model is not a sum of PDFs
    ranges                : List of tuples with ranges if any was used for the fit, e.g. [(0, 3), (7, 10)]
    nbins                 : Bin numbers
    unit                  : Unit for x axis, default is MeV/c^2
    no_data (bool)        : If true data won't be plotted as well as pull
    xlabel                : xlabel, if empty the name of the observable and the unit are used
    ylabel                : ylabel, if empty the bin width and the unit are used
    d_leg                 : Customize legend, dictionary from name of PDF, or "Data", to legend entry
    d_col                 : Customize color, dictionary from name of PDF to color
    plot_range            : Set plot_range, tuple with lower and upper edges. It also overrides the edges stored in this object
    plot_components (list): List of strings, with names of PDFs, which are expected to be sums of PDFs and whose components should be plotted separately
    ext_text              : Text that can be added to plot, it is placed in the title of the legend
    add_pars (list|str)   : List of names of parameters to be added or string with value 'all' to add all fit parameters.
                            NOTE(review): this used to state that the LHCb style is not used with this
                            option, but `_get_axis` sets the style for every layout. To be confirmed
    skip_pulls(bool)      : Will not draw pulls if True, default False
    ymax (float)          : Optional, if specified will be used to set the maximum in plot
    blind (list)          : PDF name for the signal if blinding is needed, followed by blinding range, min and max.
    axs                   : Optional, axes to draw on. If None they are made by `_get_axis`
    figsize (tuple)       : Tuple with figure size, default (13, 7)
    leg_loc (str)         : Location of legend, default 'best'
    xerr (bool or float)  : Used to pass xerr to mplhep histplot. True will use error with bin size, False, no error, otherwise it's the size of the xerror bar
    '''
    # pylint: disable=too-many-locals, too-many-positional-arguments, too-many-arguments
    d_leg           = {} if d_leg           is None else d_leg
    d_col           = {} if d_col           is None else d_col
    plot_components = [] if plot_components is None else plot_components

    if not hasattr(self.total_model, 'pdfs'):
        #if it's not a sum of PDFs, do not stack
        stacked=False

    self._figsize = figsize
    self._leg_loc = leg_loc

    # Refills the default colors, they are consumed while plotting stacked components
    self._initialize()

    self._l_plot_components = plot_components

    self._leg      = d_leg
    self._col      = d_col
    # This also updates self.lower and self.upper if plot_range was passed
    self.x         = self._get_xcoor(plot_range)
    # The figure size has to be set before this line, _get_axis reads it
    self.axs       = self._get_axis(add_pars, skip_pulls) if axs is None else axs
    self._l_blind  = blind
    total_entries  = self._get_data_yield(plot_range)
    # The errors and dat_xerr have to be set before the data and the pulls are plotted
    self.errors    = self._get_errors(nbins, ranges)
    self.dat_xerr  = xerr

    if not stacked:
        log.debug('Plotting full model, for non-stacked case')
        self._plot_model(self.axs[0], self.total_model, nbins)

    log.debug('Plotting model components')
    self._plot_model_components(nbins, stacked)

    if not no_data:
        log.debug('Plotting data')
        self._plot_data(self.axs[0], nbins, ranges)

    if not skip_pulls and not no_data:
        log.debug('Plotting pulls')
        self._plot_pulls(self.axs[1], nbins, total_entries, ranges)

    text           = self._get_text(ext_text)
    xlabel, ylabel = self._get_labels(xlabel, ylabel, unit, nbins)

    # User text and goodness of fit are shown as title of the legend
    self.axs[0].legend(title=text, fontsize=20, title_fontsize=20, loc=self._leg_loc)
    self.axs[0].set(xlabel=xlabel, ylabel=ylabel)
    self.axs[0].set_xlim([self.lower, self.upper])

    if title is not None:
        self.axs[0].set_title(title)

    if ymax is not None:
        self.axs[0].set_ylim([0, ymax])

    if not skip_pulls:
        # The pull axis is labeled even if no_data left it empty
        self.axs[1].set(xlabel=xlabel, ylabel="pulls")
        self.axs[1].set_xlim([self.lower, self.upper])

    for ax in self.axs:
        ax.label_outer()
527
+ #----------------------------------------
@@ -1,14 +1,20 @@
1
1
  saving:
2
2
  plt_dir : tests/plotting/2d_weighted
3
+ definitions:
4
+ z : x + y
3
5
  general:
4
6
  size : [20, 10]
5
7
  plots_2d:
6
8
  - [x, y, weights, 'xy_w']
7
9
  - [x, y, null, 'xy_r']
10
+ - [x, z, null, 'xz_r']
8
11
  axes:
9
12
  x :
10
- binning : [-5.0, 8.0, 40]
13
+ binning : [-3.0, 3.0, 40]
11
14
  label : 'x'
12
15
  y :
13
16
  binning : [-5.0, 8.0, 40]
14
17
  label : 'y'
18
+ z :
19
+ binning : [-5.0, 16.0, 40]
20
+ label : 'z'