mlquantify 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlquantify/__init__.py +6 -0
- mlquantify/base.py +256 -0
- mlquantify/classification/__init__.py +1 -0
- mlquantify/classification/pwkclf.py +73 -0
- mlquantify/evaluation/__init__.py +2 -0
- mlquantify/evaluation/measures/__init__.py +26 -0
- mlquantify/evaluation/measures/ae.py +11 -0
- mlquantify/evaluation/measures/bias.py +16 -0
- mlquantify/evaluation/measures/kld.py +8 -0
- mlquantify/evaluation/measures/mse.py +12 -0
- mlquantify/evaluation/measures/nae.py +16 -0
- mlquantify/evaluation/measures/nkld.py +13 -0
- mlquantify/evaluation/measures/nrae.py +16 -0
- mlquantify/evaluation/measures/rae.py +12 -0
- mlquantify/evaluation/measures/se.py +12 -0
- mlquantify/evaluation/protocol/_Protocol.py +202 -0
- mlquantify/evaluation/protocol/__init__.py +2 -0
- mlquantify/evaluation/protocol/app.py +146 -0
- mlquantify/evaluation/protocol/npp.py +34 -0
- mlquantify/methods/__init__.py +40 -0
- mlquantify/methods/aggregative/ThreholdOptm/_ThreholdOptimization.py +62 -0
- mlquantify/methods/aggregative/ThreholdOptm/__init__.py +7 -0
- mlquantify/methods/aggregative/ThreholdOptm/acc.py +27 -0
- mlquantify/methods/aggregative/ThreholdOptm/max.py +23 -0
- mlquantify/methods/aggregative/ThreholdOptm/ms.py +21 -0
- mlquantify/methods/aggregative/ThreholdOptm/ms2.py +25 -0
- mlquantify/methods/aggregative/ThreholdOptm/pacc.py +41 -0
- mlquantify/methods/aggregative/ThreholdOptm/t50.py +21 -0
- mlquantify/methods/aggregative/ThreholdOptm/x.py +23 -0
- mlquantify/methods/aggregative/__init__.py +9 -0
- mlquantify/methods/aggregative/cc.py +32 -0
- mlquantify/methods/aggregative/emq.py +86 -0
- mlquantify/methods/aggregative/fm.py +72 -0
- mlquantify/methods/aggregative/gac.py +96 -0
- mlquantify/methods/aggregative/gpac.py +87 -0
- mlquantify/methods/aggregative/mixtureModels/_MixtureModel.py +81 -0
- mlquantify/methods/aggregative/mixtureModels/__init__.py +5 -0
- mlquantify/methods/aggregative/mixtureModels/dys.py +55 -0
- mlquantify/methods/aggregative/mixtureModels/dys_syn.py +89 -0
- mlquantify/methods/aggregative/mixtureModels/hdy.py +46 -0
- mlquantify/methods/aggregative/mixtureModels/smm.py +27 -0
- mlquantify/methods/aggregative/mixtureModels/sord.py +77 -0
- mlquantify/methods/aggregative/pcc.py +33 -0
- mlquantify/methods/aggregative/pwk.py +38 -0
- mlquantify/methods/meta/__init__.py +1 -0
- mlquantify/methods/meta/ensemble.py +236 -0
- mlquantify/methods/non_aggregative/__init__.py +1 -0
- mlquantify/methods/non_aggregative/hdx.py +71 -0
- mlquantify/model_selection.py +232 -0
- mlquantify/plots/__init__.py +2 -0
- mlquantify/plots/distribution_plot.py +109 -0
- mlquantify/plots/protocol_plot.py +157 -0
- mlquantify/utils/__init__.py +2 -0
- mlquantify/utils/general_purposes/__init__.py +8 -0
- mlquantify/utils/general_purposes/convert_col_to_array.py +13 -0
- mlquantify/utils/general_purposes/generate_artificial_indexes.py +29 -0
- mlquantify/utils/general_purposes/get_real_prev.py +9 -0
- mlquantify/utils/general_purposes/load_quantifier.py +4 -0
- mlquantify/utils/general_purposes/make_prevs.py +23 -0
- mlquantify/utils/general_purposes/normalize.py +20 -0
- mlquantify/utils/general_purposes/parallel.py +10 -0
- mlquantify/utils/general_purposes/round_protocol_df.py +14 -0
- mlquantify/utils/method_purposes/__init__.py +6 -0
- mlquantify/utils/method_purposes/distances.py +21 -0
- mlquantify/utils/method_purposes/getHist.py +13 -0
- mlquantify/utils/method_purposes/get_scores.py +33 -0
- mlquantify/utils/method_purposes/moss.py +16 -0
- mlquantify/utils/method_purposes/ternary_search.py +14 -0
- mlquantify/utils/method_purposes/tprfpr.py +42 -0
- mlquantify-0.0.1.dist-info/METADATA +23 -0
- mlquantify-0.0.1.dist-info/RECORD +73 -0
- mlquantify-0.0.1.dist-info/WHEEL +5 -0
- mlquantify-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from .base import Quantifier
|
|
2
|
+
from typing import Union, List
|
|
3
|
+
import itertools
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
import signal
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
import numpy as np
|
|
8
|
+
from sklearn.model_selection import train_test_split
|
|
9
|
+
from .utils import parallel
|
|
10
|
+
from .evaluation import get_measure, APP, NPP
|
|
11
|
+
|
|
12
|
+
class GridSearchQ(Quantifier):
    """
    Hyperparameter optimization for quantification models using grid search.

    Args:
        model (Quantifier): The base quantification model.
        param_grid (dict): Hyperparameters to search over.
        protocol (str, optional): Quantification protocol ('app' or 'npp'). Defaults to 'app'.
        n_prevs (int, optional): Number of prevalence points for APP. Defaults to None.
        n_repetitions (int, optional): Number of repetitions for NPP. Defaults to 1.
        scoring (Union[List[str], str], optional): Metric(s) for evaluation. Defaults to "ae".
        refit (bool, optional): Refit model on best parameters. Defaults to True.
        val_split (float, optional): Proportion of data for validation. Defaults to 0.4.
        n_jobs (int, optional): Number of parallel jobs. Defaults to 1.
        random_seed (int, optional): Seed for reproducibility. Defaults to 42.
        timeout (int, optional): Max time per parameter combination (seconds). Defaults to -1 (no limit).
        verbose (bool, optional): Verbosity of output. Defaults to False.
    """

    def __init__(self,
                 model: Quantifier,
                 param_grid: dict,
                 protocol: str = 'app',
                 n_prevs: int = None,
                 n_repetitions: int = 1,
                 scoring: Union[List[str], str] = "ae",
                 refit: bool = True,
                 val_split: float = 0.4,
                 n_jobs: int = 1,
                 random_seed: int = 42,
                 timeout: int = -1,
                 verbose: bool = False):

        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol.lower()
        self.n_prevs = n_prevs
        self.n_repetitions = n_repetitions
        self.refit = refit
        self.val_split = val_split
        self.n_jobs = n_jobs
        self.random_seed = random_seed
        self.timeout = timeout
        self.verbose = verbose
        # Normalize `scoring` to a list of measure callables.
        self.scoring = [get_measure(measure) for measure in (scoring if isinstance(scoring, list) else [scoring])]

        assert self.protocol in {'app', 'npp'}, 'Unknown protocol; valid ones are "app" or "npp".'

        if self.protocol == 'npp' and self.n_repetitions <= 1:
            raise ValueError('For "npp" protocol, n_repetitions must be greater than 1.')

    def sout(self, msg):
        """Prints messages if verbose is True."""
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __get_protocol(self, model, sample_size):
        """Get the appropriate protocol instance.

        Args:
            model (Quantifier): The quantification model.
            sample_size (int): The sample size for batch processing.

        Returns:
            object: Instance of APP or NPP protocol.
        """
        protocol_params = {
            'models': model,
            'batch_size': sample_size,
            'n_iterations': self.n_repetitions,
            'n_jobs': self.n_jobs,
            'verbose': False,
            # Fix: was the hard-coded literal 35, which silently ignored the
            # user-supplied random_seed and made "reproducibility" misleading.
            'random_state': self.random_seed,
            'return_type': "predictions"
        }
        return APP(n_prevs=self.n_prevs, **protocol_params) if self.protocol == 'app' else NPP(**protocol_params)

    def fit(self, X, y):
        """Fit the quantifier model and perform grid search.

        Args:
            X (array-like): Training features.
            y (array-like): Training labels.

        Returns:
            self: Fitted GridSearchQ instance.

        Raises:
            RuntimeError: If no parameter combination could be evaluated
                (e.g. every combination timed out).
        """
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=self.val_split, random_state=self.random_seed)
        param_combinations = list(itertools.product(*self.param_grid.values()))
        best_score, best_params = None, None

        if self.timeout > 0:
            # NOTE(review): SIGALRM only fires in the main thread of the main
            # process; with n_jobs > 1 the timeout may not apply in workers.
            signal.signal(signal.SIGALRM, self._timeout_handler)

        def evaluate_combination(params):
            """Evaluate a single combination of hyperparameters.

            Args:
                params (tuple): A tuple of hyperparameter values.

            Returns:
                float or None: The evaluation score, or None if timeout occurred.
            """
            if self.verbose:
                print(f"\tEvaluate Combination for {str(params)}")

            # Work on a copy so the base model is not mutated between trials.
            model = deepcopy(self.model)
            model.set_params(**dict(zip(self.param_grid.keys(), params)))
            protocol_instance = self.__get_protocol(model, len(y_train))

            try:
                if self.timeout > 0:
                    signal.alarm(self.timeout)

                protocol_instance.fit(X_train, y_train)
                _, real_prevs, pred_prevs = protocol_instance.predict(X_val, y_val)
                # Average each measure over all evaluated samples.
                scores = [np.mean([measure(rp, pp) for rp, pp in zip(real_prevs, pred_prevs)]) for measure in self.scoring]

                if self.timeout > 0:
                    signal.alarm(0)

                if self.verbose:
                    print(f"\t\\--ended evaluation of {str(params)}")

                return np.mean(scores) if scores else None
            except TimeoutError:
                self.sout(f'Timeout reached for combination {params}.')
                return None

        results = parallel(
            evaluate_combination,
            tqdm(param_combinations, desc="Evaluating combination", total=len(param_combinations)) if self.verbose else param_combinations,
            n_jobs=self.n_jobs
        )

        # Lower scores are better (error measures).
        for score, params in zip(results, param_combinations):
            if score is not None and (best_score is None or score < best_score):
                best_score, best_params = score, params

        # Fix: previously this fell through to dict(zip(keys, None)) and raised
        # an opaque TypeError when every combination failed or timed out.
        if best_params is None:
            raise RuntimeError('No parameter combination could be evaluated successfully.')

        self.best_score_ = best_score
        self.best_params_ = dict(zip(self.param_grid.keys(), best_params))
        self.sout(f'Optimization complete. Best score: {self.best_score_}, with parameters: {self.best_params_}.')

        if self.refit and self.best_params_:
            self.model.set_params(**self.best_params_)
            self.model.fit(X, y)
            self.best_model_ = self.model

        return self

    def predict(self, X):
        """Make predictions using the best found model.

        Args:
            X (array-like): Data to predict on.

        Returns:
            array-like: Predictions.

        Raises:
            RuntimeError: If called before fitting (with refit enabled).
        """
        if not hasattr(self, 'best_model_'):
            raise RuntimeError("The model has not been fitted yet.")
        return self.best_model_.predict(X)

    @property
    def classes_(self):
        """Get the classes of the best model.

        Returns:
            array-like: The classes.
        """
        return self.best_model_.classes_

    def set_params(self, **parameters):
        """Set the hyperparameters for grid search.

        Args:
            parameters (dict): Hyperparameters to set.
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Get the parameters of the best model.

        Args:
            deep (bool, optional): If True, will return the parameters for this estimator and contained subobjects. Defaults to True.

        Returns:
            dict: Parameters of the best model.

        Raises:
            ValueError: If called before fit.
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_.get_params()
        raise ValueError('get_params called before fit')

    def best_model(self):
        """Return the best model after fitting.

        Returns:
            Quantifier: The best model.

        Raises:
            ValueError: If called before fitting.
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')

    def _timeout_handler(self, signum, frame):
        """Handle timeouts during evaluation.

        Args:
            signum (int): Signal number.
            frame (object): Current stack frame.

        Raises:
            TimeoutError: When the timeout is reached.
        """
        raise TimeoutError()
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import List, Optional, Dict, Any, Union
|
|
5
|
+
|
|
6
|
+
# Module-wide matplotlib style: light-grey canvas, thin boxplot strokes and a
# subtle dashed grid. Applied once at import time, so it affects every figure
# created after this module is imported.
plt.rcParams.update({
    'axes.facecolor': "#F8F8F8",
    'figure.facecolor': "#F8F8F8",
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12,
    'font.weight': 'light',
    'axes.labelsize': 14,
    'axes.labelweight': 'light',
    'axes.titlesize': 16,
    'axes.titleweight': 'normal',
    'boxplot.boxprops.linewidth': 0.3,
    'boxplot.whiskerprops.linewidth': 0.3,
    'boxplot.capprops.linewidth': 0.3,
    'boxplot.medianprops.linewidth': 0.6,
    'boxplot.flierprops.linewidth': 0.3,
    'boxplot.flierprops.markersize': 0.9,
    'boxplot.medianprops.color': 'black',
    'figure.subplot.bottom': 0.2,
    'axes.grid': True,
    'grid.color': 'black',
    'grid.alpha': 0.1,
    'grid.linewidth': 0.5,
    'grid.linestyle': '--'
})


