birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (122)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/assessment/metrics.py (new file)
@@ -0,0 +1,388 @@
+"""
+Module containing functions to calculate various performance metrics using scikit-learn.
+
+This script includes implementations for calculating accuracy, precision, recall, F1 score,
+average precision, and AUROC for binary and multilabel classification tasks. It supports
+various averaging methods and thresholds for predictions.
+
+Functions:
+- calculate_accuracy: Computes accuracy for binary or multilabel classification.
+- calculate_recall: Computes recall for binary or multilabel classification.
+- calculate_precision: Computes precision for binary or multilabel classification.
+- calculate_f1_score: Computes the F1 score for binary or multilabel classification.
+- calculate_average_precision: Computes the average precision score (AP).
+- calculate_auroc: Computes the Area Under the Receiver Operating Characteristic curve (AUROC).
+"""
+
+from typing import Literal
+
+import numpy as np
+from sklearn.metrics import (
+    accuracy_score,
+    average_precision_score,
+    f1_score,
+    precision_score,
+    recall_score,
+    roc_auc_score,
+)
+
+
+def calculate_accuracy(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    num_classes: int,
+    threshold: float,
+    averaging_method: Literal["micro", "macro", "weighted", "none"] | None = "macro",
+) -> np.ndarray:
+    """
+    Calculate accuracy for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        num_classes (int): Number of classes (only for multilabel tasks).
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["micro", "macro", "weighted", "none"]], optional):
+            Averaging method to compute accuracy for multilabel tasks. Defaults to "macro".
+
+    Returns:
+        np.ndarray: Accuracy metric(s) based on the task and averaging method.
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task/averaging method is specified.
+    """
+    # Input validation for predictions, labels, and threshold
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Handle binary and multilabel tasks separately
+    if task == "binary":
+        # Binary classification: Binarize predictions and compute accuracy
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        acc = accuracy_score(y_true, y_pred)
+        acc = np.array([acc])
+
+    elif task == "multilabel":
+        # Multilabel classification: Handle based on the specified averaging method
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+
+        if averaging_method == "micro":
+            # Micro-averaging: Overall accuracy across all labels
+            correct = (y_pred == y_true).sum()
+            total = y_true.size
+            acc = correct / total if total > 0 else np.nan
+            acc = np.array([acc])
+
+        elif averaging_method == "macro":
+            # Macro-averaging: Compute accuracy per class and take the mean
+            accuracies = [accuracy_score(y_true[:, i], y_pred[:, i]) for i in range(num_classes)]
+            acc = np.mean(accuracies)
+            acc = np.array([acc])
+
+        elif averaging_method == "weighted":
+            # Weighted averaging: Weight class accuracies by class prevalence
+            accuracies, weights = [], []
+            for i in range(num_classes):
+                accuracies.append(accuracy_score(y_true[:, i], y_pred[:, i]))
+                weights.append(np.sum(y_true[:, i]))
+            acc = np.average(accuracies, weights=weights) if sum(weights) > 0 else np.array([0.0])
+            acc = np.array([acc])
+
+        elif averaging_method in [None, "none"]:
+            # No averaging: Return accuracy per class
+            acc = np.array([accuracy_score(y_true[:, i], y_pred[:, i]) for i in range(num_classes)])
+
+        else:
+            # Unsupported averaging method
+            raise ValueError(f"Invalid averaging method: {averaging_method}")
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    return acc
+
+
+def calculate_recall(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    threshold: float,
+    averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate recall for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel recall. Defaults to None.
+
+    Returns:
+        np.ndarray: Recall metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size, threshold, and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute recall based on task type
+    if task == "binary":
+        averaging = averaging or "binary"
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        recall = recall_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    elif task == "multilabel":
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        recall = recall_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(recall, np.ndarray):
+        return recall
+    return np.array([recall])
+
+
+def calculate_precision(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    threshold: float,
+    averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate precision for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel precision. Defaults to None.
+
+    Returns:
+        np.ndarray: Precision metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size, threshold, and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute precision based on task type
+    if task == "binary":
+        averaging = averaging or "binary"
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        precision = precision_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    elif task == "multilabel":
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        precision = precision_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(precision, np.ndarray):
+        return precision
+    return np.array([precision])
+
+
+def calculate_f1_score(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    threshold: float,
+    averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate the F1 score for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel F1 score. Defaults to None.
+
+    Returns:
+        np.ndarray: F1 score metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size, threshold, and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute F1 score based on task type
+    if task == "binary":
+        averaging = averaging or "binary"
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        f1 = f1_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    elif task == "multilabel":
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        f1 = f1_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(f1, np.ndarray):
+        return f1
+    return np.array([f1])
+
+
+def calculate_average_precision(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    averaging_method: Literal["micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate the average precision (AP) for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        averaging_method (Optional[Literal["micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for AP. Defaults to None.
+
+    Returns:
+        np.ndarray: Average precision metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute average precision based on task type
+    if task in ("binary", "multilabel"):
+        y_true = labels.astype(int)
+        y_scores = predictions
+        ap = average_precision_score(y_true, y_scores, average=averaging)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type for average precision: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(ap, np.ndarray):
+        return ap
+    return np.array([ap])
+
+
+def calculate_auroc(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    averaging_method: Literal["macro", "weighted", "samples", "none"] | None = "macro",
+) -> np.ndarray:
+    """
+    Calculate the Area Under the Receiver Operating Characteristic curve (AUROC).
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        averaging_method (Optional[Literal["macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel AUROC. Defaults to "macro".
+
+    Returns:
+        np.ndarray: AUROC metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    try:
+        # Compute AUROC based on task type
+        if task == "binary":
+            y_true = labels.astype(int)
+            y_scores = predictions
+            auroc = roc_auc_score(y_true, y_scores)
+
+        elif task == "multilabel":
+            y_true = labels.astype(int)
+            y_scores = predictions
+            auroc = roc_auc_score(y_true, y_scores, average=averaging)
+
+        else:
+            # Unsupported task type
+            raise ValueError(f"Unsupported task type: {task}")
+
+    except ValueError as e:
+        # Handle edge cases where AUROC cannot be computed
+        if "Only one class present in y_true" in str(e) or "Number of classes in y_true" in str(e):
+            auroc = np.nan
+        else:
+            raise
+
+    # Ensure return type is consistent
+    if isinstance(auroc, np.ndarray):
+        return auroc
+    return np.array([auroc])