mlquantify 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlquantify/__init__.py +2 -1
- mlquantify/adjust_counting/__init__.py +6 -5
- mlquantify/adjust_counting/_adjustment.py +208 -37
- mlquantify/adjust_counting/_base.py +5 -6
- mlquantify/adjust_counting/_counting.py +10 -7
- mlquantify/likelihood/__init__.py +0 -2
- mlquantify/likelihood/_classes.py +45 -199
- mlquantify/meta/_classes.py +12 -12
- mlquantify/mixture/__init__.py +2 -1
- mlquantify/mixture/_classes.py +310 -15
- mlquantify/model_selection/_search.py +1 -1
- mlquantify/neighbors/_base.py +15 -15
- mlquantify/neighbors/_classes.py +2 -2
- mlquantify/neighbors/_kde.py +6 -6
- mlquantify/neural/__init__.py +1 -1
- mlquantify/neural/_base.py +0 -0
- mlquantify/neural/_classes.py +611 -0
- mlquantify/neural/_perm_invariant.py +0 -0
- mlquantify/neural/_utils.py +0 -0
- mlquantify/utils/__init__.py +2 -1
- mlquantify/utils/_constraints.py +2 -0
- mlquantify/utils/_validation.py +9 -0
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/METADATA +13 -18
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/RECORD +27 -23
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/WHEEL +1 -1
- mlquantify-0.1.22.dist-info/licenses/LICENSE +28 -0
- mlquantify/likelihood/_base.py +0 -147
- {mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/top_level.txt +0 -0
{mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mlquantify
-Version: 0.1.20
+Version: 0.1.22
 Summary: Quantification Library
 Home-page: https://github.com/luizfernandolj/QuantifyML/tree/master
 Maintainer: Luiz Fernando Luth Junior
@@ -12,6 +12,7 @@ Classifier: Operating System :: Unix
 Classifier: Operating System :: MacOS :: MacOS X
 Classifier: Operating System :: Microsoft :: Windows
 Description-Content-Type: text/markdown
+License-File: LICENSE
 Requires-Dist: scikit-learn
 Requires-Dist: numpy
 Requires-Dist: scipy
@@ -26,25 +27,23 @@ Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: keywords
+Dynamic: license-file
 Dynamic: maintainer
 Dynamic: requires-dist
 Dynamic: summary
 
-
+
+[](https://github.com/luizfernandolj/mlquantify/)
+
+
+<a href="https://luizfernandolj.github.io/mlquantify/"><img src="assets/logo_mlquantify-white.svg" alt="mlquantify logo"></a>
 <h4 align="center">A Python Package for Quantification</h4>
 
 ___
 
 **mlquantify** is a Python library for quantification, also known as supervised prevalence estimation, designed to estimate the distribution of classes within datasets. It offers a range of tools for various quantification methods, model selection tailored for quantification tasks, evaluation metrics, and protocols to assess quantification performance. Additionally, mlquantify includes popular datasets and visualization tools to help analyze and interpret results.
 
-
-
-## Latest Release
-
-- **Version 0.1.11**: Inicial beta version. For a detailed list of changes, check the [changelog](#).
-- In case you need any help, refer to the [User Guide](https://luizfernandolj.github.io/mlquantify/user_guide.html).
-- Explore the [API documentation](https://luizfernandolj.github.io/mlquantify/api/index.html) for detailed developer information.
-- See also the library in the pypi site in [pypi mlquantify](https://pypi.org/project/mlquantify/)
+Website: https://luizfernandolj.github.io/mlquantify/
 
 ___
 
@@ -112,6 +111,10 @@ print(f"Mean Absolute Error -> {mae}")
 print(f"Normalized Relative Absolute Error -> {nrae}")
 ```
 
+- In case you need any help, refer to the [User Guide](https://luizfernandolj.github.io/mlquantify/user_guide.html).
+- Explore the [API documentation](https://luizfernandolj.github.io/mlquantify/api/index.html) for detailed developer information.
+- See also the library in the pypi site in [pypi mlquantify](https://pypi.org/project/mlquantify/)
+
 ___
 
 ## Requirements
@@ -123,11 +126,3 @@ ___
 - tqdm
 - matplotlib
 - xlrd
-
-___
-
-## Documentation
-
-##### API is avaliable [here](https://luizfernandolj.github.io/mlquantify/api/)
-
-___
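The README hunk above closes a quickstart that reports quantification error metrics (MAE and NRAE). As a hedged illustration of what those numbers measure, and not necessarily mlquantify's own function names or signatures, the mean absolute error between an estimated and a true class-prevalence vector can be computed like this:

```python
import numpy as np

# Illustrative sketch only: prevalence_mae is a hypothetical helper,
# not taken from the mlquantify API.
def prevalence_mae(true_prev, pred_prev):
    """Mean absolute error between two class-prevalence vectors."""
    true_prev = np.asarray(true_prev, dtype=float)
    pred_prev = np.asarray(pred_prev, dtype=float)
    return float(np.mean(np.abs(true_prev - pred_prev)))

# Example: true prevalences vs. a quantifier's estimate for three classes.
mae = prevalence_mae([0.30, 0.50, 0.20], [0.25, 0.55, 0.20])
print(f"Mean Absolute Error -> {mae:.3f}")  # Mean Absolute Error -> 0.033
```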
{mlquantify-0.1.20.dist-info → mlquantify-0.1.22.dist-info}/RECORD

@@ -1,41 +1,44 @@
-mlquantify/__init__.py,sha256=
+mlquantify/__init__.py,sha256=gFtdd0MKMje_yVztHzEsKA3p-JF90bO-VQxzSplyovU,348
 mlquantify/base.py,sha256=o7IaKODocyi4tEmCvGmHKQ8F4ZJsaEh4kymsNcLyHAg,5077
 mlquantify/base_aggregative.py,sha256=uqfhpUmgv5pNLLvqgROCWHfjs3sj_2jfwOTyzUySuGo,7545
 mlquantify/calibration.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
 mlquantify/confidence.py,sha256=QkEWr6s-Su3Nbinia_TRQbBeTM6ymDPe7Bv204XBKKA,10799
 mlquantify/multiclass.py,sha256=wFbbXKqGsFVSsI9zC0EHGYyyx1JRxFpzMi_q8l80TUM,11770
-mlquantify/adjust_counting/__init__.py,sha256=
-mlquantify/adjust_counting/_adjustment.py,sha256=
-mlquantify/adjust_counting/_base.py,sha256=
-mlquantify/adjust_counting/_counting.py,sha256=
+mlquantify/adjust_counting/__init__.py,sha256=3Qsgd-UN-qTv2jO47CzgGZ56dubQ3E7u_L4fBfz7i0o,309
+mlquantify/adjust_counting/_adjustment.py,sha256=GKvJafMHmONQIhY9Dl2Stoy2ROOZWYWaFZqtRUk6ERY,29430
+mlquantify/adjust_counting/_base.py,sha256=gZtxyMut5CFgQChlWQyNtMR5xkq9NbMmtf3mgS_h060,9425
+mlquantify/adjust_counting/_counting.py,sha256=90Rah5E8lfQrathH7Bw4E1-v7Wt-gsI6Rc6t0CD7Btk,5779
 mlquantify/adjust_counting/_utils.py,sha256=DEPNzvcr0KszCnfUJaRzBilwWzuNVMSdy5eV7aQ_JPE,2907
-mlquantify/likelihood/__init__.py,sha256=
-mlquantify/likelihood/
-mlquantify/likelihood/_classes.py,sha256=PZ31cAwO8q5X3O2_oSmQ1FM6bY4EsB8hWEcAgcEmWXQ,14731
+mlquantify/likelihood/__init__.py,sha256=vkeh_5Mb6MFcz3BZTR4MLwO5AU5Qgr_2y9uWmb9uyEQ,35
+mlquantify/likelihood/_classes.py,sha256=RUFd2cfGV9DTeRJ4g2OI3BFwRupyC3XixwR_zpDkqfg,9713
 mlquantify/meta/__init__.py,sha256=GzdGw4ky_kmd5VNWiLBULy06IdN_MLCDAuJKbnMOx4s,62
-mlquantify/meta/_classes.py,sha256=
+mlquantify/meta/_classes.py,sha256=RwM7UmfcWpdLDMCF4CFOjjPLP0U8h0R-qwwjnR3y34Q,30807
 mlquantify/metrics/__init__.py,sha256=3bzzjSYTgrZIJsfAgJidQlB-bnjInwVYUvJ34bPhZxY,186
 mlquantify/metrics/_oq.py,sha256=koXDKeHWksl_vHpZuhc2pAps8wvu_MOgEztlSr04MmE,3544
 mlquantify/metrics/_rq.py,sha256=3yiEmGaRAGpzL29Et3tNqkJ3RMsLXwUX3uL9RoIgi40,3034
 mlquantify/metrics/_slq.py,sha256=JZceO2LR3mjbT_0zVcl9xI6jf8pn3tIcpP3vP3Luf9I,6817
-mlquantify/mixture/__init__.py,sha256=
+mlquantify/mixture/__init__.py,sha256=3JPPxFqPPU9aTghLc41G3La36dc89x5vpekdg6ug9lQ,84
 mlquantify/mixture/_base.py,sha256=1-yW64FPQXB_d9hH9KjSlDnmFtW9FY7S2hppXAd1DBg,5645
-mlquantify/mixture/_classes.py,sha256=
+mlquantify/mixture/_classes.py,sha256=n6P47xo740TljCAuMfFdrAHoYUWdiZOTuCMP3Lw_uGE,25656
 mlquantify/mixture/_utils.py,sha256=CKlC081nrkJ8Pil7lrPZvNZC_xfpXV8SsuQq3M_LHgA,4037
 mlquantify/model_selection/__init__.py,sha256=98I0uf8k6lbWAjazGyGjbOdPOvzU8aMRLqC3I7D3jzk,113
 mlquantify/model_selection/_protocol.py,sha256=XhkNUN-XAuGkihm0jwQL665ps2G9bevxme_yrETNQHo,12902
-mlquantify/model_selection/_search.py,sha256=
+mlquantify/model_selection/_search.py,sha256=ie2KI3WOKa5ejCkIiemEDczAlURjaArcBqHxIGHCuic,10696
 mlquantify/model_selection/_split.py,sha256=chG3GNX2BBDTWIuSVfZUJ_YF_ZVBSoel2d_AN0OChS0,6
 mlquantify/neighbors/__init__.py,sha256=rIOuSaUhjqEXsUN9HNZ62P53QG0N7lJ3j1pvf8kJzms,93
-mlquantify/neighbors/_base.py,sha256=
-mlquantify/neighbors/_classes.py,sha256=
+mlquantify/neighbors/_base.py,sha256=Dje7wXn3iTFurbSksfbyF9mFSc-vxD2cLnBLShywkck,6470
+mlquantify/neighbors/_classes.py,sha256=AcZxperEN31Mufh6aAsiZZ8rsG9vneoNNoEF_yMRy78,5258
 mlquantify/neighbors/_classification.py,sha256=8xNqaTQXUGg_dbQd6SqwKWb07BM2QM0uwZeXZ5C_DMs,4136
-mlquantify/neighbors/_kde.py,sha256=
+mlquantify/neighbors/_kde.py,sha256=x6DkONFhCec44HPFY5H2DTXjMVKPQkz2kEUmfxx5DrM,9889
 mlquantify/neighbors/_utils.py,sha256=CozcKtmd6ZDluMT4bvOj4QI7xwORF_vCIJRucPEzJJo,4123
-mlquantify/neural/__init__.py,sha256=
-mlquantify/
+mlquantify/neural/__init__.py,sha256=UFHkMnUCyqi6zmH6YZr6aTq8v0ndbzTWS9se14rSEx8,28
+mlquantify/neural/_base.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mlquantify/neural/_classes.py,sha256=PCX4-LH3QhgEav44u12bhP_T1BHGNoNcB2aVJ4VWIXk,23373
+mlquantify/neural/_perm_invariant.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mlquantify/neural/_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mlquantify/utils/__init__.py,sha256=4_Mnerh2GV6mOiKYW9SX_mjMTxqKeai0umBCnVj4vow,1370
 mlquantify/utils/_artificial.py,sha256=6tqMoAuxUULFGHXtMez56re4DZ7d2Q6tK55LPGeEiO8,713
-mlquantify/utils/_constraints.py,sha256=
+mlquantify/utils/_constraints.py,sha256=wULvW1Kkj6FAlA6Zv541jf1ydJYpJsurtITetzQabEE,5679
 mlquantify/utils/_context.py,sha256=25QmzmfSiuF_hwCjY_7db_XfCnj1dVe4mIbDycVTHf8,661
 mlquantify/utils/_decorators.py,sha256=yYtnPBh1sLSN6wTY-7ZVAV0j--qbpJxBsgncm794JPc,1205
 mlquantify/utils/_exceptions.py,sha256=C3BQSv3-7QDLaorKcV-ANxnBcSaxHQSlCc6YSZrPK6c,392
@@ -45,9 +48,10 @@ mlquantify/utils/_parallel.py,sha256=XotpX9nsj6nW-tNCmZ-ahTcRztgnn9oQKP2cl1rLdYM
 mlquantify/utils/_random.py,sha256=7F3nyy7Pa_kN8xP8P1L6MOM4WFu4BirE7bOfGTZ1Spk,1275
 mlquantify/utils/_sampling.py,sha256=3W0vUuvLvoYrt-BZpSM0HM1XJEZr0XYIdkOcUP5hp-8,8350
 mlquantify/utils/_tags.py,sha256=Rz78TLpxgVxBKS0mKTlC9Qo_kn6HaEwVKNXh8pxFT7M,1095
-mlquantify/utils/_validation.py,sha256=
+mlquantify/utils/_validation.py,sha256=V3y-wbH69yFdogcEtQ15ShvtIrlIZ4ObzduaqLYvXp0,18319
 mlquantify/utils/prevalence.py,sha256=LG-KXJ5Eb4w26WMpu4PoBpxMSHaqrmTQqdRlyqNRJ1o,2020
-mlquantify-0.1.
-mlquantify-0.1.
-mlquantify-0.1.
-mlquantify-0.1.
+mlquantify-0.1.22.dist-info/licenses/LICENSE,sha256=DyKiou3Tffi-9NKcNBomiuHedeiF9sDC2Y9tZK_3Sko,1539
+mlquantify-0.1.22.dist-info/METADATA,sha256=Smq1DW-R4IAu7X1ZXxxYy68_oMeoguVhFvDTcyaYdjk,4791
+mlquantify-0.1.22.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+mlquantify-0.1.22.dist-info/top_level.txt,sha256=tGEkYkbbFElwULvqENjam3u1uXtyC1J9dRmibsq8_n0,11
+mlquantify-0.1.22.dist-info/RECORD,,
mlquantify-0.1.22.dist-info/licenses/LICENSE

@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright (c) 2025, Luiz Fernando Luth Junior and Andre Gustavo Maletzke
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
mlquantify/likelihood/_base.py DELETED

@@ -1,147 +0,0 @@
-import numpy as np
-from abc import abstractmethod
-
-from mlquantify.base import BaseQuantifier
-
-from mlquantify.base_aggregative import (
-    AggregationMixin,
-    _get_learner_function
-)
-from mlquantify.adjust_counting import CC
-from mlquantify.utils._decorators import _fit_context
-from mlquantify.utils._validation import check_classes_attribute, validate_predictions, validate_y, validate_data, validate_prevalences
-
-
-
-class BaseIterativeLikelihood(AggregationMixin, BaseQuantifier):
-    r"""Iterative likelihood-based quantification adjustment methods.
-
-    This base class encompasses quantification approaches that estimate class prevalences
-    by maximizing the likelihood of observed data, adjusting prevalence estimates on test
-    sets under the assumption of prior probability shift.
-
-    These methods iteratively refine estimates of class prevalences by maximizing the
-    likelihood of classifier outputs, usually the posterior probabilities provided by
-    a trained model, assuming that the class-conditional distributions remain fixed
-    between training and test domains.
-
-    Mathematical formulation
-    ------------------------
-    Let:
-
-    - :math:`p_k^t` be the prior probabilities for class \(k\) in the training set, satisfying \( \sum_k p_k^t = 1 \),
-    - :math:`s_k(x)` be the posterior probability estimate from the classifier for class \(k\) given instance \(x\),
-    - :math:`p_k` be the unknown prior probabilities for class \(k\) in the test set,
-    - \( x_1, \dots, x_N \) be unlabeled test set instances.
-
-    The likelihood of the observed data is:
-
-    .. math::
-
-        L = \prod_{i=1}^N \sum_{k=1}^K s_k(x_i) \frac{p_k}{p_k^t}
-
-    Methods in this class seek a solution that maximizes this likelihood via iterative methods.
-
-    Notes
-    -----
-    - Applicable to binary and multiclass problems as long as the classifier provides calibrated posterior probabilities.
-    - Assumes changes only in prior probabilities (prior probability shift).
-    - Algorithms converge to local maxima of the likelihood function.
-    - Includes methods such as Class Distribution Estimation (CDE), Maximum Likelihood Prevalence Estimation (MLPE), and Expectation-Maximization (EM) based quantification.
-
-    Parameters
-    ----------
-    learner : estimator, optional
-        Probabilistic classifier implementing the methods `fit(X, y)` and `predict_proba(X)`.
-    tol : float, default=1e-4
-        Convergence tolerance for prevalence update criteria.
-    max_iter : int, default=100
-        Maximum allowed number of iterations.
-
-    Attributes
-    ----------
-    learner : estimator
-        Underlying classification model.
-    tol : float
-        Tolerance for stopping criterion.
-    max_iter : int
-        Maximum number of iterations.
-    classes : ndarray of shape (n_classes,)
-        Unique classes observed during training.
-    priors : ndarray of shape (n_classes,)
-        Class distribution in the training set.
-    y_train : array-like
-        Training labels used to estimate priors.
-
-    Examples
-    --------
-    >>> import numpy as np
-    >>> from sklearn.linear_model import LogisticRegression
-    >>> class MyQuantifier(BaseIterativeLikelihood):
-    ...     def _iterate(self, predictions, priors):
-    ...         # Implementation of iterative update logic
-    ...         pass
-    >>> X = np.random.randn(200, 8)
-    >>> y = np.random.randint(0, 3, size=(200,))
-    >>> q = MyQuantifier(learner=LogisticRegression(max_iter=200))
-    >>> q.fit(X, y)
-    >>> q.predict(X)
-    {0: 0.32, 1: 0.40, 2: 0.28}
-
-    References
-    ----------
-    .. [1] Saerens, M., Latinne, P., & Decaestecker, C. (2002). "Adjusting the Outputs of a Classifier to New a Priori Probabilities: A Simple Procedure." Neural Computation, 14(1), 2141-2156.
-
-    .. [2] Esuli, A., Moreo, A., & Sebastiani, F. (2023). "Learning to Quantify." The Information Retrieval Series 47, Springer. https://doi.org/10.1007/978-3-031-20467-8
-    """
-
-    @abstractmethod
-    def __init__(self,
-                 learner=None,
-                 tol=1e-4,
-                 max_iter=100):
-        self.learner = learner
-        self.tol = tol
-        self.max_iter = max_iter
-
-    def __mlquantify_tags__(self):
-        tags = super().__mlquantify_tags__()
-        tags.prediction_requirements.requires_train_proba = False
-        return tags
-
-
-    @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y):
-        """Fit the quantifier using the provided data and learner."""
-        X, y = validate_data(self, X, y)
-        validate_y(self, y)
-        self.classes_ = np.unique(y)
-        self.learner.fit(X, y)
-        counts = np.array([np.count_nonzero(y == _class) for _class in self.classes_])
-        self.priors = counts / len(y)
-        self.y_train = y
-
-        return self
-
-    def predict(self, X):
-        """Predict class prevalences for the given data."""
-        estimator_function = _get_learner_function(self)
-        predictions = getattr(self.learner, estimator_function)(X)
-        prevalences = self.aggregate(predictions, self.y_train)
-        return prevalences
-
-    def aggregate(self, predictions, y_train):
-        predictions = validate_predictions(self, predictions)
-        self.classes_ = check_classes_attribute(self, np.unique(y_train))
-
-        if not hasattr(self, 'priors') or len(self.priors) != len(self.classes_):
-            counts = np.array([np.count_nonzero(y_train == _class) for _class in self.classes_])
-            self.priors = counts / len(y_train)
-
-        prevalences = self._iterate(predictions, self.priors)
-        prevalences = validate_prevalences(self, prevalences, self.classes_)
-        return prevalences
-
-    @abstractmethod
-    def _iterate(self, predictions, priors):
-        ...
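For context on what the removed base class documented: its docstring defines the likelihood L = prod_i sum_k s_k(x_i) p_k / p_k^t and points to EM-based quantification (Saerens et al., 2002). A minimal, hypothetical sketch of that EM-style prior adjustment is shown below; the function name, signature, and details are illustrative assumptions, not the mlquantify implementation.

```python
import numpy as np

# Sketch of the EM-style prior adjustment described in the removed docstring
# (Saerens et al., 2002). Illustrative only; not taken from mlquantify.
def em_prior_adjustment(posteriors, train_priors, tol=1e-4, max_iter=100):
    """Estimate test-set priors p_k from posteriors s_k(x_i) and training priors p_k^t."""
    posteriors = np.asarray(posteriors, dtype=float)    # shape (n_samples, n_classes)
    train_priors = np.asarray(train_priors, dtype=float)
    priors = train_priors.copy()
    for _ in range(max_iter):
        # E-step: rescale each posterior by the ratio of current to training priors
        # and renormalize per instance.
        weighted = posteriors * (priors / train_priors)
        weighted /= weighted.sum(axis=1, keepdims=True)
        # M-step: the new prevalence estimate is the mean responsibility per class.
        new_priors = weighted.mean(axis=0)
        if np.abs(new_priors - priors).max() < tol:
            priors = new_priors
            break
        priors = new_priors
    return priors
```

Each sweep increases the likelihood given in the removed docstring, so the loop stops either on convergence (within `tol`) or after `max_iter` iterations, mirroring the `tol` and `max_iter` parameters the deleted class exposed.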