mlquantify 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  from abc import abstractmethod
2
2
  import numpy as np
3
+ import warnings
3
4
  from sklearn.base import BaseEstimator
4
5
 
5
6
  from ..base import AggregativeQuantifier
@@ -446,9 +447,8 @@ class MS(ThresholdOptimization):
446
447
  {0: 0.3991228070175439, 1: 0.6008771929824561}
447
448
  """
448
449
 
449
- def __init__(self, learner: BaseEstimator=None, threshold: float = 0.5):
450
+ def __init__(self, learner: BaseEstimator=None):
450
451
  super().__init__(learner)
451
- self.threshold = threshold
452
452
 
453
453
  def best_tprfpr(self, thresholds: np.ndarray, tprs: np.ndarray, fprs: np.ndarray) -> tuple:
454
454
  """
@@ -481,11 +481,42 @@ class MS(ThresholdOptimization):
481
481
  ValueError
482
482
  If `thresholds`, `tprs`, or `fprs` are empty or have mismatched lengths.
483
483
  """
484
- # Compute median TPR and FPR
485
- tpr = np.median(tprs)
486
- fpr = np.median(fprs)
487
484
 
488
- return (self.threshold, tpr, fpr)
485
+ return (thresholds, tprs, fprs)
486
+
487
+ def _predict_method(self, X) -> dict:
488
+ """
489
+ Predicts class prevalences using the adjusted threshold.
490
+
491
+ Parameters
492
+ ----------
493
+ X : pd.DataFrame or np.ndarray
494
+ The input features for prediction.
495
+
496
+ Returns
497
+ -------
498
+ np.ndarray
499
+ An array of predicted prevalences for the classes.
500
+ """
501
+ # Get predicted probabilities for the positive class
502
+ probabilities = self.predict_learner(X)[:, 1]
503
+
504
+ prevs = []
505
+
506
+ for thr, tpr, fpr in zip(self.threshold, self.tpr, self.fpr):
507
+ cc_output = len(probabilities[probabilities >= thr]) / len(probabilities)
508
+
509
+ if tpr - fpr == 0:
510
+ prevalence = cc_output
511
+ else:
512
+ prev = np.clip((cc_output - fpr) / (tpr - fpr), 0, 1)
513
+ prevs.append(prev)
514
+
515
+ prevalence = np.median(prevs)
516
+
517
+ prevalences = [1 - prevalence, prevalence]
518
+
519
+ return np.asarray(prevalences)
489
520
 
490
521
 
491
522
 
@@ -570,33 +601,63 @@ class MS2(ThresholdOptimization):
570
601
  - The median threshold value for cases meeting the condition (float).
571
602
  - The median true positive rate for cases meeting the condition (float).
572
603
  - The median false positive rate for cases meeting the condition (float).
573
-
604
+
574
605
  Raises
575
606
  ------
576
607
  ValueError
577
- If no cases satisfy the condition `|TPR - FPR| > 0.25` or if the
578
- input arrays are empty or have mismatched lengths.
608
+ If no cases satisfy the condition `|TPR - FPR| > 0.25`.
609
+ Warning
610
+ If all TPR or FPR values are zero.
579
611
  """
612
+ # Check if all TPR or FPR values are zero
613
+ if np.all(tprs == 0) or np.all(fprs == 0):
614
+ warnings.warn("All TPR or FPR values are zero.")
615
+
580
616
  # Identify indices where the condition is satisfied
581
617
  indices = np.where(np.abs(tprs - fprs) > 0.25)[0]
582
618
  if len(indices) == 0:
583
- raise ValueError("No cases meet the condition |TPR - FPR| > 0.25.")
619
+ warnings.warn("No cases satisfy the condition |TPR - FPR| > 0.25.")
620
+ indices = np.where(np.abs(tprs - fprs) >= 0)[0]
621
+
622
+ thresholds_ = thresholds[indices]
623
+ tprs_ = tprs[indices]
624
+ fprs_ = fprs[indices]
584
625
 
585
- # Compute medians for the selected cases
586
- threshold = np.median(thresholds[indices])
587
- tpr = np.median(tprs[indices])
588
- fpr = np.median(fprs[indices])
589
-
590
- return (threshold, tpr, fpr)
591
-
592
-
593
-
594
-
595
-
626
+ return (thresholds_, tprs_, fprs_)
596
627
 
628
+ def _predict_method(self, X) -> dict:
629
+ """
630
+ Predicts class prevalences using the adjusted threshold.
597
631
 
632
+ Parameters
633
+ ----------
634
+ X : pd.DataFrame or np.ndarray
635
+ The input features for prediction.
598
636
 
637
+ Returns
638
+ -------
639
+ np.ndarray
640
+ An array of predicted prevalences for the classes.
641
+ """
642
+ # Get predicted probabilities for the positive class
643
+ probabilities = self.predict_learner(X)[:, 1]
644
+
645
+ prevs = []
646
+
647
+ for thr, tpr, fpr in zip(self.threshold, self.tpr, self.fpr):
648
+ cc_output = len(probabilities[probabilities >= thr]) / len(probabilities)
649
+
650
+ if tpr - fpr == 0:
651
+ prevalence = cc_output
652
+ else:
653
+ prev = np.clip((cc_output - fpr) / (tpr - fpr), 0, 1)
654
+ prevs.append(prev)
655
+
656
+ prevalence = np.median(prevs)
657
+
658
+ prevalences = [1 - prevalence, prevalence]
599
659
 
