data-manipulation-utilities 0.2.6__py3-none-any.whl → 0.2.8.dev714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/METADATA +800 -34
  2. data_manipulation_utilities-0.2.8.dev714.dist-info/RECORD +93 -0
  3. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/WHEEL +1 -1
  4. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/entry_points.txt +1 -0
  5. dmu/__init__.py +0 -0
  6. dmu/generic/hashing.py +70 -0
  7. dmu/generic/utilities.py +175 -9
  8. dmu/generic/version_management.py +3 -5
  9. dmu/logging/log_store.py +34 -2
  10. dmu/logging/messages.py +96 -0
  11. dmu/ml/cv_classifier.py +3 -3
  12. dmu/ml/cv_diagnostics.py +224 -0
  13. dmu/ml/cv_performance.py +58 -0
  14. dmu/ml/cv_predict.py +149 -46
  15. dmu/ml/train_mva.py +587 -112
  16. dmu/ml/utilities.py +29 -10
  17. dmu/pdataframe/utilities.py +61 -3
  18. dmu/plotting/fwhm.py +64 -0
  19. dmu/plotting/matrix.py +1 -1
  20. dmu/plotting/plotter.py +25 -3
  21. dmu/plotting/plotter_1d.py +159 -14
  22. dmu/plotting/plotter_2d.py +5 -0
  23. dmu/rdataframe/utilities.py +54 -3
  24. dmu/rfile/ddfgetter.py +102 -0
  25. dmu/stats/fit_stats.py +129 -0
  26. dmu/stats/fitter.py +56 -23
  27. dmu/stats/gof_calculator.py +7 -0
  28. dmu/stats/model_factory.py +305 -50
  29. dmu/stats/parameters.py +100 -0
  30. dmu/stats/utilities.py +443 -12
  31. dmu/stats/wdata.py +187 -0
  32. dmu/stats/zfit.py +17 -0
  33. dmu/stats/zfit_models.py +68 -0
  34. dmu/stats/zfit_plotter.py +175 -56
  35. dmu/testing/utilities.py +120 -15
  36. dmu/workflow/__init__.py +0 -0
  37. dmu/workflow/cache.py +266 -0
  38. dmu_data/ml/tests/diagnostics_from_file.yaml +13 -0
  39. dmu_data/ml/tests/diagnostics_from_model.yaml +10 -0
  40. dmu_data/ml/tests/diagnostics_multiple_methods.yaml +10 -0
  41. dmu_data/ml/tests/diagnostics_overlay.yaml +33 -0
  42. dmu_data/ml/tests/train_mva.yaml +20 -12
  43. dmu_data/ml/tests/train_mva_def.yaml +75 -0
  44. dmu_data/ml/tests/train_mva_with_diagnostics.yaml +87 -0
  45. dmu_data/ml/tests/train_mva_with_preffix.yaml +58 -0
  46. dmu_data/plotting/tests/2d.yaml +5 -5
  47. dmu_data/plotting/tests/line.yaml +15 -0
  48. dmu_data/plotting/tests/plug_fwhm.yaml +24 -0
  49. dmu_data/plotting/tests/plug_stats.yaml +19 -0
  50. dmu_data/plotting/tests/simple.yaml +4 -3
  51. dmu_data/plotting/tests/styling.yaml +18 -0
  52. dmu_data/rfile/friends.yaml +13 -0
  53. dmu_data/stats/fitter/test_simple.yaml +28 -0
  54. dmu_data/stats/kde_optimizer/control.json +1 -0
  55. dmu_data/stats/kde_optimizer/signal.json +1 -0
  56. dmu_data/stats/parameters/data.yaml +178 -0
  57. dmu_data/tests/config.json +6 -0
  58. dmu_data/tests/config.yaml +4 -0
  59. dmu_data/tests/pdf_to_tex.txt +34 -0
  60. dmu_scripts/kerberos/check_expiration +21 -0
  61. dmu_scripts/kerberos/convert_certificate +22 -0
  62. dmu_scripts/ml/compare_classifiers.py +85 -0
  63. data_manipulation_utilities-0.2.6.dist-info/RECORD +0 -57
  64. {data_manipulation_utilities-0.2.6.data → data_manipulation_utilities-0.2.8.dev714.data}/scripts/publish +0 -0
  65. {data_manipulation_utilities-0.2.6.dist-info → data_manipulation_utilities-0.2.8.dev714.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,224 @@
1
+ '''
2
+ Module containing CVDiagnostics class
3
+ '''
4
+ import os
5
+
6
+ import numpy
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pnd
10
+
11
+ from scipy.stats import kendalltau
12
+ from ROOT import RDataFrame, RDF
13
+ from dmu.ml.cv_classifier import CVClassifier
14
+ from dmu.ml.cv_predict import CVPredict
15
+ from dmu.logging.log_store import LogStore
16
+ from dmu.plotting.plotter_1d import Plotter1D as Plotter
17
+
NPA  = numpy.ndarray
Axis = matplotlib.axes.Axes
log  = LogStore.add_logger('dmu:ml:cv_diagnostics')
# -------------------------
class CVDiagnostics:
    '''
    Class meant to run diagnostics on a classifier

    Correlations
    ------------------
    Will calculate correlations between the features + signal probability and some
    external target variable specified in the config. If the target section of the
    config has an `overlay` entry, the target distribution is also plotted after
    requiring the score to be above each of the configured working points.
    '''
    # -------------------------
    def __init__(self, models : list[CVClassifier], rdf : RDataFrame, cfg : dict):
        '''
        Parameters
        ------------------
        models: List of trained CVClassifier instances. The configuration of the first
                one provides the features, the column definitions and the axis labels
        rdf   : ROOT dataframe with the data used for the diagnostics
        cfg   : Diagnostics configuration, needs at least `correlations: target: name`
        '''
        self._l_model = models
        self._cfg     = cfg
        self._rdf     = rdf
        self._target  = cfg['correlations']['target']['name']
        self._l_feat  = self._get_features()
        self._d_xlab  = self._get_xlabels()

        # Filled lazily by _get_arrays. Evaluating the models is expensive, so the
        # arrays are built once and shared by every correlation method and the cutflow
        self._d_arr : dict[str,NPA] = {}
    # -------------------------
    def _get_features(self) -> list[str]:
        '''
        Returns the names of the features used to train the models
        '''
        cfg = self._l_model[0].cfg

        return cfg['training']['features']
    # -------------------------
    def _get_xlabels(self) -> dict[str,str]:
        '''
        Returns dictionary mapping variable name (features, target and `score`)
        to the label shown on the x axis of the correlations plot.
        The substring `MeV` is removed from the labels.
        '''
        cfg   = self._l_model[0].cfg
        d_var = cfg['plotting']['features']['plots']
        d_lab = { varname : d_field['labels'][0] for varname, d_field in d_var.items() }

        cfg_target = self._cfg['correlations']['target']
        if 'overlay' in cfg_target:
            xlabel = cfg_target['overlay']['plots'][self._target]['labels'][0]
        else:
            xlabel = self._target

        d_lab[self._target] = xlabel
        d_lab['score']      = 'score'

        return { var_id : var_name.replace('MeV', '') for var_id, var_name in d_lab.items() }
    # -------------------------
    def _add_columns(self, rdf : RDataFrame) -> RDataFrame:
        '''
        Defines in `rdf` the columns listed in the `dataset: define` section of the
        model configuration. That section is optional; without it `rdf` is returned as is.
        '''
        cfg   = self._l_model[0].cfg
        d_def = cfg['dataset'].get('define') or {}
        for var, expr in d_def.items():
            rdf = rdf.Define(var, expr)

        return rdf
    # -------------------------
    def _get_scores(self) -> NPA:
        '''
        Returns the array of scores. These are evaluated with the models or, if
        `score_from_rdf` is in the config, read from that column of the dataframe.
        '''
        if 'score_from_rdf' not in self._cfg:
            log.debug('Using score from model')
            prd = CVPredict(models=self._l_model, rdf = self._rdf)

            return prd.predict()

        name = self._cfg['score_from_rdf']
        log.debug(f'Picking up score from dataframe, column: {name}')

        return self._rdf.AsNumpy([name])[name]
    # -------------------------
    def _get_arrays(self) -> dict[str, NPA]:
        '''
        Returns dictionary with the arrays of the features, the target and the score
        (key `score`). The result is cached, since computing the score is expensive.

        Raises
        ------------------
        ValueError: If any feature or the target is missing from the dataframe
        '''
        if self._d_arr:
            return self._d_arr

        rdf   = self._add_columns(self._rdf)
        s_col = { name.c_str() for name in rdf.GetColumnNames() }
        l_var = self._l_feat + [self._target]

        l_missing = [ var for var in l_var if var not in s_col ]
        for var in l_missing:
            log.error(f'{"Missing":<20}{var}')

        if l_missing:
            raise ValueError('Columns missing')

        d_var = rdf.AsNumpy(l_var)
        d_var['score'] = self._get_scores()

        self._d_arr = d_var

        return d_var
    # -------------------------
    def _run_correlations(self, method : str, ax : Axis) -> Axis:
        '''
        Calculates the correlation, with `method`, of every variable with the target
        and adds the resulting curve to `ax`
        '''
        d_arr      = self._get_arrays()
        arr_target = d_arr[self._target]

        d_corr = {
            name : self._calculate_correlations(var=arr_val, target=arr_target, method=method)
            for name, arr_val in d_arr.items()
            if name != self._target }

        return self._plot_correlations(d_corr=d_corr, method=method, ax=ax)
    # -------------------------
    def _plot_correlations(self, d_corr : dict[str,float], method : str, ax : Axis) -> Axis:
        '''
        Plots the correlations in `d_corr` on `ax`. If `ax` is None, pandas
        creates a new figure. Returns the axis used.
        '''
        df = pnd.DataFrame.from_dict(d_corr, orient='index', columns=[method])
        # Variables without a configured label are shown with their own name, instead of NaN
        df['variable'] = [ self._d_xlab.get(name, name) for name in df.index ]

        cfg_fig = self._cfg['correlations']['figure']
        figsize = cfg_fig['size']
        ax      = df.plot(x='variable', y=method, label=method, figsize=figsize, ax=ax)

        # Needed to show all labels on x axis
        plt.xticks(ticks=range(len(df)), labels=df.variable)
        xlabsize = cfg_fig.get('xlabelsize', 30)
        ax.tick_params(axis='x', labelsize=xlabsize)

        return ax
    # -------------------------
    def _save_plot(self):
        '''
        Formats the current figure and saves it as `correlations.png` in the `output` directory
        '''
        plot_dir = self._cfg['output']
        os.makedirs(plot_dir, exist_ok=True)

        plot_path = f'{plot_dir}/correlations.png'
        log.info(f'Saving to: {plot_path}')

        cfg_fig  = self._cfg['correlations']['figure']
        title    = cfg_fig.get('title')
        rotation = cfg_fig.get('rotate', 30)

        plt.ylim(-1, +1)
        plt.title(title)
        plt.xlabel('')
        plt.ylabel('Correlation')
        plt.grid()
        plt.xticks(rotation=rotation)
        plt.tight_layout()
        plt.savefig(plot_path)
        plt.close()
    # -------------------------
    def _remove_nans(self, var : NPA, tgt : NPA) -> tuple[NPA,NPA]:
        '''
        Drops the entries where either `var` or `tgt` is NaN
        '''
        arr_keep = ~(numpy.isnan(var) | numpy.isnan(tgt))

        return var[arr_keep], tgt[arr_keep]
    # -------------------------
    def _calculate_correlations(self, var : NPA, target : NPA, method : str) -> float:
        '''
        Returns correlation coefficient between `var` and `target`, ignoring NaNs.
        `method` is either `Pearson` or `Kendall-$\\tau$`.

        Raises
        ------------------
        NotImplementedError: For any other method
        '''
        var, target = self._remove_nans(var, target)

        if method == 'Pearson':
            mat = numpy.corrcoef(var, target)

            return mat[0,1]

        if method == r'Kendall-$\tau$':
            tau, _ = kendalltau(var, target)

            return tau

        raise NotImplementedError(f'Correlation coefficient {method} not implemented')
    # -------------------------
    def _plot_cutflow(self) -> None:
        '''
        Plots the target column after requiring the score to be above each of the
        working points in `correlations: target: overlay: wp`.
        Does nothing if there is no `overlay` section.
        '''
        cfg_target = self._cfg['correlations']['target']
        if 'overlay' not in cfg_target:
            log.debug('Not plotting cutflow of target distribution')
            return

        # Same arrays as the correlations, so targets defined via `dataset: define` work too.
        # Contiguous copies, to be safe when building the dataframe: the scores may be
        # strided views of the probability matrix. These locals outlive ptr.run(), where
        # the lazy filters are executed.
        d_arr = self._get_arrays()
        d_np  = {
            'Score'      : numpy.ascontiguousarray(d_arr['score']),
            self._target : numpy.ascontiguousarray(d_arr[self._target])}

        rdf = RDF.FromNumpy(d_np)

        cfg_overlay = cfg_target['overlay']
        d_rdf       = {}
        for wp in cfg_overlay['wp']:
            # float(): integer working points (e.g. 0 in YAML) would break float formatting.
            # The cut uses the exact value. The label uses 6 significant digits, so it
            # reflects the cut and close working points do not overwrite each other.
            wp   = float(wp)
            name = f'WP > {wp:g}'
            expr = f'Score > {wp}'
            d_rdf[name] = rdf.Filter(expr)

        ptr = Plotter(d_rdf=d_rdf, cfg=cfg_overlay)
        ptr.run()
    # -------------------------
    def run(self) -> None:
        '''
        Runs diagnostics: correlations plot (one curve per method) and, if configured,
        the target distribution for several working points
        '''
        if 'correlations' in self._cfg:
            ax = None
            for method in self._cfg['correlations']['methods']:
                ax = self._run_correlations(method=method, ax=ax)

            self._save_plot()

        self._plot_cutflow()
# -------------------------
@@ -0,0 +1,58 @@
1
+ '''
2
+ This module contains the class CVPerformance
3
+ '''
4
+ # pylint: disable=too-many-positional-arguments, too-many-arguments
5
+
6
+ from ROOT import RDataFrame
7
+ from dmu.ml.cv_classifier import CVClassifier
8
+ from dmu.ml.cv_predict import CVPredict
9
+ from dmu.ml.train_mva import TrainMva
10
+ from dmu.logging.log_store import LogStore
11
+
log = LogStore.add_logger('dmu:ml:cv_performance')
# -----------------------------------------------------
class CVPerformance:
    '''
    Class used to compare classifiers through their ROC curves.

    For a given model and a pair of signal and background samples,
    it evaluates the model on both samples and draws the ROC curve.
    '''
    # ---------------------------
    def plot_roc(
        self,
        name  : str,
        color : str,
        sig   : RDataFrame,
        bkg   : RDataFrame,
        model : list[CVClassifier] ) -> float:
        '''
        Evaluates `model` on the signal and background samples and plots the ROC curve

        Parameters
        --------------------------
        name : Label of the combination, used in the plots
        color: Color passed to the ROC plotting function
        sig  : ROOT dataframe with the signal sample
        bkg  : ROOT dataframe with the background sample
        model: List of CVClassifier instances, one per fold

        Returns
        --------------------------
        Area under the ROC curve
        '''
        log.info(f'Loading {name}')

        # Signal is evaluated first, then background
        arr_sig, arr_bkg = [ CVPredict(models=model, rdf=rdf).predict() for rdf in (sig, bkg) ]

        # Per the original author, ifold=999 lets plot_roc_from_prob take the
        # label from `kind`. NOTE(review): confirm against TrainMva.plot_roc_from_prob
        _, _, auc = TrainMva.plot_roc_from_prob(
            arr_sig_prb = arr_sig,
            arr_bkg_prb = arr_bkg,
            kind        = name,
            color       = color,
            ifold       = 999)

        return auc
# -----------------------------------------------------
dmu/ml/cv_predict.py CHANGED
@@ -1,8 +1,6 @@
1
1
  '''
2
2
  Module holding CVPredict class
3
3
  '''
4
- from typing import Optional
5
-
6
4
  import pandas as pnd
7
5
  import numpy
8
6
  import tqdm
@@ -21,41 +19,107 @@ class CVPredict:
21
19
  Class used to get classification probabilities from ROOT
22
20
  dataframe and a set of models. The models were trained with CVClassifier
23
21
  '''
def __init__(
        self,
        rdf    : RDataFrame,
        models : list[CVClassifier]):
    '''
    Will take a ROOT dataframe and the list of CVClassifier models used to evaluate it

    rdf   : ROOT dataframe from which the features will be extracted
    models: List of models, one per fold
    '''
    self._l_model = models
    self._rdf     = rdf

    # Set later, in _initialize
    self._nrows     : int
    self._l_column  : list[str]
    self._d_nan_rep : dict[str,str]

    # Score assigned to entries for which no prediction was made
    self._dummy_score = -1.0

    # Column of the ROOT dataframe; entries where it equals 1 are excluded from the prediction
    self._skip_index_column = 'skip_mva_prediction'

    # Key, in the `attrs` of the features dataframe, holding the indices of the skipped entries
    self._index_skip = 'skip_mva_prediction'
40
46
  # --------------------------------------------
def _initialize(self):
    '''
    Prepares the dataframe for evaluation:

    - Adds underscore aliases for columns with periods in their names
    - Defines the columns that were defined before training
    - Caches the NaN replacements, the column names and the number of entries
    '''
    # Order matters: _define_columns works on the dataframe that already has the aliases
    self._rdf       = self._remove_periods(self._rdf)
    self._rdf       = self._define_columns(self._rdf)
    self._d_nan_rep = self._get_nan_replacements()

    v_name          = self._rdf.GetColumnNames()
    self._l_column  = [ name.c_str() for name in v_name ]
    self._nrows     = self._rdf.Count().GetValue()
53
+ # ----------------------------------
def _remove_periods(self, rdf : RDataFrame) -> RDataFrame:
    '''
    For every column whose name contains a period (e.g. columns coming
    from friend trees), defines an alias where the periods are replaced
    by underscores:

    friend_prefix.branch_name -> friend_prefix_branch_name

    The original dotted columns are kept.

    Parameters
    ----------------
    rdf: ROOT dataframe

    Returns
    ----------------
    Dataframe with the aliases; `rdf` itself if no column name contains a period
    '''
    l_col = [ col.c_str() for col in rdf.GetColumnNames() ]
    l_col = [ col for col in l_col if '.' in col ]

    if len(l_col) == 0:
        return rdf

    log.debug(60 * '-')
    log.debug('Renaming dotted columns')
    log.debug(60 * '-')
    for col in l_col:
        new = col.replace('.', '_')
        log.debug(f'{col:<50}{"->":10}{new:<20}')
        rdf = rdf.Define(new, col)

    return rdf
75
+ # --------------------------------------------
def _get_definitions(self) -> dict[str,str]:
    '''
    Collects the column definitions applied to the dataframe before it was
    used to train the model:

    - Generic ones, from `dataset: define`, if present
    - Signal specific ones, from `dataset: samples: sig: definitions`, if present.
      These override generic definitions with the same name.

    Returns
    ----------------
    Dictionary mapping column name to expression, possibly empty
    '''
    cfg_dataset = self._l_model[0].cfg['dataset']
    d_def       = dict(cfg_dataset['define']) if 'define' in cfg_dataset else {}

    sig_name = 'sig'
    try:
        # Taken from the signal section, because predicted scores should come
        # from features defined as for the signal
        d_def_sam = cfg_dataset['samples'][sig_name]['definitions']
    except KeyError:
        log.debug(f'No sample specific definitions were found in: {sig_name}')
        return d_def

    log.info('Adding sample dependent definitions')
    d_def.update(d_def_sam)

    return d_def
44
100
  # --------------------------------------------
def _define_columns(self, rdf : RDataFrame) -> RDataFrame:
    '''
    Defines in `rdf` the columns that were defined before training
    (generic and signal specific, see `_get_definitions`).

    Periods inside dotted names in the expressions (e.g. `friend.branch`, `a.b.c`)
    are replaced by underscores, so that they match the aliases created by
    `_remove_periods`. Numeric literals such as `0.5` or `1.e3` are left untouched.

    Parameters
    ----------------
    rdf: ROOT dataframe where the columns will be defined

    Returns
    ----------------
    Dataframe with the extra columns; `rdf` itself if there are no definitions

    Raises
    ----------------
    TypeError: If at least one column could not be defined. All definitions are
               attempted first, so every failing one gets logged.
    '''
    import re

    d_def = self._get_definitions()
    if len(d_def) == 0:
        log.info('No definitions found')
        return rdf

    # A dotted name starts with a letter or underscore, and it is not preceded by a word
    # character or a period. This excludes the period in decimal literals, which a blanket
    # replace would turn into invalid C++ (0.5 -> 0_5).
    rgx_dotted = re.compile(r'(?<![\w.])[A-Za-z_]\w*(?:\.\w+)+')

    dexc = None
    log.debug(60 * '-')
    log.info('Defining columns in RDF before evaluating classifier')
    log.debug(60 * '-')
    for name, expr in d_def.items():
        expr = rgx_dotted.sub(lambda mtc: mtc.group(0).replace('.', '_'), expr)

        log.debug(f'{name:<20}{"<---":20}{expr:<100}')
        try:
            rdf = rdf.Define(name, expr)
        except TypeError as exc:
            log.error(f'Cannot define {name}={expr}')
            dexc = exc

    if dexc is not None:
        raise TypeError('Could not define at least one column') from dexc

    return rdf
61
125
  # --------------------------------------------
@@ -68,21 +132,25 @@ class CVPredict:
68
132
 
69
133
  return cfg['dataset']['nan']
70
134
  # --------------------------------------------
def _replace_nans(self, df_ft : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Fills NaNs in user specified columns with user specified values.
    These NaNs are expected. The replacements come from `_get_nan_replacements`.

    The dataframe is modified in place and also returned.
    '''
    if len(self._d_nan_rep) == 0:
        log.debug('Not doing any NaN replacement')
        return df_ft

    log.info(60 * '-')
    log.info('Doing NaN replacements')
    log.info(60 * '-')
    for column, value in self._d_nan_rep.items():
        log.info(f'{column:<20}{"--->":20}{value:<20.3f}')
        df_ft[column] = df_ft[column].fillna(value)

    return df_ft
84
152
  # --------------------------------------------
85
- def _get_df(self):
153
+ def _get_df(self) -> pnd.DataFrame:
86
154
  '''
87
155
  Will make ROOT rdf into dataframe and return it
88
156
  '''
@@ -90,11 +158,11 @@ class CVPredict:
90
158
  l_ft = model.features
91
159
  d_data= self._rdf.AsNumpy(l_ft)
92
160
  df_ft = pnd.DataFrame(d_data)
93
- df_ft = self._replace_nans(df_ft)
94
- df_ft = ut.patch_and_tag(df_ft)
95
-
96
- if 'patched_indices' in df_ft.attrs:
97
- self._arr_patch = df_ft.attrs['patched_indices']
161
+ df_ft = self._replace_nans(df_ft=df_ft)
162
+ df_ft = self._tag_skipped(df_ft=df_ft)
163
+ df_ft = ut.tag_nans(
164
+ df = df_ft,
165
+ indexes = self._index_skip)
98
166
 
99
167
  nfeat = len(l_ft)
100
168
  log.info(f'Found {nfeat} features')
@@ -103,6 +171,24 @@ class CVPredict:
103
171
 
104
172
  return df_ft
105
173
  # --------------------------------------------
def _tag_skipped(self, df_ft : pnd.DataFrame) -> pnd.DataFrame:
    '''
    Tags, without dropping them, the entries that must not be evaluated.

    The positions of the entries where the column `self._skip_index_column`
    ("skip_mva_prediction") of the ROOT dataframe equals 1 are stored
    in `df_ft.attrs[self._index_skip]`. If that column does not exist,
    `df_ft` is returned untouched.

    Parameters
    ----------------
    df_ft: Dataframe with the features, one row per entry of the ROOT dataframe

    Raises
    ----------------
    ValueError: If the features dataframe already carries that attribute
    '''
    if self._skip_index_column not in self._l_column:
        log.debug(f'Not dropping any rows through: {self._skip_index_column}')
        return df_ft

    # Checked before reading the column, so the dataframe is not read for nothing
    if self._index_skip in df_ft.attrs:
        raise ValueError(f'Feature dataframe already contains attribute: {self._index_skip}')

    log.info(f'Dropping rows through: {self._skip_index_column}')
    arr_drop = self._rdf.AsNumpy([self._skip_index_column])[self._skip_index_column]

    df_ft.attrs[self._index_skip] = numpy.where(arr_drop == 1)[0]

    return df_ft
191
+ # --------------------------------------------
106
192
  def _non_overlapping_hashes(self, model, df_ft):
107
193
  '''
108
194
  Will return True if hashes of model and data do not overlap
@@ -147,8 +233,8 @@ class CVPredict:
147
233
  '''
148
234
  Evaluate the dataset for one of the folds, by taking the model and the full dataset
149
235
  '''
150
- s_dat_hash = set(df_ft.index)
151
- s_mod_hash = model.hashes
236
+ s_dat_hash : set[str] = set(df_ft.index)
237
+ s_mod_hash : set[str] = model.hashes
152
238
 
153
239
  s_dif_hash = s_dat_hash - s_mod_hash
154
240
 
@@ -164,19 +250,29 @@ class CVPredict:
164
250
  d_prob = dict(zip(l_hash, l_prob))
165
251
  nfeat = len(df_ft_group)
166
252
  nprob = len(l_prob)
167
- log.debug(f'{nfeat:<10}{"->":10}{nprob:<10}')
253
+
254
+ if nfeat != nprob:
255
+ raise ValueError(f'Number of features and probabilities do not agree: {nfeat} != {nprob}')
168
256
 
169
257
  return d_prob
170
258
  # --------------------------------------------
def _predict_signal_probabilities(
        self,
        model : CVClassifier,
        df_ft : pnd.DataFrame) -> numpy.ndarray:
    '''
    Evaluates `model` on the features in `df_ft` and returns the array of signal probabilities.

    If any row of `df_ft` overlaps with the data used to train the model
    (checked through the hashes), the overlap-aware prediction is used instead.
    '''
    if not self._non_overlapping_hashes(model, df_ft):
        log.info('Intersecting hashes found between model and data')
        arr_prb = self._predict_with_overlap(df_ft)
    else:
        log.debug('No intersecting hashes found between model and data')
        arr_prb = model.predict_proba(df_ft)

    # Second column of the probabilities, i.e. the signal class
    return arr_prb.T[1]
180
276
  # --------------------------------------------
181
277
  def predict(self) -> numpy.ndarray:
182
278
  '''
@@ -187,15 +283,22 @@ class CVPredict:
187
283
  df_ft = self._get_df()
188
284
  model = self._l_model[0]
189
285
 
190
- if self._non_overlapping_hashes(model, df_ft):
191
- log.debug('No intersecting hashes found between model and data')
192
- arr_prb = model.predict_proba(df_ft)
193
- else:
194
- log.info('Intersecting hashes found between model and data')
195
- arr_prb = self._predict_with_overlap(df_ft)
286
+ arr_keep = None
287
+ arr_skip = None
288
+ if self._index_skip in df_ft.attrs:
289
+ arr_skip = df_ft.attrs[self._index_skip]
290
+ df_ft = df_ft.drop(arr_skip)
291
+ arr_keep = df_ft.index.to_numpy()
292
+
293
+ arr_sig_prb = self._predict_signal_probabilities(
294
+ model = model,
295
+ df_ft = df_ft)
296
+
297
+ if arr_skip is None:
298
+ return arr_sig_prb
196
299
 
197
- arr_prb = self._patch_probabilities(arr_prb)
198
- arr_prb = arr_prb.T[1]
300
+ arr_all_sig_prb = numpy.full(self._nrows, self._dummy_score)
301
+ arr_all_sig_prb[arr_keep] = arr_sig_prb
199
302
 
200
- return arr_prb
303
+ return arr_all_sig_prb
201
304
  # ---------------------------------------