mlquantify 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -104,7 +104,7 @@ class ThresholdAdjustment(SoftLearnerQMixin, BaseAdjustCount):
104
104
  thresholds, tprs, fprs = evaluate_thresholds(train_y_values, positive_scores)
105
105
  threshold, tpr, fpr = self.get_best_threshold(thresholds, tprs, fprs)
106
106
 
107
- cc_predictions = CC(threshold).aggregate(predictions)[1]
107
+ cc_predictions = CC(threshold=threshold).aggregate(predictions, train_y_values)[1]
108
108
 
109
109
  if tpr - fpr == 0:
110
110
  prevalence = cc_predictions
@@ -514,7 +514,7 @@ class X_method(ThresholdAdjustment):
514
514
  *ECML*, pp. 564-575.
515
515
  """
516
516
  def get_best_threshold(self, thresholds, tprs, fprs):
517
- idx = np.argmin(np.abs(1 - (tprs + fprs)))
517
+ idx = np.argmin(np.abs((1-tprs) - fprs))
518
518
  return thresholds[idx], tprs[idx], fprs[idx]
519
519
 
520
520
 
@@ -601,7 +601,7 @@ class MS(ThresholdAdjustment):
601
601
 
602
602
  prevs = []
603
603
  for thr, tpr, fpr in zip(thresholds, tprs, fprs):
604
- cc_predictions = CC(thr).aggregate(predictions)
604
+ cc_predictions = CC(threshold=thr).aggregate(predictions, train_y_values)
605
605
  cc_predictions = cc_predictions[1]
606
606
 
607
607
  if tpr - fpr == 0:
@@ -204,7 +204,7 @@ class BaseAdjustCount(AggregationMixin, BaseQuantifier):
204
204
  self.learner = learner
205
205
 
206
206
  @_fit_context(prefer_skip_nested_validation=True)
207
- def fit(self, X, y, learner_fitted=False):
207
+ def fit(self, X, y, learner_fitted=False, cv=10, stratified=True, random_state=None, shuffle=True):
208
208
  """Fit the quantifier using the provided data and learner."""
209
209
  X, y = validate_data(self, X, y)
210
210
  validate_y(self, y)
@@ -220,10 +220,10 @@ class BaseAdjustCount(AggregationMixin, BaseQuantifier):
220
220
  X,
221
221
  y,
222
222
  function=learner_function,
223
- cv=5,
224
- stratified=True,
225
- random_state=None,
226
- shuffle=True
223
+ cv=cv,
224
+ stratified=stratified,
225
+ random_state=random_state,
226
+ shuffle=shuffle
227
227
  )
228
228
 
229
229
  self.train_predictions = train_predictions
@@ -241,7 +241,6 @@ class BaseAdjustCount(AggregationMixin, BaseQuantifier):
241
241
  self.classes_ = check_classes_attribute(self, np.unique(y_train_values))
242
242
 
243
243
  predictions = validate_predictions(self, predictions)
244
- train_predictions = validate_predictions(self, train_predictions)
245
244
 
246
245
  prevalences = self._adjust(predictions, train_predictions, y_train_values)
247
246
  prevalences = validate_prevalences(self, prevalences, self.classes_)
@@ -75,10 +75,12 @@ class CC(CrispLearnerQMixin, BaseCount):
75
75
  super().__init__(learner=learner)
76
76
  self.threshold = threshold
77
77
 
78
- def aggregate(self, predictions):
79
- predictions = validate_predictions(self, predictions)
78
+ def aggregate(self, predictions, train_y_values=None):
79
+ predictions = validate_predictions(self, predictions, self.threshold)
80
80
 
81
- self.classes_ = check_classes_attribute(self, np.unique(predictions))
81
+ if train_y_values is None:
82
+ train_y_values = np.unique(predictions)
83
+ self.classes_ = check_classes_attribute(self, np.unique(train_y_values))
82
84
  class_counts = np.array([np.count_nonzero(predictions == _class) for _class in self.classes_])
83
85
  prevalences = class_counts / len(predictions)
84
86
 
@@ -776,9 +776,9 @@ class QuaDapt(MetaquantifierMixin, BaseQuantifier):
776
776
  n_neg = n - n_pos
777
777
 
778
778
  # Scores positivos
779
- p_score = np.random.uniform(size=n_pos) ** merging_factor
779
+ p_score = np.random.uniform(size=n_pos) ** m
780
780
  # Scores negativos
781
- n_score = 1 - (np.random.uniform(size=n_neg) ** merging_factor)
781
+ n_score = 1 - (np.random.uniform(size=n_neg) ** m)
782
782
 
783
783
  # Construção dos arrays de features (duas colunas iguais)
784
784
  moss = np.column_stack(
mlquantify/multiclass.py CHANGED
@@ -337,7 +337,7 @@ class BinaryQuantifier(MetaquantifierMixin, BaseQuantifier):
337
337
  classes = np.unique(args_dict["y_train"])
338
338
  qtf.strategy = getattr(qtf, "strategy", "ovr")
339
339
 
340
- if hasattr(qtf, "binary") and qtf.binary:
340
+ if (hasattr(qtf, "binary") and qtf.binary) or len(classes) <= 2:
341
341
  return qtf._original_aggregate(*args_dict.values())
342
342
 
343
343
  if qtf.strategy == "ovr":
@@ -97,22 +97,22 @@ def validate_y(quantifier: Any, y: np.ndarray) -> None:
97
97
  def _get_valid_crisp_predictions(predictions, threshold=0.5):
98
98
  predictions = np.asarray(predictions)
99
99
 
100
- dimensions = predictions.shape[1] if len(predictions.shape) > 1 else 1
100
+ dimensions = predictions.ndim
101
101
 
102
102
  if dimensions > 2:
103
103
  predictions = np.argmax(predictions, axis=1)
104
104
  elif dimensions == 2:
105
- predictions = (predictions[:, 1] > threshold).astype(int)
105
+ predictions = (predictions[:, 1] >= threshold).astype(int)
106
106
  elif dimensions == 1:
107
107
  if np.issubdtype(predictions.dtype, np.floating):
108
- predictions = (predictions > threshold).astype(int)
108
+ predictions = (predictions >= threshold).astype(int)
109
109
  else:
110
110
  raise ValueError(f"Predictions array has an invalid number of dimensions. Expected 1 or more dimensions, got {predictions.ndim}.")
111
111
 
112
112
  return predictions
113
113
 
114
114
 
115
- def validate_predictions(quantifier: Any, predictions: np.ndarray) -> None:
115
+ def validate_predictions(quantifier: Any, predictions: np.ndarray, threshold: float = 0.5) -> np.ndarray:
116
116
  """
