mlquantify 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. mlquantify/__init__.py +6 -0
  2. mlquantify/base.py +256 -0
  3. mlquantify/classification/__init__.py +1 -0
  4. mlquantify/classification/pwkclf.py +73 -0
  5. mlquantify/evaluation/__init__.py +2 -0
  6. mlquantify/evaluation/measures/__init__.py +26 -0
  7. mlquantify/evaluation/measures/ae.py +11 -0
  8. mlquantify/evaluation/measures/bias.py +16 -0
  9. mlquantify/evaluation/measures/kld.py +8 -0
  10. mlquantify/evaluation/measures/mse.py +12 -0
  11. mlquantify/evaluation/measures/nae.py +16 -0
  12. mlquantify/evaluation/measures/nkld.py +13 -0
  13. mlquantify/evaluation/measures/nrae.py +16 -0
  14. mlquantify/evaluation/measures/rae.py +12 -0
  15. mlquantify/evaluation/measures/se.py +12 -0
  16. mlquantify/evaluation/protocol/_Protocol.py +202 -0
  17. mlquantify/evaluation/protocol/__init__.py +2 -0
  18. mlquantify/evaluation/protocol/app.py +146 -0
  19. mlquantify/evaluation/protocol/npp.py +34 -0
  20. mlquantify/methods/__init__.py +40 -0
  21. mlquantify/methods/aggregative/ThreholdOptm/_ThreholdOptimization.py +62 -0
  22. mlquantify/methods/aggregative/ThreholdOptm/__init__.py +7 -0
  23. mlquantify/methods/aggregative/ThreholdOptm/acc.py +27 -0
  24. mlquantify/methods/aggregative/ThreholdOptm/max.py +23 -0
  25. mlquantify/methods/aggregative/ThreholdOptm/ms.py +21 -0
  26. mlquantify/methods/aggregative/ThreholdOptm/ms2.py +25 -0
  27. mlquantify/methods/aggregative/ThreholdOptm/pacc.py +41 -0
  28. mlquantify/methods/aggregative/ThreholdOptm/t50.py +21 -0
  29. mlquantify/methods/aggregative/ThreholdOptm/x.py +23 -0
  30. mlquantify/methods/aggregative/__init__.py +9 -0
  31. mlquantify/methods/aggregative/cc.py +32 -0
  32. mlquantify/methods/aggregative/emq.py +86 -0
  33. mlquantify/methods/aggregative/fm.py +72 -0
  34. mlquantify/methods/aggregative/gac.py +96 -0
  35. mlquantify/methods/aggregative/gpac.py +87 -0
  36. mlquantify/methods/aggregative/mixtureModels/_MixtureModel.py +81 -0
  37. mlquantify/methods/aggregative/mixtureModels/__init__.py +5 -0
  38. mlquantify/methods/aggregative/mixtureModels/dys.py +55 -0
  39. mlquantify/methods/aggregative/mixtureModels/dys_syn.py +89 -0
  40. mlquantify/methods/aggregative/mixtureModels/hdy.py +46 -0
  41. mlquantify/methods/aggregative/mixtureModels/smm.py +27 -0
  42. mlquantify/methods/aggregative/mixtureModels/sord.py +77 -0
  43. mlquantify/methods/aggregative/pcc.py +33 -0
  44. mlquantify/methods/aggregative/pwk.py +38 -0
  45. mlquantify/methods/meta/__init__.py +1 -0
  46. mlquantify/methods/meta/ensemble.py +236 -0
  47. mlquantify/methods/non_aggregative/__init__.py +1 -0
  48. mlquantify/methods/non_aggregative/hdx.py +71 -0
  49. mlquantify/model_selection.py +232 -0
  50. mlquantify/plots/__init__.py +2 -0
  51. mlquantify/plots/distribution_plot.py +109 -0
  52. mlquantify/plots/protocol_plot.py +157 -0
  53. mlquantify/utils/__init__.py +2 -0
  54. mlquantify/utils/general_purposes/__init__.py +8 -0
  55. mlquantify/utils/general_purposes/convert_col_to_array.py +13 -0
  56. mlquantify/utils/general_purposes/generate_artificial_indexes.py +29 -0
  57. mlquantify/utils/general_purposes/get_real_prev.py +9 -0
  58. mlquantify/utils/general_purposes/load_quantifier.py +4 -0
  59. mlquantify/utils/general_purposes/make_prevs.py +23 -0
  60. mlquantify/utils/general_purposes/normalize.py +20 -0
  61. mlquantify/utils/general_purposes/parallel.py +10 -0
  62. mlquantify/utils/general_purposes/round_protocol_df.py +14 -0
  63. mlquantify/utils/method_purposes/__init__.py +6 -0
  64. mlquantify/utils/method_purposes/distances.py +21 -0
  65. mlquantify/utils/method_purposes/getHist.py +13 -0
  66. mlquantify/utils/method_purposes/get_scores.py +33 -0
  67. mlquantify/utils/method_purposes/moss.py +16 -0
  68. mlquantify/utils/method_purposes/ternary_search.py +14 -0
  69. mlquantify/utils/method_purposes/tprfpr.py +42 -0
  70. mlquantify-0.0.1.dist-info/METADATA +23 -0
  71. mlquantify-0.0.1.dist-info/RECORD +73 -0
  72. mlquantify-0.0.1.dist-info/WHEEL +5 -0
  73. mlquantify-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,232 @@