660
+ return np.asarray(prevalences)
600
661
 
601
662
  class PACC(ThresholdOptimization):
602
663
  """
@@ -131,7 +131,7 @@ class GridSearchQ(Quantifier):
131
131
  model: Quantifier,
132
132
  param_grid: dict,
133
133
  protocol: str = 'app',
134
- n_prevs: int = None,
134
+ n_prevs: int = 100,
135
135
  n_repetitions: int = 1,
136
136
  scoring: Union[List[str], str] = "ae",
137
137
  refit: bool = True,
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: mlquantify
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Quantification Library
5
5
  Home-page: https://github.com/luizfernandolj/QuantifyML/tree/master
6
6
  Maintainer: Luiz Fernando Luth Junior
@@ -58,7 +58,7 @@ pip install mlquantify
58
58
  If you only want to update, run the code below:
59
59
 
60
60
  ```bash
61
- pip install --update mlquantify
61
+ pip install --upgrade mlquantify
62
62
  ```
63
63
 
64
64
  ___
@@ -1,6 +1,6 @@
1
1
  mlquantify/__init__.py,sha256=Q9jCEkG0EJoHXukrxh194mhO_Yfu-BZPRfjpQ4T1XlQ,978
2
2
  mlquantify/base.py,sha256=hJ9FYYNGeO5-WJlpJpsUiu_LQL1fimvZPPNsKptxN7w,19196
3
- mlquantify/model_selection.py,sha256=deMdTKarMu4DtfSy6ucTIxQKj1yrC-x3nhBT77kMOtI,12679
3
+ mlquantify/model_selection.py,sha256=rPR4fwwxuihzx5Axq4NhMOeuMBzpoC9pKp5taYNt_LY,12678
4
4
  mlquantify/plots.py,sha256=9XOhx4QXkN9RkkiErLuL90FWIBUV2YTEJNT4Jwfy0ac,12380
5
5
  mlquantify/classification/__init__.py,sha256=3FGf-F4SOM3gByUPsWdnBzjyC_31B3MtzuolEuocPls,22
6
6
  mlquantify/classification/methods.py,sha256=yDSbpoqM3hfF0a9ATzKqfG9S-44x-0Rq0lkAVJKTIEs,5006
@@ -12,11 +12,11 @@ mlquantify/methods/aggregative.py,sha256=rL_xlX2nYECrxFSjBJNlxj6h3b-iIs7l_XgxIRS
12
12
  mlquantify/methods/meta.py,sha256=sZWQHUGkm6iiqujmIpHDL_8tDdKQ161bzD5mcpXLWEY,19066
13
13
  mlquantify/methods/mixture_models.py,sha256=si2Pzaka5Kbva4QKBzLolvb_8V0ZEjp68UBAiOwl49s,35166
14
14
  mlquantify/methods/non_aggregative.py,sha256=xaBu21TUtiYkOEUKO16NaNMwdNa6-SNjfBsc5PpIMyI,4815
15
- mlquantify/methods/threshold_optimization.py,sha256=P88VXG-czZiaHSHTGnzFmZVzm3SoJHnrmi60Zvv7IJU,33726
15
+ mlquantify/methods/threshold_optimization.py,sha256=-iOcP5YcXZd0XZHGvbmcoE72hXR6D9YCoTnr1l80-9k,35796
16
16
  mlquantify/utils/__init__.py,sha256=logWrL6B6mukP8tvYm_UPEdO9eNA-J-ySILr7-syDoc,44
17
17
  mlquantify/utils/general.py,sha256=Li5ix_dy19dUhYNgiUsNHdqqnSVYvznUBUuyr-zYSPI,7554
18
18
  mlquantify/utils/method.py,sha256=RL4vBJGl5_6DZ59Bs62hdNXI_hnoDIWilMMyMPiOjBg,12631
19
- mlquantify-0.1.0.dist-info/METADATA,sha256=6ud6gvzxxaQr7oZLD3fu3piid1ZHjJuHAQZzZeUw7Rs,4939
20
- mlquantify-0.1.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
21
- mlquantify-0.1.0.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
22
- mlquantify-0.1.0.dist-info/RECORD,,
19
+ mlquantify-0.1.2.dist-info/METADATA,sha256=2j3pqrm5djMAPm7bKTIjBjtg71OzAbFpwC-_ofOoSlc,4940
20
+ mlquantify-0.1.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
21
+ mlquantify-0.1.2.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
22
+ mlquantify-0.1.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5