birdnet-analyzer 2.0.0-py3-none-any.whl → 2.0.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (122)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/assessment/performance_assessor.py
@@ -0,0 +1,409 @@
+ """
+ PerformanceAssessor Module
+
+ This module defines the `PerformanceAssessor` class to evaluate classification model performance.
+ It includes methods to compute metrics like precision, recall, F1 score, AUROC, and accuracy,
+ as well as utilities for generating related plots.
+ """
+
+ from typing import Literal
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
+
+ from birdnet_analyzer.evaluation.assessment import metrics, plotting
+
+
+ class PerformanceAssessor:
+     """
+     A class to assess the performance of classification models by computing metrics
+     and generating visualizations for binary and multilabel classification tasks.
+     """
+
+     def __init__(
+         self,
+         num_classes: int,
+         threshold: float = 0.5,
+         classes: tuple[str, ...] | None = None,
+         task: Literal["binary", "multilabel"] = "multilabel",
+         metrics_list: tuple[str, ...] = (
+             "recall",
+             "precision",
+             "f1",
+             "ap",
+             "auroc",
+             "accuracy",
+         ),
+     ) -> None:
+         """
+         Initialize the PerformanceAssessor.
+
+         Args:
+             num_classes (int): The number of classes in the classification problem.
+             threshold (float): The threshold for binarizing probabilities into class labels.
+             classes (Optional[Tuple[str, ...]]): Optional tuple of class names.
+             task (Literal["binary", "multilabel"]): The classification task type.
+             metrics_list (Tuple[str, ...]): A tuple of metrics to compute.
+
+         Raises:
+             ValueError: If any of the inputs are invalid.
+         """
+         # Validate the number of classes
+         if not isinstance(num_classes, int) or num_classes <= 0:
+             raise ValueError("num_classes must be a positive integer.")
+
+         # Validate the threshold value
+         if not isinstance(threshold, float) or not 0 < threshold < 1:
+             raise ValueError("threshold must be a float between 0 and 1 (exclusive).")
+
+         # Validate class names
+         if classes is not None:
+             if not isinstance(classes, tuple):
+                 raise ValueError("classes must be a tuple of strings.")
+             if len(classes) != num_classes:
+                 raise ValueError(f"Length of classes ({len(classes)}) must match num_classes ({num_classes}).")
+             if not all(isinstance(class_name, str) for class_name in classes):
+                 raise ValueError("All elements in classes must be strings.")
+
+         # Validate the task type
+         if task not in {"binary", "multilabel"}:
+             raise ValueError("task must be 'binary' or 'multilabel'.")
+
+         # Validate the metrics list
+         valid_metrics = ["accuracy", "recall", "precision", "f1", "ap", "auroc"]
+         if not metrics_list:
+             raise ValueError("metrics_list cannot be empty.")
+         if not all(metric in valid_metrics for metric in metrics_list):
+             raise ValueError(f"Invalid metrics in {metrics_list}. Valid options are {valid_metrics}.")
+
+         # Assign instance variables
+         self.num_classes = num_classes
+         self.threshold = threshold
+         self.classes = classes
+         self.task = task
+         self.metrics_list = metrics_list
+
+         # Set default colors for plotting
+         self.colors = ["#3A50B1", "#61A83E", "#D74C4C", "#A13FA1", "#D9A544", "#F3A6E0"]
+
+     def calculate_metrics(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ) -> pd.DataFrame:
+         """
+         Calculate multiple performance metrics for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, compute metrics for each class individually.
+
+         Returns:
+             pd.DataFrame: A DataFrame containing the computed metrics.
+
+         Raises:
+             TypeError: If predictions or labels are not NumPy arrays.
+             ValueError: If predictions and labels have mismatched dimensions or invalid shapes.
+         """
+         # Validate that predictions and labels are NumPy arrays
+         if not isinstance(predictions, np.ndarray):
+             raise TypeError("predictions must be a NumPy array.")
+         if not isinstance(labels, np.ndarray):
+             raise TypeError("labels must be a NumPy array.")
+
+         # Ensure predictions and labels have the same shape
+         if predictions.shape != labels.shape:
+             raise ValueError("predictions and labels must have the same shape.")
+         if predictions.ndim != 2:
+             raise ValueError("predictions and labels must be 2-dimensional arrays.")
+         if predictions.shape[1] != self.num_classes:
+             raise ValueError(
+                 f"The number of columns in predictions ({predictions.shape[1]}) "
+                 + f"must match num_classes ({self.num_classes})."
+             )
+
+         # Determine the averaging method for metrics
+         if per_class_metrics and self.num_classes == 1:
+             averaging_method = "macro"
+         else:
+             averaging_method = None if per_class_metrics else "macro"
+
+         # Dictionary to store the results of each metric
+         metrics_results = {}
+
+         # Compute each metric in the metrics list
+         for metric_name in self.metrics_list:
+             if metric_name == "recall":
+                 result = metrics.calculate_recall(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Recall"] = np.atleast_1d(result)
+             elif metric_name == "precision":
+                 result = metrics.calculate_precision(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Precision"] = np.atleast_1d(result)
+             elif metric_name == "f1":
+                 result = metrics.calculate_f1_score(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["F1"] = np.atleast_1d(result)
+             elif metric_name == "ap":
+                 result = metrics.calculate_average_precision(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["AP"] = np.atleast_1d(result)
+             elif metric_name == "auroc":
+                 result = metrics.calculate_auroc(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["AUROC"] = np.atleast_1d(result)
+             elif metric_name == "accuracy":
+                 result = metrics.calculate_accuracy(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     num_classes=self.num_classes,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Accuracy"] = np.atleast_1d(result)
+
+         # Define column names for the DataFrame
+         columns = (
+             (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)])
+             if per_class_metrics
+             else ["Overall"]
+         )
+
+         # Create a DataFrame to organize metric results
+         metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
+         return pd.DataFrame.from_dict(metrics_data, orient="index", columns=columns)
+
+     def plot_metrics(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ) -> None:
+         """
+         Plot performance metrics for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, plots metrics for each class individually.
+
+         Raises:
+             ValueError: If the metrics cannot be calculated or plotting fails.
+
+         Returns:
+             None
+         """
+         # Calculate metrics using the provided predictions and labels
+         metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)
+
+         # Choose the plotting method based on whether per-class metrics are required
+         return (
+             plotting.plot_metrics_per_class(metrics_df, self.colors)
+             if per_class_metrics
+             else plotting.plot_overall_metrics(metrics_df, self.colors)
+         )
+
+     def plot_metrics_all_thresholds(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ) -> None:
+         """
+         Plot performance metrics across thresholds for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, plots metrics for each class individually.
+
+         Raises:
+             ValueError: If metrics calculation or plotting fails.
+
+         Returns:
+             None
+         """
+         # Save the original threshold value to restore it later
+         original_threshold = self.threshold
+
+         # Define a range of thresholds for analysis
+         thresholds = np.arange(0.05, 1.0, 0.05)
+
+         # Exclude metrics that are not threshold-dependent
+         metrics_to_plot = [m for m in self.metrics_list if m not in ["auroc", "ap"]]
+
+         if per_class_metrics:
+             # Define class names for plotting
+             class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+             # Initialize a dictionary to store metric values per class
+             metric_values_dict_per_class = {
+                 class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names
+             }
+
+             # Compute metrics for each threshold
+             for thresh in thresholds:
+                 self.threshold = thresh
+                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=True)
+                 for metric_name in metrics_to_plot:
+                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                     for class_name in class_names:
+                         value = metrics_df.loc[metric_label, class_name]
+                         metric_values_dict_per_class[class_name][metric_name].append(value)
+
+             # Restore the original threshold
+             self.threshold = original_threshold
+
+             # Plot metrics across thresholds per class
+             fig = plotting.plot_metrics_across_thresholds_per_class(
+                 thresholds,
+                 metric_values_dict_per_class,
+                 metrics_to_plot,
+                 class_names,
+                 self.colors,
+             )
+         else:
+             # Initialize a dictionary to store overall metric values
+             metric_values_dict = {metric_name: [] for metric_name in metrics_to_plot}
+
+             # Compute metrics for each threshold
+             for thresh in thresholds:
+                 self.threshold = thresh
+                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=False)
+                 for metric_name in metrics_to_plot:
+                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                     value = metrics_df.loc[metric_label, "Overall"]
+                     metric_values_dict[metric_name].append(value)
+
+             # Restore the original threshold
+             self.threshold = original_threshold
+
+             # Plot metrics across thresholds
+             fig = plotting.plot_metrics_across_thresholds(
+                 thresholds,
+                 metric_values_dict,
+                 metrics_to_plot,
+                 self.colors,
+             )
+
+         return fig
+
+     def plot_confusion_matrix(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+     ) -> None:
+         """
+         Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+
+         Raises:
+             TypeError: If predictions or labels are not NumPy arrays.
+             ValueError: If predictions and labels have mismatched shapes or invalid dimensions.
+
+         Returns:
+             None
+         """
+         # Validate that predictions and labels are NumPy arrays and match in shape
+         if not isinstance(predictions, np.ndarray):
+             raise TypeError("predictions must be a NumPy array.")
+         if not isinstance(labels, np.ndarray):
+             raise TypeError("labels must be a NumPy array.")
+         if predictions.shape != labels.shape:
+             raise ValueError("predictions and labels must have the same shape.")
+         if predictions.ndim != 2:
+             raise ValueError("predictions and labels must be 2-dimensional arrays.")
+         if predictions.shape[1] != self.num_classes:
+             raise ValueError(
+                 f"The number of columns in predictions ({predictions.shape[1]}) "
+                 + f"must match num_classes ({self.num_classes})."
+             )
+
+         if self.task == "binary":
+             # Binarize predictions using the threshold
+             y_pred = (predictions >= self.threshold).astype(int).flatten()
+             y_true = labels.astype(int).flatten()
+
+             # Compute and normalize the confusion matrix
+             conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
+             conf_mat = np.round(conf_mat, 2)
+
+             # Plot the confusion matrix
+             disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
+             fig, ax = plt.subplots(figsize=(6, 6))
+             disp.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
+             ax.set_title("Confusion Matrix")
+
+             return fig
+
+         if self.task == "multilabel":
+             # Binarize predictions for multilabel classification
+             y_pred = (predictions >= self.threshold).astype(int)
+             y_true = labels.astype(int)
+
+             # Compute confusion matrices for each class
+             conf_mats = []
+             class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+             for i in range(self.num_classes):
+                 conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
+                 conf_mat = np.round(conf_mat, 2)
+                 conf_mats.append(conf_mat)
+
+             # Determine grid size for subplots
+             num_matrices = self.num_classes
+             n_cols = int(np.ceil(np.sqrt(num_matrices)))
+             n_rows = int(np.ceil(num_matrices / n_cols))
+
+             # Create subplots for each confusion matrix
+             fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
+             axes = axes.flatten()
+
+             # Plot each confusion matrix
+             for idx, (conf_mat, class_name) in enumerate(zip(conf_mats, class_names, strict=True)):
+                 disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
+                 disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f")
+                 axes[idx].set_title(f"{class_name}")
+                 axes[idx].set_xlabel("Predicted class")
+                 axes[idx].set_ylabel("True class")
+
+             # Remove unused subplot axes
+             for ax in axes[num_matrices:]:
+                 fig.delaxes(ax)
+
+             plt.tight_layout()
+
+             return fig
+
+         raise ValueError(f"Unsupported task type: {self.task}")
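
For readers skimming the diff, a minimal usage sketch of the new class, based only on the signatures shown above, is given below. The sample arrays, class names, and random seed are made up for illustration, and it assumes birdnet-analyzer 2.0.1 (including its evaluation.assessment metrics and plotting helpers) is installed.

import numpy as np

from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor

# Hypothetical example data: 100 clips scored against 3 classes.
rng = np.random.default_rng(0)
predictions = rng.random((100, 3))                 # scores in [0, 1)
labels = (rng.random((100, 3)) > 0.5).astype(int)  # binary ground truth, same shape

assessor = PerformanceAssessor(
    num_classes=3,
    threshold=0.5,
    classes=("Class A", "Class B", "Class C"),
    task="multilabel",
)

# Overall metrics: rows are metrics, a single "Overall" column.
print(assessor.calculate_metrics(predictions, labels))

# Per-class metrics: one column per class name.
print(assessor.calculate_metrics(predictions, labels, per_class_metrics=True))

# Matplotlib figures for metric bars and per-class confusion matrices.
fig_metrics = assessor.plot_metrics(predictions, labels)
fig_conf = assessor.plot_confusion_matrix(predictions, labels)

plot_metrics_all_thresholds takes the same arguments and, as the method body above shows, sweeps the decision threshold from 0.05 to 0.95 in steps of 0.05, skipping the threshold-independent AUROC and AP metrics.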