1
+ from .base import Quantifier
2
+ from typing import Union, List
3
+ import itertools
4
+ from tqdm import tqdm
5
+ import signal
6
+ from copy import deepcopy
7
+ import numpy as np
8
+ from sklearn.model_selection import train_test_split
9
+ from .utils import parallel
10
+ from .evaluation import get_measure, APP, NPP
11
+
12
class GridSearchQ(Quantifier):
    """
    Hyperparameter optimization for quantification models using grid search.

    Args:
        model (Quantifier): The base quantification model.
        param_grid (dict): Hyperparameters to search over.
        protocol (str, optional): Quantification protocol ('app' or 'npp'). Defaults to 'app'.
        n_prevs (int, optional): Number of prevalence points for APP. Defaults to None.
        n_repetitions (int, optional): Number of repetitions for NPP. Defaults to 1.
        scoring (Union[List[str], str], optional): Metric(s) for evaluation. Defaults to "ae".
        refit (bool, optional): Refit model on best parameters. Defaults to True.
        val_split (float, optional): Proportion of data for validation. Defaults to 0.4.
        n_jobs (int, optional): Number of parallel jobs. Defaults to 1.
        random_seed (int, optional): Seed for reproducibility. Defaults to 42.
        timeout (int, optional): Max time per parameter combination (seconds); -1 disables. Defaults to -1.
        verbose (bool, optional): Verbosity of output. Defaults to False.
    """

    def __init__(self,
                 model: Quantifier,
                 param_grid: dict,
                 protocol: str = 'app',
                 n_prevs: int = None,
                 n_repetitions: int = 1,
                 scoring: Union[List[str], str] = "ae",
                 refit: bool = True,
                 val_split: float = 0.4,
                 n_jobs: int = 1,
                 random_seed: int = 42,
                 timeout: int = -1,
                 verbose: bool = False):

        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol.lower()
        self.n_prevs = n_prevs
        self.n_repetitions = n_repetitions
        self.refit = refit
        self.val_split = val_split
        self.n_jobs = n_jobs
        self.random_seed = random_seed
        self.timeout = timeout
        self.verbose = verbose
        # Normalize scoring to a list of measure callables.
        self.scoring = [get_measure(measure) for measure in (scoring if isinstance(scoring, list) else [scoring])]

        assert self.protocol in {'app', 'npp'}, 'Unknown protocol; valid ones are "app" or "npp".'

        if self.protocol == 'npp' and self.n_repetitions <= 1:
            raise ValueError('For "npp" protocol, n_repetitions must be greater than 1.')

    def sout(self, msg):
        """Prints messages if verbose is True."""
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __get_protocol(self, model, sample_size):
        """Get the appropriate protocol instance.

        Args:
            model (Quantifier): The quantification model.
            sample_size (int): The sample size for batch processing.

        Returns:
            object: Instance of APP or NPP protocol.
        """
        protocol_params = {
            'models': model,
            'batch_size': sample_size,
            'n_iterations': self.n_repetitions,
            'n_jobs': self.n_jobs,
            'verbose': False,
            # Use the configured seed so evaluation is reproducible.
            # (Previously hard-coded to 35, which silently ignored `random_seed`.)
            'random_state': self.random_seed,
            'return_type': "predictions"
        }
        return APP(n_prevs=self.n_prevs, **protocol_params) if self.protocol == 'app' else NPP(**protocol_params)

    def fit(self, X, y):
        """Fit the quantifier model and perform grid search.

        Args:
            X (array-like): Training features.
            y (array-like): Training labels.

        Returns:
            self: Fitted GridSearchQ instance.

        Raises:
            RuntimeError: If no parameter combination could be evaluated
                (all failed or timed out).
        """
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=self.val_split, random_state=self.random_seed)
        param_combinations = list(itertools.product(*self.param_grid.values()))
        best_score, best_params = None, None

        if self.timeout > 0:
            # NOTE(review): signal.SIGALRM is Unix-only; timeout > 0 will fail on
            # Windows — confirm target platforms.
            signal.signal(signal.SIGALRM, self._timeout_handler)

        def evaluate_combination(params):
            """Evaluate a single combination of hyperparameters.

            Args:
                params (tuple): A tuple of hyperparameter values.

            Returns:
                float or None: The evaluation score, or None if timeout occurred.
            """
            if self.verbose:
                print(f"\tEvaluate Combination for {str(params)}")

            model = deepcopy(self.model)
            model.set_params(**dict(zip(self.param_grid.keys(), params)))
            protocol_instance = self.__get_protocol(model, len(y_train))

            try:
                if self.timeout > 0:
                    signal.alarm(self.timeout)

                protocol_instance.fit(X_train, y_train)
                _, real_prevs, pred_prevs = protocol_instance.predict(X_val, y_val)
                # One mean score per measure, averaged over all evaluated samples.
                scores = [np.mean([measure(rp, pp) for rp, pp in zip(real_prevs, pred_prevs)]) for measure in self.scoring]

                if self.timeout > 0:
                    signal.alarm(0)

                if self.verbose:
                    print(f"\t\\--ended evaluation of {str(params)}")

                return np.mean(scores) if scores else None
            except TimeoutError:
                self.sout(f'Timeout reached for combination {params}.')
                return None

        results = parallel(
            evaluate_combination,
            tqdm(param_combinations, desc="Evaluating combination", total=len(param_combinations)) if self.verbose else param_combinations,
            n_jobs=self.n_jobs
        )

        # Lower score is better (error measures).
        for score, params in zip(results, param_combinations):
            if score is not None and (best_score is None or score < best_score):
                best_score, best_params = score, params

        if best_params is None:
            # Without this guard, dict(zip(..., None)) below raised an opaque TypeError.
            raise RuntimeError('No parameter combination could be evaluated; all failed or timed out.')

        self.best_score_ = best_score
        self.best_params_ = dict(zip(self.param_grid.keys(), best_params))
        self.sout(f'Optimization complete. Best score: {self.best_score_}, with parameters: {self.best_params_}.')

        if self.refit and self.best_params_:
            self.model.set_params(**self.best_params_)
            self.model.fit(X, y)
            self.best_model_ = self.model

        return self

    def predict(self, X):
        """Make predictions using the best found model.

        Args:
            X (array-like): Data to predict on.

        Returns:
            array-like: Predictions.

        Raises:
            RuntimeError: If called before fitting.
        """
        if not hasattr(self, 'best_model_'):
            raise RuntimeError("The model has not been fitted yet.")
        return self.best_model_.predict(X)

    @property
    def classes_(self):
        """Get the classes of the best model.

        Returns:
            array-like: The classes.
        """
        return self.best_model_.classes_

    def set_params(self, **parameters):
        """Set the hyperparameters for grid search.

        Args:
            parameters (dict): Hyperparameters to set.
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Get the parameters of the best model.

        Args:
            deep (bool, optional): If True, will return the parameters for this estimator and contained subobjects. Defaults to True.

        Returns:
            dict: Parameters of the best model.

        Raises:
            ValueError: If called before fit.
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_.get_params()
        raise ValueError('get_params called before fit')

    def best_model(self):
        """Return the best model after fitting.

        Returns:
            Quantifier: The best model.

        Raises:
            ValueError: If called before fitting.
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')

    def _timeout_handler(self, signum, frame):
        """Handle timeouts during evaluation.

        Args:
            signum (int): Signal number.
            frame (object): Current stack frame.

        Raises:
            TimeoutError: When the timeout is reached.
        """
        raise TimeoutError()
