data-manipulation-utilities 0.2.7__py3-none-any.whl → 0.2.8.dev720__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/METADATA +669 -42
- data_manipulation_utilities-0.2.8.dev720.dist-info/RECORD +45 -0
- {data_manipulation_utilities-0.2.7.dist-info → data_manipulation_utilities-0.2.8.dev720.dist-info}/WHEEL +1 -2
- data_manipulation_utilities-0.2.8.dev720.dist-info/entry_points.txt +8 -0
- dmu/generic/hashing.py +34 -8
- dmu/generic/utilities.py +164 -11
- dmu/logging/log_store.py +34 -2
- dmu/logging/messages.py +96 -0
- dmu/ml/cv_classifier.py +3 -3
- dmu/ml/cv_diagnostics.py +3 -0
- dmu/ml/cv_performance.py +58 -0
- dmu/ml/cv_predict.py +149 -46
- dmu/ml/train_mva.py +482 -100
- dmu/ml/utilities.py +29 -10
- dmu/pdataframe/utilities.py +28 -3
- dmu/plotting/fwhm.py +2 -2
- dmu/plotting/matrix.py +1 -1
- dmu/plotting/plotter.py +23 -3
- dmu/plotting/plotter_1d.py +96 -32
- dmu/plotting/plotter_2d.py +5 -0
- dmu/rdataframe/utilities.py +54 -3
- dmu/rfile/ddfgetter.py +102 -0
- dmu/stats/fit_stats.py +129 -0
- dmu/stats/fitter.py +55 -22
- dmu/stats/gof_calculator.py +7 -0
- dmu/stats/model_factory.py +153 -62
- dmu/stats/parameters.py +100 -0
- dmu/stats/utilities.py +443 -12
- dmu/stats/wdata.py +187 -0
- dmu/stats/zfit.py +17 -0
- dmu/stats/zfit_plotter.py +147 -36
- dmu/testing/utilities.py +102 -24
- dmu/workflow/__init__.py +0 -0
- dmu/workflow/cache.py +266 -0
- data_manipulation_utilities-0.2.7.data/scripts/publish +0 -89
- data_manipulation_utilities-0.2.7.dist-info/RECORD +0 -69
- data_manipulation_utilities-0.2.7.dist-info/entry_points.txt +0 -6
- data_manipulation_utilities-0.2.7.dist-info/top_level.txt +0 -3
- dmu_data/ml/tests/diagnostics_from_file.yaml +0 -13
- dmu_data/ml/tests/diagnostics_from_model.yaml +0 -10
- dmu_data/ml/tests/diagnostics_multiple_methods.yaml +0 -10
- dmu_data/ml/tests/diagnostics_overlay.yaml +0 -33
- dmu_data/ml/tests/train_mva.yaml +0 -58
- dmu_data/ml/tests/train_mva_with_diagnostics.yaml +0 -82
- dmu_data/plotting/tests/2d.yaml +0 -24
- dmu_data/plotting/tests/fig_size.yaml +0 -13
- dmu_data/plotting/tests/high_stat.yaml +0 -22
- dmu_data/plotting/tests/legend.yaml +0 -12
- dmu_data/plotting/tests/name.yaml +0 -14
- dmu_data/plotting/tests/no_bounds.yaml +0 -12
- dmu_data/plotting/tests/normalized.yaml +0 -9
- dmu_data/plotting/tests/plug_fwhm.yaml +0 -24
- dmu_data/plotting/tests/plug_stats.yaml +0 -19
- dmu_data/plotting/tests/simple.yaml +0 -9
- dmu_data/plotting/tests/stats.yaml +0 -9
- dmu_data/plotting/tests/styling.yaml +0 -11
- dmu_data/plotting/tests/title.yaml +0 -14
- dmu_data/plotting/tests/weights.yaml +0 -13
- dmu_data/text/transform.toml +0 -4
- dmu_data/text/transform.txt +0 -6
- dmu_data/text/transform_set.toml +0 -8
- dmu_data/text/transform_set.txt +0 -6
- dmu_data/text/transform_trf.txt +0 -12
- dmu_scripts/git/publish +0 -89
- dmu_scripts/physics/check_truth.py +0 -121
- dmu_scripts/rfile/compare_root_files.py +0 -299
- dmu_scripts/rfile/print_trees.py +0 -35
- dmu_scripts/ssh/coned.py +0 -168
- dmu_scripts/text/transform_text.py +0 -46
- {dmu_data → dmu}/__init__.py +0 -0
dmu/stats/zfit_plotter.py
CHANGED
@@ -3,25 +3,29 @@ Module containing plot class, used to plot fits
|
|
3
3
|
'''
|
4
4
|
# pylint: disable=too-many-instance-attributes, too-many-arguments
|
5
5
|
|
6
|
+
import math
|
6
7
|
import warnings
|
7
8
|
import pprint
|
8
9
|
|
9
10
|
import zfit
|
10
11
|
import hist
|
11
12
|
import mplhep
|
13
|
+
import tensorflow as tf
|
12
14
|
import pandas as pd
|
13
15
|
import numpy as np
|
14
16
|
import matplotlib.pyplot as plt
|
15
|
-
|
17
|
+
from zfit.core.basepdf import BasePDF as zpdf
|
16
18
|
|
19
|
+
import dmu.generic.utilities as gut
|
17
20
|
from dmu.logging.log_store import LogStore
|
18
21
|
|
19
|
-
log = LogStore.add_logger('dmu:
|
22
|
+
log = LogStore.add_logger('dmu:zfit_plotter')
|
20
23
|
#----------------------------------------
|
21
24
|
class ZFitPlotter:
|
22
25
|
'''
|
23
26
|
Class used to plot fits done with zfit
|
24
27
|
'''
|
28
|
+
#----------------------------------------
|
25
29
|
def __init__(self, data=None, model=None, weights=None, result=None, suffix=''):
|
26
30
|
'''
|
27
31
|
obs: zfit space you are using to define the data and model
|
@@ -62,7 +66,7 @@ class ZFitPlotter:
|
|
62
66
|
self._l_def_col = list(mcolors.TABLEAU_COLORS.keys())
|
63
67
|
#----------------------------------------
|
64
68
|
def _data_to_zdata(self, obs, data, weights):
|
65
|
-
if isinstance(data, zfit.
|
69
|
+
if isinstance(data, zfit.Data):
|
66
70
|
return data
|
67
71
|
|
68
72
|
if isinstance(data, np.ndarray):
|
@@ -76,36 +80,76 @@ class ZFitPlotter:
|
|
76
80
|
|
77
81
|
return data
|
78
82
|
#----------------------------------------
|
79
|
-
def _get_errors(
|
80
|
-
|
81
|
-
|
83
|
+
def _get_errors(
|
84
|
+
self,
|
85
|
+
nbins : int = 100,
|
86
|
+
l_range: list[tuple[float,float]]|None = None) -> list[float]:
|
87
|
+
'''
|
88
|
+
Parameters
|
89
|
+
---------------------
|
90
|
+
nbins : Number of bins
|
91
|
+
l_range: List of ranges where data should be picked, if None, will pick full range
|
92
|
+
|
93
|
+
Returns
|
94
|
+
---------------------
|
95
|
+
list of errors associated to histogram filled with data
|
96
|
+
'''
|
97
|
+
dat, wgt = self._get_range_data(l_range=l_range, blind=False)
|
98
|
+
data_hist = hist.Hist.new.Regular(
|
99
|
+
nbins,
|
100
|
+
self.lower,
|
101
|
+
self.upper,
|
102
|
+
name =self.obs.obs[0],
|
103
|
+
underflow =False,
|
104
|
+
overflow =False)
|
105
|
+
|
82
106
|
data_hist = data_hist.Weight()
|
83
107
|
data_hist.fill(dat, weight=wgt)
|
84
108
|
|
85
109
|
tmp_fig, tmp_ax = plt.subplots()
|
86
|
-
errorbars
|
110
|
+
errorbars = mplhep.histplot(
|
87
111
|
data_hist,
|
88
|
-
yerr=True,
|
89
|
-
color='white',
|
90
|
-
histtype=
|
91
|
-
label=None,
|
92
|
-
ax=tmp_ax
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
lines = errorbars[0].errorbar[2]
|
97
|
-
segs = lines[0].get_segments()
|
112
|
+
yerr =True,
|
113
|
+
color ='white',
|
114
|
+
histtype ='errorbar',
|
115
|
+
label =None,
|
116
|
+
ax =tmp_ax)
|
117
|
+
|
118
|
+
lines = errorbars[0].errorbar[2]
|
119
|
+
segs = lines[0].get_segments()
|
98
120
|
values = data_hist.values()
|
99
121
|
|
100
122
|
l_error=[]
|
101
123
|
for i in range(nbins):
|
102
|
-
|
103
|
-
|
124
|
+
seg = segs[i]
|
125
|
+
val = values[i]
|
126
|
+
|
127
|
+
try:
|
128
|
+
low = val - seg[0][1]
|
129
|
+
up = -val + seg[1][1]
|
130
|
+
except IndexError as exc:
|
131
|
+
raise IndexError(f'Cannot read the upper/lower errors, found {seg}') from exc
|
132
|
+
|
104
133
|
l_error.append((low, up))
|
105
134
|
|
135
|
+
plt.close(tmp_fig)
|
136
|
+
|
106
137
|
return l_error
|
107
138
|
#----------------------------------------
|
108
|
-
def _get_range_data(
|
139
|
+
def _get_range_data(
|
140
|
+
self,
|
141
|
+
l_range : list[tuple[float,float]]|None,
|
142
|
+
blind : bool =True) -> tuple[np.ndarray, np.ndarray]:
|
143
|
+
'''
|
144
|
+
Parameters
|
145
|
+
-----------------
|
146
|
+
l_range: List of ranges, i.e. tuples of bounds
|
147
|
+
blind : If true (default) will blind the range specified, i.e. will exclude it
|
148
|
+
|
149
|
+
Returns
|
150
|
+
-----------------
|
151
|
+
Tuple with two numpy arrays defined in those ranges, with the observable and the weights.
|
152
|
+
'''
|
109
153
|
sdat = self.data_np
|
110
154
|
swgt = self.data_weight_np
|
111
155
|
dmat = np.array([sdat, swgt]).T
|
@@ -117,6 +161,8 @@ class ZFitPlotter:
|
|
117
161
|
|
118
162
|
if l_range is None:
|
119
163
|
[dat, wgt] = dmat.T
|
164
|
+
self._check_data(dat=dat, wgt=wgt)
|
165
|
+
|
120
166
|
return dat, wgt
|
121
167
|
|
122
168
|
l_dat = []
|
@@ -132,23 +178,42 @@ class ZFitPlotter:
|
|
132
178
|
dat_f = np.concatenate(l_dat)
|
133
179
|
wgt_f = np.concatenate(l_wgt)
|
134
180
|
|
181
|
+
self._check_data(dat=dat_f, wgt=wgt_f)
|
182
|
+
|
135
183
|
return dat_f, wgt_f
|
136
184
|
#----------------------------------------
|
185
|
+
def _check_data(
|
186
|
+
self,
|
187
|
+
dat : np.ndarray,
|
188
|
+
wgt : np.ndarray) -> None:
|
189
|
+
'''
|
190
|
+
Checks for empty data, etc
|
191
|
+
|
192
|
+
Parameters
|
193
|
+
------------
|
194
|
+
Numpy arrays with data and weights
|
195
|
+
'''
|
196
|
+
|
197
|
+
if dat.shape != wgt.shape:
|
198
|
+
raise ValueError(f'Shapes or data and weights differ: {dat.shape}/{wgt.shape}')
|
199
|
+
|
200
|
+
if len(dat) == 0:
|
201
|
+
raise ValueError('Dataset is empty')
|
202
|
+
#----------------------------------------
|
137
203
|
def _plot_data(self, ax, nbins=100, l_range=None):
|
138
204
|
dat, wgt = self._get_range_data(l_range, blind=True)
|
139
205
|
data_hist = hist.Hist.new.Regular(nbins, self.lower, self.upper, name=self.obs.obs[0], underflow=False, overflow=False)
|
140
206
|
data_hist = data_hist.Weight()
|
141
207
|
data_hist.fill(dat, weight=wgt)
|
142
208
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
)
|
209
|
+
mplhep.histplot(
|
210
|
+
data_hist,
|
211
|
+
yerr = True,
|
212
|
+
color = 'black',
|
213
|
+
histtype = 'errorbar',
|
214
|
+
label = self._leg.get('Data', 'Data'),
|
215
|
+
ax = ax,
|
216
|
+
xerr = self.dat_xerr)
|
152
217
|
#----------------------------------------
|
153
218
|
def _pull_hist(self, pdf_hist, nbins, data_yield, l_range=None):
|
154
219
|
pdf_values= pdf_hist.values()
|
@@ -170,8 +235,16 @@ class ZFitPlotter:
|
|
170
235
|
err = low if res > 0 else up
|
171
236
|
pul = res / err
|
172
237
|
|
173
|
-
|
174
|
-
|
238
|
+
# If the data is weighted
|
239
|
+
# and the data does not exist
|
240
|
+
# The pulls will have an error of zero => pull is inf
|
241
|
+
# Ignore these cases
|
242
|
+
if math.isinf(pul):
|
243
|
+
pass
|
244
|
+
elif abs(pul) > 5:
|
245
|
+
log.info(f'Pull: {pul:.2f}=({dat_val:.2f}-{pdf_val:.2f})/{err:.2f}')
|
246
|
+
else:
|
247
|
+
log.debug(f'Pull: {pul:.2f}=({dat_val:.2f}-{pdf_val:.2f})/{err:.2f}')
|
175
248
|
|
176
249
|
pulls.append(pul)
|
177
250
|
pull_errors[0].append(low / err)
|
@@ -241,7 +314,7 @@ class ZFitPlotter:
|
|
241
314
|
val = d_val['value']
|
242
315
|
name= par if isinstance(par, str) else par.name
|
243
316
|
try:
|
244
|
-
err = d_val['
|
317
|
+
err = d_val['minuit_hesse']['error']
|
245
318
|
except KeyError:
|
246
319
|
log.warning(f'Cannot extract {name} Hesse errors, using zeros')
|
247
320
|
pprint.pprint(d_val)
|
@@ -348,7 +421,7 @@ class ZFitPlotter:
|
|
348
421
|
if stacked:
|
349
422
|
ax.fill_between(self.x, y, alpha=1.0, label=self._leg.get(name, name), color=self._get_col(name))
|
350
423
|
else:
|
351
|
-
ax.plot(self.x, y, '
|
424
|
+
ax.plot(self.x, y, ':', label=self._leg.get(name, name), color=self._col.get(name))
|
352
425
|
|
353
426
|
if (blind_name is not None) and (was_blinded is False):
|
354
427
|
for model in self.total_model.pdfs:
|
@@ -365,10 +438,34 @@ class ZFitPlotter:
|
|
365
438
|
|
366
439
|
return col
|
367
440
|
#----------------------------------------
|
441
|
+
def _print_data(self) -> None:
|
442
|
+
log.info(f'Data shape : {self.data_np.shape}')
|
443
|
+
log.info(f'Weights shape: {self.data_weight_np.shape}')
|
444
|
+
|
445
|
+
nnans = np.sum(np.isnan(self.data_np))
|
446
|
+
log.info(f'NaNs: {nnans}')
|
447
|
+
|
448
|
+
# This function will run before program raises
|
449
|
+
# One should be able to drop any plot
|
450
|
+
plt.close('all')
|
451
|
+
plt.hist(self.data_np, weights=self.data_weight_np)
|
452
|
+
plt.show()
|
453
|
+
#----------------------------------------
|
454
|
+
def _evaluate_pdf(self, pdf : zpdf) -> np.ndarray:
|
455
|
+
try:
|
456
|
+
arr_y = pdf.pdf(self.x)
|
457
|
+
except tf.errors.InvalidArgumentError as exc:
|
458
|
+
log.info(f'X values: {self.x}')
|
459
|
+
self._print_data()
|
460
|
+
raise ValueError('Cannot evaluate PDF') from exc
|
461
|
+
|
462
|
+
return arr_y
|
463
|
+
#----------------------------------------
|
368
464
|
def _plot_sub_components(self, y, nbins, stacked, nevt, l_model):
|
369
465
|
l_y = []
|
370
466
|
for frc, model in l_model:
|
371
|
-
|
467
|
+
arr_y = self._evaluate_pdf(pdf = model)
|
468
|
+
this_y = arr_y * nevt * frc / nbins * (self.upper - self.lower)
|
372
469
|
|
373
470
|
if stacked:
|
374
471
|
y = this_y if y is None else y + this_y
|
@@ -385,7 +482,13 @@ class ZFitPlotter:
|
|
385
482
|
return
|
386
483
|
|
387
484
|
data_yield = self.data_weight_np.sum()
|
388
|
-
|
485
|
+
try:
|
486
|
+
arr_y = self._evaluate_pdf(model)
|
487
|
+
y = arr_y * data_yield / nbins * (self.upper - self.lower)
|
488
|
+
except tf.errors.InvalidArgumentError as exc:
|
489
|
+
log.warning(f'Data yield: {data_yield:.0f}')
|
490
|
+
log.info(self.data_np)
|
491
|
+
raise RuntimeError('Cannot parse PDF') from exc
|
389
492
|
|
390
493
|
name = model.name
|
391
494
|
ax.plot(self.x, y, linestyle, label=self._leg.get(name, name), color=self._col.get(name))
|
@@ -396,7 +499,7 @@ class ZFitPlotter:
|
|
396
499
|
|
397
500
|
if ylabel == "":
|
398
501
|
width = (self.upper-self.lower)/nbins
|
399
|
-
ylabel = f'Candidates / ({width:.
|
502
|
+
ylabel = f'Candidates / ({width:.0f} {unit})'
|
400
503
|
|
401
504
|
return xlabel, ylabel
|
402
505
|
#----------------------------------------
|
@@ -442,6 +545,7 @@ class ZFitPlotter:
|
|
442
545
|
add_pars = None,
|
443
546
|
ymax = None,
|
444
547
|
skip_pulls = False,
|
548
|
+
pull_styling :bool= True,
|
445
549
|
yscale : str = None,
|
446
550
|
axs = None,
|
447
551
|
figsize:tuple = (13, 7),
|
@@ -459,6 +563,7 @@ class ZFitPlotter:
|
|
459
563
|
d_leg : Customize legend
|
460
564
|
d_col : Customize color
|
461
565
|
plot_range : Set plot_range
|
566
|
+
pull_styling(bool) : Will add lines at +/-3 and set range to +/-5 for pull plots, by default True
|
462
567
|
plot_components (list): List of strings, with names of PDFs, which are expected to be sums of PDFs and whose components should be plotted separately
|
463
568
|
ext_text : Text that can be added to plot
|
464
569
|
add_pars (list|str) : List of names of parameters to be added or string with value 'all' to add all fit parameters. If this is used, plot won't use LHCb style.
|
@@ -532,4 +637,10 @@ class ZFitPlotter:
|
|
532
637
|
|
533
638
|
for ax in self.axs:
|
534
639
|
ax.label_outer()
|
640
|
+
|
641
|
+
if pull_styling and not skip_pulls:
|
642
|
+
self.axs[1].axhline(y=-3, color='red' , linestyle='-', lw=2)
|
643
|
+
self.axs[1].axhline(y= 0, color='gray', linestyle='-', lw=1)
|
644
|
+
self.axs[1].axhline(y=+3, color='red' , linestyle='-', lw=2)
|
645
|
+
self.axs[1].set_ylim(-5, 5)
|
535
646
|
#----------------------------------------
|
dmu/testing/utilities.py
CHANGED
@@ -10,6 +10,7 @@ from importlib.resources import files
|
|
10
10
|
|
11
11
|
from ROOT import RDF, TFile, RDataFrame
|
12
12
|
|
13
|
+
import uproot
|
13
14
|
import joblib
|
14
15
|
import pandas as pnd
|
15
16
|
import numpy
|
@@ -27,6 +28,14 @@ class Data:
|
|
27
28
|
Class storing shared data
|
28
29
|
'''
|
29
30
|
out_dir = '/tmp/tests/dmu/ml/cv_predict'
|
31
|
+
|
32
|
+
d_col = {
|
33
|
+
'main' : ['index', 'a0', 'b0'],
|
34
|
+
'frn1' : ['index', 'a1', 'b1'],
|
35
|
+
'frn2' : ['index', 'a2', 'b2'],
|
36
|
+
'frn3' : ['index', 'a3', 'b3'],
|
37
|
+
'frn4' : ['index', 'a4', 'b4'],
|
38
|
+
}
|
30
39
|
# -------------------------------
|
31
40
|
def _double_data(df_1 : pnd.DataFrame) -> pnd.DataFrame:
|
32
41
|
df_2 = df_1.copy()
|
@@ -53,25 +62,39 @@ def _add_nans(df : pnd.DataFrame, columns : list[str]) -> pnd.DataFrame:
|
|
53
62
|
|
54
63
|
return df
|
55
64
|
# -------------------------------
|
56
|
-
def get_rdf(
|
57
|
-
|
58
|
-
|
59
|
-
|
65
|
+
def get_rdf(
|
66
|
+
kind : Union[str,None] = None,
|
67
|
+
repeated : bool = False,
|
68
|
+
nentries : int = 3_000,
|
69
|
+
use_preffix : bool = False,
|
70
|
+
columns_with_nans : list[str] = None):
|
60
71
|
'''
|
61
72
|
Return ROOT dataframe with toy data
|
73
|
+
|
74
|
+
kind : sig, bkg or bkg_alt
|
75
|
+
repeated : Will add repeated rows
|
76
|
+
nentries : Number of rows
|
77
|
+
columns_with_nans : List of column names in [w, y, z]
|
62
78
|
'''
|
79
|
+
# Needed for a specific test
|
80
|
+
xnm = 'preffix.x.suffix' if use_preffix else 'x'
|
63
81
|
|
64
82
|
d_data = {}
|
65
83
|
if kind == 'sig':
|
66
|
-
d_data[
|
67
|
-
d_data['
|
68
|
-
d_data['y'] = numpy.random.normal(0, 1, size=nentries)
|
69
|
-
d_data['z'] = numpy.random.normal(0, 1, size=nentries)
|
84
|
+
d_data[xnm] = numpy.random.normal(0.0, 1.0, size=nentries)
|
85
|
+
d_data['w'] = numpy.random.normal(0.0, 1.0, size=nentries)
|
86
|
+
d_data['y'] = numpy.random.normal(0.0, 1.0, size=nentries)
|
87
|
+
d_data['z'] = numpy.random.normal(0.0, 1.0, size=nentries)
|
70
88
|
elif kind == 'bkg':
|
71
|
-
d_data[
|
72
|
-
d_data['
|
73
|
-
d_data['y'] = numpy.random.normal(1, 1, size=nentries)
|
74
|
-
d_data['z'] = numpy.random.normal(1, 1, size=nentries)
|
89
|
+
d_data[xnm] = numpy.random.normal(1.0, 1.0, size=nentries)
|
90
|
+
d_data['w'] = numpy.random.normal(1.0, 1.0, size=nentries)
|
91
|
+
d_data['y'] = numpy.random.normal(1.0, 1.0, size=nentries)
|
92
|
+
d_data['z'] = numpy.random.normal(1.0, 1.0, size=nentries)
|
93
|
+
elif kind == 'bkg_alt':
|
94
|
+
d_data[xnm] = numpy.random.normal(1.3, 1.3, size=nentries)
|
95
|
+
d_data['w'] = numpy.random.normal(1.3, 1.3, size=nentries)
|
96
|
+
d_data['y'] = numpy.random.normal(1.3, 1.3, size=nentries)
|
97
|
+
d_data['z'] = numpy.random.normal(1.3, 1.3, size=nentries)
|
75
98
|
else:
|
76
99
|
log.error(f'Invalid kind: {kind}')
|
77
100
|
raise ValueError
|
@@ -132,24 +155,79 @@ def get_file_with_trees(path : str) -> TFile:
|
|
132
155
|
|
133
156
|
return TFile(path)
|
134
157
|
# -------------------------------
|
135
|
-
def get_models(
|
158
|
+
def get_models(
|
159
|
+
rdf_sig : RDataFrame,
|
160
|
+
rdf_bkg : RDataFrame,
|
161
|
+
name : str = 'train_mva',
|
162
|
+
out_dir : str | None = None) -> tuple[list[CVClassifier], float]:
|
136
163
|
'''
|
137
|
-
Will train and return models
|
164
|
+
Will train and return models together with the AUC in a tuple
|
165
|
+
|
166
|
+
rdf_xxx : Signal or background dataframe used for training
|
167
|
+
name : Name of config file, e.g. train_mva
|
168
|
+
out_dir : Directory where the training output will go, optional.
|
138
169
|
'''
|
170
|
+
out_dir = Data.out_dir if out_dir is None else out_dir
|
139
171
|
|
140
|
-
cfg
|
141
|
-
|
142
|
-
plt_dir = f'{Data.out_dir}/cv_predict'
|
143
|
-
cfg['saving']['path'] = pkl_path
|
144
|
-
cfg['plotting']['val_dir'] = plt_dir
|
145
|
-
cfg['plotting']['features']['saving']['plt_dir'] = plt_dir
|
172
|
+
cfg = get_config(f'ml/tests/{name}.yaml')
|
173
|
+
cfg['saving']['output'] = out_dir
|
146
174
|
|
147
|
-
obj= TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
|
148
|
-
obj.run()
|
175
|
+
obj = TrainMva(sig=rdf_sig, bkg=rdf_bkg, cfg=cfg)
|
176
|
+
auc = obj.run()
|
149
177
|
|
150
|
-
pkl_wc =
|
178
|
+
pkl_wc = f'{out_dir}/model*.pkl'
|
151
179
|
l_pkl_path = glob.glob(pkl_wc)
|
152
180
|
l_model = [ joblib.load(pkl_path) for pkl_path in l_pkl_path ]
|
153
181
|
|
154
|
-
return l_model
|
182
|
+
return l_model, auc
|
155
183
|
# -------------------------------
|
184
|
+
def _make_file(
|
185
|
+
fpath : str,
|
186
|
+
tree : str,
|
187
|
+
nentries : int) -> None:
|
188
|
+
|
189
|
+
fdir = os.path.dirname(fpath)
|
190
|
+
sample = os.path.basename(fdir)
|
191
|
+
l_col_name = Data.d_col[sample]
|
192
|
+
data = {}
|
193
|
+
for col_name in l_col_name:
|
194
|
+
if col_name == 'index':
|
195
|
+
data[col_name] = numpy.arange(nentries)
|
196
|
+
continue
|
197
|
+
|
198
|
+
data[col_name] = numpy.random.normal(0, 1, nentries)
|
199
|
+
|
200
|
+
with uproot.recreate(fpath) as ofile:
|
201
|
+
log.debug(f'Saving to: {fpath}:{tree}')
|
202
|
+
ofile[tree] = data
|
203
|
+
# -------------------------------
|
204
|
+
def build_friend_structure(file_name : str, nentries : int) -> None:
|
205
|
+
'''
|
206
|
+
Will load YAML file with file structure needed to
|
207
|
+
test code that relies on friend trees, e.g. DDFGetter
|
208
|
+
|
209
|
+
Parameters:
|
210
|
+
-------------------
|
211
|
+
file_name (str): Name of YAML file with wanted structure, e.g. friends.yaml
|
212
|
+
nentries (int) : Number of entries in file
|
213
|
+
'''
|
214
|
+
cfg_path = files('dmu_data').joinpath(f'rfile/{file_name}')
|
215
|
+
with open(cfg_path, encoding='utf=8') as ifile:
|
216
|
+
data = yaml.safe_load(ifile)
|
217
|
+
|
218
|
+
if 'tree' not in data:
|
219
|
+
raise ValueError('tree entry missing in: {cfg_path}')
|
220
|
+
|
221
|
+
tree_name = data['tree']
|
222
|
+
|
223
|
+
if 'samples' not in data:
|
224
|
+
raise ValueError('Samples section missing in: {cfg_path}')
|
225
|
+
|
226
|
+
if 'files' not in data:
|
227
|
+
raise ValueError('Files section missing in: {cfg_path}')
|
228
|
+
|
229
|
+
for fdir in data['samples']:
|
230
|
+
for fname in data['files']:
|
231
|
+
path = f'{fdir}/{fname}'
|
232
|
+
_make_file(fpath=path, tree=tree_name, nentries=nentries)
|
233
|
+
# ----------------------------------------------
|
dmu/workflow/__init__.py
ADDED
File without changes
|