mlquantify 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. mlquantify/__init__.py +0 -29
  2. mlquantify/adjust_counting/__init__.py +14 -0
  3. mlquantify/adjust_counting/_adjustment.py +365 -0
  4. mlquantify/adjust_counting/_base.py +247 -0
  5. mlquantify/adjust_counting/_counting.py +145 -0
  6. mlquantify/adjust_counting/_utils.py +114 -0
  7. mlquantify/base.py +117 -519
  8. mlquantify/base_aggregative.py +209 -0
  9. mlquantify/calibration.py +1 -0
  10. mlquantify/confidence.py +335 -0
  11. mlquantify/likelihood/__init__.py +5 -0
  12. mlquantify/likelihood/_base.py +161 -0
  13. mlquantify/likelihood/_classes.py +414 -0
  14. mlquantify/meta/__init__.py +1 -0
  15. mlquantify/meta/_classes.py +761 -0
  16. mlquantify/metrics/__init__.py +21 -0
  17. mlquantify/metrics/_oq.py +109 -0
  18. mlquantify/metrics/_rq.py +98 -0
  19. mlquantify/{evaluation/measures.py → metrics/_slq.py} +43 -28
  20. mlquantify/mixture/__init__.py +7 -0
  21. mlquantify/mixture/_base.py +153 -0
  22. mlquantify/mixture/_classes.py +400 -0
  23. mlquantify/mixture/_utils.py +112 -0
  24. mlquantify/model_selection/__init__.py +9 -0
  25. mlquantify/model_selection/_protocol.py +358 -0
  26. mlquantify/model_selection/_search.py +315 -0
  27. mlquantify/model_selection/_split.py +1 -0
  28. mlquantify/multiclass.py +350 -0
  29. mlquantify/neighbors/__init__.py +9 -0
  30. mlquantify/neighbors/_base.py +198 -0
  31. mlquantify/neighbors/_classes.py +159 -0
  32. mlquantify/{classification/methods.py → neighbors/_classification.py} +48 -66
  33. mlquantify/neighbors/_kde.py +270 -0
  34. mlquantify/neighbors/_utils.py +135 -0
  35. mlquantify/neural/__init__.py +1 -0
  36. mlquantify/utils/__init__.py +47 -2
  37. mlquantify/utils/_artificial.py +27 -0
  38. mlquantify/utils/_constraints.py +219 -0
  39. mlquantify/utils/_context.py +21 -0
  40. mlquantify/utils/_decorators.py +36 -0
  41. mlquantify/utils/_exceptions.py +12 -0
  42. mlquantify/utils/_get_scores.py +159 -0
  43. mlquantify/utils/_load.py +18 -0
  44. mlquantify/utils/_parallel.py +6 -0
  45. mlquantify/utils/_random.py +36 -0
  46. mlquantify/utils/_sampling.py +273 -0
  47. mlquantify/utils/_tags.py +44 -0
  48. mlquantify/utils/_validation.py +447 -0
  49. mlquantify/utils/prevalence.py +61 -0
  50. {mlquantify-0.1.8.dist-info → mlquantify-0.1.9.dist-info}/METADATA +2 -1
  51. mlquantify-0.1.9.dist-info/RECORD +53 -0
  52. mlquantify/classification/__init__.py +0 -1
  53. mlquantify/evaluation/__init__.py +0 -14
  54. mlquantify/evaluation/protocol.py +0 -289
  55. mlquantify/methods/__init__.py +0 -37
  56. mlquantify/methods/aggregative.py +0 -1159
  57. mlquantify/methods/meta.py +0 -472
  58. mlquantify/methods/mixture_models.py +0 -1003
  59. mlquantify/methods/non_aggregative.py +0 -136
  60. mlquantify/methods/threshold_optimization.py +0 -869
  61. mlquantify/model_selection.py +0 -377
  62. mlquantify/plots.py +0 -367
  63. mlquantify/utils/general.py +0 -371
  64. mlquantify/utils/method.py +0 -449
  65. mlquantify-0.1.8.dist-info/RECORD +0 -22
  66. {mlquantify-0.1.8.dist-info → mlquantify-0.1.9.dist-info}/WHEEL +0 -0
  67. {mlquantify-0.1.8.dist-info → mlquantify-0.1.9.dist-info}/top_level.txt +0 -0
mlquantify/multiclass.py
@@ -0,0 +1,350 @@
+ from copy import deepcopy
+ from itertools import combinations
+
+ import numpy as np
+
+ from mlquantify.base import BaseQuantifier, MetaquantifierMixin
+ from mlquantify.base_aggregative import get_aggregation_requirements
+ from mlquantify.utils._decorators import _fit_context
+ from mlquantify.utils._validation import validate_prevalences, check_has_method
+
+
+ # ============================================================
+ # Decorator for enabling binary quantification behavior
+ # ============================================================
+ def define_binary(cls):
+     """Decorator to enable binary quantification extensions (One-vs-Rest or One-vs-One).
+
+     This decorator dynamically extends a quantifier class to handle multiclass
+     quantification tasks by decomposing them into multiple binary subproblems,
+     following either the One-vs-Rest (OvR) or One-vs-One (OvO) strategy.
+
+     It replaces the class methods `fit`, `predict`, and `aggregate` with
+     binary-aware versions from `BinaryQuantifier`, while preserving access
+     to the original implementations via `_original_fit`, `_original_predict`,
+     and `_original_aggregate`.
+
+     Parameters
+     ----------
+     cls : class
+         A subclass of `BaseQuantifier` implementing standard binary quantification
+         methods (`fit`, `predict`, and `aggregate`).
+
+     Returns
+     -------
+     class
+         The same class with binary quantification capabilities added.
+
+     Examples
+     --------
+     >>> import numpy as np
+     >>> from mlquantify.base import BaseQuantifier
+     >>> from mlquantify.multiclass import define_binary
+
+     >>> @define_binary
+     ... class MyQuantifier(BaseQuantifier):
+     ...     def fit(self, X, y):
+     ...         # Custom binary training logic
+     ...         self.classes_ = np.unique(y)
+     ...         return self
+     ...
+     ...     def predict(self, X):
+     ...         # Return dummy prevalences
+     ...         return np.array([0.4, 0.6])
+     ...
+     ...     def aggregate(self, preds, y_train):
+     ...         # Example aggregation method
+     ...         return np.mean(preds, axis=0)
+
+     >>> qtf = MyQuantifier()
+     >>> qtf.strategy = 'ovr'  # or 'ovo'
+     >>> X = np.random.randn(10, 5)
+     >>> y = np.random.randint(0, 3, 10)
+     >>> qtf.fit(X, y)
+     MyQuantifier(...)
+     >>> qtf.predict(X)
+     array([...])
+     """
+     if check_has_method(cls, "fit"):
+         cls._original_fit = cls.fit
+     if check_has_method(cls, "predict"):
+         cls._original_predict = cls.predict
+     if check_has_method(cls, "aggregate"):
+         cls._original_aggregate = cls.aggregate
+
+     cls.fit = BinaryQuantifier.fit
+     cls.predict = BinaryQuantifier.predict
+     cls.aggregate = BinaryQuantifier.aggregate
+
+     return cls
+
+
+ # ============================================================
+ # Fitting strategies
+ # ============================================================
+ def _fit_ovr(quantifier, X, y):
+     """Fit using the One-vs-Rest (OvR) strategy.
+
+     Creates one binary quantifier per class, trained to distinguish that class
+     from all others.
+
+     Parameters
+     ----------
+     quantifier : BaseQuantifier
+         The quantifier instance being trained.
+     X : array-like of shape (n_samples, n_features)
+         Training feature matrix.
+     y : array-like of shape (n_samples,)
+         Class labels.
+
+     Returns
+     -------
+     dict
+         A mapping from class label to fitted binary quantifier.
+     """
+     quantifiers = {}
+     for cls in np.unique(y):
+         qtf = deepcopy(quantifier)
+         y_bin = (y == cls).astype(int)
+         qtf._original_fit(X, y_bin)
+         quantifiers[cls] = qtf
+     return quantifiers
+
+
+ def _fit_ovo(quantifier, X, y):
+     """Fit using the One-vs-One (OvO) strategy.
+
+     Creates one binary quantifier per pair of classes, trained to distinguish
+     one class of the pair from the other.
+
+     Parameters
+     ----------
+     quantifier : BaseQuantifier
+         The quantifier instance being trained.
+     X : array-like of shape (n_samples, n_features)
+         Training feature matrix.
+     y : array-like of shape (n_samples,)
+         Class labels.
+
+     Returns
+     -------
+     dict
+         A mapping from (class1, class2) tuples to fitted binary quantifiers.
+     """
+     quantifiers = {}
+     for cls1, cls2 in combinations(np.unique(y), 2):
+         qtf = deepcopy(quantifier)
+         mask = (y == cls1) | (y == cls2)
+         y_bin = (y[mask] == cls1).astype(int)
+         qtf._original_fit(X[mask], y_bin)
+         quantifiers[(cls1, cls2)] = qtf
+     return quantifiers
+
+
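
For intuition, the two strategies differ only in how many binary subproblems they spawn and how the labels are recoded: OvR fits one quantifier per class (K subproblems), OvO one per unordered pair (K(K-1)/2). A minimal sketch of the relabeling, in plain numpy and independent of mlquantify:

    import numpy as np
    from itertools import combinations

    y = np.array([0, 1, 2, 1, 0, 2])

    # OvR: one binary problem per class
    for cls in np.unique(y):
        y_bin = (y == cls).astype(int)        # `cls` vs. rest
        print(cls, y_bin)

    # OvO: one binary problem per pair; only samples of the pair are kept
    for cls1, cls2 in combinations(np.unique(y), 2):
        mask = (y == cls1) | (y == cls2)
        y_bin = (y[mask] == cls1).astype(int)
        print((cls1, cls2), y_bin)
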
+ # ============================================================
+ # Prediction strategies
+ # ============================================================
+ def _predict_ovr(quantifier, X):
+     """Predict using the One-vs-Rest (OvR) strategy.
+
+     Each binary quantifier produces a prevalence estimate for its corresponding class.
+
+     Parameters
+     ----------
+     quantifier : BinaryQuantifier
+         Fitted quantifier containing the binary models.
+     X : array-like of shape (n_samples, n_features)
+         Test feature matrix.
+
+     Returns
+     -------
+     np.ndarray
+         Predicted prevalences, one per class.
+     """
+     preds = np.zeros(len(quantifier.qtfs_))
+     for i, qtf in enumerate(quantifier.qtfs_.values()):
+         preds[i] = qtf._original_predict(X)[1]
+     return preds
+
+
+ def _predict_ovo(quantifier, X):
+     """Predict using the One-vs-One (OvO) strategy.
+
+     Each binary quantifier outputs a prevalence estimate for the pair of
+     classes it was trained on.
+
+     Parameters
+     ----------
+     quantifier : BinaryQuantifier
+         Fitted quantifier containing the binary models.
+     X : array-like of shape (n_samples, n_features)
+         Test feature matrix.
+
+     Returns
+     -------
+     np.ndarray
+         Pairwise prevalence predictions.
+     """
+     # qtfs_ is already keyed by (cls1, cls2) pairs, so iterate it directly
+     # rather than re-combining its keys.
+     preds = np.zeros(len(quantifier.qtfs_))
+     for i, qtf in enumerate(quantifier.qtfs_.values()):
+         preds[i] = qtf._original_predict(X)[1]
+     return preds
+
+
+ # ============================================================
+ # Aggregation strategies
+ # ============================================================
+ def _aggregate_ovr(quantifier, preds, y_train, train_preds=None):
+     """Aggregate binary predictions using One-vs-Rest (OvR).
+
+     Parameters
+     ----------
+     quantifier : BinaryQuantifier
+         Quantifier performing the aggregation.
+     preds : ndarray of shape (n_samples, n_classes)
+         Model predictions.
+     y_train : ndarray of shape (n_samples,)
+         Training labels.
+     train_preds : ndarray of shape (n_samples, n_classes), optional
+         Predictions on the training set.
+
+     Returns
+     -------
+     dict
+         Class-wise prevalence estimates.
+     """
+     prevalences = {}
+     for i, cls in enumerate(np.unique(y_train)):
+         # Recast the i-th column as binary [rest, class] posteriors.
+         bin_preds = np.column_stack([1 - preds[:, i], preds[:, i]])
+         y_bin = (y_train == cls).astype(int)
+         args = [bin_preds]
+
+         if train_preds is not None:
+             bin_train_preds = np.column_stack([1 - train_preds[:, i], train_preds[:, i]])
+             args.append(bin_train_preds)
+
+         args.append(y_bin)
+         prevalences[cls] = quantifier._original_aggregate(*args)[1]
+     return prevalences
+
+
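
The column_stack above is the whole trick of _aggregate_ovr: it recasts one column of an (n_samples, n_classes) posterior matrix into the (n_samples, 2) [rest, class] layout a binary aggregate expects. A concrete illustration:

    import numpy as np

    preds = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.6, 0.3]])
    i = 0  # column of the class treated as "positive" in this OvR subproblem
    bin_preds = np.column_stack([1 - preds[:, i], preds[:, i]])
    # [[0.3, 0.7],
    #  [0.9, 0.1]]
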
+ def _aggregate_ovo(quantifier, preds, y_train, train_preds=None):
+     """Aggregate binary predictions using One-vs-One (OvO).
+
+     Parameters
+     ----------
+     quantifier : BinaryQuantifier
+         Quantifier performing the aggregation.
+     preds : ndarray
+         Model predictions.
+     y_train : ndarray
+         Training labels.
+     train_preds : ndarray, optional
+         Predictions on the training set.
+
+     Returns
+     -------
+     dict
+         Pairwise prevalence estimates.
+     """
+     prevalences = {}
+     for cls1, cls2 in combinations(np.unique(y_train), 2):
+         # Restrict the posteriors to the two classes of the subproblem and
+         # renormalize into the (n_samples, 2) layout expected by a binary
+         # ``aggregate`` (class labels double as column indices here).
+         pair = preds[:, [cls1, cls2]]
+         pair = pair / np.clip(pair.sum(axis=1, keepdims=True), 1e-12, None)
+         bin_preds = np.column_stack([pair[:, 1], pair[:, 0]])
+         mask = (y_train == cls1) | (y_train == cls2)
+         y_bin = (y_train[mask] == cls1).astype(int)
+
+         args = [bin_preds]
+         if train_preds is not None:
+             train_pair = train_preds[:, [cls1, cls2]]
+             train_pair = train_pair / np.clip(train_pair.sum(axis=1, keepdims=True), 1e-12, None)
+             args.append(np.column_stack([train_pair[:, 1], train_pair[:, 0]]))
+
+         args.append(y_bin)
+         prevalences[(cls1, cls2)] = quantifier._original_aggregate(*args)[1]
+     return prevalences
+
+
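
The pairwise restriction in _aggregate_ovo conditions the posteriors on the event "the sample belongs to one of the two classes". Sketched in isolation (note the assumption, inherited from the code above, that class labels double as column indices):

    import numpy as np

    preds = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.6, 0.3]])
    cls1, cls2 = 0, 2
    pair = preds[:, [cls1, cls2]]
    pair = pair / pair.sum(axis=1, keepdims=True)   # P(cls | cls in {cls1, cls2})
    bin_preds = np.column_stack([pair[:, 1], pair[:, 0]])
    # [[0.125, 0.875],
    #  [0.75 , 0.25 ]]
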
+ # ============================================================
+ # Main Class
+ # ============================================================
+ class BinaryQuantifier(MetaquantifierMixin, BaseQuantifier):
+     """Meta-quantifier enabling One-vs-Rest and One-vs-One strategies.
+
+     This class extends a base quantifier to handle multiclass problems by
+     decomposing them into binary subproblems. It automatically delegates
+     fitting, prediction, and aggregation to the appropriate binary quantifiers.
+
+     Attributes
+     ----------
+     qtfs_ : dict
+         Dictionary mapping class labels or label pairs to fitted binary quantifiers.
+     strategy : {'ovr', 'ovo'}
+         Defines how multiclass quantification is decomposed.
+     """
+
+     @_fit_context(prefer_skip_nested_validation=False)
+     def fit(qtf, X, y):
+         """Fit the quantifier under a binary decomposition strategy."""
+         if len(np.unique(y)) <= 2:
+             qtf.binary = True
+             return qtf._original_fit(X, y)
+
+         qtf.strategy = getattr(qtf, "strategy", "ovr")
+
+         if qtf.strategy == "ovr":
+             qtf.qtfs_ = _fit_ovr(qtf, X, y)
+         elif qtf.strategy == "ovo":
+             qtf.qtfs_ = _fit_ovo(qtf, X, y)
+         else:
+             raise ValueError("Strategy must be 'ovr' or 'ovo'")
+
+         return qtf
+
+     def predict(qtf, X):
+         """Predict class prevalences using the trained binary quantifiers."""
+         if getattr(qtf, "binary", False):
+             return qtf._original_predict(X)
+
+         if qtf.strategy == "ovr":
+             preds = _predict_ovr(qtf, X)
+         elif qtf.strategy == "ovo":
+             preds = _predict_ovo(qtf, X)
+         else:
+             raise ValueError("Strategy must be 'ovr' or 'ovo'")
+
+         return validate_prevalences(qtf, preds, qtf.qtfs_.keys())
+
+     def aggregate(qtf, *args):
+         """Aggregate binary predictions to obtain multiclass prevalence estimates."""
+         requirements = get_aggregation_requirements(qtf)
+
+         if requirements.requires_train_proba and requirements.requires_train_labels:
+             preds, train_preds, y_train = args
+             args_dict = dict(preds=preds, train_preds=train_preds, y_train=y_train)
+         elif requirements.requires_train_labels:
+             preds, y_train = args
+             args_dict = dict(preds=preds, y_train=y_train)
+         else:
+             raise ValueError("Binary aggregation requires at least train labels")
+
+         classes = np.unique(args_dict["y_train"])
+         qtf.strategy = getattr(qtf, "strategy", "ovr")
+
+         if getattr(qtf, "binary", False):
+             return qtf._original_aggregate(*args_dict.values())
+
+         if qtf.strategy == "ovr":
+             prevalences = _aggregate_ovr(qtf, **args_dict)
+         elif qtf.strategy == "ovo":
+             prevalences = _aggregate_ovo(qtf, **args_dict)
+         else:
+             raise ValueError("Strategy must be 'ovr' or 'ovo'")
+
+         return validate_prevalences(qtf, prevalences, classes)
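
Putting the pieces together, a hedged end-to-end sketch; MyQuantifier is the toy class from the define_binary docstring above, and nothing here beyond fit/predict/strategy is a guaranteed public API:

    import numpy as np
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=120, n_classes=3, n_informative=6,
                               random_state=0)
    qtf = MyQuantifier()          # decorated with @define_binary
    qtf.strategy = "ovo"          # default is "ovr"
    qtf.fit(X, y)                 # one binary quantifier per class pair
    prevalences = qtf.predict(X)  # renormalized via validate_prevalences
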
mlquantify/neighbors/__init__.py
@@ -0,0 +1,9 @@
+ from ._kde import (
+     KDEyCS,
+     KDEyHD,
+     KDEyML,
+ )
+
+ from ._classes import (
+     PWK,
+ )
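
Given the BaseKDE constructor shown in the next hunk (learner, bandwidth, kernel), a plausible usage sketch for the exported classes; treat the defaults and the placeholder data (X_train, y_train, X_test) as assumptions rather than documented behavior:

    from sklearn.linear_model import LogisticRegression
    from mlquantify.neighbors import KDEyML

    qtf = KDEyML(learner=LogisticRegression(max_iter=1000), bandwidth=0.1)
    qtf.fit(X_train, y_train)         # cross-validated posteriors -> per-class KDEs
    prevalences = qtf.predict(X_test)
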
mlquantify/neighbors/_base.py
@@ -0,0 +1,198 @@
+ import numpy as np
+ from abc import abstractmethod
+ from sklearn.neighbors import KernelDensity
+
+ from mlquantify.utils._decorators import _fit_context
+ from mlquantify.base import BaseQuantifier
+ from mlquantify.utils import validate_y, validate_predictions, validate_data, check_classes_attribute
+ from mlquantify.base_aggregative import AggregationMixin, SoftLearnerQMixin, _get_learner_function
+ from mlquantify.utils._constraints import Interval, Options
+ from mlquantify.utils._get_scores import apply_cross_validation
+ from mlquantify.utils._validation import validate_prevalences
+
+ EPS = 1e-12
+
+ class BaseKDE(SoftLearnerQMixin, AggregationMixin, BaseQuantifier):
+     r"""
+     Base class for KDEy quantification methods.
+
+     KDEy methods model the class-conditional densities of posterior probabilities
+     using Kernel Density Estimation (KDE) in the probability simplex.
+     Given a probabilistic classifier's posterior outputs, each class distribution
+     is approximated as a smooth density function via KDE. Class prevalences in
+     the test set are estimated as the mixture weights of these densities that
+     best explain the test posterior distribution.
+
+     Formally, KDEy approximates the test posterior distribution as
+
+     \[
+         p_{\mathrm{test}}(x) \approx \sum_{k=1}^K \alpha_k p_k(x),
+     \]
+
+     where \( p_k(x) \) is the KDE of the posterior scores of class \( k \) on the
+     training data, and \( \alpha_k \) are the unknown class prevalences to be
+     estimated under
+
+     \[
+         \alpha_k \geq 0, \quad \sum_{k=1}^K \alpha_k = 1.
+     \]
+
+     The quantification task is then to find the vector
+     \( \boldsymbol{\alpha} = (\alpha_1, \dots, \alpha_K) \) minimizing an objective
+     function defined on the mixture density and the test posteriors, subject to
+     the simplex constraints on \( \boldsymbol{\alpha} \).
+
+     Attributes
+     ----------
+     learner : estimator
+         The underlying probabilistic classifier yielding posterior predictions.
+     bandwidth : float
+         Bandwidth (smoothing parameter) of the KDE models.
+     kernel : str
+         Kernel type used in KDE (e.g., 'gaussian').
+     _precomputed : bool
+         Indicates whether the KDE models have been fitted on training data.
+     best_distance_ : float or None
+         Stores the best value of the objective (distance or loss) achieved.
+
+     Methods
+     -------
+     fit(X, y, learner_fitted=False)
+         Fits KDE models for each class using posterior predictions of the learner.
+     predict(X)
+         Aggregates the learner's posterior predictions on X to estimate class prevalences.
+     aggregate(predictions, train_predictions, train_y_values)
+         Core estimation method that validates inputs, ensures KDE precomputation,
+         and calls `_solve_prevalences` implemented by subclasses.
+     _fit_kde_models(train_predictions, train_y_values)
+         Fits one KDE model per class on the training-data posteriors.
+     _solve_prevalences(predictions)
+         Abstract method estimating the prevalence vector \( \boldsymbol{\alpha} \)
+         for the given posteriors. Must be implemented by subclasses.
+
+     Examples
+     --------
+     To implement a new KDEy quantifier, subclass BaseKDE and implement
+     `_precompute_training` (typically delegating to `_fit_kde_models`) and
+     `_solve_prevalences`, which receives posterior predictions and returns a
+     tuple (estimated prevalences \( \boldsymbol{\alpha} \), objective value).
+
+     >>> class KDEyExample(BaseKDE):
+     ...     def _precompute_training(self, train_predictions, train_y_values):
+     ...         self._fit_kde_models(train_predictions, train_y_values)
+     ...
+     ...     def _solve_prevalences(self, predictions):
+     ...         # Example: uniform prevalences; replace with actual optimization
+     ...         n_classes = len(self._class_kdes)
+     ...         alpha = np.ones(n_classes) / n_classes
+     ...         obj_val = 0.0  # Replace with actual objective computation
+     ...         return alpha, obj_val
+
+     The mathematical formulation for prevalence estimation typically involves
+     optimizing
+
+     \[
+         \min_{\boldsymbol{\alpha} \in \Delta^{K-1}}
+         \mathcal{L} \Bigg( \sum_{k=1}^K \alpha_k p_k(x), \hat{p}(x) \Bigg),
+     \]
+
+     where \( \hat{p}(x) \) is the test posterior distribution (an empirical KDE or
+     the direct predictions), \( \Delta^{K-1} \) is the probability simplex defined
+     by the constraints on \( \boldsymbol{\alpha} \), and \( \mathcal{L} \) is an
+     appropriate divergence or loss function, e.g., negative log-likelihood,
+     Hellinger distance, or Cauchy–Schwarz divergence.
+
+     This optimization is typically solved numerically with constrained methods
+     such as sequential quadratic programming or projected gradient descent.
+
+     References
+     ----------
+     [1] Moreo, A., et al. (2023). Kernel Density Quantification methods and
+         applications. In *Learning to Quantify*. Springer.
+     """
+
+     _parameter_constraints = {
+         "bandwidth": [Interval(0, None, inclusive_right=False)],
+         "kernel": [Options(["gaussian", "tophat", "epanechnikov", "exponential", "linear", "cosine"])],
+     }
+
+     def __init__(self, learner=None, bandwidth: float = 0.1, kernel: str = "gaussian"):
+         self.learner = learner
+         self.bandwidth = bandwidth
+         self.kernel = kernel
+         self._precomputed = False
+         # Trailing underscore avoids shadowing the ``best_distance`` method.
+         self.best_distance_ = None
+
+     @_fit_context(prefer_skip_nested_validation=True)
+     def fit(self, X, y, learner_fitted=False):
+         X, y = validate_data(self, X, y, ensure_2d=True, ensure_min_samples=2)
+         validate_y(self, y)
+
+         self.classes_ = np.unique(y)
+
+         learner_function = _get_learner_function(self)
+
+         if learner_fitted:
+             train_predictions = getattr(self.learner, learner_function)(X)
+             train_y_values = y
+         else:
+             train_predictions, train_y_values = apply_cross_validation(
+                 self.learner, X, y,
+                 function=learner_function, cv=5,
+                 stratified=True, shuffle=True
+             )
+
+         self.train_predictions = train_predictions
+         self.train_y_values = train_y_values
+         self._precompute_training(train_predictions, train_y_values)
+         return self
+
+     def _fit_kde_models(self, train_predictions, train_y_values):
+         P = np.atleast_2d(train_predictions)
+         y = np.asarray(train_y_values)
+         self._class_kdes = []
+
+         for c in self.classes_:
+             Xi = P[y == c]
+             if Xi.shape[0] == 0:
+                 # Fallback: a single uniform point when a class is absent
+                 Xi = np.ones((1, P.shape[1])) / P.shape[1]
+             kde = KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel)
+             kde.fit(Xi)
+             self._class_kdes.append(kde)
+
+         self._precomputed = True
+
+     def predict(self, X):
+         predictions = getattr(self.learner, _get_learner_function(self))(X)
+         return self.aggregate(predictions, self.train_predictions, self.train_y_values)
+
+     def aggregate(self, predictions, train_predictions, train_y_values):
+         predictions = validate_predictions(self, predictions)
+
+         if hasattr(self, "classes_") and len(np.unique(train_y_values)) != len(self.classes_):
+             self._precomputed = False
+
+         self.classes_ = check_classes_attribute(self, np.unique(train_y_values))
+
+         if not self._precomputed:
+             self._precompute_training(train_predictions, train_y_values)
+             self._precomputed = True
+
+         prevalence, self.best_distance_ = self._solve_prevalences(predictions)
+         prevalence = np.clip(prevalence, EPS, None)
+         prevalence = validate_prevalences(self, prevalence, self.classes_)
+         return prevalence
+
+     def best_distance(self, predictions, train_predictions, train_y_values):
+         """Return the best distance (objective value) found during fitting."""
+         if self.best_distance_ is not None:
+             return self.best_distance_
+
+         self.classes_ = check_classes_attribute(self, np.unique(train_y_values))
+
+         if not self._precomputed:
+             self._precompute_training(train_predictions, train_y_values)
+             self._precomputed = True
+
+         _, self.best_distance_ = self._solve_prevalences(predictions)
+         return self.best_distance_
+
+     @abstractmethod
+     def _precompute_training(self, train_predictions, train_y_values):
+         raise NotImplementedError
+
+     @abstractmethod
+     def _solve_prevalences(self, predictions):
+         raise NotImplementedError
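
For reference, a minimal sketch of what a concrete subclass might look like for the maximum-likelihood flavor of KDEy: it mirrors the negative log-likelihood objective from the BaseKDE docstring and solves it with SLSQP over the simplex. This is an illustration under stated assumptions, not the library's actual KDEyML implementation:

    import numpy as np
    from scipy.optimize import minimize

    class KDEyMLSketch(BaseKDE):
        def _precompute_training(self, train_predictions, train_y_values):
            # Delegate to the per-class KDE fitting provided by BaseKDE.
            self._fit_kde_models(train_predictions, train_y_values)

        def _solve_prevalences(self, predictions):
            # log p_k(x) for every test point under every class KDE
            log_dens = np.column_stack(
                [kde.score_samples(predictions) for kde in self._class_kdes]
            )
            dens = np.exp(log_dens)                  # shape (n_test, K)
            K = dens.shape[1]

            def neg_log_likelihood(alpha):
                mixture = dens @ alpha               # sum_k alpha_k p_k(x)
                return -np.log(mixture + EPS).sum()

            res = minimize(
                neg_log_likelihood,
                x0=np.full(K, 1.0 / K),
                method="SLSQP",
                bounds=[(0.0, 1.0)] * K,
                constraints={"type": "eq", "fun": lambda a: a.sum() - 1.0},
            )
            return res.x, res.fun
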