mlquantify 0.0.11.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. mlquantify/__init__.py +32 -6
  2. mlquantify/base.py +559 -257
  3. mlquantify/classification/__init__.py +1 -1
  4. mlquantify/classification/methods.py +160 -0
  5. mlquantify/evaluation/__init__.py +14 -2
  6. mlquantify/evaluation/measures.py +215 -0
  7. mlquantify/evaluation/protocol.py +647 -0
  8. mlquantify/methods/__init__.py +37 -40
  9. mlquantify/methods/aggregative.py +1030 -0
  10. mlquantify/methods/meta.py +472 -0
  11. mlquantify/methods/mixture_models.py +1003 -0
  12. mlquantify/methods/non_aggregative.py +136 -0
  13. mlquantify/methods/threshold_optimization.py +957 -0
  14. mlquantify/model_selection.py +377 -232
  15. mlquantify/plots.py +367 -0
  16. mlquantify/utils/__init__.py +2 -2
  17. mlquantify/utils/general.py +334 -0
  18. mlquantify/utils/method.py +449 -0
  19. {mlquantify-0.0.11.2.dist-info → mlquantify-0.1.1.dist-info}/METADATA +137 -122
  20. mlquantify-0.1.1.dist-info/RECORD +22 -0
  21. {mlquantify-0.0.11.2.dist-info → mlquantify-0.1.1.dist-info}/WHEEL +1 -1
  22. mlquantify/classification/pwkclf.py +0 -73
  23. mlquantify/evaluation/measures/__init__.py +0 -26
  24. mlquantify/evaluation/measures/ae.py +0 -11
  25. mlquantify/evaluation/measures/bias.py +0 -16
  26. mlquantify/evaluation/measures/kld.py +0 -8
  27. mlquantify/evaluation/measures/mse.py +0 -12
  28. mlquantify/evaluation/measures/nae.py +0 -16
  29. mlquantify/evaluation/measures/nkld.py +0 -13
  30. mlquantify/evaluation/measures/nrae.py +0 -16
  31. mlquantify/evaluation/measures/rae.py +0 -12
  32. mlquantify/evaluation/measures/se.py +0 -12
  33. mlquantify/evaluation/protocol/_Protocol.py +0 -202
  34. mlquantify/evaluation/protocol/__init__.py +0 -2
  35. mlquantify/evaluation/protocol/app.py +0 -146
  36. mlquantify/evaluation/protocol/npp.py +0 -34
  37. mlquantify/methods/aggregative/ThreholdOptm/_ThreholdOptimization.py +0 -62
  38. mlquantify/methods/aggregative/ThreholdOptm/__init__.py +0 -7
  39. mlquantify/methods/aggregative/ThreholdOptm/acc.py +0 -27
  40. mlquantify/methods/aggregative/ThreholdOptm/max.py +0 -23
  41. mlquantify/methods/aggregative/ThreholdOptm/ms.py +0 -21
  42. mlquantify/methods/aggregative/ThreholdOptm/ms2.py +0 -25
  43. mlquantify/methods/aggregative/ThreholdOptm/pacc.py +0 -41
  44. mlquantify/methods/aggregative/ThreholdOptm/t50.py +0 -21
  45. mlquantify/methods/aggregative/ThreholdOptm/x.py +0 -23
  46. mlquantify/methods/aggregative/__init__.py +0 -9
  47. mlquantify/methods/aggregative/cc.py +0 -32
  48. mlquantify/methods/aggregative/emq.py +0 -86
  49. mlquantify/methods/aggregative/fm.py +0 -72
  50. mlquantify/methods/aggregative/gac.py +0 -96
  51. mlquantify/methods/aggregative/gpac.py +0 -87
  52. mlquantify/methods/aggregative/mixtureModels/_MixtureModel.py +0 -81
  53. mlquantify/methods/aggregative/mixtureModels/__init__.py +0 -5
  54. mlquantify/methods/aggregative/mixtureModels/dys.py +0 -55
  55. mlquantify/methods/aggregative/mixtureModels/dys_syn.py +0 -89
  56. mlquantify/methods/aggregative/mixtureModels/hdy.py +0 -46
  57. mlquantify/methods/aggregative/mixtureModels/smm.py +0 -27
  58. mlquantify/methods/aggregative/mixtureModels/sord.py +0 -77
  59. mlquantify/methods/aggregative/pcc.py +0 -33
  60. mlquantify/methods/aggregative/pwk.py +0 -38
  61. mlquantify/methods/meta/__init__.py +0 -1
  62. mlquantify/methods/meta/ensemble.py +0 -236
  63. mlquantify/methods/non_aggregative/__init__.py +0 -1
  64. mlquantify/methods/non_aggregative/hdx.py +0 -71
  65. mlquantify/plots/__init__.py +0 -2
  66. mlquantify/plots/distribution_plot.py +0 -109
  67. mlquantify/plots/protocol_plot.py +0 -193
  68. mlquantify/utils/general_purposes/__init__.py +0 -8
  69. mlquantify/utils/general_purposes/convert_col_to_array.py +0 -13
  70. mlquantify/utils/general_purposes/generate_artificial_indexes.py +0 -29
  71. mlquantify/utils/general_purposes/get_real_prev.py +0 -9
  72. mlquantify/utils/general_purposes/load_quantifier.py +0 -4
  73. mlquantify/utils/general_purposes/make_prevs.py +0 -23
  74. mlquantify/utils/general_purposes/normalize.py +0 -20
  75. mlquantify/utils/general_purposes/parallel.py +0 -10
  76. mlquantify/utils/general_purposes/round_protocol_df.py +0 -14
  77. mlquantify/utils/method_purposes/__init__.py +0 -6
  78. mlquantify/utils/method_purposes/distances.py +0 -21
  79. mlquantify/utils/method_purposes/getHist.py +0 -13
  80. mlquantify/utils/method_purposes/get_scores.py +0 -33
  81. mlquantify/utils/method_purposes/moss.py +0 -16
  82. mlquantify/utils/method_purposes/ternary_search.py +0 -14
  83. mlquantify/utils/method_purposes/tprfpr.py +0 -42
  84. mlquantify-0.0.11.2.dist-info/RECORD +0 -73
  85. {mlquantify-0.0.11.2.dist-info → mlquantify-0.1.1.dist-info}/top_level.txt +0 -0
