mlquantify-0.1.8-py3-none-any.whl → mlquantify-0.1.9-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (67)
  1. mlquantify/__init__.py +0 -29
  2. mlquantify/adjust_counting/__init__.py +14 -0
  3. mlquantify/adjust_counting/_adjustment.py +365 -0
  4. mlquantify/adjust_counting/_base.py +247 -0
  5. mlquantify/adjust_counting/_counting.py +145 -0
  6. mlquantify/adjust_counting/_utils.py +114 -0
  7. mlquantify/base.py +117 -519
  8. mlquantify/base_aggregative.py +209 -0
  9. mlquantify/calibration.py +1 -0
  10. mlquantify/confidence.py +335 -0
  11. mlquantify/likelihood/__init__.py +5 -0
  12. mlquantify/likelihood/_base.py +161 -0
  13. mlquantify/likelihood/_classes.py +414 -0
  14. mlquantify/meta/__init__.py +1 -0
  15. mlquantify/meta/_classes.py +761 -0
  16. mlquantify/metrics/__init__.py +21 -0
  17. mlquantify/metrics/_oq.py +109 -0
  18. mlquantify/metrics/_rq.py +98 -0
  19. mlquantify/{evaluation/measures.py → metrics/_slq.py} +43 -28
  20. mlquantify/mixture/__init__.py +7 -0
  21. mlquantify/mixture/_base.py +153 -0
  22. mlquantify/mixture/_classes.py +400 -0
  23. mlquantify/mixture/_utils.py +112 -0
  24. mlquantify/model_selection/__init__.py +9 -0
  25. mlquantify/model_selection/_protocol.py +358 -0
  26. mlquantify/model_selection/_search.py +315 -0
  27. mlquantify/model_selection/_split.py +1 -0
  28. mlquantify/multiclass.py +350 -0
  29. mlquantify/neighbors/__init__.py +9 -0
  30. mlquantify/neighbors/_base.py +198 -0
  31. mlquantify/neighbors/_classes.py +159 -0
  32. mlquantify/{classification/methods.py → neighbors/_classification.py} +48 -66
  33. mlquantify/neighbors/_kde.py +270 -0
  34. mlquantify/neighbors/_utils.py +135 -0
  35. mlquantify/neural/__init__.py +1 -0
  36. mlquantify/utils/__init__.py +47 -2
  37. mlquantify/utils/_artificial.py +27 -0
  38. mlquantify/utils/_constraints.py +219 -0
  39. mlquantify/utils/_context.py +21 -0
  40. mlquantify/utils/_decorators.py +36 -0
  41. mlquantify/utils/_exceptions.py +12 -0
  42. mlquantify/utils/_get_scores.py +159 -0
  43. mlquantify/utils/_load.py +18 -0
  44. mlquantify/utils/_parallel.py +6 -0
  45. mlquantify/utils/_random.py +36 -0
  46. mlquantify/utils/_sampling.py +273 -0
  47. mlquantify/utils/_tags.py +44 -0
  48. mlquantify/utils/_validation.py +447 -0
  49. mlquantify/utils/prevalence.py +61 -0
  50. {mlquantify-0.1.8.dist-info → mlquantify-0.1.9.dist-info}/METADATA +2 -1
  51. mlquantify-0.1.9.dist-info/RECORD +53 -0
  52. mlquantify/classification/__init__.py +0 -1
  53. mlquantify/evaluation/__init__.py +0 -14
  54. mlquantify/evaluation/protocol.py +0 -289
  55. mlquantify/methods/__init__.py +0 -37
  56. mlquantify/methods/aggregative.py +0 -1159
  57. mlquantify/methods/meta.py +0 -472
  58. mlquantify/methods/mixture_models.py +0 -1003
  59. mlquantify/methods/non_aggregative.py +0 -136
  60. mlquantify/methods/threshold_optimization.py +0 -869
  61. mlquantify/model_selection.py +0 -377
  62. mlquantify/plots.py +0 -367
  63. mlquantify/utils/general.py +0 -371
  64. mlquantify/utils/method.py +0 -449
  65. mlquantify-0.1.8.dist-info/RECORD +0 -22
  66. {mlquantify-0.1.8.dist-info → mlquantify-0.1.9.dist-info}/WHEEL +0 -0
  67. {mlquantify-0.1.8.dist-info → mlquantify-0.1.9.dist-info}/top_level.txt +0 -0
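
At the import level, the release replaces the flat 0.1.8 modules (mlquantify.methods, mlquantify.evaluation, the single-file mlquantify.model_selection) with subpackages. The sketch below illustrates the resulting import paths; it is based only on the paths visible in this diff, and whether GridSearchQ or MAE are also re-exported at package level is an assumption not confirmed here.

    # 0.1.8 modules removed in this release (see the deletions above):
    #   mlquantify/evaluation/protocol.py, mlquantify/methods/*.py, mlquantify/model_selection.py
    # 0.1.9: evaluation protocols now live in the mlquantify.model_selection subpackage
    from mlquantify.model_selection import APP, NPP, UPP        # import used by _search.py below
    from mlquantify.model_selection._search import GridSearchQ  # private module path; public alias not shown in this diff
    from mlquantify.metrics._slq import MAE                     # renamed from mlquantify/evaluation/measures.py (file 19 above)
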
mlquantify/model_selection/_protocol.py (new file)
@@ -0,0 +1,358 @@
+ import numpy as np
+
+ from mlquantify.base import BaseQuantifier, ProtocolMixin
+ from mlquantify.utils._constraints import Interval, Options
+ from mlquantify.utils._sampling import (
+     get_indexes_with_prevalence,
+     simplex_grid_sampling,
+     simplex_uniform_kraemer,
+     simplex_uniform_sampling,
+ )
+ from mlquantify.utils._validation import validate_data
+ from abc import ABC, abstractmethod
+ from logging import warning
+ import numpy as np
+
+
+ class BaseProtocol(ProtocolMixin, BaseQuantifier):
+     """Base class for evaluation protocols.
+
+     Parameters
+     ----------
+     batch_size : int or list of int
+         The size of the batches to be used in the evaluation.
+     random_state : int, optional
+         The random seed for reproducibility.
+
+     Attributes
+     ----------
+     n_combinations : int
+
+     Raises
+     ------
+     ValueError
+         If the batch size is not a positive integer or list of positive integers.
+
+     Notes
+     -----
+     This class serves as a base class for different evaluation protocols, each with its own strategy for splitting the data into batches.
+
+     Examples
+     --------
+     >>> class MyCustomProtocol(Protocol):
+     ...     def _iter_indices(self, X: np.ndarray, y: np.ndarray):
+     ...         for batch_size in self.batch_size:
+     ...             yield np.random.choice(X.shape[0], batch_size, replace=True)
+     ...
+     >>> protocol = MyCustomProtocol(batch_size=100, random_state=42)
+     >>> for idx in protocol.split(X, y):
+     ...     # Train and evaluate model
+     ...     pass
+
+     """
+
+     _parameter_constraints = {
+         "batch_size": [Interval(left=1, right=None, discrete=True)],
+         "random_state": [Interval(left=0, right=None, discrete=True)]
+     }
+
+     def __init__(self, batch_size, random_state=None, **kwargs):
+         if isinstance(batch_size, int):
+             self.n_combinations = 1
+         else:
+             self.n_combinations = len(batch_size)
+
+         self.batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
+         self.random_state = random_state
+
+         for name, value in kwargs.items():
+             setattr(self, name, value)
+             if isinstance(value, list):
+                 self.n_combinations *= len(value)
+             elif isinstance(value, (int, float)):
+                 self.n_combinations *= value
+             else:
+                 raise ValueError(f"Invalid argument {name}={value}: must be int/float or list of int/float.")
+
+
+     def split(self, X: np.ndarray, y: np.ndarray):
+         """
+         Split the data into samples for evaluation.
+
+         Parameters
+         ----------
+         X : np.ndarray
+             The input features.
+         y : np.ndarray
+             The target labels.
+
+         Yields
+         ------
+         Generator[np.ndarray, np.ndarray]
+             A generator that yields the indices for each split.
+         """
+         X, y = validate_data(self, X, y)
+         for idx in self._iter_indices(X, y):
+             if len(idx) > len(X):
+                 warning(f"Batch size {len(idx)} exceeds dataset size {len(X)}. Replacement sampling will be used.")
+             yield idx
+
+
+     @abstractmethod
+     def _iter_indices(self, X, y):
+         """Abstract method to be implemented by subclasses to yield indices for each batch."""
+         pass
+
+     def get_n_combinations(self):
+         """
+         Get the number of combinations for the current protocol.
+         """
+         return self.n_combinations
+
+
+
+ # ===========================================
+ # Protocol Implementations
+ # ===========================================
+
+
+ class APP(BaseProtocol):
+     """
+     Artificial Prevalence Protocol (APP) for exhaustive prevalent batch evaluation.
+
+     Generates batches with artificially imposed prevalences across all possible
+     combinations within specified bounds. This allows comprehensive evaluation
+     over a range of prevalence scenarios.
+
+     Parameters
+     ----------
+     batch_size : int or list of int
+         Size(s) of the evaluation batches.
+     n_prevalences : int
+         Number of artificial prevalence levels to sample per class dimension.
+     repeats : int, optional (default=1)
+         Number of repetitions for each prevalence sampling.
+     random_state : int, optional
+         Random seed for reproducibility.
+     min_prev : float, optional (default=0.0)
+         Minimum possible prevalence for any class.
+     max_prev : float, optional (default=1.0)
+         Maximum possible prevalence for any class.
+
+     Notes
+     -----
+     For multiclass problems, this protocol may have high computational complexity
+     due to combinatorial explosion in prevalence combinations.
+
+     Examples
+     --------
+     >>> protocol = APP(batch_size=[100, 200], n_prevalences=5, repeats=3, random_state=42)
+     >>> for idx in protocol.split(X, y):
+     ...     # Train and evaluate
+     ...     pass
+     """
+
+     _parameter_constraints = {
+         "n_prevalences": [Interval(left=1, right=None, discrete=True)],
+         "repeats": [Interval(left=1, right=None, discrete=True)],
+         "min_prev": [Interval(left=0.0, right=1.0)],
+         "max_prev": [Interval(left=0.0, right=1.0)]
+     }
+
+     def __init__(self, batch_size, n_prevalences, repeats=1, random_state=None, min_prev=0.0, max_prev=1.0):
+         super().__init__(batch_size=batch_size,
+                          random_state=random_state,
+                          n_prevalences=n_prevalences,
+                          repeats=repeats)
+         self.min_prev = min_prev
+         self.max_prev = max_prev
+
+     def _iter_indices(self, X: np.ndarray, y: np.ndarray):
+
+         n_dim = len(np.unique(y))
+
+         for batch_size in self.batch_size:
+             prevalences = simplex_grid_sampling(n_dim=n_dim,
+                                                 n_prev=self.n_prevalences,
+                                                 n_iter=self.repeats,
+                                                 min_val=self.min_prev,
+                                                 max_val=self.max_prev)
+             for prev in prevalences:
+                 indexes = get_indexes_with_prevalence(y, prev, batch_size)
+                 yield indexes
+
+
+
+
+ class NPP(BaseProtocol):
+     """
+     Natural Prevalence Protocol (NPP) that samples data without imposing prevalence constraints.
+
+     This protocol simply samples batches randomly with replacement,
+     ignoring prevalence distributions.
+
+     Parameters
+     ----------
+     batch_size : int or list of int
+         Size(s) of the evaluation batches.
+     n_samples : int, optional (default=1)
+         Number of distinct batch samples per batch size.
+     repeats : int, optional (default=1)
+         Number of repetitions for each batch sample.
+     random_state : int, optional
+         Random seed for reproducibility.
+
+     Examples
+     --------
+     >>> protocol = NPP(batch_size=100, random_state=42)
+     >>> for idx in protocol.split(X, y):
+     ...     # Train and evaluate
+     ...     pass
+     """
+
+     _parameter_constraints = {
+         "repeats": [Interval(left=1, right=None, discrete=True)]
+     }
+
+     def __init__(self, batch_size, n_samples=1, repeats=1, random_state=None):
+         super().__init__(batch_size=batch_size,
+                          random_state=random_state)
+         self.n_samples = n_samples
+         self.repeats = repeats
+
+     def _iter_indices(self, X: np.ndarray, y: np.ndarray):
+
+         for _ in range(self.n_samples):
+             for batch_size in self.batch_size:
+                 idx = np.random.choice(X.shape[0], batch_size, replace=True)
+                 for _ in range(self.repeats):
+                     yield idx
+
+
+ class UPP(BaseProtocol):
+     """
+     Uniform Prevalence Protocol (UPP) for uniform sampling of artificial prevalences.
+
+     Similar to APP, but uses uniform prevalence distribution generation
+     methods such as Kraemer or uniform simplex sampling to generate batches
+     with uniformly sampled class prevalences.
+
+     Parameters
+     ----------
+     batch_size : int or list of int
+         Batch size(s) for evaluation.
+     n_prevalences : int
+         Number of prevalence samples per class.
+     repeats : int
+         Number of evaluation repeats with different samples.
+     random_state : int, optional
+         Random seed for reproducibility.
+     min_prev : float, optional (default=0.0)
+         Minimum prevalence limit.
+     max_prev : float, optional (default=1.0)
+         Maximum prevalence limit.
+     algorithm : {'kraemer', 'uniform'}, optional (default='kraemer')
+         Sampling algorithm used to generate artificial prevalences.
+
+     Examples
+     --------
+     >>> protocol = UPP(batch_size=100, n_prevalences=5, repeats=3, random_state=42)
+     >>> for idx in protocol.split(X, y):
+     ...     # Train and evaluate
+     ...     pass
+     """
+
+     _parameter_constraints = {
+         "n_prevalences": [Interval(left=1, right=None, discrete=True)],
+         "repeats": [Interval(left=1, right=None, discrete=True)],
+         "min_prev": [Interval(left=0.0, right=1.0)],
+         "max_prev": [Interval(left=0.0, right=1.0)],
+         "algorithm": [Options(['kraemer', 'uniform'])]
+     }
+
+     def __init__(self,
+                  batch_size,
+                  n_prevalences,
+                  repeats=1,
+                  random_state=None,
+                  min_prev=0.0,
+                  max_prev=1.0,
+                  algorithm='kraemer'):
+         super().__init__(batch_size=batch_size,
+                          random_state=random_state,
+                          n_prevalences=n_prevalences,
+                          repeats=repeats)
+         self.min_prev = min_prev
+         self.max_prev = max_prev
+         self.algorithm = algorithm
+
+     def _iter_indices(self, X: np.ndarray, y: np.ndarray):
+
+         n_dim = len(np.unique(y))
+
+         for batch_size in self.batch_size:
+             if self.algorithm == 'kraemer':
+                 prevalences = simplex_uniform_kraemer(n_dim=n_dim,
+                                                       n_prev=self.n_prevalences,
+                                                       n_iter=self.repeats,
+                                                       min_val=self.min_prev,
+                                                       max_val=self.max_prev)
+             elif self.algorithm == 'uniform':
+                 prevalences = simplex_uniform_sampling(n_dim=n_dim,
+                                                        n_prev=self.n_prevalences,
+                                                        n_iter=self.repeats,
+                                                        min_val=self.min_prev,
+                                                        max_val=self.max_prev)
+
+             for prev in prevalences:
+                 indexes = get_indexes_with_prevalence(y, prev, batch_size)
+                 yield indexes
+
+
+ class PPP(BaseProtocol):
+     """
+     Personalized Prevalence Protocol (PPP) for targeted prevalence batch generation.
+
+     Generates batches with user-specified prevalence distributions, allowing for
+     controlled evaluation on specific scenarios.
+
+     Parameters
+     ----------
+     batch_size : int or list of int
+         Batch sizes to generate.
+     prevalences : list of floats or array-like
+         Custom target prevalences per class to generate evaluation batches.
+     repeats : int, optional (default=1)
+         Number of evaluation repetitions with different batches.
+     random_state : int, optional
+         Random seed for reproducibility.
+
+     Examples
+     --------
+     >>> protocol = PPP(batch_size=100, prevalences=[0.1, 0.9], repeats=3, random_state=42)
+     >>> for idx in protocol.split(X, y):
+     ...     # Train and evaluate
+     ...     pass
+     """
+
+     _parameter_constraints = {
+         "repeats": [Interval(left=1, right=None, discrete=True)],
+         "prevalences": ["array-like"]
+     }
+
+     def __init__(self, batch_size, prevalences, repeats=1, random_state=None):
+         super().__init__(batch_size=batch_size,
+                          random_state=random_state,
+                          prevalences=prevalences,
+                          repeats=repeats)
+
+     def _iter_indices(self, X: np.ndarray, y: np.ndarray):
+
+         for batch_size in self.batch_size:
+             for prev in self.prevalences:
+                 if isinstance(prev, float):
+                     prev = [1-prev, prev]
+
+                 indexes = get_indexes_with_prevalence(y, prev, batch_size)
+                 yield indexes
+
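
For orientation, the loop below sketches how these protocol classes are meant to be consumed; it mirrors the evaluation loop that GridSearchQ runs internally (see the _search.py hunk that follows). The synthetic data and the UniformQuantifier stand-in are invented for illustration, and the sketch assumes the protocols accept plain NumPy arrays.

    import numpy as np
    from mlquantify.model_selection import APP
    from mlquantify.utils.prevalence import get_prev_from_labels
    from mlquantify.metrics._slq import MAE

    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 5))          # synthetic features (illustration only)
    y = rng.integers(0, 2, size=1000)       # synthetic binary labels

    class UniformQuantifier:
        """Made-up stand-in: always predicts a 50/50 class prevalence."""
        def predict(self, X):
            return np.array([0.5, 0.5])

    quantifier = UniformQuantifier()

    # APP yields index sets whose class prevalences sweep a grid over the simplex.
    protocol = APP(batch_size=[100, 200], n_prevalences=5, repeats=1, random_state=42)
    errors = []
    for idx in protocol.split(X, y):                 # one index array per generated batch
        true_prev = get_prev_from_labels(y[idx])     # prevalences realized in the batch
        pred_prev = quantifier.predict(X[idx])       # prevalences estimated by the quantifier
        errors.append(MAE(true_prev, pred_prev))
    print(f"mean absolute error over {len(errors)} batches: {np.mean(errors):.3f}")
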
mlquantify/model_selection/_search.py (new file)
@@ -0,0 +1,315 @@
+ from mlquantify.base import BaseQuantifier, MetaquantifierMixin
+ import itertools
+ from joblib import Parallel, delayed
+ from copy import deepcopy
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+ from mlquantify.metrics._slq import MAE
+ from mlquantify.utils._constraints import (
+     Interval,
+     Options,
+     CallableConstraint
+ )
+ from mlquantify.utils._validation import validate_data
+ from mlquantify.utils._decorators import _fit_context
+ from mlquantify.utils.prevalence import get_prev_from_labels
+ from mlquantify.model_selection import (
+     APP, NPP, UPP
+ )
+
+ class GridSearchQ(MetaquantifierMixin, BaseQuantifier):
+     """
+     Grid Search for Quantifiers with evaluation protocols.
+
+     This class automates the hyperparameter search over a grid of parameter
+     combinations for a given quantifier. It evaluates each combination using
+     a specified evaluation protocol (e.g., APP, NPP, UPP), over multiple splits
+     of the validation data, and selects the best-performing parameters based on
+     a chosen scoring metric such as Mean Absolute Error (MAE).
+
+     Parameters
+     ----------
+     quantifier : BaseQuantifier
+         Quantifier class (not instance). It must implement fit and predict.
+     param_grid : dict
+         Dictionary where keys are parameter names and values are lists of parameter
+         values to try.
+     protocol : {'app', 'npp', 'upp'}, default='app'
+         Evaluation protocol to use for splitting the validation data.
+     samples_sizes : int or list of int, default=100
+         Batch size(s) for evaluation splits.
+     n_repetitions : int, default=10
+         Number of random repetitions per evaluation.
+     scoring : callable, default=MAE
+         Scoring function to evaluate prevalence prediction quality.
+         Must accept (true_prevalences, predicted_prevalences) arrays.
+     refit : bool, default=True
+         If True, refits the quantifier on the whole data using best parameters.
+     val_split : float, default=0.4
+         Fraction of data reserved for validation during parameter search.
+     n_jobs : int or None, default=1
+         Number of parallel jobs for evaluation.
+     random_seed : int, default=42
+         Random seed for reproducibility.
+     verbose : bool, default=False
+         Enable verbose output during evaluation.
+
+     Attributes
+     ----------
+     best_score : float
+         Best score (lowest loss) found during grid search.
+     best_params : dict
+         Parameter combination corresponding to best_score.
+     best_model_ : BaseQuantifier
+         Refitted quantifier instance with best parameters after search.
+
+     Methods
+     -------
+     fit(X, y)
+         Runs grid search over param_grid, evaluates with the selected protocol,
+         and stores best found parameters and model.
+     predict(X)
+         Predicts prevalences using the best fitted model after search.
+     best_params()
+         Returns the best parameter dictionary after fitting.
+     best_model()
+         Returns the best refitted quantifier after fitting.
+     sout(msg)
+         Utility method to print messages if verbose is enabled.
+
+     Examples
+     --------
+     >>> from mlquantify.quantifiers import SomeQuantifier
+     >>> param_grid = {'alpha': [0.1, 1.0], 'beta': [10, 20]}
+     >>> grid_search = GridSearchQ(quantifier=SomeQuantifier,
+     ...                           param_grid=param_grid,
+     ...                           protocol='app',
+     ...                           samples_sizes=100,
+     ...                           n_repetitions=5,
+     ...                           scoring=MAE,
+     ...                           refit=True,
+     ...                           val_split=0.3,
+     ...                           n_jobs=2,
+     ...                           random_seed=123,
+     ...                           verbose=True)
+     >>> grid_search.fit(X_train, y_train)
+     >>> y_pred = grid_search.predict(X_test)
+     >>> best_params = grid_search.best_params()
+     >>> best_model = grid_search.best_model()
+     """
+
+     _parameter_constraints = {
+         "quantifier": [BaseQuantifier],
+         "param_grid": [dict],
+         "protocol": [Options({'app', 'npp', 'upp'})],
+         "n_samples": [Interval(1, None)],
+         "n_repetitions": [Interval(1, None)],
+         "scoring": [CallableConstraint()],
+         "refit": [bool],
+         "val_split": [Interval(0.0, 1.0, inclusive_left=False, inclusive_right=False)],
+         "n_jobs": [Interval(1, None), None],
+         "random_seed": [Interval(0, None), None],
+         "timeout": [Interval(-1, None)],
+         "verbose": [bool]
+     }
+
+
+     def __init__(self,
+                  quantifier,
+                  param_grid,
+                  protocol="app",
+                  samples_sizes=100,
+                  n_repetitions=10,
+                  scoring=MAE,
+                  refit=True,
+                  val_split=0.4,
+                  n_jobs=1,
+                  random_seed=42,
+                  verbose=False):
+
+         self.quantifier = quantifier()
+         self.param_grid = param_grid
+         self.protocol = protocol.lower()
+         self.samples_sizes = samples_sizes
+         self.n_repetitions = n_repetitions
+         self.refit = refit
+         self.val_split = val_split
+         self.n_jobs = n_jobs
+         self.random_seed = random_seed
+         self.verbose = verbose
+         self.scoring = scoring
+
+
+     def sout(self, msg):
+         """Prints messages if verbose is True."""
+         if self.verbose:
+             print(f'[{self.__class__.__name__}]: {msg}')
+
+     def __get_protocol(self):
+
+         if self.protocol == "app":
+             return APP(batch_size=self.samples_sizes,
+                        n_prevalences=self.n_repetitions,
+                        repeats=self.n_repetitions,
+                        random_state=self.random_seed,
+                        min_prev=0.0,
+                        max_prev=1.0)
+         elif self.protocol == "npp":
+             return NPP(batch_size=self.samples_sizes,
+                        n_samples=self.n_repetitions,
+                        repeats=self.n_repetitions,
+                        random_state=self.random_seed)
+         elif self.protocol == "upp":
+             return UPP(batch_size=self.samples_sizes,
+                        n_prevalences=self.n_repetitions,
+                        repeats=self.n_repetitions,
+                        random_state=self.random_seed,
+                        min_prev=0.0,
+                        max_prev=1.0)
+         else:
+             raise ValueError(f'Unknown protocol: {self.protocol}')
+
+     @_fit_context(prefer_skip_nested_validation=True)
+     def fit(self, X, y):
+         """
+         Fit quantifiers over grid parameter combinations with evaluation protocol.
+
+         Splits data into training and validation by val_split, and evaluates
+         each parameter combination multiple times with protocol-generated batches.
+
+         Parameters
+         ----------
+         X : array-like
+             Feature matrix for training.
+         y : array-like
+             Target labels for training.
+
+         Returns
+         -------
+         self : object
+             Returns self for chaining.
+         """
+         X, y = validate_data(self, X, y)
+
+
+         X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=self.val_split, random_state=self.random_seed)
+         param_combinations = list(itertools.product(*self.param_grid.values()))
+         params = list(self.param_grid.keys())
+
+         best_score, best_params = None, None
+
+         def evaluate_combination(params):
+
+             self.sout(f'Evaluating combination: {str(params)}')
+
+             errors = []
+
+             params = dict(zip(self.param_grid.keys(), params))
+
+             model = deepcopy(self.quantifier)
+             model.set_params(**params)
+
+             protocol = self.__get_protocol()
+
+             model.fit(X_train, y_train)
+
+             for idx in protocol.split(X_val, y_val):
+                 X_batch, y_batch = X_val[idx], y_val[idx]
+
+                 y_real = get_prev_from_labels(y_batch)
+                 y_pred = model.predict(X_batch)
+
+
+                 errors.append(self.scoring(y_real, y_pred))
+
+             avg_score = np.mean(errors)
+
+             self.sout(f'\\--Finished evaluation: {str(params)} with score: {avg_score}')
+
+             return avg_score
+
+
+         results = Parallel(n_jobs=self.n_jobs)(
+             delayed(evaluate_combination)(params) for params in param_combinations
+         )
+
+
+         for score, params in zip(results, param_combinations):
+             if score is not None and (best_score is None or score < best_score):
+                 best_score, best_params = score, params
+
+
+         self.best_score = best_score
+         self.best_params = dict(zip(self.param_grid.keys(), best_params))
+         self.sout(f'Optimization complete. Best score: {self.best_score}, with parameters: {self.best_params}.')
+
+         if self.refit and self.best_params:
+             model = deepcopy(self.quantifier)
+             model.set_params(**self.best_params)
+             model.fit(X, y)
+             self.best_model_ = model
+
+         return self
+
+
+
+     def predict(self, X):
+         """
+         Predict using the best found model.
+
+         Parameters
+         ----------
+         X : array-like
+             Data for prediction.
+
+         Returns
+         -------
+         predictions : array-like
+             Prevalence predictions.
+
+         Raises
+         ------
+         RuntimeError
+             If called before fitting.
+         """
+         if not hasattr(self, 'best_model_'):
+             raise RuntimeError("The model has not been fitted yet.")
+         return self.best_model_.predict(X)
+
+
+
+     def best_params(self):
+         """Return the best parameters found during fitting.
+
+         Returns
+         -------
+         dict
+             The best parameters.
+
+         Raises
+         ------
+         ValueError
+             If called before fitting.
+         """
+         if hasattr(self, 'best_params'):
+             return self.best_params
+         raise ValueError('best_params called before fit.')
+
+
+
+     def best_model(self):
+         """Return the best model after fitting.
+
+         Returns
+         -------
+         Quantifier
+             The best fitted model.
+
+         Raises
+         ------
+         ValueError
+             If called before fitting.
+         """
+         if hasattr(self, 'best_model_'):
+             return self.best_model_
+         raise ValueError('best_model called before fit.')
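
A usage sketch for GridSearchQ, adapted from the class docstring above. ConstantQuantifier is a toy quantifier written only for this example; the sketch assumes that BaseQuantifier follows scikit-learn estimator conventions (a set_params derived from __init__) and requires nothing beyond fit and predict, which this diff does not show.

    import numpy as np
    from mlquantify.base import BaseQuantifier
    from mlquantify.model_selection._search import GridSearchQ  # module path per this diff

    class ConstantQuantifier(BaseQuantifier):
        """Toy quantifier: ignores the data and predicts a fixed binary prevalence."""
        def __init__(self, positive_prev=0.5):
            self.positive_prev = positive_prev
        def fit(self, X, y):
            return self
        def predict(self, X):
            return np.array([1.0 - self.positive_prev, self.positive_prev])

    rng = np.random.default_rng(0)
    X = rng.normal(size=(600, 4))       # synthetic features (illustration only)
    y = rng.integers(0, 2, size=600)    # synthetic binary labels

    # GridSearchQ receives the quantifier *class* and instantiates it itself (see __init__ above);
    # scoring defaults to MAE, and the 'app' protocol drives the validation batches.
    search = GridSearchQ(quantifier=ConstantQuantifier,
                         param_grid={'positive_prev': [0.25, 0.5, 0.75]},
                         protocol='app',
                         samples_sizes=100,
                         n_repetitions=3,
                         val_split=0.3,
                         random_seed=123)
    search.fit(X, y)
    print(search.best_params, search.best_score)  # attributes assigned by fit()

Note that fit() assigns best_params and best_score as plain attributes, so attribute access is used here rather than the best_params() helper shown in the docstring, which that assignment shadows after fitting.
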
@@ -0,0 +1 @@
+ # TODO