@@ -0,0 +1,2 @@
1
+ from .protocol_plot import protocol_boxplot, protocol_lineplot
2
+ from .distribution_plot import class_distribution_plot
@@ -0,0 +1,109 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from pathlib import Path
4
+ from typing import List, Optional, Dict, Any, Union
5
+
6
# Global matplotlib style shared by every plot produced in this module.
plt.rcParams.update({
    'axes.facecolor': "#F8F8F8",
    'figure.facecolor': "#F8F8F8",
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12,
    'font.weight': 'light',
    'axes.labelsize': 14,
    'axes.labelweight': 'light',
    'axes.titlesize': 16,
    'axes.titleweight': 'normal',
    'boxplot.boxprops.linewidth': 0.3,
    'boxplot.whiskerprops.linewidth': 0.3,
    'boxplot.capprops.linewidth': 0.3,
    'boxplot.medianprops.linewidth': 0.6,
    'boxplot.flierprops.linewidth': 0.3,
    'boxplot.flierprops.markersize': 0.9,
    'boxplot.medianprops.color': 'black',
    'figure.subplot.bottom': 0.2,
    'axes.grid': True,
    'grid.color': 'black',
    'grid.alpha': 0.1,
    'grid.linewidth': 0.5,
    'grid.linestyle': '--'
})