mlquantify/plots.py ADDED
@@ -0,0 +1,367 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.colors as mcolors
4
+ import matplotlib.patches as mpatches
5
+ import pandas as pd
6
+ from typing import List, Optional, Dict, Any, Union
7
+
8
+
9
+
10
+
11
# Module-wide matplotlib styling. These updates run at import time and change
# the process-global ``plt.rcParams`` for every later figure, not only the
# figures produced by this module.

# Background colours, default marker size and typography.
plt.rcParams.update({
    'lines.markersize': 6,
    'axes.facecolor': "#F8F8F8",
    'figure.facecolor': "#F8F8F8",
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12,
    'font.weight': 'light',
    'axes.labelsize': 14,
    'axes.labelweight': 'light',
    'axes.titlesize': 16,
    'axes.titleweight': 'normal',
})

# Thin boxplot outlines with a slightly heavier, black median line.
plt.rcParams.update({
    'boxplot.boxprops.linewidth': 0.3,
    'boxplot.whiskerprops.linewidth': 0.3,
    'boxplot.capprops.linewidth': 0.3,
    'boxplot.medianprops.linewidth': 0.6,
    'boxplot.flierprops.linewidth': 0.3,
    'boxplot.flierprops.markersize': 0.9,
    'boxplot.medianprops.color': 'black',
})

# Extra bottom margin and a faint dashed grid.
plt.rcParams.update({
    'figure.subplot.bottom': 0.2,
    'axes.grid': True,
    'grid.color': 'black',
    'grid.alpha': 0.1,
    'grid.linewidth': 0.5,
    'grid.linestyle': '--',
})

