birdnet-analyzer 2.0.1__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. birdnet_analyzer/__init__.py +9 -9
  2. birdnet_analyzer/analyze/__init__.py +19 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -3
  4. birdnet_analyzer/analyze/cli.py +30 -25
  5. birdnet_analyzer/analyze/core.py +268 -241
  6. birdnet_analyzer/analyze/utils.py +700 -692
  7. birdnet_analyzer/audio.py +368 -368
  8. birdnet_analyzer/cli.py +732 -709
  9. birdnet_analyzer/config.py +243 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
  11. birdnet_analyzer/embeddings/__init__.py +3 -3
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -12
  14. birdnet_analyzer/embeddings/core.py +70 -69
  15. birdnet_analyzer/embeddings/utils.py +173 -179
  16. birdnet_analyzer/evaluation/__init__.py +189 -196
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/metrics.py +388 -388
  19. birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -409
  20. birdnet_analyzer/evaluation/assessment/plotting.py +378 -379
  21. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -631
  22. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -98
  23. birdnet_analyzer/gui/__init__.py +19 -19
  24. birdnet_analyzer/gui/__main__.py +3 -3
  25. birdnet_analyzer/gui/analysis.py +179 -175
  26. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  27. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  28. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  30. birdnet_analyzer/gui/assets/gui.css +36 -28
  31. birdnet_analyzer/gui/assets/gui.js +89 -93
  32. birdnet_analyzer/gui/embeddings.py +638 -619
  33. birdnet_analyzer/gui/evaluation.py +801 -795
  34. birdnet_analyzer/gui/localization.py +75 -75
  35. birdnet_analyzer/gui/multi_file.py +265 -245
  36. birdnet_analyzer/gui/review.py +472 -519
  37. birdnet_analyzer/gui/segments.py +191 -191
  38. birdnet_analyzer/gui/settings.py +149 -128
  39. birdnet_analyzer/gui/single_file.py +264 -267
  40. birdnet_analyzer/gui/species.py +95 -95
  41. birdnet_analyzer/gui/train.py +687 -696
  42. birdnet_analyzer/gui/utils.py +803 -810
  43. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  44. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  80. birdnet_analyzer/lang/de.json +342 -334
  81. birdnet_analyzer/lang/en.json +342 -334
  82. birdnet_analyzer/lang/fi.json +342 -334
  83. birdnet_analyzer/lang/fr.json +342 -334
  84. birdnet_analyzer/lang/id.json +342 -334
  85. birdnet_analyzer/lang/pt-br.json +342 -334
  86. birdnet_analyzer/lang/ru.json +342 -334
  87. birdnet_analyzer/lang/se.json +342 -334
  88. birdnet_analyzer/lang/tlh.json +342 -334
  89. birdnet_analyzer/lang/zh_TW.json +342 -334
  90. birdnet_analyzer/model.py +1213 -1212
  91. birdnet_analyzer/search/__init__.py +3 -3
  92. birdnet_analyzer/search/__main__.py +3 -3
  93. birdnet_analyzer/search/cli.py +11 -11
  94. birdnet_analyzer/search/core.py +78 -78
  95. birdnet_analyzer/search/utils.py +104 -107
  96. birdnet_analyzer/segments/__init__.py +3 -3
  97. birdnet_analyzer/segments/__main__.py +3 -3
  98. birdnet_analyzer/segments/cli.py +13 -13
  99. birdnet_analyzer/segments/core.py +81 -81
  100. birdnet_analyzer/segments/utils.py +383 -383
  101. birdnet_analyzer/species/__init__.py +3 -3
  102. birdnet_analyzer/species/__main__.py +3 -3
  103. birdnet_analyzer/species/cli.py +13 -13
  104. birdnet_analyzer/species/core.py +35 -35
  105. birdnet_analyzer/species/utils.py +73 -74
  106. birdnet_analyzer/train/__init__.py +3 -3
  107. birdnet_analyzer/train/__main__.py +3 -3
  108. birdnet_analyzer/train/cli.py +13 -13
  109. birdnet_analyzer/train/core.py +113 -113
  110. birdnet_analyzer/train/utils.py +878 -877
  111. birdnet_analyzer/translate.py +132 -133
  112. birdnet_analyzer/utils.py +425 -426
  113. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/METADATA +147 -137
  114. birdnet_analyzer-2.1.1.dist-info/RECORD +124 -0
  115. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/WHEEL +1 -1
  116. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/licenses/LICENSE +18 -18
  117. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
  118. birdnet_analyzer/playground.py +0 -5
  119. birdnet_analyzer-2.0.1.dist-info/RECORD +0 -125
  120. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/entry_points.txt +0 -0
  121. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/assessment/performance_assessor.py
@@ -1,409 +1,364 @@
- """
- PerformanceAssessor Module
-
- This module defines the `PerformanceAssessor` class to evaluate classification model performance.
- It includes methods to compute metrics like precision, recall, F1 score, AUROC, and accuracy,
- as well as utilities for generating related plots.
- """
-
- from typing import Literal
-
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
-
- from birdnet_analyzer.evaluation.assessment import metrics, plotting
-
-
- class PerformanceAssessor:
-     """
-     A class to assess the performance of classification models by computing metrics
-     and generating visualizations for binary and multilabel classification tasks.
-     """
-
-     def __init__(
-         self,
-         num_classes: int,
-         threshold: float = 0.5,
-         classes: tuple[str, ...] | None = None,
-         task: Literal["binary", "multilabel"] = "multilabel",
-         metrics_list: tuple[str, ...] = (
-             "recall",
-             "precision",
-             "f1",
-             "ap",
-             "auroc",
-             "accuracy",
-         ),
-     ) -> None:
-         """
-         Initialize the PerformanceAssessor.
-
-         Args:
-             num_classes (int): The number of classes in the classification problem.
-             threshold (float): The threshold for binarizing probabilities into class labels.
-             classes (Optional[Tuple[str, ...]]): Optional tuple of class names.
-             task (Literal["binary", "multilabel"]): The classification task type.
-             metrics_list (Tuple[str, ...]): A tuple of metrics to compute.
-
-         Raises:
-             ValueError: If any of the inputs are invalid.
-         """
-         # Validate the number of classes
-         if not isinstance(num_classes, int) or num_classes <= 0:
-             raise ValueError("num_classes must be a positive integer.")
-
-         # Validate the threshold value
-         if not isinstance(threshold, float) or not 0 < threshold < 1:
-             raise ValueError("threshold must be a float between 0 and 1 (exclusive).")
-
-         # Validate class names
-         if classes is not None:
-             if not isinstance(classes, tuple):
-                 raise ValueError("classes must be a tuple of strings.")
-             if len(classes) != num_classes:
-                 raise ValueError(f"Length of classes ({len(classes)}) must match num_classes ({num_classes}).")
-             if not all(isinstance(class_name, str) for class_name in classes):
-                 raise ValueError("All elements in classes must be strings.")
-
-         # Validate the task type
-         if task not in {"binary", "multilabel"}:
-             raise ValueError("task must be 'binary' or 'multilabel'.")
-
-         # Validate the metrics list
-         valid_metrics = ["accuracy", "recall", "precision", "f1", "ap", "auroc"]
-         if not metrics_list:
-             raise ValueError("metrics_list cannot be empty.")
-         if not all(metric in valid_metrics for metric in metrics_list):
-             raise ValueError(f"Invalid metrics in {metrics_list}. Valid options are {valid_metrics}.")
-
-         # Assign instance variables
-         self.num_classes = num_classes
-         self.threshold = threshold
-         self.classes = classes
-         self.task = task
-         self.metrics_list = metrics_list
-
-         # Set default colors for plotting
-         self.colors = ["#3A50B1", "#61A83E", "#D74C4C", "#A13FA1", "#D9A544", "#F3A6E0"]
-
-     def calculate_metrics(
-         self,
-         predictions: np.ndarray,
-         labels: np.ndarray,
-         per_class_metrics: bool = False,
-     ) -> pd.DataFrame:
-         """
-         Calculate multiple performance metrics for the given predictions and labels.
-
-         Args:
-             predictions (np.ndarray): Model predictions as a 2D NumPy array (probabilities or logits).
-             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
-             per_class_metrics (bool): If True, compute metrics for each class individually.
-
-         Returns:
-             pd.DataFrame: A DataFrame containing the computed metrics.
-
-         Raises:
-             TypeError: If predictions or labels are not NumPy arrays.
-             ValueError: If predictions and labels have mismatched dimensions or invalid shapes.
-         """
-         # Validate that predictions and labels are NumPy arrays
-         if not isinstance(predictions, np.ndarray):
-             raise TypeError("predictions must be a NumPy array.")
-         if not isinstance(labels, np.ndarray):
-             raise TypeError("labels must be a NumPy array.")
-
-         # Ensure predictions and labels have the same shape
-         if predictions.shape != labels.shape:
-             raise ValueError("predictions and labels must have the same shape.")
-         if predictions.ndim != 2:
-             raise ValueError("predictions and labels must be 2-dimensional arrays.")
-         if predictions.shape[1] != self.num_classes:
-             raise ValueError(
-                 f"The number of columns in predictions ({predictions.shape[1]}) "
-                 + f"must match num_classes ({self.num_classes})."
-             )
-
-         # Determine the averaging method for metrics
-         if per_class_metrics and self.num_classes == 1:
-             averaging_method = "macro"
-         else:
-             averaging_method = None if per_class_metrics else "macro"
-
-         # Dictionary to store the results of each metric
-         metrics_results = {}
-
-         # Compute each metric in the metrics list
-         for metric_name in self.metrics_list:
-             if metric_name == "recall":
-                 result = metrics.calculate_recall(
-                     predictions=predictions,
-                     labels=labels,
-                     task=self.task,
-                     threshold=self.threshold,
-                     averaging_method=averaging_method,
-                 )
-                 metrics_results["Recall"] = np.atleast_1d(result)
-             elif metric_name == "precision":
-                 result = metrics.calculate_precision(
-                     predictions=predictions,
-                     labels=labels,
-                     task=self.task,
-                     threshold=self.threshold,
-                     averaging_method=averaging_method,
-                 )
-                 metrics_results["Precision"] = np.atleast_1d(result)
-             elif metric_name == "f1":
-                 result = metrics.calculate_f1_score(
-                     predictions=predictions,
-                     labels=labels,
-                     task=self.task,
-                     threshold=self.threshold,
-                     averaging_method=averaging_method,
-                 )
-                 metrics_results["F1"] = np.atleast_1d(result)
-             elif metric_name == "ap":
-                 result = metrics.calculate_average_precision(
-                     predictions=predictions,
-                     labels=labels,
-                     task=self.task,
-                     averaging_method=averaging_method,
-                 )
-                 metrics_results["AP"] = np.atleast_1d(result)
-             elif metric_name == "auroc":
-                 result = metrics.calculate_auroc(
-                     predictions=predictions,
-                     labels=labels,
-                     task=self.task,
-                     averaging_method=averaging_method,
-                 )
-                 metrics_results["AUROC"] = np.atleast_1d(result)
-             elif metric_name == "accuracy":
-                 result = metrics.calculate_accuracy(
-                     predictions=predictions,
-                     labels=labels,
-                     task=self.task,
-                     num_classes=self.num_classes,
-                     threshold=self.threshold,
-                     averaging_method=averaging_method,
-                 )
-                 metrics_results["Accuracy"] = np.atleast_1d(result)
-
-         # Define column names for the DataFrame
-         columns = (
-             (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)])
-             if per_class_metrics
-             else ["Overall"]
-         )
-
-         # Create a DataFrame to organize metric results
-         metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
-         return pd.DataFrame.from_dict(metrics_data, orient="index", columns=columns)
-
-     def plot_metrics(
-         self,
-         predictions: np.ndarray,
-         labels: np.ndarray,
-         per_class_metrics: bool = False,
-     ) -> None:
-         """
-         Plot performance metrics for the given predictions and labels.
-
-         Args:
-             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
-             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
-             per_class_metrics (bool): If True, plots metrics for each class individually.
-
-         Raises:
-             ValueError: If the metrics cannot be calculated or plotting fails.
-
-         Returns:
-             None
-         """
-         # Calculate metrics using the provided predictions and labels
-         metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)
-
-         # Choose the plotting method based on whether per-class metrics are required
-         return (
-             plotting.plot_metrics_per_class(metrics_df, self.colors)
-             if per_class_metrics
-             else plotting.plot_overall_metrics(metrics_df, self.colors)
-         )
-
-     def plot_metrics_all_thresholds(
-         self,
-         predictions: np.ndarray,
-         labels: np.ndarray,
-         per_class_metrics: bool = False,
-     ) -> None:
-         """
-         Plot performance metrics across thresholds for the given predictions and labels.
-
-         Args:
-             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
-             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
-             per_class_metrics (bool): If True, plots metrics for each class individually.
-
-         Raises:
-             ValueError: If metrics calculation or plotting fails.
-
-         Returns:
-             None
-         """
-         # Save the original threshold value to restore it later
-         original_threshold = self.threshold
-
-         # Define a range of thresholds for analysis
-         thresholds = np.arange(0.05, 1.0, 0.05)
-
-         # Exclude metrics that are not threshold-dependent
-         metrics_to_plot = [m for m in self.metrics_list if m not in ["auroc", "ap"]]
-
-         if per_class_metrics:
-             # Define class names for plotting
-             class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]
-
-             # Initialize a dictionary to store metric values per class
-             metric_values_dict_per_class = {
-                 class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names
-             }
-
-             # Compute metrics for each threshold
-             for thresh in thresholds:
-                 self.threshold = thresh
-                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=True)
-                 for metric_name in metrics_to_plot:
-                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
-                     for class_name in class_names:
-                         value = metrics_df.loc[metric_label, class_name]
-                         metric_values_dict_per_class[class_name][metric_name].append(value)
-
-             # Restore the original threshold
-             self.threshold = original_threshold
-
-             # Plot metrics across thresholds per class
-             fig = plotting.plot_metrics_across_thresholds_per_class(
-                 thresholds,
-                 metric_values_dict_per_class,
-                 metrics_to_plot,
-                 class_names,
-                 self.colors,
-             )
-         else:
-             # Initialize a dictionary to store overall metric values
-             metric_values_dict = {metric_name: [] for metric_name in metrics_to_plot}
-
-             # Compute metrics for each threshold
-             for thresh in thresholds:
-                 self.threshold = thresh
-                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=False)
-                 for metric_name in metrics_to_plot:
-                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
-                     value = metrics_df.loc[metric_label, "Overall"]
-                     metric_values_dict[metric_name].append(value)
-
-             # Restore the original threshold
-             self.threshold = original_threshold
-
-             # Plot metrics across thresholds
-             fig = plotting.plot_metrics_across_thresholds(
-                 thresholds,
-                 metric_values_dict,
-                 metrics_to_plot,
-                 self.colors,
-             )
-
-         return fig
-
-     def plot_confusion_matrix(
-         self,
-         predictions: np.ndarray,
-         labels: np.ndarray,
-     ) -> None:
-         """
-         Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.
-
-         Args:
-             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
-             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
-
-         Raises:
-             TypeError: If predictions or labels are not NumPy arrays.
-             ValueError: If predictions and labels have mismatched shapes or invalid dimensions.
-
-         Returns:
-             None
-         """
-         # Validate that predictions and labels are NumPy arrays and match in shape
-         if not isinstance(predictions, np.ndarray):
-             raise TypeError("predictions must be a NumPy array.")
-         if not isinstance(labels, np.ndarray):
-             raise TypeError("labels must be a NumPy array.")
-         if predictions.shape != labels.shape:
-             raise ValueError("predictions and labels must have the same shape.")
-         if predictions.ndim != 2:
-             raise ValueError("predictions and labels must be 2-dimensional arrays.")
-         if predictions.shape[1] != self.num_classes:
-             raise ValueError(
-                 f"The number of columns in predictions ({predictions.shape[1]}) "
-                 + f"must match num_classes ({self.num_classes})."
-             )
-
-         if self.task == "binary":
-             # Binarize predictions using the threshold
-             y_pred = (predictions >= self.threshold).astype(int).flatten()
-             y_true = labels.astype(int).flatten()
-
-             # Compute and normalize the confusion matrix
-             conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
-             conf_mat = np.round(conf_mat, 2)
-
-             # Plot the confusion matrix
-             disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
-             fig, ax = plt.subplots(figsize=(6, 6))
-             disp.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
-             ax.set_title("Confusion Matrix")
-
-             return fig
-
-         if self.task == "multilabel":
-             # Binarize predictions for multilabel classification
-             y_pred = (predictions >= self.threshold).astype(int)
-             y_true = labels.astype(int)
-
-             # Compute confusion matrices for each class
-             conf_mats = []
-             class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
-             for i in range(self.num_classes):
-                 conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
-                 conf_mat = np.round(conf_mat, 2)
-                 conf_mats.append(conf_mat)
-
-             # Determine grid size for subplots
-             num_matrices = self.num_classes
-             n_cols = int(np.ceil(np.sqrt(num_matrices)))
-             n_rows = int(np.ceil(num_matrices / n_cols))
-
-             # Create subplots for each confusion matrix
-             fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
-             axes = axes.flatten()
-
-             # Plot each confusion matrix
-             for idx, (conf_mat, class_name) in enumerate(zip(conf_mats, class_names, strict=True)):
-                 disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
-                 disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f")
-                 axes[idx].set_title(f"{class_name}")
-                 axes[idx].set_xlabel("Predicted class")
-                 axes[idx].set_ylabel("True class")
-
-             # Remove unused subplot axes
-             for ax in axes[num_matrices:]:
-                 fig.delaxes(ax)
-
-             plt.tight_layout()
-
-             return fig
-
-         raise ValueError(f"Unsupported task type: {self.task}")
+ """
+ PerformanceAssessor Module
+
+ This module defines the `PerformanceAssessor` class to evaluate classification model performance.
+ It includes methods to compute metrics like precision, recall, F1 score, AUROC, and accuracy,
+ as well as utilities for generating related plots.
+ """
+
+ from typing import Literal
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import confusion_matrix
+
+ from birdnet_analyzer.evaluation.assessment import metrics, plotting
+
+
+ class PerformanceAssessor:
+     """
+     A class to assess the performance of classification models by computing metrics
+     and generating visualizations for binary and multilabel classification tasks.
+     """
+
+     def __init__(
+         self,
+         num_classes: int,
+         threshold: float = 0.5,
+         classes: tuple[str, ...] | None = None,
+         task: Literal["binary", "multilabel"] = "multilabel",
+         metrics_list: tuple[str, ...] = (
+             "recall",
+             "precision",
+             "f1",
+             "ap",
+             "auroc",
+             "accuracy",
+         ),
+     ) -> None:
+         """
+         Initialize the PerformanceAssessor.
+
+         Args:
+             num_classes (int): The number of classes in the classification problem.
+             threshold (float): The threshold for binarizing probabilities into class labels.
+             classes (Optional[Tuple[str, ...]]): Optional tuple of class names.
+             task (Literal["binary", "multilabel"]): The classification task type.
+             metrics_list (Tuple[str, ...]): A tuple of metrics to compute.
+
+         Raises:
+             ValueError: If any of the inputs are invalid.
+         """
+         # Validate the number of classes
+         if not isinstance(num_classes, int) or num_classes <= 0:
+             raise ValueError("num_classes must be a positive integer.")
+
+         # Validate the threshold value
+         if not isinstance(threshold, float) or not 0 < threshold < 1:
+             raise ValueError("threshold must be a float between 0 and 1 (exclusive).")
+
+         # Validate class names
+         if classes is not None:
+             if not isinstance(classes, tuple):
+                 raise ValueError("classes must be a tuple of strings.")
+             if len(classes) != num_classes:
+                 raise ValueError(f"Length of classes ({len(classes)}) must match num_classes ({num_classes}).")
+             if not all(isinstance(class_name, str) for class_name in classes):
+                 raise ValueError("All elements in classes must be strings.")
+
+         # Validate the task type
+         if task not in {"binary", "multilabel"}:
+             raise ValueError("task must be 'binary' or 'multilabel'.")
+
+         # Validate the metrics list
+         valid_metrics = ["accuracy", "recall", "precision", "f1", "ap", "auroc"]
+         if not metrics_list:
+             raise ValueError("metrics_list cannot be empty.")
+         if not all(metric in valid_metrics for metric in metrics_list):
+             raise ValueError(f"Invalid metrics in {metrics_list}. Valid options are {valid_metrics}.")
+
+         # Assign instance variables
+         self.num_classes = num_classes
+         self.threshold = threshold
+         self.classes = classes
+         self.task = task
+         self.metrics_list = metrics_list
+
+         # Set default colors for plotting
+         self.colors = ["#3A50B1", "#61A83E", "#D74C4C", "#A13FA1", "#D9A544", "#F3A6E0"]
+
+     def calculate_metrics(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ) -> pd.DataFrame:
+         """
+         Calculate multiple performance metrics for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, compute metrics for each class individually.
+
+         Returns:
+             pd.DataFrame: A DataFrame containing the computed metrics.
+
+         Raises:
+             TypeError: If predictions or labels are not NumPy arrays.
+             ValueError: If predictions and labels have mismatched dimensions or invalid shapes.
+         """
+         # Validate that predictions and labels are NumPy arrays
+         if not isinstance(predictions, np.ndarray):
+             raise TypeError("predictions must be a NumPy array.")
+         if not isinstance(labels, np.ndarray):
+             raise TypeError("labels must be a NumPy array.")
+
+         # Ensure predictions and labels have the same shape
+         if predictions.shape != labels.shape:
+             raise ValueError("predictions and labels must have the same shape.")
+         if predictions.ndim != 2:
+             raise ValueError("predictions and labels must be 2-dimensional arrays.")
+         if predictions.shape[1] != self.num_classes:
+             raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")
+
+         # Determine the averaging method for metrics
+         if per_class_metrics and self.num_classes == 1:
+             averaging_method = "macro"
+         else:
+             averaging_method = None if per_class_metrics else "macro"
+
+         # Dictionary to store the results of each metric
+         metrics_results = {}
+
+         # Compute each metric in the metrics list
+         for metric_name in self.metrics_list:
+             if metric_name == "recall":
+                 result = metrics.calculate_recall(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Recall"] = np.atleast_1d(result)
+             elif metric_name == "precision":
+                 result = metrics.calculate_precision(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Precision"] = np.atleast_1d(result)
+             elif metric_name == "f1":
+                 result = metrics.calculate_f1_score(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["F1"] = np.atleast_1d(result)
+             elif metric_name == "ap":
+                 result = metrics.calculate_average_precision(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["AP"] = np.atleast_1d(result)
+             elif metric_name == "auroc":
+                 result = metrics.calculate_auroc(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["AUROC"] = np.atleast_1d(result)
+             elif metric_name == "accuracy":
+                 result = metrics.calculate_accuracy(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     num_classes=self.num_classes,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Accuracy"] = np.atleast_1d(result)
+
+         # Define column names for the DataFrame
+         columns = (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]) if per_class_metrics else ["Overall"]
+
+         # Create a DataFrame to organize metric results
+         metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
+         return pd.DataFrame.from_dict(metrics_data, orient="index", columns=columns)
+
+     def plot_metrics(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ):
+         """
+         Plot performance metrics for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, plots metrics for each class individually.
+
+         Raises:
+             ValueError: If the metrics cannot be calculated or plotting fails.
+
+         Returns:
+             None
+         """
+         # Calculate metrics using the provided predictions and labels
+         metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)
+
+         # Choose the plotting method based on whether per-class metrics are required
+         return plotting.plot_metrics_per_class(metrics_df, self.colors) if per_class_metrics else plotting.plot_overall_metrics(metrics_df, self.colors)
+
+     def plot_metrics_all_thresholds(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ):
+         """
+         Plot performance metrics across thresholds for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, plots metrics for each class individually.
+
+         Raises:
+             ValueError: If metrics calculation or plotting fails.
+
+         Returns:
+             None
+         """
+         # Save the original threshold value to restore it later
+         original_threshold = self.threshold
+
+         # Define a range of thresholds for analysis
+         thresholds = np.arange(0.05, 1.0, 0.05)
+
+         # Exclude metrics that are not threshold-dependent
+         metrics_to_plot = [m for m in self.metrics_list if m not in ["auroc", "ap"]]
+
+         if per_class_metrics:
+             # Define class names for plotting
+             class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+             # Initialize a dictionary to store metric values per class
+             metric_values_dict_per_class = {class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names}
+
+             # Compute metrics for each threshold
+             for thresh in thresholds:
+                 self.threshold = thresh
+                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=True)
+                 for metric_name in metrics_to_plot:
+                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                     for class_name in class_names:
+                         value = metrics_df.loc[metric_label, class_name]
+                         metric_values_dict_per_class[class_name][metric_name].append(value)
+
+             # Restore the original threshold
+             self.threshold = original_threshold
+
+             # Plot metrics across thresholds per class
+             fig = plotting.plot_metrics_across_thresholds_per_class(
+                 thresholds,
+                 metric_values_dict_per_class,
+                 metrics_to_plot,
+                 class_names,
+                 self.colors,
+             )
+         else:
+             # Initialize a dictionary to store overall metric values
+             metric_values_dict = {metric_name: [] for metric_name in metrics_to_plot}
+
+             # Compute metrics for each threshold
+             for thresh in thresholds:
+                 self.threshold = thresh
+                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=False)
+                 for metric_name in metrics_to_plot:
+                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                     value = metrics_df.loc[metric_label, "Overall"]
+                     metric_values_dict[metric_name].append(value)
+
+             # Restore the original threshold
+             self.threshold = original_threshold
+
+             # Plot metrics across thresholds
+             fig = plotting.plot_metrics_across_thresholds(
+                 thresholds,
+                 metric_values_dict,
+                 metrics_to_plot,
+                 self.colors,
+             )
+
+         return fig
+
+     def plot_confusion_matrix(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+     ):
+         """
+         Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+
+         Raises:
+             TypeError: If predictions or labels are not NumPy arrays.
+             ValueError: If predictions and labels have mismatched shapes or invalid dimensions.
+
+         Returns:
+             None
+         """
+         # Validate that predictions and labels are NumPy arrays and match in shape
+         if not isinstance(predictions, np.ndarray):
+             raise TypeError("predictions must be a NumPy array.")
+         if not isinstance(labels, np.ndarray):
+             raise TypeError("labels must be a NumPy array.")
+         if predictions.shape != labels.shape:
+             raise ValueError("predictions and labels must have the same shape.")
+         if predictions.ndim != 2:
+             raise ValueError("predictions and labels must be 2-dimensional arrays.")
+         if predictions.shape[1] != self.num_classes:
+             raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")
+
+         if self.task == "binary":
+             # Binarize predictions using the threshold
+             y_pred = (predictions >= self.threshold).astype(int).flatten()
+             y_true = labels.astype(int).flatten()
+
+             # Compute and normalize the confusion matrix
+             conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
+             conf_mat = np.round(conf_mat, 2)
+
+             return plotting.plot_confusion_matrices(conf_mat, self.task, self.classes)
+
+         if self.task == "multilabel":
+             # Binarize predictions for multilabel classification
+             y_pred = (predictions >= self.threshold).astype(int)
+             y_true = labels.astype(int)
+
+             # Compute confusion matrices for each class
+             conf_mats = []
+             class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+             for i in range(self.num_classes):
+                 conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
+                 conf_mat = np.round(conf_mat, 2)
+                 conf_mats.append(conf_mat)
+
+             return plotting.plot_confusion_matrices(np.array(conf_mats), self.task, class_names)
+
+         raise ValueError(f"Unsupported task type: {self.task}")
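For orientation, here is a rough usage sketch of the refactored PerformanceAssessor (not part of the diff itself): only the constructor and method signatures come from the 2.1.1 file above; the synthetic scores, the example class names, and the print-based inspection are illustrative assumptions.

    # Usage sketch only: synthetic scores/labels stand in for real BirdNET evaluation data.
    import numpy as np

    from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor

    rng = np.random.default_rng(0)
    num_samples, num_classes = 50, 3

    # 2D arrays of shape (samples, classes): scores in [0, 1] and binary ground truth.
    predictions = rng.random((num_samples, num_classes))
    labels = (rng.random((num_samples, num_classes)) > 0.5).astype(int)

    assessor = PerformanceAssessor(
        num_classes=num_classes,
        threshold=0.5,
        classes=("Robin", "Wren", "Owl"),  # hypothetical class names
        task="multilabel",
    )

    # Metric tables come back as pandas DataFrames (rows = metrics, columns = "Overall" or the classes).
    print(assessor.calculate_metrics(predictions, labels))
    print(assessor.calculate_metrics(predictions, labels, per_class_metrics=True))

    # In 2.1.1 the confusion-matrix figure is built by plotting.plot_confusion_matrices and returned.
    fig = assessor.plot_confusion_matrix(predictions, labels)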