# Color palette cycled through (modulo its length) when drawing one
# histogram per class in class_distribution_plot.
COLORS = [
    '#FFAB91', '#FFE082', '#A5D6A7', '#4DD0E1', '#FF6F61', '#FF8C94', '#D4A5A5',
    '#FF677D', '#B9FBC0', '#C2C2F0', '#E3F9A6', '#E2A8F7', '#F7B7A3', '#F7C6C7',
    '#8D9BFC', '#B4E6FF', '#FF8A65', '#FFC3A0', '#FFCCBC', '#F8BBD0', '#FF9AA2',
    '#FFB3B3', '#FFDDC1', '#FFE0B2', '#E2A8F7', '#F7C6C7', '#E57373', '#BA68C8',
    '#4FC3F7', '#FFB3B3', '#FF6F61'
]
42
+
43
def class_distribution_plot(values: Union[List, np.ndarray],
                            labels: Union[List, np.ndarray],
                            bins: int = 30,
                            title: Optional[str] = None,
                            legend: bool = True,
                            save_path: Optional[str] = None,
                            plot_params: Optional[Dict[str, Any]] = None):
    """Draw overlaid histograms, one per class or category.

    Each entry of ``values`` is plotted as a semi-transparent histogram using
    the module palette, so class distributions can be compared visually.

    Args:
        values (Union[List, np.ndarray]): One array of values per class/category.
        labels (Union[List, np.ndarray]): One label per entry of ``values``;
            must have the same length as ``values``.
        bins (int, optional): Number of histogram bins. Default is 30.
        title (Optional[str], optional): Plot title; omitted when None.
        legend (bool, optional): Whether to show a legend. Default is True.
        save_path (Optional[str], optional): If given, the figure is saved there.
        plot_params (Optional[Dict[str, Any]], optional): Extra matplotlib
            rcParams applied before plotting.

    Raises:
        AssertionError: If ``values`` and ``labels`` differ in length.
    """
    # Merge any caller-supplied rcParams into the global style.
    if plot_params:
        plt.rcParams.update(plot_params)

    assert len(values) == len(labels), "The number of value sets must match the number of labels."

    # One semi-transparent histogram per class, cycling through the palette.
    for idx, (data, name) in enumerate(zip(values, labels)):
        plt.hist(data, bins=bins, color=COLORS[idx % len(COLORS)],
                 edgecolor='black', alpha=0.5, label=name)

    if title:
        plt.title(title)

    if legend:
        plt.legend(loc='upper right')

    plt.xlabel('Values')
    plt.ylabel('Frequency')

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')

    plt.show()