# Palette indexed by position (method / class index) in the plotting helpers.
# NOTE(review): several entries repeat ('#E2A8F7', '#F7C6C7', '#FFB3B3',
# '#FF6F61'), so some distinct indices share the same colour.
COLORS = [
    '#FFAB91', '#FFE082', '#A5D6A7', '#4DD0E1', '#FF6F61', '#FF8C94', '#D4A5A5',
    '#FF677D', '#B9FBC0', '#C2C2F0', '#E3F9A6', '#E2A8F7', '#F7B7A3', '#F7C6C7',
    '#8D9BFC', '#B4E6FF', '#FF8A65', '#FFC3A0', '#FFCCBC', '#F8BBD0', '#FF9AA2',
    '#FFB3B3', '#FFDDC1', '#FFE0B2', '#E2A8F7', '#F7C6C7', '#E57373', '#BA68C8',
    '#4FC3F7', '#FFB3B3', '#FF6F61',
]

# Matplotlib marker codes used, by position, for line plots.
MARKERS = [
    "o", "s", "^", "D", "p", "*", "+", "x",
    "H", "1", "2", "3", "4", "|", "_",
]
48
+
49
def adjust_color_saturation(color: str, saturation_factor: float = 5) -> str:
    """
    Return ``color`` with its HSV saturation multiplied by ``saturation_factor``.

    The resulting saturation is capped at 1; hue and value are left unchanged.

    Parameters
    ----------
    color : str
        Any colour accepted by ``matplotlib.colors.to_rgb`` (e.g. a hex string).
    saturation_factor : float, optional
        Multiplier applied to the saturation channel. Default is 5.

    Returns
    -------
    str
        The adjusted colour as a hexadecimal string.
    """
    rgb = mcolors.to_rgb(color)
    hue, saturation, value = mcolors.rgb_to_hsv(rgb)
    boosted = min(1, saturation * saturation_factor)
    adjusted_rgb = mcolors.hsv_to_rgb((hue, boosted, value))
    return mcolors.to_hex(adjusted_rgb)