# Pastel colour palette cycled over when drawing one histogram per class.
# NOTE(review): a few hex values repeat (e.g. '#E2A8F7', '#F7C6C7', '#FFB3B3',
# '#FF6F61') — confirm whether the repetition is intentional.
COLORS = [
    '#FFAB91', '#FFE082', '#A5D6A7', '#4DD0E1', '#FF6F61', '#FF8C94', '#D4A5A5',
    '#FF677D', '#B9FBC0', '#C2C2F0', '#E3F9A6', '#E2A8F7', '#F7B7A3', '#F7C6C7',
    '#8D9BFC', '#B4E6FF', '#FF8A65', '#FFC3A0', '#FFCCBC', '#F8BBD0', '#FF9AA2',
    '#FFB3B3', '#FFDDC1', '#FFE0B2', '#E2A8F7', '#F7C6C7', '#E57373', '#BA68C8',
    '#4FC3F7', '#FFB3B3', '#FF6F61'
]
|
|
42
|
+
|
|
43
|
+
def class_distribution_plot(values: Union[List, np.ndarray],
                            labels: Union[List, np.ndarray],
                            bins: int = 30,
                            title: Optional[str] = None,
                            legend: bool = True,
                            save_path: Optional[str] = None,
                            plot_params: Optional[Dict[str, Any]] = None):
    """Plot overlaid histograms of class distributions.

    Each entry of ``values`` is drawn as a semi-transparent histogram on a
    single set of axes, so per-class (or per-category) distributions can be
    compared directly.

    Args:
        values (Union[List, np.ndarray]):
            A list of arrays or a single array, one value set per class/category.
        labels (Union[List, np.ndarray]):
            One label per value set; must have the same length as `values`.
        bins (int, optional):
            Number of histogram bins. Default is 30.
        title (Optional[str], optional):
            Plot title; no title is drawn when omitted.
        legend (bool, optional):
            Whether to display a legend. Default is True.
        save_path (Optional[str], optional):
            If given, the figure is saved to this path before being shown.
        plot_params (Optional[Dict[str, Any]], optional):
            Extra matplotlib rcParams applied before plotting. Default is None.

    Raises:
        AssertionError:
            If the number of labels does not match the number of value sets.
    """
    # Per-call rcParams overrides, if any.
    if plot_params:
        plt.rcParams.update(plot_params)

    assert len(values) == len(labels), "The number of value sets must match the number of labels."

    # One semi-transparent histogram per class, cycling through the palette.
    for idx, (sample, label) in enumerate(zip(values, labels)):
        colour = COLORS[idx % len(COLORS)]
        plt.hist(sample, bins=bins, color=colour, edgecolor='black', alpha=0.5, label=label)

    if title:
        plt.title(title)

    if legend:
        plt.legend(loc='upper right')

    plt.xlabel('Values')
    plt.ylabel('Frequency')

    # Persist before showing, since show() may clear the current figure.
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')

    plt.show()
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import matplotlib.patches as mpatches
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from typing import List, Optional, Dict, Any, Union
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
# Module-wide matplotlib style (duplicated from distribution_plot): light-grey
# canvas, thin boxplot strokes and a subtle dashed grid. Applied at import time.
plt.rcParams.update({
    'axes.facecolor': "#F8F8F8",
    'figure.facecolor': "#F8F8F8",
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12,
    'font.weight': 'light',
    'axes.labelsize': 14,
    'axes.labelweight': 'light',
    'axes.titlesize': 16,
    'axes.titleweight': 'normal',
    'boxplot.boxprops.linewidth': 0.3,
    'boxplot.whiskerprops.linewidth': 0.3,
    'boxplot.capprops.linewidth': 0.3,
    'boxplot.medianprops.linewidth': 0.6,
    'boxplot.flierprops.linewidth': 0.3,
    'boxplot.flierprops.markersize': 0.9,
    'boxplot.medianprops.color': 'black',
    'figure.subplot.bottom': 0.2,
    'axes.grid': True,
    'grid.color': 'black',
    'grid.alpha': 0.1,
    'grid.linewidth': 0.5,
    'grid.linestyle': '--'
})

