mlquantify 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- mlquantify/adjust_counting/_adjustment.py +3 -2
- mlquantify/adjust_counting/_counting.py +3 -2
- mlquantify/utils/_validation.py +33 -7
- {mlquantify-0.1.16.dist-info → mlquantify-0.1.18.dist-info}/METADATA +1 -1
- {mlquantify-0.1.16.dist-info → mlquantify-0.1.18.dist-info}/RECORD +7 -7
- {mlquantify-0.1.16.dist-info → mlquantify-0.1.18.dist-info}/WHEEL +0 -0
- {mlquantify-0.1.16.dist-info → mlquantify-0.1.18.dist-info}/top_level.txt +0 -0
|
@@ -105,7 +105,8 @@ class ThresholdAdjustment(SoftLearnerQMixin, BaseAdjustCount):
|
|
|
105
105
|
thresholds, tprs, fprs = evaluate_thresholds(train_y_values, positive_scores)
|
|
106
106
|
threshold, tpr, fpr = self.get_best_threshold(thresholds, tprs, fprs)
|
|
107
107
|
|
|
108
|
-
cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)
|
|
108
|
+
cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)
|
|
109
|
+
cc_predictions = list(cc_predictions.values())[1]
|
|
109
110
|
|
|
110
111
|
if tpr - fpr == 0:
|
|
111
112
|
prevalence = cc_predictions
|
|
@@ -609,7 +610,7 @@ class MS(ThresholdAdjustment):
|
|
|
609
610
|
prevs = []
|
|
610
611
|
for thr, tpr, fpr in zip(thresholds, tprs, fprs):
|
|
611
612
|
cc_predictions = CC(threshold=thr).aggregate(predictions, train_y_values)
|
|
612
|
-
cc_predictions = cc_predictions[1]
|
|
613
|
+
cc_predictions = list(cc_predictions.values())[1]
|
|
613
614
|
|
|
614
615
|
if tpr - fpr == 0:
|
|
615
616
|
prevalence = cc_predictions
|
|
@@ -76,14 +76,15 @@ class CC(CrispLearnerQMixin, BaseCount):
|
|
|
76
76
|
self.threshold = threshold
|
|
77
77
|
|
|
78
78
|
def aggregate(self, predictions, train_y_values=None):
|
|
79
|
-
predictions = validate_predictions(self, predictions, self.threshold)
|
|
79
|
+
predictions = validate_predictions(self, predictions, self.threshold, train_y_values)
|
|
80
80
|
|
|
81
81
|
if train_y_values is None:
|
|
82
82
|
train_y_values = np.unique(predictions)
|
|
83
|
+
|
|
83
84
|
self.classes_ = check_classes_attribute(self, np.unique(train_y_values))
|
|
84
85
|
class_counts = np.array([np.count_nonzero(predictions == _class) for _class in self.classes_])
|
|
85
86
|
prevalences = class_counts / len(predictions)
|
|
86
|
-
|
|
87
|
+
|
|
87
88
|
prevalences = validate_prevalences(self, prevalences, self.classes_)
|
|
88
89
|
return prevalences
|
|
89
90
|
|
mlquantify/utils/_validation.py
CHANGED
|
@@ -94,25 +94,51 @@ def validate_y(quantifier: Any, y: np.ndarray) -> None:
|
|
|
94
94
|
f"Predictions must be 1D or 2D array, got array with ndim={y.ndim} and shape={y.shape}."
|
|
95
95
|
)
|
|
96
96
|
|
|
97
|
-
def _get_valid_crisp_predictions(predictions, threshold=0.5):
|
|
97
|
+
def _get_valid_crisp_predictions(predictions, train_y_values=None, threshold=0.5):
|
|
98
98
|
predictions = np.asarray(predictions)
|
|
99
|
-
|
|
100
99
|
dimensions = predictions.ndim
|
|
101
100
|
|
|
101
|
+
if train_y_values is not None:
|
|
102
|
+
classes = np.unique(train_y_values)
|
|
103
|
+
else:
|
|
104
|
+
classes = None
|
|
105
|
+
|
|
102
106
|
if dimensions > 2:
|
|
103
|
-
|
|
107
|
+
# Assuming the last dimension contains class probabilities
|
|
108
|
+
crisp_indices = np.argmax(predictions, axis=-1)
|
|
109
|
+
if classes is not None:
|
|
110
|
+
predictions = classes[crisp_indices]
|
|
111
|
+
else:
|
|
112
|
+
predictions = crisp_indices
|
|
104
113
|
elif dimensions == 2:
|
|
105
|
-
|
|
114
|
+
# Binary or multi-class probabilities (N, C)
|
|
115
|
+
if classes is not None and len(classes) == 2:
|
|
116
|
+
# Binary case with explicit classes
|
|
117
|
+
predictions = np.where(predictions[:, 1] >= threshold, classes[1], classes[0])
|
|
118
|
+
elif classes is not None and len(classes) > 2:
|
|
119
|
+
# Multi-class case with explicit classes
|
|
120
|
+
crisp_indices = np.argmax(predictions, axis=1)
|
|
121
|
+
predictions = classes[crisp_indices]
|
|
122
|
+
else:
|
|
123
|
+
# Default binary (0 or 1) or multi-class (0 to C-1)
|
|
124
|
+
if predictions.shape[1] == 2:
|
|
125
|
+
predictions = (predictions[:, 1] >= threshold).astype(int)
|
|
126
|
+
else:
|
|
127
|
+
predictions = np.argmax(predictions, axis=1)
|
|
106
128
|
elif dimensions == 1:
|
|
129
|
+
# 1D probabilities (e.g., probability of positive class)
|
|
107
130
|
if np.issubdtype(predictions.dtype, np.floating):
|
|
108
|
-
|
|
131
|
+
if classes is not None and len(classes) == 2:
|
|
132
|
+
predictions = np.where(predictions >= threshold, classes[1], classes[0])
|
|
133
|
+
else:
|
|
134
|
+
predictions = (predictions >= threshold).astype(int)
|
|
109
135
|
else:
|
|
110
136
|
raise ValueError(f"Predictions array has an invalid number of dimensions. Expected 1 or more dimensions, got {predictions.ndim}.")
|
|
111
137
|
|
|
112
138
|
return predictions
|
|
113
139
|
|
|
114
140
|
|
|
115
|
-
def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
|
|
141
|
+
def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5, train_y_values=None) -> np.ndarray:
|
|
116
142
|
"""
|
|
117
143
|
Validate predictions using the quantifier's declared output tags.
|
|
118
144
|
Raises InputValidationError if inconsistent with tags.
|
|
@@ -132,7 +158,7 @@ def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: fl
|
|
|
132
158
|
f"Soft predictions for {quantifier.__class__.__name__} must be float, got dtype {predictions.dtype}."
|
|
133
159
|
)
|
|
134
160
|
elif estimator_type == "crisp" and np.issubdtype(predictions.dtype, np.floating):
|
|
135
|
-
predictions = _get_valid_crisp_predictions(predictions, threshold)
|
|
161
|
+
predictions = _get_valid_crisp_predictions(predictions, train_y_values, threshold)
|
|
136
162
|
return predictions
|
|
137
163
|
|
|
138
164
|
|
|
@@ -5,9 +5,9 @@ mlquantify/calibration.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
|
|
|
5
5
|
mlquantify/confidence.py,sha256=QkEWr6s-Su3Nbinia_TRQbBeTM6ymDPe7Bv204XBKKA,10799
|
|
6
6
|
mlquantify/multiclass.py,sha256=wFbbXKqGsFVSsI9zC0EHGYyyx1JRxFpzMi_q8l80TUM,11770
|
|
7
7
|
mlquantify/adjust_counting/__init__.py,sha256=AWio99zeaUULQq9vKggkFhnq-tqgXxasQt167NdcNVY,307
|
|
8
|
-
mlquantify/adjust_counting/_adjustment.py,sha256=
|
|
8
|
+
mlquantify/adjust_counting/_adjustment.py,sha256=YGZiaGBdlWaw7vxGmSfgtWRSqU7Ppc9Lh1eoxmAbAX8,23705
|
|
9
9
|
mlquantify/adjust_counting/_base.py,sha256=MjBsNG7wE0Z_KToXX8WbthhVvz-yc0-d2zIqPo1CB9g,9429
|
|
10
|
-
mlquantify/adjust_counting/_counting.py,sha256=
|
|
10
|
+
mlquantify/adjust_counting/_counting.py,sha256=pAWfKK2mHiWAtyuhlld51CnMJFFfMo_X_3IEt5LfCdw,5720
|
|
11
11
|
mlquantify/adjust_counting/_utils.py,sha256=DEPNzvcr0KszCnfUJaRzBilwWzuNVMSdy5eV7aQ_JPE,2907
|
|
12
12
|
mlquantify/likelihood/__init__.py,sha256=3dC5uregNmquUKz0r0-3aPspfjZjKGn3TRBoZPO1uFs,53
|
|
13
13
|
mlquantify/likelihood/_base.py,sha256=seu_Vb58QttcGbFjHKAplMYGZcVbIHqkyTXEK2cax9A,5830
|
|
@@ -45,9 +45,9 @@ mlquantify/utils/_parallel.py,sha256=XotpX9nsj6nW-tNCmZ-ahTcRztgnn9oQKP2cl1rLdYM
|
|
|
45
45
|
mlquantify/utils/_random.py,sha256=7F3nyy7Pa_kN8xP8P1L6MOM4WFu4BirE7bOfGTZ1Spk,1275
|
|
46
46
|
mlquantify/utils/_sampling.py,sha256=QQxE2WKLdiCFUfPF6fKgzyrsOUIWYf74w_w8fbYVc2c,8409
|
|
47
47
|
mlquantify/utils/_tags.py,sha256=Rz78TLpxgVxBKS0mKTlC9Qo_kn6HaEwVKNXh8pxFT7M,1095
|
|
48
|
-
mlquantify/utils/_validation.py,sha256=
|
|
48
|
+
mlquantify/utils/_validation.py,sha256=zn4OHfa704YBaPKskhiThUG7wS5fvDoHBpcEgb1i8qM,18078
|
|
49
49
|
mlquantify/utils/prevalence.py,sha256=FXLCJViQb2yDbyTXeGZt8WsPPnSZINhorQYZTKXOn14,1772
|
|
50
|
-
mlquantify-0.1.
|
|
51
|
-
mlquantify-0.1.
|
|
52
|
-
mlquantify-0.1.
|
|
53
|
-
mlquantify-0.1.
|
|
50
|
+
mlquantify-0.1.18.dist-info/METADATA,sha256=XrQ188Icw5RZEAN8tvHRHTsRm1IKB1iwR_tm6G7uB0w,4701
|
|
51
|
+
mlquantify-0.1.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
52
|
+
mlquantify-0.1.18.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
|
|
53
|
+
mlquantify-0.1.18.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|