117
117
  Validate predictions using the quantifier's declared output tags.
118
118
  Raises InputValidationError if inconsistent with tags.
@@ -132,7 +132,7 @@ def validate_predictions(quantifier: Any, predictions: np.ndarray) -> None:
132
132
  f"Soft predictions for {quantifier.__class__.__name__} must be float, got dtype {predictions.dtype}."
133
133
  )
134
134
  elif estimator_type == "crisp" and np.issubdtype(predictions.dtype, np.floating):
135
- predictions = _get_valid_crisp_predictions(predictions)
135
+ predictions = _get_valid_crisp_predictions(predictions, threshold)
136
136
  return predictions
137
137
 
138
138
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlquantify
3
- Version: 0.1.12
3
+ Version: 0.1.14
4
4
  Summary: Quantification Library
5
5
  Home-page: https://github.com/luizfernandolj/QuantifyML/tree/master
6
6
  Maintainer: Luiz Fernando Luth Junior
@@ -3,17 +3,17 @@ mlquantify/base.py,sha256=o7IaKODocyi4tEmCvGmHKQ8F4ZJsaEh4kymsNcLyHAg,5077
3
3
  mlquantify/base_aggregative.py,sha256=uqfhpUmgv5pNLLvqgROCWHfjs3sj_2jfwOTyzUySuGo,7545
4
4
  mlquantify/calibration.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
5
5
  mlquantify/confidence.py,sha256=QkEWr6s-Su3Nbinia_TRQbBeTM6ymDPe7Bv204XBKKA,10799
6
- mlquantify/multiclass.py,sha256=Jux0fvL5IBZA3DXLCuqUEE77JYYBGAcW6GaEH9srmu4,11747
6
+ mlquantify/multiclass.py,sha256=wFbbXKqGsFVSsI9zC0EHGYyyx1JRxFpzMi_q8l80TUM,11770
7
7
  mlquantify/adjust_counting/__init__.py,sha256=AWio99zeaUULQq9vKggkFhnq-tqgXxasQt167NdcNVY,307