# Colors and markers
# Pastel palette: one colour per quantifier (cycled with modulo when there are
# more quantifiers than entries). NOTE(review): some hex values repeat —
# confirm whether the repetition is intentional.
COLORS = [
    '#FFAB91', '#FFE082', '#A5D6A7', '#4DD0E1', '#FF6F61', '#FF8C94', '#D4A5A5',
    '#FF677D', '#B9FBC0', '#C2C2F0', '#E3F9A6', '#E2A8F7', '#F7B7A3', '#F7C6C7',
    '#8D9BFC', '#B4E6FF', '#FF8A65', '#FFC3A0', '#FFCCBC', '#F8BBD0', '#FF9AA2',
    '#FFB3B3', '#FFDDC1', '#FFE0B2', '#E2A8F7', '#F7C6C7', '#E57373', '#BA68C8',
    '#4FC3F7', '#FFB3B3', '#FF6F61'
]

# Line-plot marker cycle, paired positionally with quantifiers.
MARKERS = ["o", "s", "^", "D", "p", "*", "+", "x", "H", "1", "2", "3", "4", "|", "_"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def protocol_boxplot(
        table_protocol: pd.DataFrame,
        x: str,
        y: str,
        methods: Optional[List[str]] = None,
        title: Optional[str] = None,
        legend: bool = True,
        save_path: Optional[str] = None,
        order: Optional[str] = None,
        plot_params: Optional[Dict[str, Any]] = None):
    """
    Plot one box per quantifier for the metric in column `y` of a protocol
    result table. With `order='rank'` the boxes are sorted by median value.
    """
    # 'figsize' is consumed here; remaining entries are forwarded to boxplot().
    plot_params = plot_params or {}
    figsize = plot_params.pop('figsize', (10, 6))  # Default figsize if not provided

    # Keep only scalar columns and the requested quantifiers.
    table = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()
    methods = methods or table['QUANTIFIER'].unique()
    table = table[table['QUANTIFIER'].isin(methods)]

    # Optionally rank quantifiers by median metric value.
    if order == 'rank':
        methods = table.groupby('QUANTIFIER')[y].median().sort_values().index.tolist()

    # Build the figure with the requested size; the boxplot gets its own grid.
    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)

    per_method_values = [table.loc[table['QUANTIFIER'] == name, y] for name in methods]
    box = ax.boxplot(per_method_values,
                     patch_artist=True, widths=0.8, labels=methods, **plot_params)

    # Colour each box from the shared palette.
    for patch, colour in zip(box['boxes'], COLORS[:len(methods)]):
        patch.set_facecolor(colour)

    # Legend outside the axes, one patch per quantifier.
    if legend:
        handles = [mpatches.Patch(color=COLORS[idx], label=name) for idx, name in enumerate(methods)]
        ax.legend(handles=handles, title="Quantifiers", loc='upper left', bbox_to_anchor=(1, 1), fontsize=10, title_fontsize='11')

    # Axis cosmetics.
    ax.set_xticklabels(methods, rotation=45, fontstyle='italic')
    ax.set_xlabel(x.capitalize())
    ax.set_ylabel(f"{y.capitalize()}")
    if title:
        ax.set_title(title)

    # Leave room for the external legend, then persist and show.
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def protocol_lineplot(
        table_protocol: pd.DataFrame,
        methods: Union[List[str], str, None],
        x: str,
        y: str,
        title: Optional[str] = None,
        legend: bool = True,
        save_path: Optional[str] = None,
        group_by: str = "mean",
        pos_alpha: int = 1,
        plot_params: Optional[Dict[str, Any]] = None):
    """
    Plots a line plot based on the provided DataFrame of the protocol and selected methods.

    Args:
        table_protocol (pd.DataFrame): Protocol results; must contain the
            columns QUANTIFIER, PRED_PREVS and REAL_PREVS.
        methods (Union[List[str], str, None]): Quantifier(s) to plot; None (or
            an empty list) plots every quantifier present in the table.
        x (str): Column for the x axis. The special value "ALPHA" uses the
            real prevalence of class `pos_alpha` (taken from REAL_PREVS).
        y (str): Column for the y axis ("ALPHA" is supported as for `x`).
        title (Optional[str], optional): Plot title; omitted when None.
        legend (bool, optional): Whether to draw a legend. Default True.
        save_path (Optional[str], optional): If given, the figure is saved here.
        group_by (str, optional): Aggregation applied per (quantifier, x) group
            (any pandas agg name). Default "mean".
        pos_alpha (int, optional): Index of the class whose real prevalence
            defines the "ALPHA" axis. Default 1.
        plot_params (Optional[Dict[str, Any]], optional): Extra kwargs forwarded
            to `ax.plot`; a 'figsize' entry is popped and used for the figure.
    """
    # Handle plot_params: 'figsize' is consumed here, the rest go to ax.plot.
    plot_params = plot_params or {}
    figsize = plot_params.pop('figsize', (10, 6))  # Default figsize if not provided

    # Fix: a single method name used to be iterated character-by-character by
    # `isin`; normalize it to a one-element list first.
    if isinstance(methods, str):
        methods = [methods]
    methods = methods or table_protocol['QUANTIFIER'].unique()
    table_protocol = table_protocol[table_protocol['QUANTIFIER'].isin(methods)]

    # Materialize the ALPHA column (real prevalence of class `pos_alpha`)
    # whenever either axis requests it. Fix: the original referenced an
    # unfiltered, unaggregated `real` series when y == "ALPHA" (and raised
    # NameError when x != "ALPHA"); deriving the column here keeps it aligned
    # with the per-method, aggregated data used below.
    table = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()
    if "ALPHA" in (x, y):
        table["ALPHA"] = table_protocol["REAL_PREVS"].apply(lambda prevs: prevs[pos_alpha])

    # Aggregate data per (quantifier, x-value).
    if group_by:
        table = table.groupby(['QUANTIFIER', x])[y].agg(group_by).reset_index()

    # Create plot with custom figsize; one line + marker per quantifier.
    fig, ax = plt.subplots(figsize=figsize)
    for i, (method, marker) in enumerate(zip(methods, MARKERS[:len(methods)])):
        method_data = table[table['QUANTIFIER'] == method]
        ax.plot(method_data[x], method_data[y], color=COLORS[i % len(COLORS)], marker=marker, label=method, **plot_params)

    # Legend outside the axes.
    if legend:
        ax.legend(title="Quantifiers", loc='upper left', bbox_to_anchor=(1, 1), fontsize=10, title_fontsize='11')

    # Axis cosmetics.
    ax.set_xlabel(x.capitalize())
    ax.set_ylabel(y.capitalize())
    if title:
        ax.set_title(title)

    # Leave room for the external legend, then persist and show.
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .normalize import normalize_prevalence
|
|
2
|
+
from .parallel import parallel
|
|
3
|
+
from .get_real_prev import get_real_prev
|
|
4
|
+
from .make_prevs import make_prevs
|
|
5
|
+
from .generate_artificial_indexes import generate_artificial_indexes
|
|
6
|
+
from .round_protocol_df import round_protocol_df
|
|
7
|
+
from .convert_col_to_array import convert_columns_to_arrays
|
|
8
|
+
from .load_quantifier import load_quantifier
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def convert_columns_to_arrays(df, columns: tuple = ('PRED_PREVS', 'REAL_PREVS')):
    """Convert the specified columns from strings of arrays to numpy arrays.

    Cells such as "[0.2 0.8]" (the serialized form of prevalence vectors in
    protocol dataframes) are parsed back into ``np.ndarray``; cells that are
    not strings are left untouched.

    Args:
        df (pd.DataFrame): The dataframe whose columns are converted in place.
        columns (tuple, optional): Columns holding stringified arrays.
            Defaults to the protocol dataframe columns.
            (Fix: was a mutable list default — the mutable-default-argument
            anti-pattern — now an immutable tuple; callers may still pass a list.)

    Returns:
        pd.DataFrame: The same dataframe, with the columns converted.
    """
    for col in columns:
        df[col] = df[col].apply(lambda x: np.fromstring(x.strip('[]'), sep=' ') if isinstance(x, str) else x)
    return df
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def generate_artificial_indexes(y, prevalence: list, sample_size: int, classes: list):
    """Sample indexes of `y` so the sample follows the given class prevalences.

    Sampling is done with replacement within each class; the last class absorbs
    the rounding remainder so exactly `sample_size` indexes are returned.

    Args:
        y (array-like): Label array to sample from.
        prevalence (list): Desired prevalence per class; must sum to 1.
        sample_size (int): Total number of indexes to draw.
        classes (list): Class values, aligned positionally with `prevalence`.

    Returns:
        list: `sample_size` shuffled indexes into `y`.
    """
    # Ensure the sum of prevalences is 1
    assert np.isclose(sum(prevalence), 1), "The sum of prevalences must be 1"
    # Ensure the number of prevalences matches the number of classes.
    # Fix: this check was announced in a comment but never enforced.
    assert len(prevalence) == len(classes), "One prevalence value is required per class"

    sampled_indexes = []
    total_sampled = 0

    for i, class_ in enumerate(classes):
        # The last class takes the remainder so the total is exactly sample_size.
        if i == len(classes) - 1:
            num_samples = sample_size - total_sampled
        else:
            num_samples = int(sample_size * prevalence[i])

        # Get the indexes of the current class
        class_indexes = np.where(y == class_)[0]

        # Sample (with replacement) the indexes for the current class
        sampled_class_indexes = np.random.choice(class_indexes, size=num_samples, replace=True)

        sampled_indexes.extend(sampled_class_indexes)
        total_sampled += num_samples

    np.random.shuffle(sampled_indexes)  # Shuffle after collecting all indexes

    return sampled_indexes
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def make_prevs(ndim: int) -> np.ndarray:
    """
    Generate `ndim` prevalence values uniformly distributed on the simplex
    (non-negative values summing exactly to 1).

    Uses the uniform-spacings construction: sort ndim-1 uniform draws together
    with the endpoints 0 and 1 and take consecutive differences.

    Args:
        ndim (int): Number of prevalence values to generate.
            (Fix: the docstring previously called this parameter `n_dim`.)

    Returns:
        np.ndarray: Array of `ndim` values that sum to 1.
            (Fix: the docstring/annotation previously claimed a list.)
    """
    # Generate ndim-1 random cut points uniformly distributed between 0 and 1
    u_dist = np.random.uniform(0, 1, ndim - 1)
    # Bracket them with the simplex endpoints 0 and 1
    u_dist = np.append(u_dist, [0, 1])
    # Sort the cut points
    u_dist.sort()
    # Consecutive gaps between sorted cut points are the prevalences
    prevs = np.diff(u_dist)

    return prevs