@@ -0,0 +1,157 @@
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.patches as mpatches
3
+ import pandas as pd
4
+ from typing import List, Optional, Dict, Any, Union
5
+ from pathlib import Path
6
+
7
# Global matplotlib style shared by the protocol plotting helpers below.
plt.rcParams.update({
    'axes.facecolor': "#F8F8F8",
    'figure.facecolor': "#F8F8F8",
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12,
    'font.weight': 'light',
    'axes.labelsize': 14,
    'axes.labelweight': 'light',
    'axes.titlesize': 16,
    'axes.titleweight': 'normal',
    'boxplot.boxprops.linewidth': 0.3,
    'boxplot.whiskerprops.linewidth': 0.3,
    'boxplot.capprops.linewidth': 0.3,
    'boxplot.medianprops.linewidth': 0.6,
    'boxplot.flierprops.linewidth': 0.3,
    'boxplot.flierprops.markersize': 0.9,
    'boxplot.medianprops.color': 'black',
    'figure.subplot.bottom': 0.2,
    'axes.grid': True,
    'grid.color': 'black',
    'grid.alpha': 0.1,
    'grid.linewidth': 0.5,
    'grid.linestyle': '--'
})

# Colors and markers
# One color/marker per quantifier, cycled modulo length when there are more
# quantifiers than entries.
COLORS = [
    '#FFAB91', '#FFE082', '#A5D6A7', '#4DD0E1', '#FF6F61', '#FF8C94', '#D4A5A5',
    '#FF677D', '#B9FBC0', '#C2C2F0', '#E3F9A6', '#E2A8F7', '#F7B7A3', '#F7C6C7',
    '#8D9BFC', '#B4E6FF', '#FF8A65', '#FFC3A0', '#FFCCBC', '#F8BBD0', '#FF9AA2',
    '#FFB3B3', '#FFDDC1', '#FFE0B2', '#E2A8F7', '#F7C6C7', '#E57373', '#BA68C8',
    '#4FC3F7', '#FFB3B3', '#FF6F61'
]

