mlquantify 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -105,7 +105,8 @@ class ThresholdAdjustment(SoftLearnerQMixin, BaseAdjustCount):
         thresholds, tprs, fprs = evaluate_thresholds(train_y_values, positive_scores)
         threshold, tpr, fpr = self.get_best_threshold(thresholds, tprs, fprs)

-        cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)[1]
+        cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)
+        cc_predictions = list(cc_predictions.values())[1]

         if tpr - fpr == 0:
             prevalence = cc_predictions
@@ -609,7 +610,7 @@ class MS(ThresholdAdjustment):
         prevs = []
         for thr, tpr, fpr in zip(thresholds, tprs, fprs):
             cc_predictions = CC(threshold=thr).aggregate(predictions, train_y_values)
-            cc_predictions = cc_predictions[1]
+            cc_predictions = list(cc_predictions.values())[1]

             if tpr - fpr == 0:
                 prevalence = cc_predictions
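
Both hunks above (ThresholdAdjustment and MS, apparently in mlquantify/adjust_counting/_adjustment.py given the RECORD changes further down) replace direct [1] indexing of CC.aggregate's result with list(result.values())[1], which suggests aggregate now returns a class-to-prevalence mapping rather than a plain array. A minimal sketch of the new calling pattern, assuming CC is importable from mlquantify.adjust_counting and that aggregate does return a dict-like object keyed by class; the scores and labels are illustrative only:

    import numpy as np
    from mlquantify.adjust_counting import CC  # assumed public import path

    scores = np.array([0.1, 0.8, 0.4, 0.9, 0.3])   # illustrative P(y=1) scores from a soft classifier
    train_y_values = np.array([0, 1])              # known training classes

    prevalences = CC(threshold=0.5).aggregate(scores, train_y_values)
    # Assumption: prevalences behaves like a mapping, e.g. {0: 0.6, 1: 0.4}; taking the
    # second value mirrors list(cc_predictions.values())[1] in the diff above.
    positive_rate = list(prevalences.values())[1]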
@@ -76,14 +76,15 @@ class CC(CrispLearnerQMixin, BaseCount):
         self.threshold = threshold

     def aggregate(self, predictions, train_y_values=None):
-        predictions = validate_predictions(self, predictions, self.threshold)
+        predictions = validate_predictions(self, predictions, self.threshold, train_y_values)

         if train_y_values is None:
             train_y_values = np.unique(predictions)
+
         self.classes_ = check_classes_attribute(self, np.unique(train_y_values))
         class_counts = np.array([np.count_nonzero(predictions == _class) for _class in self.classes_])
         prevalences = class_counts / len(predictions)
-
+
         prevalences = validate_prevalences(self, prevalences, self.classes_)
         return prevalences

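
For reference, the counting step in CC.aggregate above reduces to a per-class frequency count over the (now label-aware) crisp predictions. A standalone sketch of just that step, using made-up string labels:

    import numpy as np

    # Crisp predictions already expressed in terms of the training labels (illustrative data).
    predictions = np.array(["neg", "pos", "pos", "neg", "pos"])
    classes_ = np.unique(predictions)
    class_counts = np.array([np.count_nonzero(predictions == _class) for _class in classes_])
    prevalences = class_counts / len(predictions)
    print(dict(zip(classes_.tolist(), prevalences.tolist())))   # {'neg': 0.4, 'pos': 0.6}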
@@ -94,25 +94,51 @@ def validate_y(quantifier: Any, y: np.ndarray) -> None:
             f"Predictions must be 1D or 2D array, got array with ndim={y.ndim} and shape={y.shape}."
         )

-def _get_valid_crisp_predictions(predictions, threshold=0.5):
+def _get_valid_crisp_predictions(predictions, train_y_values=None, threshold=0.5):
     predictions = np.asarray(predictions)
-
     dimensions = predictions.ndim

+    if train_y_values is not None:
+        classes = np.unique(train_y_values)
+    else:
+        classes = None
+
     if dimensions > 2:
-        predictions = np.argmax(predictions, axis=1)
+        # Assuming the last dimension contains class probabilities
+        crisp_indices = np.argmax(predictions, axis=-1)
+        if classes is not None:
+            predictions = classes[crisp_indices]
+        else:
+            predictions = crisp_indices
     elif dimensions == 2:
-        predictions = (predictions[:, 1] >= threshold).astype(int)
+        # Binary or multi-class probabilities (N, C)
+        if classes is not None and len(classes) == 2:
+            # Binary case with explicit classes
+            predictions = np.where(predictions[:, 1] >= threshold, classes[1], classes[0])
+        elif classes is not None and len(classes) > 2:
+            # Multi-class case with explicit classes
+            crisp_indices = np.argmax(predictions, axis=1)
+            predictions = classes[crisp_indices]
+        else:
+            # Default binary (0 or 1) or multi-class (0 to C-1)
+            if predictions.shape[1] == 2:
+                predictions = (predictions[:, 1] >= threshold).astype(int)
+            else:
+                predictions = np.argmax(predictions, axis=1)
     elif dimensions == 1:
+        # 1D probabilities (e.g., probability of positive class)
         if np.issubdtype(predictions.dtype, np.floating):
-            predictions = (predictions >= threshold).astype(int)
+            if classes is not None and len(classes) == 2:
+                predictions = np.where(predictions >= threshold, classes[1], classes[0])
+            else:
+                predictions = (predictions >= threshold).astype(int)
     else:
         raise ValueError(f"Predictions array has an invalid number of dimensions. Expected 1 or more dimensions, got {predictions.ndim}.")

     return predictions


-def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
+def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5, train_y_values=None) -> np.ndarray:
     """
     Validate predictions using the quantifier's declared output tags.
     Raises InputValidationError if inconsistent with tags.
@@ -132,7 +158,7 @@ def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: fl
                 f"Soft predictions for {quantifier.__class__.__name__} must be float, got dtype {predictions.dtype}."
             )
     elif estimator_type == "crisp" and np.issubdtype(predictions.dtype, np.floating):
-        predictions = _get_valid_crisp_predictions(predictions, threshold)
+        predictions = _get_valid_crisp_predictions(predictions, train_y_values, threshold)
     return predictions

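
The bulk of the _validation.py change is that _get_valid_crisp_predictions can now map scores onto the actual training labels instead of always producing 0/1 indices. The helper below, crisp_from_scores, is a hypothetical standalone re-implementation of just the binary branches shown above, included so the behaviour can be tried without the package internals:

    import numpy as np

    def crisp_from_scores(predictions, train_y_values=None, threshold=0.5):
        """Sketch of the label-aware binary branches of _get_valid_crisp_predictions."""
        predictions = np.asarray(predictions)
        classes = np.unique(train_y_values) if train_y_values is not None else None
        if predictions.ndim == 2 and classes is not None and len(classes) == 2:
            # (N, 2) probabilities with explicit binary labels
            return np.where(predictions[:, 1] >= threshold, classes[1], classes[0])
        if predictions.ndim == 1 and classes is not None and len(classes) == 2:
            # 1D positive-class probabilities with explicit binary labels
            return np.where(predictions >= threshold, classes[1], classes[0])
        raise NotImplementedError("multi-class and index-based branches omitted in this sketch")

    probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.45, 0.55]])
    print(crisp_from_scores(probs, train_y_values=["ham", "spam"]))  # ['ham' 'spam' 'spam']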
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mlquantify
-Version: 0.1.16
+Version: 0.1.18
 Summary: Quantification Library
 Home-page: https://github.com/luizfernandolj/QuantifyML/tree/master
 Maintainer: Luiz Fernando Luth Junior
@@ -5,9 +5,9 @@ mlquantify/calibration.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
 mlquantify/confidence.py,sha256=QkEWr6s-Su3Nbinia_TRQbBeTM6ymDPe7Bv204XBKKA,10799
 mlquantify/multiclass.py,sha256=wFbbXKqGsFVSsI9zC0EHGYyyx1JRxFpzMi_q8l80TUM,11770
 mlquantify/adjust_counting/__init__.py,sha256=AWio99zeaUULQq9vKggkFhnq-tqgXxasQt167NdcNVY,307
-mlquantify/adjust_counting/_adjustment.py,sha256=412kFnx3-noaA9u9AatuGIvJbKze-PLPFfBFMBVmQVA,23635
+mlquantify/adjust_counting/_adjustment.py,sha256=YGZiaGBdlWaw7vxGmSfgtWRSqU7Ppc9Lh1eoxmAbAX8,23705
 mlquantify/adjust_counting/_base.py,sha256=MjBsNG7wE0Z_KToXX8WbthhVvz-yc0-d2zIqPo1CB9g,9429
-mlquantify/adjust_counting/_counting.py,sha256=6PKea54xvsga8spNEbsngKNQPyGUXzOkCRyXQR8rTdo,5699
+mlquantify/adjust_counting/_counting.py,sha256=pAWfKK2mHiWAtyuhlld51CnMJFFfMo_X_3IEt5LfCdw,5720
 mlquantify/adjust_counting/_utils.py,sha256=DEPNzvcr0KszCnfUJaRzBilwWzuNVMSdy5eV7aQ_JPE,2907
 mlquantify/likelihood/__init__.py,sha256=3dC5uregNmquUKz0r0-3aPspfjZjKGn3TRBoZPO1uFs,53
 mlquantify/likelihood/_base.py,sha256=seu_Vb58QttcGbFjHKAplMYGZcVbIHqkyTXEK2cax9A,5830
@@ -45,9 +45,9 @@ mlquantify/utils/_parallel.py,sha256=XotpX9nsj6nW-tNCmZ-ahTcRztgnn9oQKP2cl1rLdYM
 mlquantify/utils/_random.py,sha256=7F3nyy7Pa_kN8xP8P1L6MOM4WFu4BirE7bOfGTZ1Spk,1275
 mlquantify/utils/_sampling.py,sha256=QQxE2WKLdiCFUfPF6fKgzyrsOUIWYf74w_w8fbYVc2c,8409
 mlquantify/utils/_tags.py,sha256=Rz78TLpxgVxBKS0mKTlC9Qo_kn6HaEwVKNXh8pxFT7M,1095
-mlquantify/utils/_validation.py,sha256=TGGnfv7F5rnQmVeSqGMuS9AP76O974b1TPishKCCWls,16800
+mlquantify/utils/_validation.py,sha256=zn4OHfa704YBaPKskhiThUG7wS5fvDoHBpcEgb1i8qM,18078
 mlquantify/utils/prevalence.py,sha256=FXLCJViQb2yDbyTXeGZt8WsPPnSZINhorQYZTKXOn14,1772
-mlquantify-0.1.16.dist-info/METADATA,sha256=RZWNq8k48KnN4PYd-6Iw5mv-qS6bAVL2tc6zg9cMqN0,4701
-mlquantify-0.1.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-mlquantify-0.1.16.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
-mlquantify-0.1.16.dist-info/RECORD,,
+mlquantify-0.1.18.dist-info/METADATA,sha256=XrQ188Icw5RZEAN8tvHRHTsRm1IKB1iwR_tm6G7uB0w,4701
+mlquantify-0.1.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+mlquantify-0.1.18.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
+mlquantify-0.1.18.dist-info/RECORD,,