|
|
23
|
+
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
|
|
4
|
+
def normalize_prevalence(prevalences: np.ndarray, classes: list):
    """Normalize prevalence values so they sum to 1, keyed by class.

    Accepts either a dict mapping class -> value, or a 1-D array aligned
    positionally with `classes`.

    Args:
        prevalences (Union[dict, np.ndarray]): Raw (possibly unnormalized) prevalences.
        classes (list): Class values, aligned with `prevalences` when it is an array.

    Returns:
        dict: {class (int): normalized prevalence (float)} covering every class
            in `classes`. If all values are zero, every class maps to 0.0.
    """
    if isinstance(prevalences, dict):
        summ = sum(prevalences.values())
        prevalences = {int(_class): float(value / summ) for _class, value in prevalences.items()}
        return prevalences

    summ = np.sum(prevalences, axis=-1, keepdims=True)
    # Fix: the denominator was `sum(prevalences)` (the builtin, wrong for 2-D
    # input and redundant vs the guarded `summ`), and `where=` had no
    # initialized `out`, leaving undefined garbage when the total is 0.
    prevalences = np.true_divide(prevalences, summ,
                                 out=np.zeros_like(prevalences, dtype=float),
                                 where=summ > 0)
    prevalences = {int(_class): float(prev) for _class, prev in zip(classes, prevalences)}
    prevalences = defaultdict(lambda: 0, prevalences)

    # Ensure all classes are present in the result
    for cls in classes:
        prevalences[cls] = prevalences[cls]

    return dict(prevalences)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def round_protocol_df(dataframe: pd.DataFrame, frac: int = 3):
    """Round every numeric value of a protocol dataframe to `frac` decimals.

    The array-valued columns PRED_PREVS / REAL_PREVS are rounded element-wise
    with numpy; other numeric columns use pandas' rounding; non-numeric
    columns pass through unchanged.
    """
    prev_columns = {'PRED_PREVS', 'REAL_PREVS'}

    def _round(column):
        # Prevalence columns hold arrays (or scalars): round cell by cell.
        if column.name in prev_columns:
            return column.apply(
                lambda cell: np.round(cell, frac) if isinstance(cell, (np.ndarray, float, int)) else cell
            )
        # Plain numeric columns round in one vectorized call.
        if np.issubdtype(column.dtype, np.number):
            return column.round(frac)
        return column

    return dataframe.apply(_round)
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from .getHist import getHist
|
|
2
|
+
from .distances import sqEuclidean, probsymm, hellinger, topsoe
|
|
3
|
+
from .ternary_search import ternary_search
|
|
4
|
+
from .tprfpr import compute_table, compute_tpr, compute_fpr, adjust_threshold
|
|
5
|
+
from .get_scores import get_scores
|
|
6
|
+
from .moss import MoSS
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def sqEuclidean(dist1, dist2):
    """Squared Euclidean distance between two distributions."""
    diff = dist1 - dist2
    return sum(diff ** 2)
|
|
7
|
+
|
|
8
|
+
def probsymm(dist1, dist2):
    """Probabilistic symmetric chi-square distance between two distributions."""
    squared_diff = (dist1 - dist2) ** 2
    return 2 * sum(squared_diff / (dist1 + dist2))
|
|
12
|
+
|
|
13
|
+
def topsoe(dist1, dist2):
    """Topsoe distance between two distributions."""
    total = dist1 + dist2
    terms = dist1 * np.log(2 * dist1 / total) + dist2 * np.log(2 * dist2 / total)
    return sum(terms)
|
|
17
|
+
|
|
18
|
+
def hellinger(dist1, dist2):
    """Hellinger distance between two distributions."""
    # abs() guards against tiny negative values from floating-point error.
    bhattacharyya = sum(np.sqrt(dist1 * dist2))
    return 2 * np.sqrt(np.abs(1 - bhattacharyya))
|