73
+
74
+
75
+
76
def protocol_boxplot(
        table_protocol: pd.DataFrame,
        x: str,
        y: str,
        methods: Optional[List[str]] = None,
        title: Optional[str] = None,
        legend: bool = True,
        save_path: Optional[str] = None,
        order: Optional[str] = None,
        plot_params: Optional[Dict[str, Any]] = None):
    """
    Plot one box per quantifier showing the distribution of column ``y``.

    Parameters
    ----------
    table_protocol : pd.DataFrame
        Protocol results. Must contain the columns 'QUANTIFIER', 'PRED_PREVS',
        'REAL_PREVS' and ``y``.
    x : str
        Text used (capitalized) as the x-axis label. Boxes are always grouped by
        the 'QUANTIFIER' column.
    y : str
        Column whose values are summarized by each box.
    methods : List[str] or str, optional
        Quantifiers to plot. A single name is accepted. If None or empty, every
        quantifier in 'QUANTIFIER' is plotted.
    title : str, optional
        Title of the plot. Default is None.
    legend : bool, optional
        Whether to display a legend. Default is True.
    save_path : str, optional
        File path to save the plot image. If not provided, the plot is not saved.
    order : str, optional
        If 'rank', boxes are ordered by ascending median of ``y``; otherwise the
        order of ``methods`` is kept.
    plot_params : Dict[str, Any], optional
        Extra keyword arguments forwarded to ``Axes.boxplot``. The key 'figsize'
        (default (10, 6)) sets the figure size instead. The caller's dict is not
        modified.
    """
    # Copy so that popping 'figsize' does not mutate the caller's dictionary.
    plot_params = dict(plot_params or {})
    figsize = plot_params.pop('figsize', (10, 6))

    # Prepare data; `isin` rejects a bare string, and `methods or ...` is
    # ambiguous for array-likes, hence the explicit checks.
    table = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()
    if isinstance(methods, str):
        methods = [methods]
    elif methods is None or len(methods) == 0:
        methods = table['QUANTIFIER'].unique()
    table = table[table['QUANTIFIER'].isin(methods)]

    if order == 'rank':
        methods = table.groupby('QUANTIFIER')[y].median().sort_values().index.tolist()

    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)

    data = [table.loc[table['QUANTIFIER'] == method, y] for method in methods]
    # Tick labels are applied below with set_xticklabels; the `labels=` kwarg is
    # deprecated in recent Matplotlib (renamed `tick_labels`), so it is omitted.
    box = ax.boxplot(data, patch_artist=True, widths=0.8, **plot_params)

    # Cycle the palette so more methods than colours never raises IndexError.
    colors = [COLORS[i % len(COLORS)] for i in range(len(methods))]
    for patch, color in zip(box['boxes'], colors):
        patch.set_facecolor(color)

    if legend:
        handles = [mpatches.Patch(color=color, label=method)
                   for method, color in zip(methods, colors)]
        ax.legend(handles=handles, title="Quantifiers", loc='upper left',
                  bbox_to_anchor=(1, 1), fontsize=10, title_fontsize='11')

    ax.set_xticklabels(methods, rotation=45, fontstyle='italic')
    ax.set_xlabel(x.capitalize())
    ax.set_ylabel(f"{y.capitalize()}")
    if title:
        ax.set_title(title)

    # Leave room on the right for the legend placed outside the axes.
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
def protocol_lineplot(
        table_protocol: pd.DataFrame,
        methods: Union[List[str], str, None],
        x: str,
        y: str,
        title: Optional[str] = None,
        legend: bool = True,
        save_path: Optional[str] = None,
        group_by: str = "mean",
        pos_alpha: int = 1,
        plot_params: Optional[Dict[str, Any]] = None):
    """
    Plot one line per quantifier of ``y`` against ``x`` from protocol results.

    Parameters
    ----------
    table_protocol : pd.DataFrame
        Protocol results. Must contain 'QUANTIFIER', 'PRED_PREVS' and
        'REAL_PREVS', plus ``x`` / ``y`` unless they are 'ALPHA'.
    methods : Union[List[str], str, None]
        Quantifiers to plot. A single name is accepted. If None or empty, all
        quantifiers are plotted.
    x : str
        Column for the x-axis. 'ALPHA' means component ``pos_alpha`` of each
        'REAL_PREVS' entry (the real prevalence of that class).
    y : str
        Column for the y-axis. 'ALPHA' is accepted with the same meaning as for x.
    title : str, optional
        Title of the plot. Default is None.
    legend : bool, optional
        Whether to display a legend. Default is True.
    save_path : str, optional
        File path to save the plot image. If not provided, the plot is not saved.
    group_by : str, optional
        Aggregation applied to ``y`` within each (QUANTIFIER, x) group, passed to
        pandas ``agg`` (e.g. 'mean', 'median'). A falsy value disables
        aggregation. Default is 'mean'.
    pos_alpha : int, optional
        Index inside each 'REAL_PREVS' entry used to build 'ALPHA'. Default is 1.
    plot_params : Dict[str, Any], optional
        Extra keyword arguments for ``Axes.plot``. They override the default
        colour/marker/alpha. The key 'figsize' (default (10, 6)) sets the figure
        size instead. The caller's dict is not modified.
    """
    # Copy so that popping 'figsize' does not mutate the caller's dictionary.
    plot_params = dict(plot_params or {})
    figsize = plot_params.pop('figsize', (10, 6))

    # Normalize the selection: `isin` rejects a bare string.
    if isinstance(methods, str):
        methods = [methods]
    elif methods is None or len(methods) == 0:
        methods = table_protocol['QUANTIFIER'].unique()
    table_protocol = table_protocol[table_protocol['QUANTIFIER'].isin(methods)]

    table = table_protocol.drop(["PRED_PREVS", "REAL_PREVS"], axis=1).copy()
    # 'ALPHA' is derived from REAL_PREVS, so it must exist before aggregation
    # whenever it is requested on either axis.
    if "ALPHA" in (x, y):
        table["ALPHA"] = table_protocol["REAL_PREVS"].apply(lambda prev: prev[pos_alpha])

    if group_by:
        table = table.groupby(['QUANTIFIER', x])[y].agg(group_by).reset_index()

    fig, ax = plt.subplots(figsize=figsize)
    for i, method in enumerate(methods):
        method_data = table[table['QUANTIFIER'] == method]
        # Cycle colours and markers so no method is silently dropped; defaults
        # come first so that plot_params can override them.
        line_kwargs = {
            'color': adjust_color_saturation(COLORS[i % len(COLORS)]),
            'marker': MARKERS[i % len(MARKERS)],
            'alpha': 1.0,
        }
        line_kwargs.update(plot_params)
        ax.plot(method_data[x], method_data[y], label=method, **line_kwargs)

    if legend:
        ax.legend(title="Quantifiers", loc='upper left', bbox_to_anchor=(1, 1),
                  fontsize=10, title_fontsize='11')

    ax.set_xlabel(x.capitalize())
    ax.set_ylabel(y.capitalize())
    if title:
        ax.set_title(title)

    # Leave room on the right for the legend placed outside the axes.
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
244
+
245
+
246
+
247
+
248
+
249
+
250
+
251
+
252
+
253
+
254
+
255
+
256
+
257
+
258
def class_distribution_plot(values: Union[List, np.ndarray],
                            labels: Union[List, np.ndarray],
                            bins: int = 30,
                            title: Optional[str] = None,
                            legend: bool = True,
                            save_path: Optional[str] = None,
                            plot_params: Optional[Dict[str, Any]] = None):
    """Plot overlaid histograms of class distributions.

    Two layouts are produced, depending on the shape of ``values``:

    * 1-D values (or a single column): one histogram per unique label, overlaid
      on the current pyplot axes.
    * 2-D values with more than one column: a grid with one subplot per unique
      label ``c``. Each subplot overlays, for the samples labelled ``c``, the
      histogram of every column ``j``. Column ``j`` is named after the ``j``-th
      sorted unique label, so columns are assumed to follow
      ``np.unique(labels)`` order.

    Parameters
    ----------
    values : Union[List, np.ndarray]
        Values to plot, shape (n,) or (n, k).
    labels : Union[List, np.ndarray]
        Label of each of the n rows of ``values``.
    bins : int, optional
        Number of bins to use for the histogram. Default is 30.
    title : str, optional
        Title of the plot. Default is None.
    legend : bool, optional
        Whether to display a legend. Default is True.
    save_path : str, optional
        File path to save the plot image. If not provided, the plot is not saved.
    plot_params : Dict[str, Any], optional
        Matplotlib rcParams applied only while this plot is built. Default is None.

    Raises
    ------
    AssertionError
        If the number of value sets does not match the number of labels.
    """
    values = np.asarray(values)
    labels = np.asarray(labels)

    # Explicit raise (not `assert`) so the check survives `python -O`; the
    # exception type is kept for backward compatibility.
    if len(values) != len(labels):
        raise AssertionError("The number of value sets must match the number of labels.")

    # rc_context scopes the overrides to this call instead of permanently
    # changing the global rcParams.
    with plt.rc_context(plot_params):
        # `values.shape[1]` alone would raise IndexError for 1-D input.
        if values.ndim > 1 and values.shape[1] > 1:
            _plot_per_class_grid(values, labels, bins, title, legend)
        else:
            _plot_overlaid_histograms(values.ravel(), labels, bins, title, legend)

        plt.subplots_adjust(hspace=0.9, wspace=0.4)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight')

        plt.show()