MARKERS = ["o", "s", "^", "D", "p", "*", "+", "x", "H", "1", "2", "3", "4", "|", "_"]
43
+
44
+
45
+
46
def protocol_boxplot(
        table_protocol: pd.DataFrame,
        x: str,
        y: str,
        methods: Optional[List[str]] = None,
        title: Optional[str] = None,
        legend: bool = True,
        save_path: Optional[str] = None,
        order: Optional[str] = None,
        plot_params: Optional[Dict[str, Any]] = None):
    """
    Draw one box per quantifier for metric ``y`` from a protocol results table.

    Args:
        table_protocol: Protocol results with 'QUANTIFIER', 'PRED_PREVS',
            'REAL_PREVS' and metric columns.
        x: Label used for the x axis.
        y: Metric column to plot.
        methods: Quantifiers to include; all by default.
        title: Optional plot title.
        legend: Whether to draw a legend of quantifiers.
        save_path: If given, the figure is also saved there.
        order: 'rank' sorts quantifiers by their median ``y``.
        plot_params: Extra kwargs forwarded to ``ax.boxplot``; may carry a
            'figsize' entry, which is popped out and used for the figure.
    """
    plot_params = plot_params or {}
    # 'figsize' is consumed here rather than passed to boxplot.
    figsize = plot_params.pop('figsize', (10, 6))

    # Keep only metric columns and the requested quantifiers.
    data = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()
    methods = methods or data['QUANTIFIER'].unique()
    data = data[data['QUANTIFIER'].isin(methods)]

    # Optionally sort quantifiers by median performance.
    if order == 'rank':
        methods = data.groupby('QUANTIFIER')[y].median().sort_values().index.tolist()

    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)

    per_method = [data[data['QUANTIFIER'] == method][y] for method in methods]
    box = ax.boxplot(per_method, patch_artist=True, widths=0.8,
                     labels=methods, **plot_params)

    # Give each box its own palette color.
    for patch, color in zip(box['boxes'], COLORS[:len(methods)]):
        patch.set_facecolor(color)

    if legend:
        handles = [mpatches.Patch(color=COLORS[i], label=method)
                   for i, method in enumerate(methods)]
        ax.legend(handles=handles, title="Quantifiers", loc='upper left',
                  bbox_to_anchor=(1, 1), fontsize=10, title_fontsize='11')

    ax.set_xticklabels(methods, rotation=45, fontstyle='italic')
    ax.set_xlabel(x.capitalize())
    ax.set_ylabel(f"{y.capitalize()}")
    if title:
        ax.set_title(title)

    # Leave room on the right for the external legend, then save/show.
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
100
+
101
+
102
+
103
def protocol_lineplot(
    table_protocol: pd.DataFrame,
    methods: Union[List[str], str, None],
    x: str,
    y: str,
    title: Optional[str] = None,
    legend: bool = True,
    save_path: Optional[str] = None,
    group_by: str = "mean",
    pos_alpha: int = 1,
    plot_params: Optional[Dict[str, Any]] = None):
    """
    Plots a line plot based on the provided DataFrame of the protocol and selected methods.

    One line is drawn per quantifier, with ``x`` on the horizontal axis and
    ``y`` on the vertical axis. When ``x == "ALPHA"`` the real prevalence of
    class index ``pos_alpha`` (taken from the 'REAL_PREVS' column) is used as
    the x values. If ``group_by`` is truthy, ``y`` is aggregated per
    (quantifier, x) pair with that pandas aggregation (e.g. "mean").

    Args:
        table_protocol (pd.DataFrame): protocol results; expected columns
            include 'QUANTIFIER', 'PRED_PREVS', 'REAL_PREVS' and metric columns.
        methods: quantifiers to include; all by default.
        x (str): column (or "ALPHA") for the x axis.
        y (str): column (or "ALPHA") for the y axis.
        title (Optional[str]): plot title, omitted when None.
        legend (bool): whether to draw the legend. Default True.
        save_path (Optional[str]): if given, the figure is also saved there.
        group_by (str): pandas aggregation name applied per (quantifier, x);
            falsy to skip aggregation.
        pos_alpha (int): class index extracted from 'REAL_PREVS' for "ALPHA".
        plot_params: extra kwargs forwarded to ``ax.plot``; may carry a
            'figsize' entry, which is popped out and used for the figure.
    """
    # Handle plot_params ('figsize' is consumed here, the rest goes to ax.plot)
    plot_params = plot_params or {}
    figsize = plot_params.pop('figsize', (10, 6))  # Default figsize if not provided

    # Filter data
    methods = methods or table_protocol['QUANTIFIER'].unique()
    table_protocol = table_protocol[table_protocol['QUANTIFIER'].isin(methods)]

    if x == "ALPHA":
        # Prevalence of class `pos_alpha` per row, used as the x values.
        real = table_protocol["REAL_PREVS"].apply(lambda x: x[pos_alpha])
        table = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()
        table["ALPHA"] = real
    else:
        table = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()

    # Aggregate data
    if group_by:
        table = table.groupby(['QUANTIFIER', x])[y].agg(group_by).reset_index()

    # Create plot with custom figsize
    fig, ax = plt.subplots(figsize=figsize)
    for i, (method, marker) in enumerate(zip(methods, MARKERS[:len(methods)])):
        method_data = table[table['QUANTIFIER'] == method]
        # NOTE(review): `real` is only bound when x == "ALPHA"; calling with
        # y == "ALPHA" but x != "ALPHA" raises NameError. Also, after the
        # groupby aggregation `real` keeps the pre-aggregation row count, so
        # its length may not match method_data — confirm intended usage.
        y_data = real if y == "ALPHA" else method_data[y]
        ax.plot(method_data[x], y_data, color=COLORS[i % len(COLORS)], marker=marker, label=method, **plot_params)

    # Add legend
    if legend:
        ax.legend(title="Quantifiers", loc='upper left', bbox_to_anchor=(1, 1), fontsize=10, title_fontsize='11')

    # Customize plot
    ax.set_xlabel(x.capitalize())
    ax.set_ylabel(y.capitalize())
    if title:
        ax.set_title(title)

    # Adjust layout and save plot
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
@@ -0,0 +1,2 @@
1
+ from .general_purposes import *
2
+ from .method_purposes import *
@@ -0,0 +1,8 @@
1
+ from .normalize import normalize_prevalence
2
+ from .parallel import parallel
3
+ from .get_real_prev import get_real_prev
4
+ from .make_prevs import make_prevs
5
+ from .generate_artificial_indexes import generate_artificial_indexes
6
+ from .round_protocol_df import round_protocol_df
7
+ from .convert_col_to_array import convert_columns_to_arrays
8
+ from .load_quantifier import load_quantifier
@@ -0,0 +1,13 @@
1
+ import numpy as np
2
+
3
def convert_columns_to_arrays(df, columns: list = None):
    """Convert the specified columns from strings of arrays to numpy arrays.

    Cell values like "[0.1 0.9]" become ``np.ndarray``; non-string cells are
    left untouched. The conversion happens in place on ``df``.

    Args:
        df (pd.DataFrame): the dataframe whose columns should be converted.
        columns (list, optional): the columns holding stringified arrays.
            Defaults to the protocol dataframe columns
            ['PRED_PREVS', 'REAL_PREVS'].
            (A None default avoids the shared-mutable-default pitfall.)

    Returns:
        pd.DataFrame: the same dataframe, with the columns converted.
    """
    if columns is None:
        columns = ['PRED_PREVS', 'REAL_PREVS']
    for col in columns:
        df[col] = df[col].apply(
            lambda x: np.fromstring(x.strip('[]'), sep=' ') if isinstance(x, str) else x
        )
    return df
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+
3
def generate_artificial_indexes(y, prevalence: list, sample_size: int, classes: list):
    """Sample indexes of ``y`` so the sample approximates the given prevalence.

    Draws ``sample_size`` indexes with replacement, allocating
    ``int(sample_size * prevalence[i])`` to each class and giving the last
    class whatever remains, so the total is exactly ``sample_size``.

    Args:
        y (array-like): label vector to sample indexes from.
        prevalence (list): desired class proportions; must sum to 1 and have
            one entry per class.
        sample_size (int): total number of indexes to draw.
        classes (list): class values, aligned with ``prevalence``.

    Returns:
        list: shuffled list of sampled indexes into ``y``.
    """
    # Ensure the sum of prevalences is 1
    assert np.isclose(sum(prevalence), 1), "The sum of prevalences must be 1"
    # Ensure the number of prevalences matches the number of classes
    # (the original code announced this check in a comment but never did it)
    assert len(prevalence) == len(classes), "prevalence must have one entry per class"

    sampled_indexes = []
    total_sampled = 0

    for i, class_ in enumerate(classes):
        # The last class absorbs rounding slack so the sample is exactly sample_size.
        if i == len(classes) - 1:
            num_samples = sample_size - total_sampled
        else:
            num_samples = int(sample_size * prevalence[i])

        # Get the indexes of the current class
        class_indexes = np.where(y == class_)[0]

        # Sample the indexes for the current class (with replacement)
        sampled_class_indexes = np.random.choice(class_indexes, size=num_samples, replace=True)

        sampled_indexes.extend(sampled_class_indexes)
        total_sampled += num_samples

    np.random.shuffle(sampled_indexes)  # Shuffle after collecting all indexes

    return sampled_indexes
