mlquantify 0.0.11.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlquantify/__init__.py +32 -6
- mlquantify/base.py +559 -257
- mlquantify/classification/__init__.py +1 -1
- mlquantify/classification/methods.py +160 -0
- mlquantify/evaluation/__init__.py +14 -2
- mlquantify/evaluation/measures.py +215 -0
- mlquantify/evaluation/protocol.py +647 -0
- mlquantify/methods/__init__.py +37 -40
- mlquantify/methods/aggregative.py +1030 -0
- mlquantify/methods/meta.py +472 -0
- mlquantify/methods/mixture_models.py +1003 -0
- mlquantify/methods/non_aggregative.py +136 -0
- mlquantify/methods/threshold_optimization.py +957 -0
- mlquantify/model_selection.py +377 -232
- mlquantify/plots.py +367 -0
- mlquantify/utils/__init__.py +2 -2
- mlquantify/utils/general.py +334 -0
- mlquantify/utils/method.py +449 -0
- {mlquantify-0.0.11.2.dist-info → mlquantify-0.1.1.dist-info}/METADATA +137 -122
- mlquantify-0.1.1.dist-info/RECORD +22 -0
- {mlquantify-0.0.11.2.dist-info → mlquantify-0.1.1.dist-info}/WHEEL +1 -1
- mlquantify/classification/pwkclf.py +0 -73
- mlquantify/evaluation/measures/__init__.py +0 -26
- mlquantify/evaluation/measures/ae.py +0 -11
- mlquantify/evaluation/measures/bias.py +0 -16
- mlquantify/evaluation/measures/kld.py +0 -8
- mlquantify/evaluation/measures/mse.py +0 -12
- mlquantify/evaluation/measures/nae.py +0 -16
- mlquantify/evaluation/measures/nkld.py +0 -13
- mlquantify/evaluation/measures/nrae.py +0 -16
- mlquantify/evaluation/measures/rae.py +0 -12
- mlquantify/evaluation/measures/se.py +0 -12
- mlquantify/evaluation/protocol/_Protocol.py +0 -202
- mlquantify/evaluation/protocol/__init__.py +0 -2
- mlquantify/evaluation/protocol/app.py +0 -146
- mlquantify/evaluation/protocol/npp.py +0 -34
- mlquantify/methods/aggregative/ThreholdOptm/_ThreholdOptimization.py +0 -62
- mlquantify/methods/aggregative/ThreholdOptm/__init__.py +0 -7
- mlquantify/methods/aggregative/ThreholdOptm/acc.py +0 -27
- mlquantify/methods/aggregative/ThreholdOptm/max.py +0 -23
- mlquantify/methods/aggregative/ThreholdOptm/ms.py +0 -21
- mlquantify/methods/aggregative/ThreholdOptm/ms2.py +0 -25
- mlquantify/methods/aggregative/ThreholdOptm/pacc.py +0 -41
- mlquantify/methods/aggregative/ThreholdOptm/t50.py +0 -21
- mlquantify/methods/aggregative/ThreholdOptm/x.py +0 -23
- mlquantify/methods/aggregative/__init__.py +0 -9
- mlquantify/methods/aggregative/cc.py +0 -32
- mlquantify/methods/aggregative/emq.py +0 -86
- mlquantify/methods/aggregative/fm.py +0 -72
- mlquantify/methods/aggregative/gac.py +0 -96
- mlquantify/methods/aggregative/gpac.py +0 -87
- mlquantify/methods/aggregative/mixtureModels/_MixtureModel.py +0 -81
- mlquantify/methods/aggregative/mixtureModels/__init__.py +0 -5
- mlquantify/methods/aggregative/mixtureModels/dys.py +0 -55
- mlquantify/methods/aggregative/mixtureModels/dys_syn.py +0 -89
- mlquantify/methods/aggregative/mixtureModels/hdy.py +0 -46
- mlquantify/methods/aggregative/mixtureModels/smm.py +0 -27
- mlquantify/methods/aggregative/mixtureModels/sord.py +0 -77
- mlquantify/methods/aggregative/pcc.py +0 -33
- mlquantify/methods/aggregative/pwk.py +0 -38
- mlquantify/methods/meta/__init__.py +0 -1
- mlquantify/methods/meta/ensemble.py +0 -236
- mlquantify/methods/non_aggregative/__init__.py +0 -1
- mlquantify/methods/non_aggregative/hdx.py +0 -71
- mlquantify/plots/__init__.py +0 -2
- mlquantify/plots/distribution_plot.py +0 -109
- mlquantify/plots/protocol_plot.py +0 -193
- mlquantify/utils/general_purposes/__init__.py +0 -8
- mlquantify/utils/general_purposes/convert_col_to_array.py +0 -13
- mlquantify/utils/general_purposes/generate_artificial_indexes.py +0 -29
- mlquantify/utils/general_purposes/get_real_prev.py +0 -9
- mlquantify/utils/general_purposes/load_quantifier.py +0 -4
- mlquantify/utils/general_purposes/make_prevs.py +0 -23
- mlquantify/utils/general_purposes/normalize.py +0 -20
- mlquantify/utils/general_purposes/parallel.py +0 -10
- mlquantify/utils/general_purposes/round_protocol_df.py +0 -14
- mlquantify/utils/method_purposes/__init__.py +0 -6
- mlquantify/utils/method_purposes/distances.py +0 -21
- mlquantify/utils/method_purposes/getHist.py +0 -13
- mlquantify/utils/method_purposes/get_scores.py +0 -33
- mlquantify/utils/method_purposes/moss.py +0 -16
- mlquantify/utils/method_purposes/ternary_search.py +0 -14
- mlquantify/utils/method_purposes/tprfpr.py +0 -42
- mlquantify-0.0.11.2.dist-info/RECORD +0 -73
- {mlquantify-0.0.11.2.dist-info → mlquantify-0.1.1.dist-info}/top_level.txt +0 -0
mlquantify/model_selection.py
CHANGED
|
@@ -1,232 +1,377 @@
|
|
|
1
|
-
from .base import Quantifier
|
|
2
|
-
from typing import Union, List
|
|
3
|
-
import itertools
|
|
4
|
-
from tqdm import tqdm
|
|
5
|
-
import signal
|
|
6
|
-
from copy import deepcopy
|
|
7
|
-
import numpy as np
|
|
8
|
-
from sklearn.model_selection import train_test_split
|
|
9
|
-
from .utils import parallel
|
|
10
|
-
from .evaluation import
|
|
11
|
-
|
|
12
|
-
class GridSearchQ(Quantifier):
|
|
13
|
-
"""
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
self.
|
|
156
|
-
|
|
157
|
-
self.
|
|
158
|
-
|
|
159
|
-
if self.
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
"""
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
1
|
+
from .base import Quantifier
|
|
2
|
+
from typing import Union, List
|
|
3
|
+
import itertools
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
import signal
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
import numpy as np
|
|
8
|
+
from sklearn.model_selection import train_test_split
|
|
9
|
+
from .utils.general import parallel, get_measure
|
|
10
|
+
from .evaluation.protocol import APP, NPP
|
|
11
|
+
|
|
12
|
+
class GridSearchQ(Quantifier):
    """Hyperparameter optimization for quantification models using grid search.

    GridSearchQ allows hyperparameter tuning for quantification models
    by minimizing a quantification-oriented loss over a parameter grid.
    This method evaluates hyperparameter configurations using quantification
    metrics rather than standard classification metrics, so the selected
    configuration is the one that best approximates class distributions.

    Parameters
    ----------
    model : Quantifier
        The base quantification model to optimize.

    param_grid : dict
        Dictionary where keys are parameter names (str) and values are
        lists of parameter settings to try. Every combination (Cartesian
        product) of the listed values is evaluated.

    protocol : str, default='app'
        The quantification protocol to use (case-insensitive). Supported options are:
        - 'app': Artificial Prevalence Protocol.
        - 'npp': Natural Prevalence Protocol.

    n_prevs : int, default=100
        Number of prevalence points to generate for APP.

    n_repetitions : int, default=1
        Number of iterations passed to the evaluation protocol
        (``n_iterations``). Must be greater than 1 when ``protocol='npp'``.

    scoring : Union[List[str], str], default='ae'
        Metric or metrics used to evaluate the model's performance. Can be
        a string (e.g., 'ae') or a list/tuple of metric names. When several
        metrics are given, their mean is the value that is minimized.

    refit : bool, default=True
        If True, refit the model using the best found hyperparameters
        on the entire dataset. ``predict`` is only available after a refit.

    val_split : float, default=0.4
        Proportion of the training data held out for validation.

    n_jobs : int, default=1
        The number of jobs to run in parallel. -1 means using all processors.

    random_seed : int, default=42
        Random seed used for the train/validation split.

    timeout : int, default=-1
        Maximum time (in seconds) allowed for a single parameter combination.
        Any non-positive value (e.g. -1) disables the timeout. The timeout is
        implemented with ``signal.SIGALRM``, which is only available on Unix
        and is only delivered to the main thread.

    verbose : bool, default=False
        If True, print progress messages during grid search.

    Attributes
    ----------
    best_params : dict or None
        The parameter setting that gave the best results on the validation set,
        or None if no combination could be evaluated (e.g. all timed out).

    best_score : float or None
        The best (lowest) score achieved during the grid search, or None if no
        combination could be evaluated.

    best_model_ : Quantifier
        The model refitted on the full data with ``best_params``. Only set when
        ``refit=True`` and a best configuration was found.

    References
    ----------
    The idea of using grid search for hyperparameter optimization in
    quantification models was discussed in:
    Moreo, Alejandro; Sebastiani, Fabrizio. "Re-assessing the 'Classify and Count'
    Quantification Method". In: Advances in Information Retrieval:
    43rd European Conference on IR Research, ECIR 2021, Virtual Event,
    March 28–April 1, 2021, Proceedings, Part II. Springer International Publishing,
    2021, pp. 75–91. [Link](https://link.springer.com/chapter/10.1007/978-3-030-72240-1_6).

    Examples
    --------
    >>> from mlquantify.methods.aggregative import DyS
    >>> from mlquantify.model_selection import GridSearchQ
    >>> from sklearn.ensemble import RandomForestClassifier
    >>> from sklearn.datasets import load_breast_cancer
    >>> from sklearn.model_selection import train_test_split
    >>>
    >>> # Loading dataset from sklearn
    >>> features, target = load_breast_cancer(return_X_y=True)
    >>>
    >>> # Splitting into train and test
    >>> X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.3)
    >>>
    >>> model = DyS(RandomForestClassifier())
    >>>
    >>> # Creating the hyperparameter grid
    >>> param_grid = {
    ...     'learner__n_estimators': [100, 500, 1000],
    ...     'learner__criterion': ["gini", "entropy"],
    ...     'measure': ["topsoe", "hellinger"]
    ... }
    >>>
    >>> gs = GridSearchQ(
    ...     model=model,
    ...     param_grid=param_grid,
    ...     protocol='app',  # Default
    ...     n_prevs=100,  # Default
    ...     scoring='nae',
    ...     refit=True,  # Default
    ...     val_split=0.3,
    ...     n_jobs=-1,
    ...     verbose=True)
    >>>
    >>> gs.fit(X_train, y_train)
    [GridSearchQ]: Optimization complete. Best score: 0.0060630241297973545, with parameters: {'learner__n_estimators': 500, 'learner__criterion': 'entropy', 'measure': 'topsoe'}.
    >>> predictions = gs.predict(X_test)
    >>> predictions
    {0: 0.4182508973311534, 1: 0.5817491026688466}
    """

    def __init__(self,
                 model: Quantifier,
                 param_grid: dict,
                 protocol: str = 'app',
                 n_prevs: int = 100,
                 n_repetitions: int = 1,
                 scoring: Union[List[str], str] = "ae",
                 refit: bool = True,
                 val_split: float = 0.4,
                 n_jobs: int = 1,
                 random_seed: int = 42,
                 timeout: int = -1,
                 verbose: bool = False):

        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol.lower()
        self.n_prevs = n_prevs
        self.n_repetitions = n_repetitions
        self.refit = refit
        self.val_split = val_split
        self.n_jobs = n_jobs
        self.random_seed = random_seed
        self.timeout = timeout
        self.verbose = verbose

        # Accept a single metric name or any list/tuple of names; each name is
        # resolved to a callable measure(real_prevs, pred_prevs).
        measure_names = scoring if isinstance(scoring, (list, tuple)) else [scoring]
        self.scoring = [get_measure(measure) for measure in measure_names]

        # Explicit raise (not ``assert``) so validation is not stripped under ``python -O``.
        if self.protocol not in {'app', 'npp'}:
            raise ValueError('Unknown protocol; valid ones are "app" or "npp".')

        if self.protocol == 'npp' and self.n_repetitions <= 1:
            raise ValueError('For "npp" protocol, n_repetitions must be greater than 1.')

    def sout(self, msg):
        """Prints messages if verbose is True."""
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __get_protocol(self, model, sample_size):
        """Get the appropriate protocol instance.

        Parameters
        ----------
        model : Quantifier
            The quantification model.

        sample_size : int
            The sample size for batch processing.

        Returns
        -------
        object
            Instance of APP or NPP protocol, depending on the configured protocol.
        """
        protocol_params = {
            'models': model,
            'batch_size': sample_size,
            'n_iterations': self.n_repetitions,
            'n_jobs': self.n_jobs,
            'verbose': False,
            # NOTE(review): fixed seed, independent of ``self.random_seed``;
            # kept as-is to preserve existing results — confirm this is intended.
            'random_state': 35,
            'return_type': "predictions"
        }
        return APP(n_prevs=self.n_prevs, **protocol_params) if self.protocol == 'app' else NPP(**protocol_params)

    def fit(self, X, y):
        """Fit the quantifier model and perform grid search.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training features, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Training labels.

        Returns
        -------
        self : GridSearchQ
            Returns the fitted instance of GridSearchQ.

        Notes
        -----
        If no combination could be evaluated (e.g. every one timed out),
        ``best_params`` and ``best_score`` are set to None, a warning is
        emitted and no refit takes place.
        """
        import warnings

        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=self.val_split, random_state=self.random_seed)
        param_names = list(self.param_grid.keys())
        param_combinations = list(itertools.product(*self.param_grid.values()))

        use_timeout = self.timeout > 0
        previous_handler = None
        if use_timeout:
            # Keep the handler we replace so it can be restored once the search ends.
            previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)

        def evaluate_combination(params):
            """Evaluate a single combination of hyperparameters.

            Parameters
            ----------
            params : tuple
                A tuple of hyperparameter values, aligned with ``param_names``.

            Returns
            -------
            float or None
                The evaluation score, or None if a timeout occurred.
            """
            if self.verbose:
                print(f"\tEvaluating combination: {str(params)}")

            model = deepcopy(self.model)
            model.set_params(**dict(zip(param_names, params)))
            # NOTE(review): the validation batch size is the *training* set size;
            # confirm the protocol is meant to sample batches of that size from X_val.
            protocol_instance = self.__get_protocol(model, len(y_train))

            try:
                # NOTE(review): when ``parallel`` runs this in worker processes,
                # the SIGALRM handler installed above is not present there.
                if use_timeout:
                    signal.alarm(self.timeout)

                protocol_instance.fit(X_train, y_train)
                _, real_prevs, pred_prevs = protocol_instance.predict(X_val, y_val)
                scores = [
                    np.mean([measure(rp, pp) for rp, pp in zip(real_prevs, pred_prevs)])
                    for measure in self.scoring
                ]
            except TimeoutError:
                self.sout(f'Timeout reached for combination: {params}.')
                return None
            finally:
                # Always cancel a pending alarm, even if evaluation raised, so it
                # cannot fire later while another combination is running.
                if use_timeout:
                    signal.alarm(0)

            if self.verbose:
                print(f"\t\\--Finished evaluation: {str(params)}")

            return np.mean(scores) if scores else None

        try:
            results = parallel(
                evaluate_combination,
                tqdm(param_combinations, desc="Evaluating combinations", total=len(param_combinations)) if self.verbose else param_combinations,
                n_jobs=self.n_jobs
            )
        finally:
            if use_timeout:
                # signal.signal() returns None when the prior handler was not set
                # from Python; fall back to the default disposition in that case.
                signal.signal(signal.SIGALRM,
                              previous_handler if previous_handler is not None else signal.SIG_DFL)

        # Lower is better for every supported quantification loss.
        best_score, best_params = None, None
        for score, params in zip(results, param_combinations):
            if score is not None and (best_score is None or score < best_score):
                best_score, best_params = score, params

        self.best_score = best_score
        if best_params is None:
            # Previously this crashed with a TypeError in zip(keys, None).
            self.best_params = None
            warnings.warn('GridSearchQ: no parameter combination could be evaluated; '
                          'the model was not refitted.')
            return self

        self.best_params = dict(zip(param_names, best_params))
        self.sout(f'Optimization complete. Best score: {self.best_score}, with parameters: {self.best_params}.')

        # ``is not None`` so an empty param_grid ({} -> best_params == {}) still refits.
        if self.refit and self.best_params is not None:
            self.model.set_params(**self.best_params)
            self.model.fit(X, y)
            self.best_model_ = self.model

        return self

    def predict(self, X):
        """Make predictions using the best found model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to predict on.

        Returns
        -------
        array-like
            Predictions for the input data.

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet (including when ``refit=False``).
        """
        if not hasattr(self, 'best_model_'):
            raise RuntimeError("The model has not been fitted yet.")
        return self.best_model_.predict(X)

    @property
    def classes_(self):
        """Get the classes of the best model.

        Returns
        -------
        array-like
            The classes learned by the best model.
        """
        return self.best_model_.classes_

    def set_params(self, **parameters):
        """Set the hyperparameters for grid search.

        Note that this replaces ``param_grid`` with ``parameters``.

        Parameters
        ----------
        parameters : dict
            Dictionary of hyperparameters to set.
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Get the parameters of the best model.

        Parameters
        ----------
        deep : bool, optional, default=True
            Accepted for API compatibility; currently ignored.

        Returns
        -------
        dict
            Parameters of the best model.

        Raises
        ------
        ValueError
            If called before the model has been fitted.
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_.get_params()
        raise ValueError('get_params called before fit.')

    def best_model(self):
        """Return the best model after fitting.

        Returns
        -------
        Quantifier
            The best fitted model.

        Raises
        ------
        ValueError
            If called before fitting.
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit.')

    def _timeout_handler(self, signum, frame):
        """Handle timeouts during evaluation.

        Parameters
        ----------
        signum : int
            Signal number.

        frame : object
            Current stack frame.

        Raises
        ------
        TimeoutError
            Raised when the timeout is reached.
        """
        raise TimeoutError