def _plot_overlaid_histograms(values, labels, bins, title, legend):
    """Overlay one histogram per unique label on the current pyplot axes."""
    ax = plt.gca()
    for i, label in enumerate(np.unique(labels)):
        ax.hist(values[labels == label], bins=bins, color=COLORS[i % len(COLORS)],
                edgecolor='black', alpha=0.5, label=label)
    ax.set_xlim([0, 1])  # Fix x-axis range between 0 and 1
    if title:
        ax.set_title(title)
    if legend:
        ax.legend(loc='upper right')
    ax.set_xlabel('Values')
    ax.set_ylabel('Frequency')


def _plot_per_class_grid(values, labels, bins, title, legend):
    """Draw one subplot per true class, overlaying the histogram of each column."""
    classes = np.unique(labels)
    num_plots = values.shape[1]
    cols = int(np.ceil(np.sqrt(num_plots)))
    rows = int(np.ceil(num_plots / cols))

    fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4))
    axs = np.ravel(axs)

    for i, true_class in enumerate(classes):
        ax = axs[i]
        in_class = labels == true_class
        for j, column_class in enumerate(classes):
            ax.hist(values[in_class, j], bins=bins, color=COLORS[j % len(COLORS)],
                    edgecolor='black', alpha=0.5, label=column_class)
        ax.set_xlim([0, 1])  # Fix x-axis range between 0 and 1
        if title:
            ax.set_title(f'{title} for class {i+1}')
        if legend:
            ax.legend(loc='upper right')
        ax.set_xlabel('Values')
        ax.set_ylabel('Frequency')

    # Remove grid cells that did not receive a class.
    for unused_ax in axs[len(classes):]:
        fig.delaxes(unused_ax)