@@ -0,0 +1,9 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
def get_real_prev(y) -> dict:
    """Return the true class prevalence of ``y`` as a dict sorted by class."""
    series = pd.Series(y) if isinstance(y, np.ndarray) else y
    proportions = series.value_counts(normalize=True).to_dict()
    return dict(sorted(proportions.items()))
@@ -0,0 +1,4 @@
1
+ import joblib
2
+
3
def load_quantifier(path: str):
    """Load a quantifier previously persisted with joblib (e.g. ``joblib.dump``).

    Args:
        path (str): filesystem path of the saved quantifier.

    Returns:
        The deserialized quantifier object.
    """
    return joblib.load(path)
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+
3
def make_prevs(ndim: int) -> list:
    """
    Generate ``ndim`` values uniformly distributed between 0 and 1 that sum exactly to 1.

    Uses the "stick-breaking by sorted cuts" construction: draw ndim-1 random
    cut points in [0, 1], add the endpoints 0 and 1, sort, and take the gaps
    between consecutive points.

    Args:
        ndim (int): Number of values to generate.

    Returns:
        Array of ``ndim`` non-negative values summing to 1.
    """
    # ndim-1 random cut points partition [0, 1] into ndim segments.
    cuts = np.random.uniform(0, 1, ndim - 1)
    cuts = np.append(cuts, [0, 1])
    cuts.sort()
    # Segment lengths are the prevalences.
    return np.diff(cuts)
