birdnet-analyzer 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +19 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +30 -25
  5. birdnet_analyzer/analyze/core.py +246 -245
  6. birdnet_analyzer/analyze/utils.py +694 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +732 -707
  9. birdnet_analyzer/config.py +243 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +70 -70
  15. birdnet_analyzer/embeddings/utils.py +220 -193
  16. birdnet_analyzer/evaluation/__init__.py +189 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +378 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +179 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +36 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +638 -620
  35. birdnet_analyzer/gui/evaluation.py +801 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +265 -246
  38. birdnet_analyzer/gui/review.py +472 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +149 -129
  41. birdnet_analyzer/gui/single_file.py +264 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +687 -698
  44. birdnet_analyzer/gui/utils.py +797 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +341 -334
  83. birdnet_analyzer/lang/en.json +341 -334
  84. birdnet_analyzer/lang/fi.json +341 -334
  85. birdnet_analyzer/lang/fr.json +341 -334
  86. birdnet_analyzer/lang/id.json +341 -334
  87. birdnet_analyzer/lang/pt-br.json +341 -334
  88. birdnet_analyzer/lang/ru.json +341 -334
  89. birdnet_analyzer/lang/se.json +341 -334
  90. birdnet_analyzer/lang/tlh.json +341 -334
  91. birdnet_analyzer/lang/zh_TW.json +341 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +425 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/METADATA +146 -129
  117. birdnet_analyzer-2.1.0.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
  121. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/entry_points.txt +0 -0
  123. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/assessment/performance_assessor.py (new file)
@@ -0,0 +1,364 @@
+ """
+ PerformanceAssessor Module
+
+ This module defines the `PerformanceAssessor` class to evaluate classification model performance.
+ It includes methods to compute metrics like precision, recall, F1 score, AUROC, and accuracy,
+ as well as utilities for generating related plots.
+ """
+
+ from typing import Literal
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import confusion_matrix
+
+ from birdnet_analyzer.evaluation.assessment import metrics, plotting
+
+
+ class PerformanceAssessor:
+     """
+     A class to assess the performance of classification models by computing metrics
+     and generating visualizations for binary and multilabel classification tasks.
+     """
+
+     def __init__(
+         self,
+         num_classes: int,
+         threshold: float = 0.5,
+         classes: tuple[str, ...] | None = None,
+         task: Literal["binary", "multilabel"] = "multilabel",
+         metrics_list: tuple[str, ...] = (
+             "recall",
+             "precision",
+             "f1",
+             "ap",
+             "auroc",
+             "accuracy",
+         ),
+     ) -> None:
+         """
+         Initialize the PerformanceAssessor.
+
+         Args:
+             num_classes (int): The number of classes in the classification problem.
+             threshold (float): The threshold for binarizing probabilities into class labels.
+             classes (Optional[Tuple[str, ...]]): Optional tuple of class names.
+             task (Literal["binary", "multilabel"]): The classification task type.
+             metrics_list (Tuple[str, ...]): A tuple of metrics to compute.
+
+         Raises:
+             ValueError: If any of the inputs are invalid.
+         """
+         # Validate the number of classes
+         if not isinstance(num_classes, int) or num_classes <= 0:
+             raise ValueError("num_classes must be a positive integer.")
+
+         # Validate the threshold value
+         if not isinstance(threshold, float) or not 0 < threshold < 1:
+             raise ValueError("threshold must be a float between 0 and 1 (exclusive).")
+
+         # Validate class names
+         if classes is not None:
+             if not isinstance(classes, tuple):
+                 raise ValueError("classes must be a tuple of strings.")
+             if len(classes) != num_classes:
+                 raise ValueError(f"Length of classes ({len(classes)}) must match num_classes ({num_classes}).")
+             if not all(isinstance(class_name, str) for class_name in classes):
+                 raise ValueError("All elements in classes must be strings.")
+
+         # Validate the task type
+         if task not in {"binary", "multilabel"}:
+             raise ValueError("task must be 'binary' or 'multilabel'.")
+
+         # Validate the metrics list
+         valid_metrics = ["accuracy", "recall", "precision", "f1", "ap", "auroc"]
+         if not metrics_list:
+             raise ValueError("metrics_list cannot be empty.")
+         if not all(metric in valid_metrics for metric in metrics_list):
+             raise ValueError(f"Invalid metrics in {metrics_list}. Valid options are {valid_metrics}.")
+
+         # Assign instance variables
+         self.num_classes = num_classes
+         self.threshold = threshold
+         self.classes = classes
+         self.task = task
+         self.metrics_list = metrics_list
+
+         # Set default colors for plotting
+         self.colors = ["#3A50B1", "#61A83E", "#D74C4C", "#A13FA1", "#D9A544", "#F3A6E0"]
+
+     def calculate_metrics(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ) -> pd.DataFrame:
+         """
+         Calculate multiple performance metrics for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, compute metrics for each class individually.
+
+         Returns:
+             pd.DataFrame: A DataFrame containing the computed metrics.
+
+         Raises:
+             TypeError: If predictions or labels are not NumPy arrays.
+             ValueError: If predictions and labels have mismatched dimensions or invalid shapes.
+         """
+         # Validate that predictions and labels are NumPy arrays
+         if not isinstance(predictions, np.ndarray):
+             raise TypeError("predictions must be a NumPy array.")
+         if not isinstance(labels, np.ndarray):
+             raise TypeError("labels must be a NumPy array.")
+
+         # Ensure predictions and labels have the same shape
+         if predictions.shape != labels.shape:
+             raise ValueError("predictions and labels must have the same shape.")
+         if predictions.ndim != 2:
+             raise ValueError("predictions and labels must be 2-dimensional arrays.")
+         if predictions.shape[1] != self.num_classes:
+             raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")
+
+         # Determine the averaging method for metrics
+         if per_class_metrics and self.num_classes == 1:
+             averaging_method = "macro"
+         else:
+             averaging_method = None if per_class_metrics else "macro"
+
+         # Dictionary to store the results of each metric
+         metrics_results = {}
+
+         # Compute each metric in the metrics list
+         for metric_name in self.metrics_list:
+             if metric_name == "recall":
+                 result = metrics.calculate_recall(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Recall"] = np.atleast_1d(result)
+             elif metric_name == "precision":
+                 result = metrics.calculate_precision(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Precision"] = np.atleast_1d(result)
+             elif metric_name == "f1":
+                 result = metrics.calculate_f1_score(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["F1"] = np.atleast_1d(result)
+             elif metric_name == "ap":
+                 result = metrics.calculate_average_precision(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["AP"] = np.atleast_1d(result)
+             elif metric_name == "auroc":
+                 result = metrics.calculate_auroc(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["AUROC"] = np.atleast_1d(result)
+             elif metric_name == "accuracy":
+                 result = metrics.calculate_accuracy(
+                     predictions=predictions,
+                     labels=labels,
+                     task=self.task,
+                     num_classes=self.num_classes,
+                     threshold=self.threshold,
+                     averaging_method=averaging_method,
+                 )
+                 metrics_results["Accuracy"] = np.atleast_1d(result)
+
+         # Define column names for the DataFrame
+         columns = (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]) if per_class_metrics else ["Overall"]
+
+         # Create a DataFrame to organize metric results
+         metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
+         return pd.DataFrame.from_dict(metrics_data, orient="index", columns=columns)
+
+     def plot_metrics(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ):
+         """
+         Plot performance metrics for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, plots metrics for each class individually.
+
+         Raises:
+             ValueError: If the metrics cannot be calculated or plotting fails.
+
+         Returns:
+             None
+         """
+         # Calculate metrics using the provided predictions and labels
+         metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)
+
+         # Choose the plotting method based on whether per-class metrics are required
+         return plotting.plot_metrics_per_class(metrics_df, self.colors) if per_class_metrics else plotting.plot_overall_metrics(metrics_df, self.colors)
+
+     def plot_metrics_all_thresholds(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+         per_class_metrics: bool = False,
+     ):
+         """
+         Plot performance metrics across thresholds for the given predictions and labels.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+             per_class_metrics (bool): If True, plots metrics for each class individually.
+
+         Raises:
+             ValueError: If metrics calculation or plotting fails.
+
+         Returns:
+             None
+         """
+         # Save the original threshold value to restore it later
+         original_threshold = self.threshold
+
+         # Define a range of thresholds for analysis
+         thresholds = np.arange(0.05, 1.0, 0.05)
+
+         # Exclude metrics that are not threshold-dependent
+         metrics_to_plot = [m for m in self.metrics_list if m not in ["auroc", "ap"]]
+
+         if per_class_metrics:
+             # Define class names for plotting
+             class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+             # Initialize a dictionary to store metric values per class
+             metric_values_dict_per_class = {class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names}
+
+             # Compute metrics for each threshold
+             for thresh in thresholds:
+                 self.threshold = thresh
+                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=True)
+                 for metric_name in metrics_to_plot:
+                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                     for class_name in class_names:
+                         value = metrics_df.loc[metric_label, class_name]
+                         metric_values_dict_per_class[class_name][metric_name].append(value)
+
+             # Restore the original threshold
+             self.threshold = original_threshold
+
+             # Plot metrics across thresholds per class
+             fig = plotting.plot_metrics_across_thresholds_per_class(
+                 thresholds,
+                 metric_values_dict_per_class,
+                 metrics_to_plot,
+                 class_names,
+                 self.colors,
+             )
+         else:
+             # Initialize a dictionary to store overall metric values
+             metric_values_dict = {metric_name: [] for metric_name in metrics_to_plot}
+
+             # Compute metrics for each threshold
+             for thresh in thresholds:
+                 self.threshold = thresh
+                 metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=False)
+                 for metric_name in metrics_to_plot:
+                     metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                     value = metrics_df.loc[metric_label, "Overall"]
+                     metric_values_dict[metric_name].append(value)
+
+             # Restore the original threshold
+             self.threshold = original_threshold
+
+             # Plot metrics across thresholds
+             fig = plotting.plot_metrics_across_thresholds(
+                 thresholds,
+                 metric_values_dict,
+                 metrics_to_plot,
+                 self.colors,
+             )
+
+         return fig
+
+     def plot_confusion_matrix(
+         self,
+         predictions: np.ndarray,
+         labels: np.ndarray,
+     ):
+         """
+         Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.
+
+         Args:
+             predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+             labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+
+         Raises:
+             TypeError: If predictions or labels are not NumPy arrays.
+             ValueError: If predictions and labels have mismatched shapes or invalid dimensions.
+
+         Returns:
+             None
+         """
+         # Validate that predictions and labels are NumPy arrays and match in shape
+         if not isinstance(predictions, np.ndarray):
+             raise TypeError("predictions must be a NumPy array.")
+         if not isinstance(labels, np.ndarray):
+             raise TypeError("labels must be a NumPy array.")
+         if predictions.shape != labels.shape:
+             raise ValueError("predictions and labels must have the same shape.")
+         if predictions.ndim != 2:
+             raise ValueError("predictions and labels must be 2-dimensional arrays.")
+         if predictions.shape[1] != self.num_classes:
+             raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")
+
+         if self.task == "binary":
+             # Binarize predictions using the threshold
+             y_pred = (predictions >= self.threshold).astype(int).flatten()
+             y_true = labels.astype(int).flatten()
+
+             # Compute and normalize the confusion matrix
+             conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
+             conf_mat = np.round(conf_mat, 2)
+
+             return plotting.plot_confusion_matrices(conf_mat, self.task, self.classes)
+
+         if self.task == "multilabel":
+             # Binarize predictions for multilabel classification
+             y_pred = (predictions >= self.threshold).astype(int)
+             y_true = labels.astype(int)
+
+             # Compute confusion matrices for each class
+             conf_mats = []
+             class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+             for i in range(self.num_classes):
+                 conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
+                 conf_mat = np.round(conf_mat, 2)
+                 conf_mats.append(conf_mat)
+
+             return plotting.plot_confusion_matrices(np.array(conf_mats), self.task, class_names)
+
+         raise ValueError(f"Unsupported task type: {self.task}")
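For orientation, below is a minimal usage sketch of the PerformanceAssessor class added in this release. It is illustrative only and not part of the diff: the class names and random data are made up, and the computed values depend on the accompanying metrics and plotting modules shipped alongside this file.

import numpy as np

from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor

# Illustrative multilabel scores and ground-truth labels for 3 classes over 100 clips.
rng = np.random.default_rng(42)
predictions = rng.random((100, 3))                 # probabilities in [0, 1)
labels = (rng.random((100, 3)) > 0.5).astype(int)  # binary ground truth, same shape

assessor = PerformanceAssessor(
    num_classes=3,
    threshold=0.5,
    classes=("Species A", "Species B", "Species C"),
    task="multilabel",
    metrics_list=("recall", "precision", "f1"),
)

# Overall (macro-averaged) metrics: a DataFrame with a single "Overall" column.
print(assessor.calculate_metrics(predictions, labels))

# Per-class metrics: one column per class name.
print(assessor.calculate_metrics(predictions, labels, per_class_metrics=True))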