366
+
367
+
@@ -1,2 +1,2 @@
1
- from .general_purposes import *
2
- from .method_purposes import *
1
+ from .general import *
2
+ from .method import *
@@ -0,0 +1,334 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from joblib import Parallel, delayed, load
4
+ from collections import defaultdict
5
+ import itertools
6
+
7
+
8
def convert_columns_to_arrays(df, columns: tuple = ('PRED_PREVS', 'REAL_PREVS')):
    """
    Converts specified columns in a DataFrame from strings of arrays to NumPy arrays.

    String cells such as ``"[0.1 0.9]"`` (NumPy's repr) or ``"[0.1, 0.9]"``
    (list style), optionally surrounded by whitespace, are parsed into 1-D float
    arrays. Non-string cells are left untouched. ``df`` is modified in place and
    also returned.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to convert.
    columns : sequence of str
        Columns to convert. Defaults to ('PRED_PREVS', 'REAL_PREVS').

    Returns
    -------
    pd.DataFrame
        DataFrame with the specified columns converted to NumPy arrays.
    """
    def _parse(cell):
        if not isinstance(cell, str):
            return cell
        # Strip outer whitespace before the brackets, and treat commas as
        # separators: with sep=' ' alone, "0.1, 0.9" stops parsing at the comma.
        text = cell.strip().strip('[]').replace(',', ' ')
        return np.fromstring(text, sep=' ')

    for col in columns:
        df[col] = df[col].apply(_parse)
    return df
27
+
28
+
29
+
30
+
31
+
32
def generate_artificial_indexes(y, prevalence: list, sample_size: int, classes: list):
    """
    Generate indexes for a stratified sample based on the prevalence of each class.

    Each class except the last receives ``sample_size * prevalence[i]`` samples,
    truncated to an integer. The last class receives whatever remains, so the
    total is exactly ``sample_size``. Indexes are drawn with replacement and
    shuffled, using NumPy's global random state.

    Parameters
    ----------
    y : array-like
        Array of class labels.
    prevalence : list
        Prevalence of each class, aligned with ``classes``; must sum to 1.
    sample_size : int
        Number of samples to generate.
    classes : list
        List of unique classes.

    Returns
    -------
    list
        List of indexes for the stratified sample.

    Raises
    ------
    AssertionError
        If the prevalences do not sum to 1, or their count differs from the
        number of classes.
    """
    # Explicit raises instead of `assert` so validation survives `python -O`.
    if not np.isclose(sum(prevalence), 1):
        raise AssertionError("The sum of prevalences must be 1")
    if len(prevalence) != len(classes):
        raise AssertionError("The number of prevalences must match the number of classes")

    y = np.asarray(y)
    sampled_indexes = []
    total_sampled = 0
    last = len(classes) - 1

    for i, class_ in enumerate(classes):
        if i == last:
            # Guard against a tiny overshoot allowed by the isclose tolerance.
            num_samples = max(sample_size - total_sampled, 0)
        else:
            # Round away binary float noise before truncating:
            # 100 * 0.29 == 28.999999999999996, which int() would turn into 28.
            num_samples = int(round(sample_size * prevalence[i], 8))

        class_indexes = np.where(y == class_)[0]
        sampled_class_indexes = np.random.choice(class_indexes, size=num_samples, replace=True)

        sampled_indexes.extend(sampled_class_indexes)
        total_sampled += num_samples

    np.random.shuffle(sampled_indexes)  # Shuffle after collecting all indexes

    return sampled_indexes