23
+
@@ -0,0 +1,20 @@
1
+ import numpy as np
2
+ from collections import defaultdict
3
+
4
def normalize_prevalence(prevalences, classes: list):
    """Normalize prevalence values so they sum to 1, keyed by class.

    Accepts either a dict ``{class: value}`` or an array-like aligned with
    ``classes``. For the array case, a zero total yields all-zero prevalences
    instead of undefined values.

    Args:
        prevalences (dict or np.ndarray): raw prevalence values.
        classes (list): class labels, aligned with the array form.

    Returns:
        dict: ``{int(class): float(prevalence)}`` with every class present.
    """
    if isinstance(prevalences, dict):
        total = sum(prevalences.values())
        return {int(_class): float(value / total) for _class, value in prevalences.items()}

    prevalences = np.asarray(prevalences, dtype=float)
    total = np.sum(prevalences, axis=-1, keepdims=True)
    # Divide by the computed total (the original divided by builtin
    # sum(prevalences), which breaks for multi-dim input) and use out= so
    # positions where total == 0 become 0 instead of uninitialized memory.
    normalized = np.true_divide(prevalences, total,
                                out=np.zeros_like(prevalences),
                                where=total > 0)
    result = {int(_class): float(prev) for _class, prev in zip(classes, normalized)}

    # Ensure all classes are present in the result, defaulting to 0.
    for cls in classes:
        result.setdefault(int(cls), 0.0)

    return result
@@ -0,0 +1,10 @@
1
+ from joblib import Parallel, delayed
2
+ import numpy as np
3
+
4
+
5
def parallel(func, elements, n_jobs: int = 1, *args):
    """Apply ``func`` to each element (plus optional extra args) via joblib.

    Args:
        func: callable applied to every element.
        elements: iterable of inputs.
        n_jobs (int): number of joblib workers. Defaults to 1.
        *args: extra positional arguments forwarded to every call.

    Returns:
        list: results in input order.
    """
    jobs = (delayed(func)(element, *args) for element in elements)
    return Parallel(n_jobs=n_jobs)(jobs)
@@ -0,0 +1,14 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+
5
def round_protocol_df(dataframe: pd.DataFrame, frac: int = 3):
    """Round the numeric content of a protocol dataframe to ``frac`` decimals.

    The prevalence columns ('PRED_PREVS', 'REAL_PREVS') are rounded
    element-wise (they may hold arrays); other numeric columns are rounded
    directly; non-numeric columns pass through unchanged.
    """
    def _round_column(col):
        if col.name in ['PRED_PREVS', 'REAL_PREVS']:
            # Cells may be ndarrays or scalars; leave anything else alone.
            return col.apply(
                lambda cell: np.round(cell, frac)
                if isinstance(cell, (np.ndarray, float, int)) else cell
            )
        if np.issubdtype(col.dtype, np.number):
            return col.round(frac)
        return col

    return dataframe.apply(_round_column)
@@ -0,0 +1,6 @@
1
+ from .getHist import getHist
2
+ from .distances import sqEuclidean, probsymm, hellinger, topsoe
3
+ from .ternary_search import ternary_search
4
+ from .tprfpr import compute_table, compute_tpr, compute_fpr, adjust_threshold
5
+ from .get_scores import get_scores
6
+ from .moss import MoSS
@@ -0,0 +1,21 @@
1
+ import numpy as np
2
+
3
def sqEuclidean(dist1, dist2):
    """Squared Euclidean distance between two distributions."""
    diff = dist1 - dist2
    return sum(diff ** 2)
7
+
8
def probsymm(dist1, dist2):
    """Probabilistic symmetric distance between two distributions."""
    squared_diff = (dist1 - dist2) ** 2
    return 2 * sum(squared_diff / (dist1 + dist2))
12
+
13
def topsoe(dist1, dist2):
    """Topsoe distance between two distributions."""
    pair_sum = dist1 + dist2
    left = dist1 * np.log(2 * dist1 / pair_sum)
    right = dist2 * np.log(2 * dist2 / pair_sum)
    return sum(left + right)
17
+
18
def hellinger(dist1, dist2):
    """Hellinger distance between two distributions."""
    # abs() guards against tiny negative values from floating-point error.
    overlap = sum(np.sqrt(dist1 * dist2))
    return 2 * np.sqrt(np.abs(1 - overlap))