mlquantify 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlquantify/__init__.py +0 -29
- mlquantify/adjust_counting/__init__.py +14 -0
- mlquantify/adjust_counting/_adjustment.py +365 -0
- mlquantify/adjust_counting/_base.py +247 -0
- mlquantify/adjust_counting/_counting.py +145 -0
- mlquantify/adjust_counting/_utils.py +114 -0
- mlquantify/base.py +117 -519
- mlquantify/base_aggregative.py +209 -0
- mlquantify/calibration.py +1 -0
- mlquantify/confidence.py +335 -0
- mlquantify/likelihood/__init__.py +5 -0
- mlquantify/likelihood/_base.py +161 -0
- mlquantify/likelihood/_classes.py +414 -0
- mlquantify/meta/__init__.py +1 -0
- mlquantify/meta/_classes.py +761 -0
- mlquantify/metrics/__init__.py +21 -0
- mlquantify/metrics/_oq.py +109 -0
- mlquantify/metrics/_rq.py +98 -0
- mlquantify/{evaluation/measures.py → metrics/_slq.py} +43 -28
- mlquantify/mixture/__init__.py +7 -0
- mlquantify/mixture/_base.py +153 -0
- mlquantify/mixture/_classes.py +400 -0
- mlquantify/mixture/_utils.py +112 -0
- mlquantify/model_selection/__init__.py +9 -0
- mlquantify/model_selection/_protocol.py +358 -0
- mlquantify/model_selection/_search.py +315 -0
- mlquantify/model_selection/_split.py +1 -0
- mlquantify/multiclass.py +350 -0
- mlquantify/neighbors/__init__.py +9 -0
- mlquantify/neighbors/_base.py +198 -0
- mlquantify/neighbors/_classes.py +159 -0
- mlquantify/{classification/methods.py → neighbors/_classification.py} +48 -66
- mlquantify/neighbors/_kde.py +270 -0
- mlquantify/neighbors/_utils.py +135 -0
- mlquantify/neural/__init__.py +1 -0
- mlquantify/utils/__init__.py +47 -2
- mlquantify/utils/_artificial.py +27 -0
- mlquantify/utils/_constraints.py +219 -0
- mlquantify/utils/_context.py +21 -0
- mlquantify/utils/_decorators.py +36 -0
- mlquantify/utils/_exceptions.py +12 -0
- mlquantify/utils/_get_scores.py +159 -0
- mlquantify/utils/_load.py +18 -0
- mlquantify/utils/_parallel.py +6 -0
- mlquantify/utils/_random.py +36 -0
- mlquantify/utils/_sampling.py +273 -0
- mlquantify/utils/_tags.py +44 -0
- mlquantify/utils/_validation.py +447 -0
- mlquantify/utils/prevalence.py +61 -0
- {mlquantify-0.1.7.dist-info → mlquantify-0.1.9.dist-info}/METADATA +2 -1
- mlquantify-0.1.9.dist-info/RECORD +53 -0
- mlquantify/classification/__init__.py +0 -1
- mlquantify/evaluation/__init__.py +0 -14
- mlquantify/evaluation/protocol.py +0 -291
- mlquantify/methods/__init__.py +0 -37
- mlquantify/methods/aggregative.py +0 -1159
- mlquantify/methods/meta.py +0 -472
- mlquantify/methods/mixture_models.py +0 -1003
- mlquantify/methods/non_aggregative.py +0 -136
- mlquantify/methods/threshold_optimization.py +0 -869
- mlquantify/model_selection.py +0 -377
- mlquantify/plots.py +0 -367
- mlquantify/utils/general.py +0 -371
- mlquantify/utils/method.py +0 -449
- mlquantify-0.1.7.dist-info/RECORD +0 -22
- {mlquantify-0.1.7.dist-info → mlquantify-0.1.9.dist-info}/WHEEL +0 -0
- {mlquantify-0.1.7.dist-info → mlquantify-0.1.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any
|
|
3
|
+
import numpy as np
|
|
4
|
+
import scipy.sparse as sp
|
|
5
|
+
from sklearn.utils.validation import check_array, check_X_y, _check_y
|
|
6
|
+
|
|
7
|
+
from mlquantify.utils._tags import TargetInputTags, get_tags
|
|
8
|
+
from mlquantify.utils._exceptions import InputValidationError, InvalidParameterError, NotFittedError
|
|
9
|
+
from mlquantify.utils._constraints import make_constraint
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ---------------------------
|
|
13
|
+
# y Validation
|
|
14
|
+
# ---------------------------
|
|
15
|
+
|
|
16
|
+
def _validate_is_numpy_array(array: Any) -> None:
|
|
17
|
+
"""Ensure y are a numpy array."""
|
|
18
|
+
if not isinstance(array, np.ndarray):
|
|
19
|
+
raise InputValidationError(
|
|
20
|
+
f"y must be a numpy array, got {type(y).__name__}."
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _validate_1d_predictions(quantifier: Any, y: np.ndarray, target_tags: TargetInputTags) -> None:
|
|
25
|
+
"""Validate 1D predictions according to quantifier tags."""
|
|
26
|
+
if target_tags.continuous:
|
|
27
|
+
return # continuous allows any numeric vector
|
|
28
|
+
|
|
29
|
+
n_class = len(np.unique(y))
|
|
30
|
+
|
|
31
|
+
if target_tags.one_d:
|
|
32
|
+
|
|
33
|
+
if n_class > 2 and not target_tags.multi_class:
|
|
34
|
+
raise InputValidationError(
|
|
35
|
+
f"1D predictions for {quantifier.__class__.__name__} must be binary "
|
|
36
|
+
f"with 2 unique values, got {n_class} unique values."
|
|
37
|
+
)
|
|
38
|
+
if not np.issubdtype(y.dtype, np.number) and not target_tags.categorical:
|
|
39
|
+
raise InputValidationError(
|
|
40
|
+
f"1D predictions for {quantifier.__class__.__name__} must be numeric (int or float), "
|
|
41
|
+
f"got dtype {y.dtype}."
|
|
42
|
+
)
|
|
43
|
+
return
|
|
44
|
+
|
|
45
|
+
if not target_tags.one_d:
|
|
46
|
+
raise InputValidationError(
|
|
47
|
+
f"{quantifier.__class__.__name__} does not accept 1D input according to its tags."
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _validate_2d_predictions(quantifier: Any, y: np.ndarray, target_tags: TargetInputTags) -> None:
|
|
52
|
+
"""Validate 2D predictions according to quantifier tags."""
|
|
53
|
+
if not (target_tags.two_d or target_tags.multi_class):
|
|
54
|
+
raise InputValidationError(
|
|
55
|
+
f"{quantifier.__class__.__name__} does not accept multi-class or 2D input."
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
if not np.issubdtype(y.dtype, np.floating):
|
|
59
|
+
raise InputValidationError(
|
|
60
|
+
f"{quantifier.__class__.__name__} expects float probabilities for 2D predictions, "
|
|
61
|
+
f"got dtype {y.dtype}."
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Efficient normalization check for soft probabilities
|
|
65
|
+
if target_tags.two_d:
|
|
66
|
+
row_sums = y.sum(axis=1)
|
|
67
|
+
if np.abs(row_sums - 1).max() > 1e-3:
|
|
68
|
+
raise InputValidationError(
|
|
69
|
+
f"Soft predictions for multiclass quantifiers must sum to 1 across columns "
|
|
70
|
+
f"(max deviation={np.abs(row_sums - 1).max():.3g})."
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def validate_y(quantifier: Any, y: np.ndarray) -> None:
    """
    Validate a target/prediction array against the quantifier's input tags.

    Parameters
    ----------
    quantifier : Any
        Estimator exposing ``__mlquantify_tags__()`` (read through ``get_tags``).
    y : np.ndarray
        1D or 2D array of labels or scores.

    Raises
    ------
    InputValidationError
        If ``y`` is not an ndarray, the quantifier exposes no tags, or ``y``
        is inconsistent with the declared target input tags.
    """
    _validate_is_numpy_array(y)

    try:
        target_tags = get_tags(quantifier).target_input_tags
    except AttributeError as e:
        raise InputValidationError(
            f"Quantifier {quantifier.__class__.__name__} does not implement __mlquantify_tags__()."
        ) from e

    # Dispatch on dimensionality; anything other than 1D/2D is rejected.
    dispatch = {1: _validate_1d_predictions, 2: _validate_2d_predictions}
    validator = dispatch.get(y.ndim)
    if validator is None:
        raise InputValidationError(
            f"Predictions must be 1D or 2D array, got array with ndim={y.ndim} and shape={y.shape}."
        )
    validator(quantifier, y, target_tags)
|
|
96
|
+
|
|
97
|
+
def _get_valid_crisp_predictions(predictions, threshold=0.5):
|
|
98
|
+
predictions = np.asarray(predictions)
|
|
99
|
+
|
|
100
|
+
dimensions = predictions.shape[1] if len(predictions.shape) > 1 else 1
|
|
101
|
+
|
|
102
|
+
if dimensions > 2:
|
|
103
|
+
predictions = np.argmax(predictions, axis=1)
|
|
104
|
+
elif dimensions == 2:
|
|
105
|
+
predictions = (predictions[:, 1] > threshold).astype(int)
|
|
106
|
+
elif dimensions == 1:
|
|
107
|
+
if np.issubdtype(predictions.dtype, np.floating):
|
|
108
|
+
predictions = (predictions > threshold).astype(int)
|
|
109
|
+
else:
|
|
110
|
+
raise ValueError(f"Predictions array has an invalid number of dimensions. Expected 1 or more dimensions, got {predictions.ndim}.")
|
|
111
|
+
|
|
112
|
+
return predictions
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def validate_predictions(quantifier: Any, predictions: np.ndarray) -> np.ndarray:
    """
    Validate predictions using the quantifier's declared output tags.

    Soft quantifiers reject integer arrays; crisp quantifiers have float
    arrays converted to crisp labels via ``_get_valid_crisp_predictions``.

    Parameters
    ----------
    quantifier : Any
        Estimator exposing ``__mlquantify_tags__()``.
    predictions : np.ndarray
        Prediction array to validate.

    Returns
    -------
    np.ndarray
        The validated (possibly crisp-converted) predictions.

    Raises
    ------
    InputValidationError
        If inconsistent with the tags or the quantifier exposes no tags.
    """
    _validate_is_numpy_array(predictions)

    try:
        tags = get_tags(quantifier)
        estimator_type = tags.estimator_type
    except AttributeError as e:
        raise InputValidationError(
            f"Quantifier {quantifier.__class__.__name__} does not implement __mlquantify_tags__()."
        ) from e

    if estimator_type == "soft" and np.issubdtype(predictions.dtype, np.integer):
        raise InputValidationError(
            f"Soft predictions for {quantifier.__class__.__name__} must be float, got dtype {predictions.dtype}."
        )
    if estimator_type == "crisp" and np.issubdtype(predictions.dtype, np.floating):
        predictions = _get_valid_crisp_predictions(predictions)

    # BUG FIX: the original returned the predictions only from the
    # crisp-conversion branch, so callers received ``None`` for perfectly
    # valid soft predictions. Always return the validated array.
    return predictions
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# ---------------------------
|
|
142
|
+
# Parameter Validation
|
|
143
|
+
# ---------------------------
|
|
144
|
+
|
|
145
|
+
def validate_parameter_constraints(parameter_constraints: dict[str, Any], params: dict[str, Any], caller_name: str) -> None:
    """Validate parameters against their declared constraints.

    Parameters
    ----------
    parameter_constraints : dict[str, Any]
        Mapping of parameter name to a list of constraint specs (or the
        sentinel string ``"no_validation"``).
    params : dict[str, Any]
        Actual parameter values to check.
    caller_name : str
        Name used in the error message.

    Raises
    ------
    InvalidParameterError
        If a parameter satisfies none of its constraints.
    """
    for name, value in params.items():
        # Undeclared parameters and the explicit opt-out are skipped.
        if name not in parameter_constraints:
            continue
        declared = parameter_constraints[name]
        if declared == "no_validation":
            continue

        constraints = [make_constraint(spec) for spec in declared]
        if any(constraint.is_satisfied_by(value) for constraint in constraints):
            continue

        # Hidden constraints are omitted from the message unless they are
        # the only ones declared.
        shown = [c for c in constraints if not getattr(c, "hidden", False)] or constraints
        if len(shown) == 1:
            expected = str(shown[0])
        else:
            expected = ", ".join(map(str, shown[:-1])) + f" or {shown[-1]}"

        raise InvalidParameterError(
            f"The parameter '{name}' of {caller_name} must be {expected}. "
            f"Got {value!r} (type={type(value).__name__})."
        )
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def validate_learner_contraints(quantifier, learner) -> None:
    """Validate that a learner is compatible with the quantifier's tags.

    Quantifiers tagged without an estimator must not receive a learner;
    otherwise the learner must expose the method named by the tags'
    ``estimator_function`` (``predict`` or ``predict_proba``).

    Raises
    ------
    InvalidParameterError
        If the tags are missing, the learner is unsupported, or the required
        prediction method is absent.
    """
    try:
        tags = get_tags(quantifier)
    except AttributeError as e:
        raise InvalidParameterError(
            f"Quantifier {quantifier.__class__.__name__} does not implement __mlquantify_tags__()."
        ) from e

    if not tags.has_estimator:
        if learner is not None:
            raise InvalidParameterError(
                f"The quantifier {quantifier.__class__.__name__} does not support using a learner."
            )
        return  # No learner needed

    estimator_function = tags.estimator_function
    if estimator_function is None:
        raise InvalidParameterError(f"The quantifier {quantifier.__class__.__name__} does not specify a valid estimator_function in its tags.")

    # NOTE(review): the checks below inspect ``quantifier.learner`` rather than
    # the ``learner`` argument — presumably the same object; confirm at callers.
    if estimator_function == "predict":
        if not hasattr(quantifier.learner, "predict"):
            raise InvalidParameterError(f"The provided learner does not have a 'predict' method, which is required by the quantifier {quantifier.__class__.__name__}.")
    elif estimator_function == "predict_proba":
        if not hasattr(quantifier.learner, "predict_proba"):
            raise InvalidParameterError(f"The provided learner does not have a 'predict_proba' method, which is required by the quantifier {quantifier.__class__.__name__}.")
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _is_fitted(quantifier, attributes=None, all_or_any=all):
|
|
204
|
+
"""Check if the quantifier is fitted by verifying the presence of specified attributes."""
|
|
205
|
+
if attributes is None:
|
|
206
|
+
attributes = ["is_fitted_"]
|
|
207
|
+
|
|
208
|
+
checks = [hasattr(quantifier, attr) for attr in attributes]
|
|
209
|
+
return all(checks) if all_or_any == all else any(checks)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def check_is_fitted(quantifier, attributes=None, *, msg=None, all_or_any=all):
    """Raise NotFittedError if the quantifier is not fitted.

    Parameters
    ----------
    quantifier : Any
        Object exposing a ``fit`` method and ``__mlquantify_tags__()``.
    attributes : list of str, optional
        Attributes whose presence marks the fitted state (see ``_is_fitted``).
    msg : str, optional
        Error-message template; may contain a ``%(name)s`` placeholder which
        is filled with the quantifier's class name.
    all_or_any : callable, default=all
        Whether all or any of the attributes must be present.

    Raises
    ------
    TypeError
        If the object has no ``fit`` method.
    NotFittedError
        If the quantifier requires fitting and is not fitted.
    """
    if msg is None:
        # BUG FIX: the original default was a pre-formatted f-string, yet
        # ``msg % {"name": ...}`` is applied unconditionally below. Use a
        # %(name)s template (sklearn convention) so the interpolation is
        # meaningful; the resulting default text is unchanged.
        msg = "This %(name)s instance is not fitted yet. Call 'fit' first."

    if not hasattr(quantifier, "fit"):
        raise TypeError(f"Cannot check if {quantifier.__class__.__name__} is fitted: no 'fit' method found.")

    tags = get_tags(quantifier)

    if not tags.requires_fit and attributes is None:
        return  # No fitting required for this quantifier

    if not _is_fitted(quantifier, attributes, all_or_any):
        raise NotFittedError(msg % {"name": type(quantifier).__name__})
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _is_arraylike_not_scalar(array):
    """Return True if array is array-like and not a scalar"""
    if np.isscalar(array):
        return False
    return _is_arraylike(array)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _is_arraylike(x):
|
|
236
|
+
"""Returns whether the input is array-like."""
|
|
237
|
+
if sp.issparse(x):
|
|
238
|
+
return False
|
|
239
|
+
|
|
240
|
+
return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def validate_data(quantifier,
                  X="no_validation",
                  y="no_validation",
                  reset=True,
                  validate_separately=False,
                  skip_check_array=False,
                  **check_params):
    """
    Validate X and/or y through sklearn's check_array / check_X_y machinery.

    The sentinel string ``"no_validation"`` (or ``None`` for y) marks an
    argument as absent; at least one of X, y must be present.

    Parameters
    ----------
    quantifier : Any
        Estimator whose tags decide whether y is mandatory.
    X, y : array-like or "no_validation"
        Data to validate.
    reset : bool, default=True
        Accepted for API parity; unused in this body.
    validate_separately : False or tuple of dicts
        When a ``(check_X_params, check_y_params)`` tuple, X and y are run
        through ``check_array`` independently instead of ``check_X_y``.
    skip_check_array : bool, default=False
        Pass the inputs through untouched (shape of the return still matches
        the validated branches).
    **check_params
        Forwarded to the sklearn validators.

    Returns
    -------
    The validated X, y, or (X, y) depending on which inputs were provided.

    Raises
    ------
    InputValidationError
        If y is required by the tags but ``None``.
    ValueError
        If neither X nor y is provided.
    """
    tags = get_tags(quantifier)
    if y is None and tags.target_input_tags.required:
        raise InputValidationError(
            f"The target variable y is required for {quantifier.__class__.__name__}."
        )

    skip_X = isinstance(X, str) and X == "no_validation"
    skip_y = y is None or (isinstance(y, str) and y == "no_validation")

    if skip_X and skip_y:
        raise ValueError("Validation should be done on X, y or both.")

    default_check_params = {"estimator": quantifier}
    check_params = {**default_check_params, **check_params}

    if skip_check_array:
        if not skip_X and skip_y:
            return X
        if skip_X and not skip_y:
            return y
        return X, y

    if not skip_X and skip_y:
        return check_array(X, input_name="X", **check_params)
    if skip_X and not skip_y:
        return _check_y(y, **check_params)

    if validate_separately:
        # Some estimators validate X and y separately; two check_array()
        # calls are not equivalent to a single check_X_y() call.
        check_X_params, check_y_params = validate_separately
        if "estimator" not in check_X_params:
            check_X_params = {**default_check_params, **check_X_params}
        X = check_array(X, input_name="X", **check_X_params)
        if "estimator" not in check_y_params:
            check_y_params = {**default_check_params, **check_y_params}
        y = check_array(y, input_name="y", **check_y_params)
    else:
        X, y = check_X_y(X, y, **check_params)
    return X, y
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def validate_prevalences(quantifier, prevalences: np.ndarray | list | dict, classes: np.ndarray, return_type: str = "dict", normalize: bool = True) -> dict | np.ndarray:
    """
    Validate class prevalences against the given class labels.

    Parameters
    ----------
    quantifier : estimator
        The quantifier instance (accepted for API parity; unused in the body).
    prevalences : np.ndarray, list, or dict
        Predicted prevalences for each class. Array/list inputs shorter than
        ``classes`` are zero-padded; longer inputs are rejected.
    classes : np.ndarray
        Array of class labels.
    return_type : str, default="dict"
        Return format: "dict" or "array".
    normalize : bool, default=True
        Whether to rescale the prevalences to sum to 1.

    Returns
    -------
    dict or np.ndarray
        Validated prevalences with numpy scalar keys/values converted to
        native Python types.

    Raises
    ------
    InvalidParameterError
        For an unknown ``return_type``.
    InputValidationError
        For an unsupported input type, mismatched keys, or a zero sum.
    """
    if return_type not in ["dict", "array"]:
        raise InvalidParameterError(
            f"return_type must be 'dict' or 'array', got {return_type!r}."
        )

    if isinstance(prevalences, dict):
        prev_dict = prevalences
    elif isinstance(prevalences, (list, np.ndarray)):
        prevalences = np.asarray(prevalences)
        if len(prevalences) > len(classes):
            raise InputValidationError(
                f"Number of prevalences ({len(prevalences)}) cannot exceed number of classes ({len(classes)})."
            )
        # Zero-pad when fewer prevalences than classes were supplied.
        prev_dict = {
            cls: (prevalences[i] if i < len(prevalences) else 0.0)
            for i, cls in enumerate(classes)
        }
    else:
        raise InputValidationError(
            f"prevalences must be a numpy array, list, or dict, got {type(prevalences).__name__}."
        )

    if set(prev_dict.keys()) != set(classes):
        raise InputValidationError(
            f"prevalences keys must match classes. Got keys {set(prev_dict.keys())}, expected {set(classes)}."
        )

    if normalize:
        total = sum(prev_dict.values())
        if total == 0:
            raise InputValidationError("Cannot normalize prevalences: sum is zero.")
        prev_dict = {cls: val / total for cls, val in prev_dict.items()}

    def _native_key(key):
        # numpy scalar keys -> native Python types for cleaner output.
        if isinstance(key, np.integer):
            return int(key)
        if isinstance(key, np.floating):
            return float(key)
        if isinstance(key, np.str_):
            return str(key)
        return key

    clean = {_native_key(cls): float(val) for cls, val in prev_dict.items()}

    if return_type == "dict":
        return clean
    # numpy scalars hash equal to their native counterparts, so indexing
    # with the original class labels still works.
    return np.array([clean[cls] for cls in classes])
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def normalize_prevalences(prevalences: np.ndarray | list | dict, classes: np.ndarray = None) -> np.ndarray | dict:
    """
    Normalize prevalences so that they sum to 1.

    Parameters
    ----------
    prevalences : np.ndarray, list, or dict
        Class prevalences to normalize.
    classes : np.ndarray, optional
        Class labels; accepted for API parity (unused in this body).

    Returns
    -------
    np.ndarray or dict
        Normalized prevalences, matching the input format. Dict keys/values
        are converted to native Python types.

    Raises
    ------
    InputValidationError
        If the input type is unsupported or the values sum to zero.
    """
    if isinstance(prevalences, dict):
        total = sum(prevalences.values())
        if total == 0:
            raise InputValidationError("Cannot normalize prevalences: sum is zero.")
        out = {}
        for cls, val in prevalences.items():
            # numpy scalar keys -> native Python types.
            if isinstance(cls, np.integer):
                cls = int(cls)
            elif isinstance(cls, np.floating):
                cls = float(cls)
            elif isinstance(cls, np.str_):
                cls = str(cls)
            out[cls] = float(val / total)
        return out

    if isinstance(prevalences, (list, np.ndarray)):
        values = np.asarray(prevalences)
        total = values.sum()
        if total == 0:
            raise InputValidationError("Cannot normalize prevalences: sum is zero.")
        return values / total

    raise InputValidationError(
        f"prevalences must be a numpy array, list, or dict, got {type(prevalences).__name__}."
    )
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def check_has_method(obj: Any, method_name: str) -> bool:
    """Return True when *obj* exposes a callable attribute named *method_name*."""
    candidate = getattr(obj, method_name, None)
    return callable(candidate)
|
|
430
|
+
|
|
431
|
+
def check_classes_attribute(quantifier: Any, classes) -> bool:
    """Return the quantifier's ``classes_`` when it matches ``classes`` exactly.

    Falls back to the supplied ``classes`` whenever the attribute is missing,
    of a different container type, of a different length, or element-wise
    unequal.
    """
    if not hasattr(quantifier, "classes_"):
        return classes

    stored = quantifier.classes_

    # Exact-type check is intentional: a list and an ndarray of the same
    # labels are treated as different here.
    if type(stored) != type(classes):
        return classes
    # Length guard before the element-wise comparison avoids broadcasting
    # surprises on mismatched shapes.
    if len(stored) != len(classes) or not np.all(stored == classes):
        return classes

    return stored
|
|
447
|
+
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_prev_from_labels(y) -> dict:
    """
    Get the real prevalence of each class in the target array.

    Parameters
    ----------
    y : np.ndarray or pd.Series
        Array of class labels.

    Returns
    -------
    dict
        Dictionary of class labels and their corresponding prevalence,
        sorted by class label.
    """
    labels = pd.Series(y) if isinstance(y, np.ndarray) else y
    shares = labels.value_counts(normalize=True)
    return {label: share for label, share in sorted(shares.items())}
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def normalize_prevalence(prevalences: np.ndarray, classes: list):
    """
    Normalize the prevalence of each class to sum to 1.

    Parameters
    ----------
    prevalences : np.ndarray or dict
        Array (or dict) of prevalences. Keys/classes are cast to ``int``, so
        labels are assumed integer-like.
    classes : list
        List of unique classes; any class missing from ``prevalences`` is
        reported with prevalence 0.

    Returns
    -------
    dict
        Dictionary of class labels and their corresponding prevalence.
    """
    if isinstance(prevalences, dict):
        total = sum(prevalences.values())
        return {int(_class): float(value / total) for _class, value in prevalences.items()}

    prevalences = np.asarray(prevalences)
    total = prevalences.sum()
    # BUG FIX: the original divided by the Python builtin ``sum(prevalences)``
    # while masking on a separately computed keepdims sum, and used
    # ``np.true_divide(..., where=...)`` WITHOUT ``out=`` — entries where the
    # mask is False are left uninitialized (undefined values). Divide by one
    # consistent total and leave the values untouched when it is not positive.
    if total > 0:
        prevalences = prevalences / total

    result = {int(_class): float(prev) for _class, prev in zip(classes, prevalences)}

    # Ensure every class label is present in the result, defaulting to 0.
    for cls in classes:
        result.setdefault(cls, 0.0)

    return result
|
|
61
|
+
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mlquantify
|
|
3
|
-
Version: 0.1.7
|
|
3
|
+
Version: 0.1.9
|
|
4
4
|
Summary: Quantification Library
|
|
5
5
|
Home-page: https://github.com/luizfernandolj/QuantifyML/tree/master
|
|
6
6
|
Maintainer: Luiz Fernando Luth Junior
|
|
@@ -20,6 +20,7 @@ Requires-Dist: tqdm
|
|
|
20
20
|
Requires-Dist: pandas
|
|
21
21
|
Requires-Dist: xlrd
|
|
22
22
|
Requires-Dist: matplotlib
|
|
23
|
+
Requires-Dist: abstention
|
|
23
24
|
Dynamic: classifier
|
|
24
25
|
Dynamic: description
|
|
25
26
|
Dynamic: description-content-type
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
mlquantify/__init__.py,sha256=P48iiVlcAeKeE6wr6yZGMTKwmtCOvQYO4ZUVCKAQMwM,52
|
|
2
|
+
mlquantify/base.py,sha256=o7IaKODocyi4tEmCvGmHKQ8F4ZJsaEh4kymsNcLyHAg,5077
|
|
3
|
+
mlquantify/base_aggregative.py,sha256=uqfhpUmgv5pNLLvqgROCWHfjs3sj_2jfwOTyzUySuGo,7545
|
|
4
|
+
mlquantify/calibration.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
|
|
5
|
+
mlquantify/confidence.py,sha256=IUF6sLVmDi4XxX5NvbSFl4-cBbl1mdedtDMFqV-GA48,10918
|
|
6
|
+
mlquantify/multiclass.py,sha256=Jux0fvL5IBZA3DXLCuqUEE77JYYBGAcW6GaEH9srmu4,11747
|
|
7
|
+
mlquantify/adjust_counting/__init__.py,sha256=8qQtTzRAoRiIux_R8wCXopdi6dOg1ESd8oPWv-LvUC0,191
|
|
8
|
+
mlquantify/adjust_counting/_adjustment.py,sha256=BdFwYTwWhwdSxBgu98yTsVyyxgPz_Xm53YEjMxXI8f8,12824
|
|
9
|
+
mlquantify/adjust_counting/_base.py,sha256=abF0lo3fetR77JP87MS7Hy204jF7NYdxwyJWPE5hNyE,9344
|
|
10
|
+
mlquantify/adjust_counting/_counting.py,sha256=n4pBdyntrrxZTu7dWMnCsgN5kz6npU7CNIgRPQLY-nA,5266
|
|
11
|
+
mlquantify/adjust_counting/_utils.py,sha256=wlBrihWKPzzxXmIqowreZ_lN6buD6hFCH98qA3H6s5s,2636
|
|
12
|
+
mlquantify/likelihood/__init__.py,sha256=3dC5uregNmquUKz0r0-3aPspfjZjKGn3TRBoZPO1uFs,53
|
|
13
|
+
mlquantify/likelihood/_base.py,sha256=J6ze15i-TlMMEVl4KvE2_wdam-fq0ZqWl7pSkas35qs,6075
|
|
14
|
+
mlquantify/likelihood/_classes.py,sha256=Xp0hU83mYmfs1AOlmGEYLLsBPZBjPoi2xTx-2H4ztuI,15111
|
|
15
|
+
mlquantify/meta/__init__.py,sha256=GzdGw4ky_kmd5VNWiLBULy06IdN_MLCDAuJKbnMOx4s,62
|
|
16
|
+
mlquantify/meta/_classes.py,sha256=msivgTXvPw6Duq2Uv_odoayX-spZPtuWtD0FQ_8UFdw,29824
|
|
17
|
+
mlquantify/metrics/__init__.py,sha256=3bzzjSYTgrZIJsfAgJidQlB-bnjInwVYUvJ34bPhZxY,186
|
|
18
|
+
mlquantify/metrics/_oq.py,sha256=qTLyKpQkdnyzNOmWjplnLnr7nMDNqlBtfnddo5XHJ48,3542
|
|
19
|
+
mlquantify/metrics/_rq.py,sha256=v0FUepNF-Wj0f1MdB1-9TXSNDze-J0BXUqaTCo5gnUA,3032
|
|
20
|
+
mlquantify/metrics/_slq.py,sha256=nigIpZtPhPYVe5GU3qf1TOxGIkmKOrrhLXAm_tDPaCQ,6808
|
|
21
|
+
mlquantify/mixture/__init__.py,sha256=_KKhpFuvi3vYwxydm5nOy9MKwmIU4eyZDN9Pe00hqtk,70
|
|
22
|
+
mlquantify/mixture/_base.py,sha256=VDAOY6vFM2OayQxN4APysZ-ZycfrUwUS5Zzjr5v2t04,6076
|
|
23
|
+
mlquantify/mixture/_classes.py,sha256=BbBrMIFKWoaP5CjW35agecwl3TE6ZmmBh8kwUyp72Ig,14012
|
|
24
|
+
mlquantify/mixture/_utils.py,sha256=3507D13aw6Xl5Ki5bcC1j8yZH4EqawmtRke9m4AouT4,3049
|
|
25
|
+
mlquantify/model_selection/__init__.py,sha256=98I0uf8k6lbWAjazGyGjbOdPOvzU8aMRLqC3I7D3jzk,113
|
|
26
|
+
mlquantify/model_selection/_protocol.py,sha256=J-96OPJCkwtwk96P962qeztENRwypO__SbDLxM-Myvo,12493
|
|
27
|
+
mlquantify/model_selection/_search.py,sha256=YXeSSJXQVrKjwxfKOKJ9amkXZ1mOPJWKh2x2SQNO5rM,10694
|
|
28
|
+
mlquantify/model_selection/_split.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
|
|
29
|
+
mlquantify/neighbors/__init__.py,sha256=rIOuSaUhjqEXsUN9HNZ62P53QG0N7lJ3j1pvf8kJzms,93
|
|
30
|
+
mlquantify/neighbors/_base.py,sha256=tYgq_yjEuqv0dipo2hZW99GxlF09mwpTxlMLV2gYUHo,8258
|
|
31
|
+
mlquantify/neighbors/_classes.py,sha256=zWPn9zhY_Xw3AgC1A112DUO3LrbuKOMvkZU2Cx0elYU,5577
|
|
32
|
+
mlquantify/neighbors/_classification.py,sha256=wzqz6eZpE4AAyBmxv8cjheluoUB_RDYse_mfLCVXfzI,4310
|
|
33
|
+
mlquantify/neighbors/_kde.py,sha256=URCq_o-bSb94_-1jX1Ag-1ZLkP0sBa-D3obAgqN6YYg,9930
|
|
34
|
+
mlquantify/neighbors/_utils.py,sha256=rAu2VuW13rBj935z_m-u0MpfQbLQC0Iq_1WPSnAXZCk,4114
|
|
35
|
+
mlquantify/neural/__init__.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
|
|
36
|
+
mlquantify/utils/__init__.py,sha256=fCozxFABSv5L7lbD16-J370dbc_xHien3w0crYKPLTc,1344
|
|
37
|
+
mlquantify/utils/_artificial.py,sha256=6tqMoAuxUULFGHXtMez56re4DZ7d2Q6tK55LPGeEiO8,713
|
|
38
|
+
mlquantify/utils/_constraints.py,sha256=r1WDJuqsO3OS2Q45IBKJGtB6iUjcAXMW8USaEakyvCI,5600
|
|
39
|
+
mlquantify/utils/_context.py,sha256=25QmzmfSiuF_hwCjY_7db_XfCnj1dVe4mIbDycVTHf8,661
|
|
40
|
+
mlquantify/utils/_decorators.py,sha256=yYtnPBh1sLSN6wTY-7ZVAV0j--qbpJxBsgncm794JPc,1205
|
|
41
|
+
mlquantify/utils/_exceptions.py,sha256=C3BQSv3-7QDLaorKcV-ANxnBcSaxHQSlCc6YSZrPK6c,392
|
|
42
|
+
mlquantify/utils/_get_scores.py,sha256=VlTvgg_t4D9MzcgsH7YvP_wIL5AZ8XmEtGpbFivdVJk,5280
|
|
43
|
+
mlquantify/utils/_load.py,sha256=cMGXIs-8mUB4blAmagyDNNvAaV2hysRgeInQMl5fDHg,303
|
|
44
|
+
mlquantify/utils/_parallel.py,sha256=XotpX9nsj6nW-tNCmZ-ahTcRztgnn9oQKP2cl1rLdYM,196
|
|
45
|
+
mlquantify/utils/_random.py,sha256=7F3nyy7Pa_kN8xP8P1L6MOM4WFu4BirE7bOfGTZ1Spk,1275
|
|
46
|
+
mlquantify/utils/_sampling.py,sha256=QQxE2WKLdiCFUfPF6fKgzyrsOUIWYf74w_w8fbYVc2c,8409
|
|
47
|
+
mlquantify/utils/_tags.py,sha256=Rz78TLpxgVxBKS0mKTlC9Qo_kn6HaEwVKNXh8pxFT7M,1095
|
|
48
|
+
mlquantify/utils/_validation.py,sha256=dE7NYLy6C5UWf8tXIhQeWLTz2-rej_gr8-aAIwgJTPk,16762
|
|
49
|
+
mlquantify/utils/prevalence.py,sha256=9chdjfUyac7Omxv50Rb_HmfkQFrHfTjGiQPdbVH7FXc,1631
|
|
50
|
+
mlquantify-0.1.9.dist-info/METADATA,sha256=QZUSlEfWxeFGjI8R1QzJx5Y4DeyCiWjFqxFHQzYEIz0,5192
|
|
51
|
+
mlquantify-0.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
52
|
+
mlquantify-0.1.9.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
|
|
53
|
+
mlquantify-0.1.9.dist-info/RECORD,,
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .methods import *
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
from . import measures
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
MEASURES = {
|
|
5
|
-
"ae": measures.absolute_error,
|
|
6
|
-
"mae": measures.mean_absolute_error,
|
|
7
|
-
"nae": measures.normalized_absolute_error,
|
|
8
|
-
"kld": measures.kullback_leibler_divergence,
|
|
9
|
-
"nkld": measures.normalized_kullback_leibler_divergence,
|
|
10
|
-
"nrae": measures.normalized_relative_absolute_error,
|
|
11
|
-
"rae": measures.relative_absolute_error,
|
|
12
|
-
"se": measures.squared_error,
|
|
13
|
-
"mse": measures.mean_squared_error
|
|
14
|
-
}
|