78
+
79
+
80
+
81
+
82
def generate_artificial_prevalences(n_dim: int, n_prev: int, n_iter: int) -> np.ndarray:
    """Generate every point of a regular grid on the probability simplex.

    The first ``n_dim - 1`` coordinates take values from
    ``np.linspace(0, 1, n_prev)``. Only combinations whose sum does not exceed 1
    are kept, and the last coordinate is ``1 - sum`` of the others.

    Parameters
    ----------
    n_dim : int
        Number of dimensions (classes).
    n_prev : int
        Number of grid values per dimension.
    n_iter : int
        Number of times each grid point is repeated (consecutively).

    Returns
    -------
    np.ndarray
        Array of shape (n_points * max(n_iter, 1), n_dim) of prevalences.
    """
    s = np.linspace(0., 1., n_prev, endpoint=True)
    # Select grid points by their integer index sum, not by summing floats:
    # e.g. 0.1*3 + 0.1*6 + 0.1 evaluates to 1.0000000000000002, so a float test
    # `sum(p) <= 1` silently dropped valid points of the simplex.
    max_steps = max(n_prev - 1, 0)
    index_grid = itertools.product(*([range(n_prev)] * (n_dim - 1)))
    points = (tuple(s[i] for i in idx) for idx in index_grid if sum(idx) <= max_steps)
    # Clip the last coordinate so float noise never yields e.g. -2.2e-16.
    prevs = np.array([p + (max(0.0, 1 - sum(p)),) for p in points])

    return np.repeat(prevs, n_iter, axis=0) if n_iter > 1 else prevs
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+
112
def get_real_prev(y) -> dict:
    """
    Get the real prevalence of each class in the target array.

    Parameters
    ----------
    y : array-like (np.ndarray, pd.Series, list, ...)
        Class labels.

    Returns
    -------
    dict
        Mapping from class label to its relative frequency, sorted by label.
    """
    # Any non-Series input (lists included) is wrapped so `value_counts` exists.
    if not isinstance(y, pd.Series):
        y = pd.Series(y)
    real_prevs = y.value_counts(normalize=True).to_dict()
    return dict(sorted(real_prevs.items()))
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
def load_quantifier(path: str):
    """
    Deserialize a quantifier with ``joblib.load``.

    Security note: ``joblib.load`` is pickle-based and can execute arbitrary
    code while loading, so only load files from trusted sources.

    Parameters
    ----------
    path : str
        Location of the serialized quantifier.

    Returns
    -------
    Quantifier
        Whatever object was stored in the file.
    """
    quantifier = load(path)
    return quantifier
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
def make_prevs(ndim: int) -> np.ndarray:
    """
    Draw a random prevalence vector of length ``ndim`` that sums to 1.

    ``ndim - 1`` cut points are drawn uniformly in [0, 1]. The end points 0 and
    1 are added, and the gaps between consecutive sorted points are returned.
    This uniform-spacings construction samples uniformly from the probability
    simplex.

    Parameters
    ----------
    ndim : int
        Number of dimensions (classes); must be >= 1.

    Returns
    -------
    np.ndarray
        1-D array of ``ndim`` non-negative values summing to 1 (up to float
        rounding). The previous annotation said ``list``, but an ndarray has
        always been returned.
    """
    # Uses NumPy's global random state.
    cut_points = np.random.uniform(0, 1, ndim - 1)
    cut_points = np.append(cut_points, [0, 1])
    cut_points.sort()
    return np.diff(cut_points)
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
def normalize_prevalence(prevalences: np.ndarray, classes: list):
    """
    Normalize the prevalence of each class to sum to 1.

    Parameters
    ----------
    prevalences : np.ndarray or dict
        1-D array of prevalences aligned with ``classes``, or a dict mapping
        class -> value (in which case ``classes`` is not used).
    classes : list
        List of unique classes.

    Returns
    -------
    dict
        Mapping from class label (cast to ``int``) to normalized prevalence.
        If the values sum to 0, every prevalence is 0.0. For array input, a
        class in ``classes`` without a corresponding value is added with 0.
    """
    if isinstance(prevalences, dict):
        total = sum(prevalences.values())
        if total == 0:
            # Avoid ZeroDivisionError; behave like the array branch.
            return {int(_class): 0.0 for _class in prevalences}
        return {int(_class): float(value / total) for _class, value in prevalences.items()}

    prevalences = np.asarray(prevalences, dtype=float)
    total = np.sum(prevalences, axis=-1, keepdims=True)
    # With `where=`, NumPy leaves masked-out entries uninitialized unless an
    # `out` buffer is given, so pre-fill it with zeros for the zero-sum case.
    normalized = np.divide(prevalences, total, out=np.zeros_like(prevalences), where=total > 0)
    result = {int(_class): float(prev) for _class, prev in zip(classes, normalized)}

    # Ensure all classes are present in the result
    for cls in classes:
        result.setdefault(cls, 0)

    return result
