data-manipulation-utilities 0.2.6__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +800 -34
- data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
- {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
- {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
- dmu/__init__.py +0 -0
- dmu/generic/hashing.py +70 -0
- dmu/generic/utilities.py +175 -9
- dmu/generic/version_management.py +3 -5
- dmu/logging/log_store.py +34 -2
- dmu/logging/messages.py +96 -0
- dmu/ml/cv_classifier.py +3 -3
- dmu/ml/cv_diagnostics.py +224 -0
- dmu/ml/cv_performance.py +58 -0
- dmu/ml/cv_predict.py +149 -46
- dmu/ml/train_mva.py +587 -112
- dmu/ml/utilities.py +29 -10
- dmu/pdataframe/utilities.py +61 -3
- dmu/plotting/fwhm.py +64 -0
- dmu/plotting/matrix.py +1 -1
- dmu/plotting/plotter.py +25 -3
- dmu/plotting/plotter_1d.py +159 -14
- dmu/plotting/plotter_2d.py +5 -0
- dmu/rdataframe/utilities.py +54 -3
- dmu/rfile/ddfgetter.py +102 -0
- dmu/stats/fit_stats.py +129 -0
- dmu/stats/fitter.py +56 -23
- dmu/stats/gof_calculator.py +7 -0
- dmu/stats/model_factory.py +305 -50
- dmu/stats/parameters.py +100 -0
- dmu/stats/utilities.py +443 -12
- dmu/stats/wdata.py +187 -0
- dmu/stats/zfit.py +17 -0
- dmu/stats/zfit_models.py +68 -0
- dmu/stats/zfit_plotter.py +175 -56
- dmu/testing/utilities.py +120 -15
- dmu/workflow/__init__.py +0 -0
- dmu/workflow/cache.py +266 -0
- dmu_data/ml/tests/diagnostics_from_file.yaml +13 -0
- dmu_data/ml/tests/diagnostics_from_model.yaml +10 -0
- dmu_data/ml/tests/diagnostics_multiple_methods.yaml +10 -0
- dmu_data/ml/tests/diagnostics_overlay.yaml +33 -0
- dmu_data/ml/tests/train_mva.yaml +20 -12
- dmu_data/ml/tests/train_mva_def.yaml +75 -0
- dmu_data/ml/tests/train_mva_with_diagnostics.yaml +87 -0
- dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
- dmu_data/plotting/tests/2d.yaml +5 -5
- dmu_data/plotting/tests/line.yaml +15 -0
- dmu_data/plotting/tests/plug_fwhm.yaml +24 -0
- dmu_data/plotting/tests/plug_stats.yaml +19 -0
- dmu_data/plotting/tests/simple.yaml +4 -3
- dmu_data/plotting/tests/styling.yaml +18 -0
- dmu_data/rfile/friends.yaml +13 -0
- dmu_data/stats/fitter/test_simple.yaml +28 -0
- dmu_data/stats/kde_optimizer/control.json +1 -0
- dmu_data/stats/kde_optimizer/signal.json +1 -0
- dmu_data/stats/parameters/data.yaml +178 -0
- dmu_data/tests/config.json +6 -0
- dmu_data/tests/config.yaml +4 -0
- dmu_data/tests/pdf_to_tex.txt +34 -0
- dmu_scripts/kerberos/check_expiration +21 -0
- dmu_scripts/kerberos/convert_certificate +22 -0
- dmu_scripts/ml/compare_classifiers.py +85 -0
- data_manipulation_utilities-0.2.6.dist-info/RECORD +0 -57
- {data_manipulation_utilities-0.2.6.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
- {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
dmu/ml/train_mva.py
CHANGED
@@ -3,16 +3,28 @@ Module with TrainMva class
|
|
3
3
|
'''
|
4
4
|
# pylint: disable = too-many-locals, no-name-in-module
|
5
5
|
# pylint: disable = too-many-arguments, too-many-positional-arguments
|
6
|
+
# pylint: disable = too-many-instance-attributes
|
7
|
+
# pylint: disable = too-many-arguments, too-many-positional-arguments
|
6
8
|
|
7
9
|
import os
|
10
|
+
import copy
|
11
|
+
import json
|
12
|
+
import math
|
13
|
+
|
14
|
+
from contextlib import contextmanager
|
15
|
+
from typing import Optional, Union
|
16
|
+
from functools import partial
|
8
17
|
|
18
|
+
import tqdm
|
9
19
|
import joblib
|
20
|
+
import optuna
|
10
21
|
import pandas as pnd
|
11
22
|
import numpy
|
12
23
|
import matplotlib.pyplot as plt
|
13
24
|
|
14
25
|
from sklearn.metrics import roc_curve, auc
|
15
|
-
from sklearn.model_selection import StratifiedKFold
|
26
|
+
from sklearn.model_selection import StratifiedKFold, cross_val_score
|
27
|
+
from sklearn.ensemble import GradientBoostingClassifier
|
16
28
|
|
17
29
|
from ROOT import RDataFrame, RDF
|
18
30
|
|
@@ -20,18 +32,29 @@ import dmu.ml.utilities as ut
|
|
20
32
|
import dmu.pdataframe.utilities as put
|
21
33
|
import dmu.plotting.utilities as plu
|
22
34
|
|
35
|
+
from dmu.ml.cv_diagnostics import CVDiagnostics
|
23
36
|
from dmu.ml.cv_classifier import CVClassifier as cls
|
24
37
|
from dmu.plotting.plotter_1d import Plotter1D as Plotter
|
25
38
|
from dmu.plotting.matrix import MatrixPlotter
|
26
39
|
from dmu.logging.log_store import LogStore
|
27
40
|
|
28
|
-
|
41
|
+
NPA = numpy.ndarray
|
29
42
|
log = LogStore.add_logger('dmu:ml:train_mva')
|
30
43
|
# ---------------------------------------------
|
44
|
+
class NoFeatureInfo(Exception):
|
45
|
+
'''
|
46
|
+
Used when information about a feature is missing in the config file
|
47
|
+
'''
|
48
|
+
def __init__(self, message : str):
|
49
|
+
super().__init__(message)
|
50
|
+
# ---------------------------------------------
|
31
51
|
class TrainMva:
|
32
52
|
'''
|
33
53
|
Interface to scikit learn used to train classifier
|
34
54
|
'''
|
55
|
+
# TODO:
|
56
|
+
# - Hyperparameter optimization methods should go into their own class
|
57
|
+
# - Data preprocessing methods might need their own class
|
35
58
|
# ---------------------------------------------
|
36
59
|
def __init__(self, bkg : RDataFrame, sig : RDataFrame, cfg : dict):
|
37
60
|
'''
|
@@ -40,32 +63,71 @@ class TrainMva:
|
|
40
63
|
cfg (dict) : Dictionary storing configuration for training
|
41
64
|
'''
|
42
65
|
self._cfg = cfg
|
66
|
+
self._auc = math.nan # This is where the Area Under the ROC curve for the full sample will be saved
|
43
67
|
self._l_ft_name = self._cfg['training']['features']
|
68
|
+
self._pbar : Optional[tqdm.tqdm]
|
69
|
+
|
70
|
+
self._rdf_sig_org = sig
|
71
|
+
self._rdf_bkg_org = bkg
|
44
72
|
|
45
|
-
|
46
|
-
|
73
|
+
rdf_bkg = self._preprocess_rdf(rdf=bkg, kind='bkg')
|
74
|
+
rdf_sig = self._preprocess_rdf(rdf=sig, kind='sig')
|
75
|
+
|
76
|
+
df_ft_sig, l_lab_sig = self._get_sample_inputs(rdf = rdf_sig, label = 1)
|
77
|
+
df_ft_bkg, l_lab_bkg = self._get_sample_inputs(rdf = rdf_bkg, label = 0)
|
47
78
|
|
48
79
|
self._df_ft = pnd.concat([df_ft_sig, df_ft_bkg], axis=0)
|
49
80
|
self._l_lab = numpy.array(l_lab_sig + l_lab_bkg)
|
50
81
|
|
51
|
-
self._rdf_bkg = self._get_rdf(rdf =
|
52
|
-
self._rdf_sig = self._get_rdf(rdf =
|
82
|
+
self._rdf_bkg = self._get_rdf(rdf = rdf_bkg, df_feat=df_ft_bkg)
|
83
|
+
self._rdf_sig = self._get_rdf(rdf = rdf_sig, df_feat=df_ft_sig)
|
84
|
+
|
85
|
+
self._rdm_state = 42 # Random state for training classifier
|
86
|
+
self._nworkers = 1 # Used to set number of workers for ANY process. Can be overriden with `use` context manager
|
87
|
+
|
88
|
+
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
89
|
+
# ---------------------------------------------
|
90
|
+
def _get_extra_columns(self, rdf : RDataFrame, df : pnd.DataFrame) -> list[str]:
|
91
|
+
d_plot = self._cfg['plotting']['features']['plots']
|
92
|
+
l_expr = list(d_plot)
|
93
|
+
l_rdf = [ name.c_str() for name in rdf.GetColumnNames() ]
|
94
|
+
|
95
|
+
l_extr = []
|
96
|
+
for expr in l_expr:
|
97
|
+
if expr not in l_rdf:
|
98
|
+
continue
|
99
|
+
|
100
|
+
if expr in df.columns:
|
101
|
+
continue
|
102
|
+
|
103
|
+
l_extr.append(expr)
|
104
|
+
|
105
|
+
return l_extr
|
53
106
|
# ---------------------------------------------
|
54
|
-
def _get_rdf(self, rdf : RDataFrame,
|
107
|
+
def _get_rdf(self, rdf : RDataFrame, df_feat : pnd.DataFrame) -> RDataFrame:
|
55
108
|
'''
|
56
109
|
Takes original ROOT dataframe and pre-processed features dataframe
|
57
110
|
Adds missing branches to latter and returns expanded ROOT dataframe
|
111
|
+
Need to make plots
|
58
112
|
'''
|
59
113
|
|
60
|
-
|
61
|
-
|
62
|
-
|
114
|
+
l_extr_col = self._get_extra_columns(rdf, df_feat)
|
115
|
+
if len(l_extr_col) > 20:
|
116
|
+
for name in l_extr_col:
|
117
|
+
log.debug(name)
|
118
|
+
raise ValueError('Found more than 20 extra columns')
|
119
|
+
|
120
|
+
d_data = rdf.AsNumpy(l_extr_col)
|
121
|
+
log.debug(f'Adding extra-nonfeature columns: {l_extr_col}')
|
122
|
+
df_extr = pnd.DataFrame(d_data)
|
123
|
+
|
124
|
+
nmain = len(df_feat.columns)
|
125
|
+
nextr = len(df_extr.columns)
|
63
126
|
|
64
|
-
log.debug(f'
|
127
|
+
log.debug(f'Main DF size: {nmain}')
|
128
|
+
log.debug(f'Extra DF size: {nextr}')
|
65
129
|
|
66
|
-
|
67
|
-
df_ext = pnd.DataFrame(d_data)
|
68
|
-
df_all = pnd.concat([df, df_ext], axis=1)
|
130
|
+
df_all = pnd.concat([df_feat, df_extr], axis=1)
|
69
131
|
|
70
132
|
return RDF.FromPandas(df_all)
|
71
133
|
# ---------------------------------------------
|
@@ -89,6 +151,53 @@ class TrainMva:
|
|
89
151
|
log.info(70 * '-')
|
90
152
|
|
91
153
|
return df
|
154
|
+
#---------------------------------
|
155
|
+
def _add_sample_columns(
|
156
|
+
self,
|
157
|
+
rdf : RDataFrame,
|
158
|
+
kind : str) -> RDataFrame:
|
159
|
+
'''
|
160
|
+
This will apply sample specific column definitions
|
161
|
+
to the dataframe
|
162
|
+
'''
|
163
|
+
try:
|
164
|
+
d_def = self._cfg['dataset']['samples'][kind]['definitions']
|
165
|
+
except KeyError:
|
166
|
+
log.debug(f'Not found sample definitions for {kind}')
|
167
|
+
return rdf
|
168
|
+
|
169
|
+
log.info(60 * '-')
|
170
|
+
log.info(f'Found sample definitions for {kind}')
|
171
|
+
log.info(60 * '-')
|
172
|
+
for name, expr in d_def.items():
|
173
|
+
log.info(f'{name:<30}{"-->":<10}{expr:<20}')
|
174
|
+
rdf = rdf.Define(name, expr)
|
175
|
+
log.info(60 * '-')
|
176
|
+
|
177
|
+
return rdf
|
178
|
+
# ---------------------------------------------
|
179
|
+
def _preprocess_rdf(self, rdf : RDataFrame, kind : str) -> RDataFrame:
|
180
|
+
rdf = self._add_sample_columns(rdf, kind)
|
181
|
+
|
182
|
+
if 'define' not in self._cfg['dataset']:
|
183
|
+
log.debug('No definitions found')
|
184
|
+
return rdf
|
185
|
+
|
186
|
+
log.debug(f'Definitions found for {kind}')
|
187
|
+
d_def = self._cfg['dataset']['define']
|
188
|
+
for name, expr in d_def.items():
|
189
|
+
log.debug(f'{name:<20}{expr}')
|
190
|
+
try:
|
191
|
+
rdf = rdf.Define(name, expr)
|
192
|
+
except TypeError as exc:
|
193
|
+
l_col = [ name.c_str() for name in rdf.GetColumnNames() ]
|
194
|
+
branch_list = 'found_branches.txt'
|
195
|
+
with open(branch_list, 'w', encoding='utf-8') as ifile:
|
196
|
+
json.dump(l_col, ifile, indent=2)
|
197
|
+
|
198
|
+
raise TypeError(f'Branches found were dumped to {branch_list}') from exc
|
199
|
+
|
200
|
+
return rdf
|
92
201
|
# ---------------------------------------------
|
93
202
|
def _get_sample_inputs(self, rdf : RDataFrame, label : int) -> tuple[pnd.DataFrame, list[int]]:
|
94
203
|
d_ft = rdf.AsNumpy(self._l_ft_name)
|
@@ -99,7 +208,7 @@ class TrainMva:
|
|
99
208
|
|
100
209
|
return df, l_lab
|
101
210
|
# ---------------------------------------------
|
102
|
-
def _get_model(self, arr_index :
|
211
|
+
def _get_model(self, arr_index : NPA) -> cls:
|
103
212
|
model = cls(cfg = self._cfg)
|
104
213
|
df_ft = self._df_ft.iloc[arr_index]
|
105
214
|
l_lab = self._l_lab[arr_index]
|
@@ -111,10 +220,14 @@ class TrainMva:
|
|
111
220
|
|
112
221
|
return model
|
113
222
|
# ---------------------------------------------
|
114
|
-
def _get_models(self):
|
223
|
+
def _get_models(self, load_trained : bool) -> list[cls]:
|
115
224
|
'''
|
116
225
|
Will create models, train them and return them
|
117
226
|
'''
|
227
|
+
if load_trained:
|
228
|
+
log.warning('Not retraining, but loading trained models')
|
229
|
+
return self._load_trained_models()
|
230
|
+
|
118
231
|
nfold = self._cfg['training']['nfold']
|
119
232
|
rdmst = self._cfg['training']['rdm_stat']
|
120
233
|
|
@@ -122,6 +235,11 @@ class TrainMva:
|
|
122
235
|
|
123
236
|
l_model=[]
|
124
237
|
ifold=0
|
238
|
+
|
239
|
+
l_arr_lab_ts = []
|
240
|
+
l_arr_all_ts = []
|
241
|
+
l_arr_sig_ts = []
|
242
|
+
l_arr_bkg_ts = []
|
125
243
|
for arr_itr, arr_its in kfold.split(self._df_ft, self._l_lab):
|
126
244
|
log.debug(20 * '-')
|
127
245
|
log.info(f'Training fold: {ifold}')
|
@@ -129,33 +247,132 @@ class TrainMva:
|
|
129
247
|
model = self._get_model(arr_itr)
|
130
248
|
l_model.append(model)
|
131
249
|
|
132
|
-
|
133
|
-
|
250
|
+
arr_sig_tr, arr_bkg_tr, arr_all_tr, arr_lab_tr = self._get_scores(model, arr_itr, on_training_ok= True)
|
251
|
+
arr_sig_ts, arr_bkg_ts, arr_all_ts, arr_lab_ts = self._get_scores(model, arr_its, on_training_ok=False)
|
134
252
|
|
135
253
|
self._save_feature_importance(model, ifold)
|
136
|
-
self.
|
137
|
-
self._plot_scores(
|
138
|
-
|
254
|
+
self._plot_correlations(arr_itr, ifold)
|
255
|
+
self._plot_scores(
|
256
|
+
ifold = ifold,
|
257
|
+
sig_trn=arr_sig_tr,
|
258
|
+
sig_tst=arr_sig_ts,
|
259
|
+
bkg_trn=arr_bkg_tr,
|
260
|
+
bkg_tst=arr_bkg_ts)
|
261
|
+
|
262
|
+
xval_ts, yval_ts, _ = TrainMva.plot_roc(arr_lab_ts, arr_all_ts, kind='Test' , ifold=ifold)
|
263
|
+
xval_tr, yval_tr, _ = TrainMva.plot_roc(arr_lab_tr, arr_all_tr, kind='Train', ifold=ifold)
|
264
|
+
self._plot_probabilities(xval_tr, yval_tr, arr_all_tr, arr_lab_tr)
|
265
|
+
self._save_roc_plot(ifold=ifold)
|
266
|
+
|
267
|
+
self._save_roc_json(xval=xval_ts, yval=yval_ts, kind='Test' , ifold=ifold)
|
268
|
+
self._save_roc_json(xval=xval_tr, yval=yval_tr, kind='Train', ifold=ifold)
|
139
269
|
|
140
270
|
ifold+=1
|
141
271
|
|
272
|
+
l_arr_lab_ts.append(arr_lab_ts)
|
273
|
+
l_arr_all_ts.append(arr_all_ts)
|
274
|
+
l_arr_sig_ts.append(arr_sig_ts)
|
275
|
+
l_arr_bkg_ts.append(arr_bkg_ts)
|
276
|
+
|
277
|
+
arr_lab_ts = numpy.concatenate(l_arr_lab_ts)
|
278
|
+
arr_all_ts = numpy.concatenate(l_arr_all_ts)
|
279
|
+
arr_sig_ts = numpy.concatenate(l_arr_sig_ts)
|
280
|
+
arr_bkg_ts = numpy.concatenate(l_arr_bkg_ts)
|
281
|
+
|
282
|
+
xval, yval, self._auc = TrainMva.plot_roc(
|
283
|
+
arr_lab_ts,
|
284
|
+
arr_all_ts,
|
285
|
+
kind ='Test',
|
286
|
+
ifold=-1)
|
287
|
+
self._plot_probabilities(xval, yval, arr_all_ts, arr_lab_ts)
|
288
|
+
self._save_roc_plot(ifold=-1)
|
289
|
+
|
290
|
+
self._plot_scores(ifold=-1, sig_tst=arr_sig_ts, bkg_tst=arr_bkg_ts)
|
291
|
+
self._save_roc_json(xval=xval, yval=yval, kind='Full', ifold=-1)
|
292
|
+
|
293
|
+
return l_model
|
294
|
+
# ---------------------------------------------
|
295
|
+
def _save_roc_json(
|
296
|
+
self,
|
297
|
+
ifold : int,
|
298
|
+
kind : str,
|
299
|
+
xval : NPA,
|
300
|
+
yval : NPA) -> None:
|
301
|
+
ifold = 'all' if ifold == -1 else ifold # -1 represents all the testing datasets combined
|
302
|
+
val_dir = self._cfg['saving']['output']
|
303
|
+
|
304
|
+
name = kind.lower()
|
305
|
+
val_dir = f'{val_dir}/fold_{ifold:03}'
|
306
|
+
os.makedirs(val_dir, exist_ok=True)
|
307
|
+
jsn_path = f'{val_dir}/roc_{name}.json'
|
308
|
+
|
309
|
+
df = pnd.DataFrame({'x' : xval, 'y' : yval})
|
310
|
+
df.to_json(jsn_path, indent=2)
|
311
|
+
# ---------------------------------------------
|
312
|
+
def _save_roc_plot(self, ifold : int) -> None:
|
313
|
+
min_x = 0
|
314
|
+
min_y = 0
|
315
|
+
ifold = 'all' if ifold == -1 else ifold
|
316
|
+
|
317
|
+
if 'min' in self._cfg['plotting']['roc']:
|
318
|
+
[min_x, min_y] = self._cfg['plotting']['roc']['min']
|
319
|
+
|
320
|
+
max_x = 1
|
321
|
+
max_y = 1
|
322
|
+
if 'max' in self._cfg['plotting']['roc']:
|
323
|
+
[max_x, max_y] = self._cfg['plotting']['roc']['max']
|
324
|
+
|
325
|
+
val_dir = self._cfg['saving']['output']
|
326
|
+
|
327
|
+
if ifold == 'all':
|
328
|
+
plt_dir = f'{val_dir}/fold_all'
|
329
|
+
else:
|
330
|
+
plt_dir = f'{val_dir}/fold_{ifold:03}'
|
331
|
+
|
332
|
+
os.makedirs(plt_dir, exist_ok=True)
|
333
|
+
|
334
|
+
plt.xlabel('Signal efficiency')
|
335
|
+
plt.ylabel('Background rejection')
|
336
|
+
plt.title(f'Fold: {ifold}')
|
337
|
+
plt.xlim(min_x, max_x)
|
338
|
+
plt.ylim(min_y, max_y)
|
339
|
+
plt.grid()
|
340
|
+
plt.legend()
|
341
|
+
plt.savefig(f'{plt_dir}/roc.png')
|
342
|
+
plt.close()
|
343
|
+
# ---------------------------------------------
|
344
|
+
def _load_trained_models(self) -> list[cls]:
|
345
|
+
out_dir = self._cfg['saving']['output']
|
346
|
+
model_path = f'{out_dir}/model.pkl'
|
347
|
+
nfold = self._cfg['training']['nfold']
|
348
|
+
l_model = []
|
349
|
+
for ifold in range(nfold):
|
350
|
+
fold_path = model_path.replace('.pkl', f'_{ifold:03}.pkl')
|
351
|
+
|
352
|
+
if not os.path.isfile(fold_path):
|
353
|
+
raise FileNotFoundError(f'Missing trained model: {fold_path}')
|
354
|
+
|
355
|
+
log.debug(f'Loading model from: {fold_path}')
|
356
|
+
model = joblib.load(fold_path)
|
357
|
+
l_model.append(model)
|
358
|
+
|
142
359
|
return l_model
|
143
360
|
# ---------------------------------------------
|
144
361
|
def _labels_from_varnames(self, l_var_name : list[str]) -> list[str]:
|
145
362
|
try:
|
146
363
|
d_plot = self._cfg['plotting']['features']['plots']
|
147
|
-
except
|
148
|
-
|
149
|
-
return l_var_name
|
364
|
+
except KeyError as exc:
|
365
|
+
raise KeyError('Cannot find plotting/features/plots section in config, using dataframe names') from exc
|
150
366
|
|
151
367
|
l_label = []
|
152
368
|
for var_name in l_var_name:
|
153
369
|
if var_name not in d_plot:
|
154
|
-
|
155
|
-
l_label.append(var_name)
|
156
|
-
continue
|
370
|
+
raise NoFeatureInfo(f'No plot found for feature {var_name}, cannot extract label')
|
157
371
|
|
158
372
|
d_setting = d_plot[var_name]
|
373
|
+
if 'labels' not in d_setting:
|
374
|
+
raise NoFeatureInfo(f'No no labels present for plot of feature {var_name}, cannot extract label')
|
375
|
+
|
159
376
|
[xlab, _ ]= d_setting['labels']
|
160
377
|
|
161
378
|
l_label.append(xlab)
|
@@ -169,7 +386,7 @@ class TrainMva:
|
|
169
386
|
d_data['Variable' ] = self._labels_from_varnames(l_var_name)
|
170
387
|
d_data['Importance'] = 100 * model.feature_importances_
|
171
388
|
|
172
|
-
val_dir = self._cfg['
|
389
|
+
val_dir = self._cfg['saving']['output']
|
173
390
|
val_dir = f'{val_dir}/fold_{ifold:03}'
|
174
391
|
os.makedirs(val_dir, exist_ok=True)
|
175
392
|
|
@@ -180,7 +397,7 @@ class TrainMva:
|
|
180
397
|
d_form = {'Variable' : '{}', 'Importance' : '{:.1f}'}
|
181
398
|
put.df_to_tex(df, table_path, d_format = d_form)
|
182
399
|
# ---------------------------------------------
|
183
|
-
def _get_scores(self, model : cls, arr_index :
|
400
|
+
def _get_scores(self, model : cls, arr_index : NPA, on_training_ok : bool) -> tuple[NPA, NPA, NPA, NPA]:
|
184
401
|
'''
|
185
402
|
Returns a tuple of four arrays
|
186
403
|
|
@@ -203,7 +420,7 @@ class TrainMva:
|
|
203
420
|
|
204
421
|
return arr_sig, arr_bkg, arr_all, arr_lab
|
205
422
|
# ---------------------------------------------
|
206
|
-
def _split_scores(self, arr_prob :
|
423
|
+
def _split_scores(self, arr_prob : NPA, arr_label : NPA) -> tuple[NPA, NPA]:
|
207
424
|
'''
|
208
425
|
Will split the testing scores (predictions) based on the training scores
|
209
426
|
|
@@ -222,7 +439,9 @@ class TrainMva:
|
|
222
439
|
'''
|
223
440
|
Saves a model, associated to a specific fold
|
224
441
|
'''
|
225
|
-
|
442
|
+
out_dir = self._cfg['saving']['output']
|
443
|
+
model_path = f'{out_dir}/model.pkl'
|
444
|
+
|
226
445
|
if os.path.isfile(model_path):
|
227
446
|
log.info(f'Model found in {model_path}, not saving')
|
228
447
|
return
|
@@ -259,49 +478,71 @@ class TrainMva:
|
|
259
478
|
|
260
479
|
return cfg
|
261
480
|
# ---------------------------------------------
|
262
|
-
def
|
481
|
+
def _plot_correlations(self, arr_index : NPA, ifold : int) -> None:
|
482
|
+
log.debug('Plotting correlations')
|
483
|
+
|
263
484
|
df_ft = self._df_ft.iloc[arr_index]
|
485
|
+
l_lab = self._l_lab[arr_index]
|
486
|
+
|
487
|
+
arr_sig_idx, = numpy.where(l_lab == 1)
|
488
|
+
arr_bkg_idx, = numpy.where(l_lab == 0)
|
489
|
+
|
490
|
+
df_ft_sig = df_ft.iloc[arr_sig_idx]
|
491
|
+
df_ft_bkg = df_ft.iloc[arr_bkg_idx]
|
492
|
+
|
493
|
+
self._plot_correlation(df_ft=df_ft_sig, ifold=ifold, name='signal' )
|
494
|
+
self._plot_correlation(df_ft=df_ft_bkg, ifold=ifold, name='background')
|
495
|
+
# ---------------------------------------------
|
496
|
+
def _plot_correlation(
|
497
|
+
self,
|
498
|
+
df_ft : pnd.DataFrame,
|
499
|
+
ifold : int,
|
500
|
+
name : str) -> None:
|
501
|
+
|
502
|
+
log.debug(f'Plotting correlation for {name}/{ifold} fold')
|
503
|
+
|
264
504
|
cfg = self._get_correlation_cfg(df_ft, ifold)
|
265
505
|
cov = df_ft.corr()
|
266
506
|
mat = cov.to_numpy()
|
267
507
|
|
268
|
-
|
269
|
-
|
270
|
-
val_dir = self._cfg['plotting']['val_dir']
|
508
|
+
val_dir = self._cfg['saving']['output']
|
271
509
|
val_dir = f'{val_dir}/fold_{ifold:03}'
|
272
510
|
os.makedirs(val_dir, exist_ok=True)
|
273
511
|
|
274
512
|
obj = MatrixPlotter(mat=mat, cfg=cfg)
|
275
513
|
obj.plot()
|
276
|
-
plt.savefig(f'{val_dir}/
|
514
|
+
plt.savefig(f'{val_dir}/correlation_{name}.png')
|
277
515
|
plt.close()
|
278
516
|
# ---------------------------------------------
|
279
|
-
def _get_nentries(self, arr_val :
|
517
|
+
def _get_nentries(self, arr_val : NPA) -> str:
|
280
518
|
size = len(arr_val)
|
281
519
|
size = size / 1000.
|
282
520
|
|
283
521
|
return f'{size:.2f}K'
|
284
522
|
# ---------------------------------------------
|
285
|
-
def _plot_scores(
|
286
|
-
|
523
|
+
def _plot_scores(
|
524
|
+
self,
|
525
|
+
ifold : int,
|
526
|
+
sig_tst : NPA,
|
527
|
+
bkg_tst : NPA,
|
528
|
+
sig_trn : NPA = None,
|
529
|
+
bkg_trn : NPA = None) -> None:
|
287
530
|
'''
|
288
531
|
Will plot an array of scores, associated to a given fold
|
289
532
|
'''
|
533
|
+
ifold = 'all' if ifold == -1 else ifold
|
290
534
|
log.debug(f'Plotting scores for {ifold} fold')
|
291
535
|
|
292
|
-
|
293
|
-
log.warning('Scores path not passed, not plotting scores')
|
294
|
-
return
|
295
|
-
|
296
|
-
val_dir = self._cfg['plotting']['val_dir']
|
536
|
+
val_dir = self._cfg['saving']['output']
|
297
537
|
val_dir = f'{val_dir}/fold_{ifold:03}'
|
298
538
|
os.makedirs(val_dir, exist_ok=True)
|
299
539
|
|
300
|
-
plt.hist(
|
301
|
-
plt.hist(
|
540
|
+
plt.hist(sig_tst, histtype='step', bins=50, range=(0,1), color='b', density=True, label='Signal Test: ' + self._get_nentries(sig_tst))
|
541
|
+
plt.hist(bkg_tst, histtype='step', bins=50, range=(0,1), color='r', density=True, label='Background Test: ' + self._get_nentries(bkg_tst))
|
302
542
|
|
303
|
-
|
304
|
-
|
543
|
+
if sig_trn is not None and bkg_trn is not None:
|
544
|
+
plt.hist(sig_trn, alpha = 0.3, bins=50, range=(0,1), color='b', density=True, label='Signal Train: ' + self._get_nentries(sig_trn))
|
545
|
+
plt.hist(bkg_trn, alpha = 0.3, bins=50, range=(0,1), color='r', density=True, label='Background Train: '+ self._get_nentries(bkg_trn))
|
305
546
|
|
306
547
|
plt.legend()
|
307
548
|
plt.title(f'Fold: {ifold}')
|
@@ -310,59 +551,12 @@ class TrainMva:
|
|
310
551
|
plt.savefig(f'{val_dir}/scores.png')
|
311
552
|
plt.close()
|
312
553
|
# ---------------------------------------------
|
313
|
-
def
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
'''
|
320
|
-
Takes the labels and the probabilities and plots ROC
|
321
|
-
curve for given fold
|
322
|
-
'''
|
323
|
-
log.debug(f'Plotting ROC curve for {ifold} fold')
|
324
|
-
|
325
|
-
val_dir = self._cfg['plotting']['val_dir']
|
326
|
-
val_dir = f'{val_dir}/fold_{ifold:03}'
|
327
|
-
os.makedirs(val_dir, exist_ok=True)
|
328
|
-
|
329
|
-
xval_ts, yval_ts, _ = roc_curve(l_lab_ts, l_prb_ts)
|
330
|
-
xval_ts = 1 - xval_ts
|
331
|
-
area_ts = auc(xval_ts, yval_ts)
|
332
|
-
|
333
|
-
xval_tr, yval_tr, _ = roc_curve(l_lab_tr, l_prb_tr)
|
334
|
-
xval_tr = 1 - xval_tr
|
335
|
-
area_tr = auc(xval_tr, yval_tr)
|
336
|
-
|
337
|
-
min_x = 0
|
338
|
-
min_y = 0
|
339
|
-
if 'min' in self._cfg['plotting']['roc']:
|
340
|
-
[min_x, min_y] = self._cfg['plotting']['roc']['min']
|
341
|
-
|
342
|
-
max_x = 1
|
343
|
-
max_y = 1
|
344
|
-
if 'max' in self._cfg['plotting']['roc']:
|
345
|
-
[max_x, max_y] = self._cfg['plotting']['roc']['max']
|
346
|
-
|
347
|
-
plt.plot(xval_ts, yval_ts, color='b', label=f'Test: {area_ts:.3f}')
|
348
|
-
plt.plot(xval_tr, yval_tr, color='r', label=f'Train: {area_tr:.3f}')
|
349
|
-
self._plot_probabilities(xval_ts, yval_ts, l_prb_ts, l_lab_ts)
|
350
|
-
|
351
|
-
plt.xlabel('Signal efficiency')
|
352
|
-
plt.ylabel('Background rejection')
|
353
|
-
plt.title(f'Fold: {ifold}')
|
354
|
-
plt.xlim(min_x, max_x)
|
355
|
-
plt.ylim(min_y, max_y)
|
356
|
-
plt.grid()
|
357
|
-
plt.legend()
|
358
|
-
plt.savefig(f'{val_dir}/roc.png')
|
359
|
-
plt.close()
|
360
|
-
# ---------------------------------------------
|
361
|
-
def _plot_probabilities(self,
|
362
|
-
arr_seff: npa,
|
363
|
-
arr_brej: npa,
|
364
|
-
arr_sprb: npa,
|
365
|
-
arr_labl: npa) -> None:
|
554
|
+
def _plot_probabilities(
|
555
|
+
self,
|
556
|
+
arr_seff: NPA,
|
557
|
+
arr_brej: NPA,
|
558
|
+
arr_sprb: NPA,
|
559
|
+
arr_labl: NPA) -> None:
|
366
560
|
|
367
561
|
roc_cfg = self._cfg['plotting']['roc']
|
368
562
|
if 'annotate' not in roc_cfg:
|
@@ -407,7 +601,10 @@ class TrainMva:
|
|
407
601
|
'''
|
408
602
|
Will plot the features, based on the settings in the config
|
409
603
|
'''
|
410
|
-
|
604
|
+
out_dir = self._cfg['saving']['output']
|
605
|
+
d_cfg = self._cfg['plotting']['features']
|
606
|
+
d_cfg['saving'] = {'plt_dir' : f'{out_dir}/features'}
|
607
|
+
|
411
608
|
ptr = Plotter(d_rdf = {'Signal' : self._rdf_sig, 'Background' : self._rdf_bkg}, cfg=d_cfg)
|
412
609
|
ptr.run()
|
413
610
|
# ---------------------------------------------
|
@@ -430,7 +627,7 @@ class TrainMva:
|
|
430
627
|
|
431
628
|
d_tex = {'Variable' : l_lab, 'Replacement' : l_val}
|
432
629
|
df = pnd.DataFrame(d_tex)
|
433
|
-
val_dir = self._cfg['
|
630
|
+
val_dir = self._cfg['saving']['output']
|
434
631
|
os.makedirs(val_dir, exist_ok=True)
|
435
632
|
put.df_to_tex(df, f'{val_dir}/nan_replacement.tex')
|
436
633
|
# ---------------------------------------------
|
@@ -438,28 +635,306 @@ class TrainMva:
|
|
438
635
|
if 'hyper' not in self._cfg['training']:
|
439
636
|
raise ValueError('Cannot find hyper parameters in configuration')
|
440
637
|
|
638
|
+
def format_value(val : Union[int,float]) -> str:
|
639
|
+
if isinstance(val, float):
|
640
|
+
return f'\\verb|{val:.3f}|'
|
641
|
+
|
642
|
+
return f'\\verb|{val}|'
|
643
|
+
|
441
644
|
d_hyper = self._cfg['training']['hyper']
|
442
|
-
d_form = { f'\\verb|{key}|' :
|
645
|
+
d_form = { f'\\verb|{key}|' : format_value(val) for key, val in d_hyper.items() }
|
443
646
|
d_latex = { 'Hyperparameter' : list(d_form.keys()), 'Value' : list(d_form.values())}
|
444
647
|
|
445
648
|
df = pnd.DataFrame(d_latex)
|
446
|
-
val_dir = self._cfg['
|
649
|
+
val_dir = self._cfg['saving']['output']
|
447
650
|
os.makedirs(val_dir, exist_ok=True)
|
448
651
|
put.df_to_tex(df, f'{val_dir}/hyperparameters.tex')
|
449
652
|
# ---------------------------------------------
|
450
|
-
def
|
653
|
+
def _run_diagnostics(self, models : list[cls], rdf : RDataFrame, name : str) -> None:
|
654
|
+
log.info(f'Running diagnostics for sample {name}')
|
655
|
+
if 'diagnostics' not in self._cfg:
|
656
|
+
log.warning('Diagnostics section not found, not running diagnostics')
|
657
|
+
return
|
658
|
+
|
659
|
+
cfg_diag = self._cfg['diagnostics']
|
660
|
+
out_dir = cfg_diag['output']
|
661
|
+
plt_dir = None
|
662
|
+
|
663
|
+
if 'overlay' in cfg_diag['correlations']['target']:
|
664
|
+
plt_dir = cfg_diag['correlations']['target']['overlay']['saving']['plt_dir']
|
665
|
+
|
666
|
+
cfg_diag = copy.deepcopy(cfg_diag)
|
667
|
+
cfg_diag['output'] = f'{out_dir}/{name}'
|
668
|
+
if plt_dir is not None:
|
669
|
+
cfg_diag['correlations']['target']['overlay']['saving']['plt_dir'] = f'{plt_dir}/{name}'
|
670
|
+
|
671
|
+
cvd = CVDiagnostics(models=models, rdf=rdf, cfg=cfg_diag)
|
672
|
+
cvd.run()
|
673
|
+
# ---------------------------------------------
|
674
|
+
#
|
675
|
+
# Hyperparameter optimization
|
676
|
+
# ---------------------------------------------
|
677
|
+
def _objective(self, trial, kfold : StratifiedKFold) -> float:
|
678
|
+
ft = self._df_ft
|
679
|
+
lab= self._l_lab
|
680
|
+
|
681
|
+
if not issubclass(cls, GradientBoostingClassifier):
|
682
|
+
raise NotImplementedError('Hyperparameter optimization only implemented for GradientBoostingClassifier')
|
683
|
+
|
684
|
+
nft = len(ft.columns)
|
685
|
+
|
686
|
+
var_learn_rate = trial.suggest_float('learning_rate' , 1e-3, 1e-1, log=True)
|
687
|
+
var_max_depth = trial.suggest_int('max_depth' , 2, 15)
|
688
|
+
var_max_features= trial.suggest_int('max_features' , 2, nft)
|
689
|
+
var_min_split = trial.suggest_int('min_samples_split', 2, 10)
|
690
|
+
var_min_samples = trial.suggest_int('min_samples_leaf' , 2, 30)
|
691
|
+
var_nestimators = trial.suggest_int('n_estimators' , 50, 400)
|
692
|
+
|
693
|
+
classifier = GradientBoostingClassifier(
|
694
|
+
learning_rate = var_learn_rate,
|
695
|
+
max_depth = var_max_depth,
|
696
|
+
max_features = var_max_features,
|
697
|
+
min_samples_split = var_min_split,
|
698
|
+
min_samples_leaf = var_min_samples,
|
699
|
+
n_estimators = var_nestimators,
|
700
|
+
random_state = self._rdm_state)
|
701
|
+
|
702
|
+
score = cross_val_score(
|
703
|
+
classifier,
|
704
|
+
ft,
|
705
|
+
lab,
|
706
|
+
n_jobs=1, # More than this will reach RLIMIT_NPROC in cluster
|
707
|
+
cv=kfold)
|
708
|
+
|
709
|
+
accuracy = score.mean()
|
710
|
+
|
711
|
+
return accuracy
|
712
|
+
# ---------------------------------------------
|
713
|
+
def _optimize_hyperparameters(self, ntrial : int):
|
714
|
+
log.info('Running hyperparameter optimization')
|
715
|
+
|
716
|
+
self._pbar = tqdm.tqdm(total=ntrial, desc='Optimizing')
|
717
|
+
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=self._rdm_state)
|
718
|
+
objective = partial(self._objective, kfold=kfold)
|
719
|
+
|
720
|
+
study = optuna.create_study(
|
721
|
+
direction='maximize',
|
722
|
+
pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=5),)
|
723
|
+
|
724
|
+
study.optimize(
|
725
|
+
objective,
|
726
|
+
callbacks = [self._update_progress],
|
727
|
+
n_jobs = self._nworkers,
|
728
|
+
n_trials = ntrial)
|
729
|
+
|
730
|
+
self._print_hyper_opt(study=study)
|
731
|
+
self._plot_hyper_opt(study=study)
|
732
|
+
|
733
|
+
log.info('Overriding hyperparameters with optimized values')
|
734
|
+
|
735
|
+
self._cfg['training']['hyper'] = study.best_params
|
736
|
+
# ---------------------------------------------
|
737
|
+
def _plot_hyper_opt(self, study) -> None:
|
738
|
+
out_dir = self._cfg['saving']['output']
|
739
|
+
opt_dir = f'{out_dir}/optimization'
|
740
|
+
os.makedirs(opt_dir, exist_ok=True)
|
741
|
+
|
742
|
+
trials_df = study.trials_dataframe()
|
743
|
+
|
744
|
+
plt.plot(trials_df['number'], trials_df['value'])
|
745
|
+
plt.xlabel('Trial')
|
746
|
+
plt.ylabel('Accuracy')
|
747
|
+
plt.title('Optimization History')
|
748
|
+
plt.grid(True)
|
749
|
+
plt.savefig(f'{opt_dir}/history.png')
|
750
|
+
plt.close()
|
751
|
+
|
752
|
+
plt.hist(trials_df['value'], bins=20, alpha=0.7)
|
753
|
+
plt.xlabel('Accuracy')
|
754
|
+
plt.ylabel('Frequency')
|
755
|
+
plt.title('Distribution of Trial Results')
|
756
|
+
plt.savefig(f'{opt_dir}/accuracy.png')
|
757
|
+
plt.close()
|
758
|
+
# ---------------------------------------------
|
759
|
+
def _update_progress(self, study, _trial):
|
760
|
+
self._pbar.set_postfix({'Best': f'{study.best_value:.4f}' if study.best_value else 'N/A'})
|
761
|
+
self._pbar.update(1)
|
762
|
+
# ---------------------------------------------
|
763
|
+
def _print_hyper_opt(self, study) -> None:
    '''
    Logs, one per line, the best hyperparameters found by an Optuna study.
    Float values are shown with three decimals, anything else as is.

    Parameters
    ---------------
    study : Optuna study, already optimized
    '''
    separator = '-' * 40

    log.info(separator)
    log.info('Optimized hyperparameters:')
    log.info(separator)

    for par_name, par_value in study.best_params.items():
        text = f'{par_value:.3f}' if isinstance(par_value, float) else f'{par_value}'
        log.info(f'{par_name:<20}{text}')
|
772
|
+
# ---------------------------------------------
|
773
|
+
# ---------------------------------------------
|
774
|
+
def _auc_from_json(self, ifold : int, kind : str) -> float:
    '''
    Reads a stored ROC curve and returns the area under it.

    Parameters
    ---------------
    ifold : Index of the fold, the curve is read from `<saving.output>/fold_<ifold, 3 digits>/`
    kind  : Sample the curve belongs to; the file read is `roc_<kind>.json`
            (callers in this file use 'test' and 'train')

    Returns
    ---------------
    Area under the curve defined by the `x` and `y` columns of the JSON file
    '''
    fold_dir = f"{self._cfg['saving']['output']}/fold_{ifold:03}"
    df_roc   = pnd.read_json(f'{fold_dir}/roc_{kind}.json')

    return auc(df_roc['x'], df_roc['y'])
|
780
|
+
# ---------------------------------------------
|
781
|
+
def _check_overtraining(self) -> None:
    '''
    Plots, for every fold, the AUC of the training and testing samples, so that
    overtraining (training AUC well above testing AUC) can be spotted.

    The AUCs are read from the per-fold ROC JSON files through `_auc_from_json`.
    The plot is saved as `<saving.output>/fold_all/auc_vs_fold.png`; the
    `fold_all` directory is created if it does not exist.
    '''
    nfold   = self._cfg['training']['nfold']
    val_dir = self._cfg['saving']['output']

    df          = pnd.DataFrame({'fold' : numpy.arange(nfold, dtype=int)})
    df['test' ] = df['fold'].apply(self._auc_from_json, args=('test' ,))
    df['train'] = df['fold'].apply(self._auc_from_json, args=('train',))

    ax = df.plot('fold', 'test' , color='blue', label='Testing sample')
    ax = df.plot('fold', 'train', color='red' , label='Training sample', ax=ax)

    # Keep 0.75 as the usual lower edge, but lower it when some AUC would
    # otherwise be drawn outside the axes (a poor fold is what this plot must show)
    lowest = min(df['test'].min(), df['train'].min())
    ax.set_ylim(bottom=min(0.75, lowest - 0.05), top=1.00)
    ax.set_ylabel('AUC')
    ax.set_xlabel('Fold')

    plt.grid()

    out_dir = f'{val_dir}/fold_all'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/auc_vs_fold.png')
    plt.close()
|
802
|
+
# ---------------------------------------------
|
803
|
+
def run(
        self,
        skip_fit     : bool = False,
        opt_ntrial   : int  = 0,
        load_trained : bool = False) -> float:
    '''
    Entry point: plots the features and, unless told otherwise, trains
    (or loads) one model per fold and runs the checks on them.

    Parameters
    ----------------
    skip_fit    : If True, only `_plot_features` runs and the method returns early. False by default
    opt_ntrial  : Number of trials of the hyperparameter optimization. By default zero, i.e. no optimization
    load_trained: If True, the models are loaded instead of being trained (and are not saved again). False by default

    Returns
    ----------------
    Area under the ROC curve from evaluating the classifiers on samples that
    were not used in their training, using the full sample.
    NOTE(review): with skip_fit=True this is whatever `self._auc` holds at that
    point, presumably not a trained result -- confirm with the constructor.
    '''
    self._plot_features()
    if skip_fit:
        return self._auc

    if opt_ntrial > 0:
        # Replaces the hyperparameters in the configuration by the optimized ones
        self._optimize_hyperparameters(ntrial=opt_ntrial)

    # Called after the (optional) optimization, presumably so that the saved
    # settings reflect the hyperparameters actually used
    self._save_settings_to_tex()

    models = self._get_models(load_trained = load_trained)
    if not load_trained:
        for ifold, model in enumerate(models):
            self._save_model(model, ifold)

    self._check_overtraining()
    self._run_diagnostics(models = models, rdf = self._rdf_sig_org, name='Signal' )
    self._run_diagnostics(models = models, rdf = self._rdf_bkg_org, name='Background')

    return self._auc
|
839
|
+
# ---------------------------------------------
|
840
|
+
@contextmanager
def use(self, nworkers : int) -> None:
    '''
    Context manager that temporarily changes the number of workers of this object.

    Parameters
    ------------------
    nworkers: Number of workers used, inside the `with` block, for ANY process that can
              be parallelized (e.g. it is passed as `n_jobs` to the Optuna optimization).
              The previous value is restored when the block exits, also if it raises.

    NOTE(review): as a generator, the return annotation would more accurately be
    Iterator[None]; left unchanged to keep the signature identical.
    '''
    previous = self._nworkers

    log.info(f'Using {nworkers} workers to run training')

    self._nworkers = nworkers
    try:
        yield
    finally:
        # Restore the caller's setting whatever happened in the body
        self._nworkers = previous
|
856
|
+
# ---------------------------------------------
|
857
|
+
@staticmethod
def plot_roc_from_prob(
        arr_sig_prb : NPA,
        arr_bkg_prb : NPA,
        kind        : str,
        ifold       : int,
        color       : str = None) -> tuple[NPA,NPA, float]:
    '''
    Takes arrays of signal and background probabilities
    and plots ROC curve

    Parameters
    --------------------
    arr_sig_prb : Array with signal probabilities, these entries get label 1
    arr_bkg_prb : Array with background probabilities, these entries get label 0
    kind        : String used to label the plot
    ifold       : If no fold makes sense (i.e. this is the full sample), use ifold=-1
    color       : String with color of curve, if None, `plot_roc` picks it from `kind`

    Returns
    --------------------
    Tuple with 3 elements:

    - Array of x coordinates of ROC curve
    - Array of y coordinates of ROC curve
    - Area under the curve
    '''
    # Signal is the positive class (1), background the negative one (0)
    arr_sig_lab = numpy.ones_like( arr_sig_prb)
    arr_bkg_lab = numpy.zeros_like(arr_bkg_prb)

    arr_prb = numpy.concatenate([arr_sig_prb, arr_bkg_prb])
    arr_lab = numpy.concatenate([arr_sig_lab, arr_bkg_lab])

    res = TrainMva.plot_roc(
            l_lab=arr_lab,
            l_prb=arr_prb,
            color=color,
            kind =kind,
            ifold=ifold)

    return res
|
898
|
+
# ---------------------------------------------
|
899
|
+
@staticmethod
def plot_roc(
        l_lab : NPA,
        l_prb : NPA,
        kind  : str,
        ifold : int,
        color : str = None) -> tuple[NPA, NPA, float]:
    '''
    Builds the ROC curve from true labels and classifier probabilities and
    draws it, with its area in the legend label, on the current matplotlib axes.

    Parameters
    --------------------
    l_lab : Array with the true labels, e.g. ones for signal and zeros for
            background, as built by `plot_roc_from_prob`
    l_prb : Array with the classifier probabilities, one per label
    kind  : Used to label the plot. When `color` is None, 'Train' gives a red curve, anything else blue
    ifold : Index of the fold. If no fold makes sense (i.e. this is the full sample),
            use ifold=-1, the label is then 'Test sample' regardless of `kind`
    color : Color of the curve, if None it is chosen from `kind`

    Returns
    --------------------
    Tuple with 3 elements:

    - Array of x coordinates of ROC curve, 1 - false positive rate
    - Array of y coordinates of ROC curve, true positive rate
    - Area under the curve
    '''
    log.debug(f'Plotting ROC curve for {ifold} fold')

    fpr, tpr, _ = roc_curve(l_lab, l_prb)
    # Plot against 1 - FPR (true negative rate) instead of the FPR itself
    xval = 1 - fpr
    yval = tpr
    area = auc(xval, yval)

    if color is None:
        color = 'red' if kind == 'Train' else 'blue'

    name = 'Test sample' if ifold == -1 else kind
    plt.plot(xval, yval, color=color, label=f'{name}: {area:.3f}')

    return xval, yval, area
|
465
940
|
# ---------------------------------------------
|