birdnet-analyzer 2.0.0-py3-none-any.whl → 2.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +5 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +25 -25
- birdnet_analyzer/analyze/core.py +241 -245
- birdnet_analyzer/analyze/utils.py +692 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +709 -707
- birdnet_analyzer/config.py +242 -242
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +69 -70
- birdnet_analyzer/embeddings/utils.py +179 -193
- birdnet_analyzer/evaluation/__init__.py +196 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +175 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +28 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +619 -620
- birdnet_analyzer/gui/evaluation.py +795 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +245 -246
- birdnet_analyzer/gui/review.py +519 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +128 -129
- birdnet_analyzer/gui/single_file.py +267 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +696 -698
- birdnet_analyzer/gui/utils.py +810 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +334 -334
- birdnet_analyzer/lang/en.json +334 -334
- birdnet_analyzer/lang/fi.json +334 -334
- birdnet_analyzer/lang/fr.json +334 -334
- birdnet_analyzer/lang/id.json +334 -334
- birdnet_analyzer/lang/pt-br.json +334 -334
- birdnet_analyzer/lang/ru.json +334 -334
- birdnet_analyzer/lang/se.json +334 -334
- birdnet_analyzer/lang/tlh.json +334 -334
- birdnet_analyzer/lang/zh_TW.json +334 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +426 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
- birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/assessment/metrics.py
@@ -0,0 +1,388 @@
+"""
+Module containing functions to calculate various performance metrics using scikit-learn.
+
+This script includes implementations for calculating accuracy, precision, recall, F1 score,
+average precision, and AUROC for binary and multilabel classification tasks. It supports
+various averaging methods and thresholds for predictions.
+
+Functions:
+- calculate_accuracy: Computes accuracy for binary or multilabel classification.
+- calculate_recall: Computes recall for binary or multilabel classification.
+- calculate_precision: Computes precision for binary or multilabel classification.
+- calculate_f1_score: Computes the F1 score for binary or multilabel classification.
+- calculate_average_precision: Computes the average precision score (AP).
+- calculate_auroc: Computes the Area Under the Receiver Operating Characteristic curve (AUROC).
+"""
+
+from typing import Literal
+
+import numpy as np
+from sklearn.metrics import (
+    accuracy_score,
+    average_precision_score,
+    f1_score,
+    precision_score,
+    recall_score,
+    roc_auc_score,
+)
+
+
+def calculate_accuracy(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    num_classes: int,
+    threshold: float,
+    averaging_method: Literal["micro", "macro", "weighted", "none"] | None = "macro",
+) -> np.ndarray:
+    """
+    Calculate accuracy for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        num_classes (int): Number of classes (only for multilabel tasks).
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["micro", "macro", "weighted", "none"]], optional):
+            Averaging method to compute accuracy for multilabel tasks. Defaults to "macro".
+
+    Returns:
+        np.ndarray: Accuracy metric(s) based on the task and averaging method.
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task/averaging method is specified.
+    """
+    # Input validation for predictions, labels, and threshold
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Handle binary and multilabel tasks separately
+    if task == "binary":
+        # Binary classification: Binarize predictions and compute accuracy
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        acc = accuracy_score(y_true, y_pred)
+        acc = np.array([acc])
+
+    elif task == "multilabel":
+        # Multilabel classification: Handle based on the specified averaging method
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+
+        if averaging_method == "micro":
+            # Micro-averaging: Overall accuracy across all labels
+            correct = (y_pred == y_true).sum()
+            total = y_true.size
+            acc = correct / total if total > 0 else np.nan
+            acc = np.array([acc])
+
+        elif averaging_method == "macro":
+            # Macro-averaging: Compute accuracy per class and take the mean
+            accuracies = [accuracy_score(y_true[:, i], y_pred[:, i]) for i in range(num_classes)]
+            acc = np.mean(accuracies)
+            acc = np.array([acc])
+
+        elif averaging_method == "weighted":
+            # Weighted averaging: Weight class accuracies by class prevalence
+            accuracies, weights = [], []
+            for i in range(num_classes):
+                accuracies.append(accuracy_score(y_true[:, i], y_pred[:, i]))
+                weights.append(np.sum(y_true[:, i]))
+            acc = np.average(accuracies, weights=weights) if sum(weights) > 0 else np.array([0.0])
+            acc = np.array([acc])
+
+        elif averaging_method in [None, "none"]:
+            # No averaging: Return accuracy per class
+            acc = np.array([accuracy_score(y_true[:, i], y_pred[:, i]) for i in range(num_classes)])
+
+        else:
+            # Unsupported averaging method
+            raise ValueError(f"Invalid averaging method: {averaging_method}")
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    return acc
+
+
+def calculate_recall(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    threshold: float,
+    averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate recall for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel recall. Defaults to None.
+
+    Returns:
+        np.ndarray: Recall metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size, threshold, and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute recall based on task type
+    if task == "binary":
+        averaging = averaging or "binary"
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        recall = recall_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    elif task == "multilabel":
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        recall = recall_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(recall, np.ndarray):
+        return recall
+    return np.array([recall])
+
+
+def calculate_precision(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    threshold: float,
+    averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate precision for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel precision. Defaults to None.
+
+    Returns:
+        np.ndarray: Precision metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size, threshold, and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute precision based on task type
+    if task == "binary":
+        averaging = averaging or "binary"
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        precision = precision_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    elif task == "multilabel":
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        precision = precision_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(precision, np.ndarray):
+        return precision
+    return np.array([precision])
+
+
+def calculate_f1_score(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    threshold: float,
+    averaging_method: Literal["binary", "micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate the F1 score for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        threshold (float): Threshold to binarize probabilities.
+        averaging_method (Optional[Literal["binary", "micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel F1 score. Defaults to None.
+
+    Returns:
+        np.ndarray: F1 score metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size, threshold, and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if not 0 <= threshold <= 1:
+        raise ValueError(f"Invalid threshold: {threshold}. Must be between 0 and 1.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute F1 score based on task type
+    if task == "binary":
+        averaging = averaging or "binary"
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        f1 = f1_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    elif task == "multilabel":
+        y_pred = (predictions >= threshold).astype(int)
+        y_true = labels.astype(int)
+        f1 = f1_score(y_true, y_pred, average=averaging, zero_division=0)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(f1, np.ndarray):
+        return f1
+    return np.array([f1])
+
+
+def calculate_average_precision(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    averaging_method: Literal["micro", "macro", "weighted", "samples", "none"] | None = None,
+) -> np.ndarray:
+    """
+    Calculate the average precision (AP) for the given predictions and labels.
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        averaging_method (Optional[Literal["micro", "macro", "weighted", "samples", "none"]], optional):
+            Averaging method for AP. Defaults to None.
+
+    Returns:
+        np.ndarray: Average precision metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    # Compute average precision based on task type
+    if task in ("binary", "multilabel"):
+        y_true = labels.astype(int)
+        y_scores = predictions
+        ap = average_precision_score(y_true, y_scores, average=averaging)
+
+    else:
+        # Unsupported task type
+        raise ValueError(f"Unsupported task type for average precision: {task}")
+
+    # Ensure return type is consistent
+    if isinstance(ap, np.ndarray):
+        return ap
+    return np.array([ap])
+
+
+def calculate_auroc(
+    predictions: np.ndarray,
+    labels: np.ndarray,
+    task: Literal["binary", "multilabel"],
+    averaging_method: Literal["macro", "weighted", "samples", "none"] | None = "macro",
+) -> np.ndarray:
+    """
+    Calculate the Area Under the Receiver Operating Characteristic curve (AUROC).
+
+    Args:
+        predictions (np.ndarray): Model predictions as probabilities.
+        labels (np.ndarray): True labels.
+        task (Literal["binary", "multilabel"]): Type of classification task.
+        averaging_method (Optional[Literal["macro", "weighted", "samples", "none"]], optional):
+            Averaging method for multilabel AUROC. Defaults to "macro".
+
+    Returns:
+        np.ndarray: AUROC metric(s).
+
+    Raises:
+        ValueError: If inputs are invalid or unsupported task type is specified.
+    """
+    # Validate inputs for size and shape
+    if predictions.size == 0 or labels.size == 0:
+        raise ValueError("Predictions and labels must not be empty.")
+    if predictions.shape != labels.shape:
+        raise ValueError("Predictions and labels must have the same shape.")
+
+    # Adjust averaging method for scikit-learn if none is specified
+    averaging = None if averaging_method == "none" else averaging_method
+
+    try:
+        # Compute AUROC based on task type
+        if task == "binary":
+            y_true = labels.astype(int)
+            y_scores = predictions
+            auroc = roc_auc_score(y_true, y_scores)
+
+        elif task == "multilabel":
+            y_true = labels.astype(int)
+            y_scores = predictions
+            auroc = roc_auc_score(y_true, y_scores, average=averaging)
+
+        else:
+            # Unsupported task type
+            raise ValueError(f"Unsupported task type: {task}")
+
+    except ValueError as e:
+        # Handle edge cases where AUROC cannot be computed
+        if "Only one class present in y_true" in str(e) or "Number of classes in y_true" in str(e):
+            auroc = np.nan
+        else:
+            raise
+
+    # Ensure return type is consistent
+    if isinstance(auroc, np.ndarray):
+        return auroc
+    return np.array([auroc])
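For context, here is a minimal usage sketch of the new metrics module added in 2.0.1. The function names and import path come from the diff above; the toy arrays, the 0.5 threshold, and the choice of averaging methods are illustrative assumptions, not part of the release.

import numpy as np

from birdnet_analyzer.evaluation.assessment import metrics

# Toy multilabel setup (assumed for illustration): 4 samples, 3 classes,
# prediction probabilities vs. binary ground-truth labels.
predictions = np.array(
    [
        [0.9, 0.2, 0.7],
        [0.1, 0.8, 0.4],
        [0.6, 0.3, 0.9],
        [0.2, 0.7, 0.1],
    ]
)
labels = np.array(
    [
        [1, 0, 1],
        [0, 1, 0],
        [1, 0, 1],
        [0, 1, 1],
    ]
)

# Per-class accuracy at a 0.5 threshold (averaging_method="none" returns one value per class).
acc = metrics.calculate_accuracy(
    predictions, labels, task="multilabel", num_classes=3, threshold=0.5, averaging_method="none"
)

# Macro-averaged recall, F1, and AUROC on the same data.
recall = metrics.calculate_recall(predictions, labels, task="multilabel", threshold=0.5, averaging_method="macro")
f1 = metrics.calculate_f1_score(predictions, labels, task="multilabel", threshold=0.5, averaging_method="macro")
auroc = metrics.calculate_auroc(predictions, labels, task="multilabel", averaging_method="macro")

print(acc, recall, f1, auroc)

Each function returns an np.ndarray, so scalar results (e.g. macro averages) come back as one-element arrays, while averaging_method="none" yields per-class values.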