232
+
233
+
234
+
235
+
236
+
237
+
238
+
239
def parallel(func, elements, n_jobs: int = 1, *args):
    """
    Apply ``func`` to every item of ``elements`` using joblib's threading backend.

    Each call is ``func(element, *args)``. Because ``n_jobs`` precedes
    ``*args``, extra arguments can only be supplied after passing ``n_jobs``
    positionally.

    Parameters
    ----------
    func : callable
        Function applied to each element.
    elements : iterable
        Items to process.
    n_jobs : int
        Number of worker threads.
    args : tuple
        Extra positional arguments appended to every call.

    Returns
    -------
    list
        Results in the same order as ``elements``.
    """
    tasks = (delayed(func)(element, *args) for element in elements)
    runner = Parallel(n_jobs=n_jobs, backend="threading")
    return runner(tasks)
262
+
263
+
264
+
265
+
266
+
267
+
268
+
269
+
270
+
271
+
272
def round_protocol_df(dataframe: pd.DataFrame, frac: int = 3):
    """
    Round the columns of a protocol dataframe to a specified number of decimal places.

    'PRED_PREVS' / 'REAL_PREVS' cells that are arrays or numbers are rounded
    element-wise. Other numeric (non-boolean) columns are rounded with
    ``Series.round``. All remaining columns (text, categorical, ...) are returned
    unchanged.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Protocol dataframe to round.
    frac : int
        Number of decimal places to round to.

    Returns
    -------
    pd.DataFrame
        Protocol dataFrame with the columns rounded to the specified number of decimal places.
    """
    def round_column(col):
        if col.name in ['PRED_PREVS', 'REAL_PREVS']:
            return col.apply(lambda x: np.round(x, frac) if isinstance(x, (np.ndarray, float, int)) else x)
        # pandas' dtype checks accept extension dtypes (category, string, ...);
        # np.issubdtype raised TypeError on them.
        if pd.api.types.is_numeric_dtype(col.dtype) and not pd.api.types.is_bool_dtype(col.dtype):
            return col.round(frac)
        return col

    return dataframe.apply(round_column)
297
+
298
+
299
+
300
+
301
def get_measure(measure: str):
    """
    Look up a measure by name in the evaluation package's ``MEASURES`` registry.

    Parameters
    ----------
    measure : str
        Registered name of the measure.

    Returns
    -------
    Measure
        Result of ``MEASURES.get(measure)``. Presumably ``None`` for an unknown
        name if ``MEASURES`` is a dict — confirm against the evaluation package.
    """
    # Imported at call time; presumably to avoid a circular import between
    # utils and evaluation — TODO confirm.
    from ..evaluation import MEASURES as registry
    return registry.get(measure)
317
+
318
+
319
def get_method(method: str):
    """
    Look up a quantification method by name in the methods package's ``METHODS`` registry.

    Parameters
    ----------
    method : str
        Registered name of the method.

    Returns
    -------
    Method
        Result of ``METHODS.get(method)``. Presumably ``None`` for an unknown
        name if ``METHODS`` is a dict — confirm against the methods package.
    """
    # Imported at call time; presumably to avoid a circular import between
    # utils and methods — TODO confirm.
    from ..methods import METHODS as registry
    return registry.get(method)