8
- mlquantify/adjust_counting/_adjustment.py,sha256=x0i_jAWCw2UP9Gt20EteYxLmCr1Xh_AbISwFRbOVoI8,23234
9
- mlquantify/adjust_counting/_base.py,sha256=tbYq2Efaxsub_vzXoMOR-J6SZlK6K8oRr5UvSSsjVvs,9428
10
- mlquantify/adjust_counting/_counting.py,sha256=7Ip7-XHQJcTWcWVDaLzEIM6WYcp8k5axsCIyD3QPWZE,5572
8
+ mlquantify/adjust_counting/_adjustment.py,sha256=aQ92-wfRF1TT_d3kQecFleSWE8CdGBJiZx3bbePuXM0,23284
9
+ mlquantify/adjust_counting/_base.py,sha256=MjBsNG7wE0Z_KToXX8WbthhVvz-yc0-d2zIqPo1CB9g,9429
10
+ mlquantify/adjust_counting/_counting.py,sha256=6PKea54xvsga8spNEbsngKNQPyGUXzOkCRyXQR8rTdo,5699
11
11
  mlquantify/adjust_counting/_utils.py,sha256=DEPNzvcr0KszCnfUJaRzBilwWzuNVMSdy5eV7aQ_JPE,2907
12
12
  mlquantify/likelihood/__init__.py,sha256=3dC5uregNmquUKz0r0-3aPspfjZjKGn3TRBoZPO1uFs,53
13
13
  mlquantify/likelihood/_base.py,sha256=seu_Vb58QttcGbFjHKAplMYGZcVbIHqkyTXEK2cax9A,5830
14
14
  mlquantify/likelihood/_classes.py,sha256=PZ31cAwO8q5X3O2_oSmQ1FM6bY4EsB8hWEcAgcEmWXQ,14731
15
15
  mlquantify/meta/__init__.py,sha256=GzdGw4ky_kmd5VNWiLBULy06IdN_MLCDAuJKbnMOx4s,62
16
- mlquantify/meta/_classes.py,sha256=JAnMS4bu2XHXI_sSZUfcW_uIXRanoA0NIS3uN6dWSv4,30956
16
+ mlquantify/meta/_classes.py,sha256=RKEVghPMBlyv516xrUtTyUkHvC2-5IsTUO_oVwAt3Gw,30930
17
17
  mlquantify/metrics/__init__.py,sha256=3bzzjSYTgrZIJsfAgJidQlB-bnjInwVYUvJ34bPhZxY,186
18
18
  mlquantify/metrics/_oq.py,sha256=koXDKeHWksl_vHpZuhc2pAps8wvu_MOgEztlSr04MmE,3544
19
19
  mlquantify/metrics/_rq.py,sha256=3yiEmGaRAGpzL29Et3tNqkJ3RMsLXwUX3uL9RoIgi40,3034
@@ -45,9 +45,9 @@ mlquantify/utils/_parallel.py,sha256=XotpX9nsj6nW-tNCmZ-ahTcRztgnn9oQKP2cl1rLdYM
45
45
  mlquantify/utils/_random.py,sha256=7F3nyy7Pa_kN8xP8P1L6MOM4WFu4BirE7bOfGTZ1Spk,1275
46
46
  mlquantify/utils/_sampling.py,sha256=QQxE2WKLdiCFUfPF6fKgzyrsOUIWYf74w_w8fbYVc2c,8409
47
47
  mlquantify/utils/_tags.py,sha256=Rz78TLpxgVxBKS0mKTlC9Qo_kn6HaEwVKNXh8pxFT7M,1095
48
- mlquantify/utils/_validation.py,sha256=yR5zqh_c7OHPnuMFBgKbrdU1bG-oXL2thojFEzydzWs,16798
48
+ mlquantify/utils/_validation.py,sha256=TGGnfv7F5rnQmVeSqGMuS9AP76O974b1TPishKCCWls,16800
49
49
  mlquantify/utils/prevalence.py,sha256=FXLCJViQb2yDbyTXeGZt8WsPPnSZINhorQYZTKXOn14,1772
50
- mlquantify-0.1.12.dist-info/METADATA,sha256=qMZWMClRDNjUuFjuiAGhC7aDA3r9hlECzSbyoSLlQ-4,4701
51
- mlquantify-0.1.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
52
- mlquantify-0.1.12.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
53
- mlquantify-0.1.12.dist-info/RECORD,,
50
+ mlquantify-0.1.14.dist-info/METADATA,sha256=-uPQqaXhgbXEq9M4EWYGKK9t3RRGNpbZY436DCL3bog,4701
51
+ mlquantify-0.1.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
52
+ mlquantify-0.1.14.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
53
+ mlquantify-0.1.14.dist-info/RECORD,,