birdnet-analyzer 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +19 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +30 -25
- birdnet_analyzer/analyze/core.py +246 -245
- birdnet_analyzer/analyze/utils.py +694 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +732 -707
- birdnet_analyzer/config.py +243 -242
- birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +70 -70
- birdnet_analyzer/embeddings/utils.py +220 -193
- birdnet_analyzer/evaluation/__init__.py +189 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +378 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +179 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +36 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +638 -620
- birdnet_analyzer/gui/evaluation.py +801 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +265 -246
- birdnet_analyzer/gui/review.py +472 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +149 -129
- birdnet_analyzer/gui/single_file.py +264 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +687 -698
- birdnet_analyzer/gui/utils.py +797 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +341 -334
- birdnet_analyzer/lang/en.json +341 -334
- birdnet_analyzer/lang/fi.json +341 -334
- birdnet_analyzer/lang/fr.json +341 -334
- birdnet_analyzer/lang/id.json +341 -334
- birdnet_analyzer/lang/pt-br.json +341 -334
- birdnet_analyzer/lang/ru.json +341 -334
- birdnet_analyzer/lang/se.json +341 -334
- birdnet_analyzer/lang/tlh.json +341 -334
- birdnet_analyzer/lang/zh_TW.json +341 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +425 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/METADATA +146 -129
- birdnet_analyzer-2.1.0.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,364 @@
+"""
+PerformanceAssessor Module
+
+This module defines the `PerformanceAssessor` class to evaluate classification model performance.
+It includes methods to compute metrics like precision, recall, F1 score, AUROC, and accuracy,
+as well as utilities for generating related plots.
+"""
+
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import confusion_matrix
+
+from birdnet_analyzer.evaluation.assessment import metrics, plotting
+
+
+class PerformanceAssessor:
+    """
+    A class to assess the performance of classification models by computing metrics
+    and generating visualizations for binary and multilabel classification tasks.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        threshold: float = 0.5,
+        classes: tuple[str, ...] | None = None,
+        task: Literal["binary", "multilabel"] = "multilabel",
+        metrics_list: tuple[str, ...] = (
+            "recall",
+            "precision",
+            "f1",
+            "ap",
+            "auroc",
+            "accuracy",
+        ),
+    ) -> None:
+        """
+        Initialize the PerformanceAssessor.
+
+        Args:
+            num_classes (int): The number of classes in the classification problem.
+            threshold (float): The threshold for binarizing probabilities into class labels.
+            classes (Optional[Tuple[str, ...]]): Optional tuple of class names.
+            task (Literal["binary", "multilabel"]): The classification task type.
+            metrics_list (Tuple[str, ...]): A tuple of metrics to compute.
+
+        Raises:
+            ValueError: If any of the inputs are invalid.
+        """
+        # Validate the number of classes
+        if not isinstance(num_classes, int) or num_classes <= 0:
+            raise ValueError("num_classes must be a positive integer.")
+
+        # Validate the threshold value
+        if not isinstance(threshold, float) or not 0 < threshold < 1:
+            raise ValueError("threshold must be a float between 0 and 1 (exclusive).")
+
+        # Validate class names
+        if classes is not None:
+            if not isinstance(classes, tuple):
+                raise ValueError("classes must be a tuple of strings.")
+            if len(classes) != num_classes:
+                raise ValueError(f"Length of classes ({len(classes)}) must match num_classes ({num_classes}).")
+            if not all(isinstance(class_name, str) for class_name in classes):
+                raise ValueError("All elements in classes must be strings.")
+
+        # Validate the task type
+        if task not in {"binary", "multilabel"}:
+            raise ValueError("task must be 'binary' or 'multilabel'.")
+
+        # Validate the metrics list
+        valid_metrics = ["accuracy", "recall", "precision", "f1", "ap", "auroc"]
+        if not metrics_list:
+            raise ValueError("metrics_list cannot be empty.")
+        if not all(metric in valid_metrics for metric in metrics_list):
+            raise ValueError(f"Invalid metrics in {metrics_list}. Valid options are {valid_metrics}.")
+
+        # Assign instance variables
+        self.num_classes = num_classes
+        self.threshold = threshold
+        self.classes = classes
+        self.task = task
+        self.metrics_list = metrics_list
+
+        # Set default colors for plotting
+        self.colors = ["#3A50B1", "#61A83E", "#D74C4C", "#A13FA1", "#D9A544", "#F3A6E0"]
+
+    def calculate_metrics(
+        self,
+        predictions: np.ndarray,
+        labels: np.ndarray,
+        per_class_metrics: bool = False,
+    ) -> pd.DataFrame:
+        """
+        Calculate multiple performance metrics for the given predictions and labels.
+
+        Args:
+            predictions (np.ndarray): Model predictions as a 2D NumPy array (probabilities or logits).
+            labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+            per_class_metrics (bool): If True, compute metrics for each class individually.
+
+        Returns:
+            pd.DataFrame: A DataFrame containing the computed metrics.
+
+        Raises:
+            TypeError: If predictions or labels are not NumPy arrays.
+            ValueError: If predictions and labels have mismatched dimensions or invalid shapes.
+        """
+        # Validate that predictions and labels are NumPy arrays
+        if not isinstance(predictions, np.ndarray):
+            raise TypeError("predictions must be a NumPy array.")
+        if not isinstance(labels, np.ndarray):
+            raise TypeError("labels must be a NumPy array.")
+
+        # Ensure predictions and labels have the same shape
+        if predictions.shape != labels.shape:
+            raise ValueError("predictions and labels must have the same shape.")
+        if predictions.ndim != 2:
+            raise ValueError("predictions and labels must be 2-dimensional arrays.")
+        if predictions.shape[1] != self.num_classes:
+            raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")
+
+        # Determine the averaging method for metrics
+        if per_class_metrics and self.num_classes == 1:
+            averaging_method = "macro"
+        else:
+            averaging_method = None if per_class_metrics else "macro"
+
+        # Dictionary to store the results of each metric
+        metrics_results = {}
+
+        # Compute each metric in the metrics list
+        for metric_name in self.metrics_list:
+            if metric_name == "recall":
+                result = metrics.calculate_recall(
+                    predictions=predictions,
+                    labels=labels,
+                    task=self.task,
+                    threshold=self.threshold,
+                    averaging_method=averaging_method,
+                )
+                metrics_results["Recall"] = np.atleast_1d(result)
+            elif metric_name == "precision":
+                result = metrics.calculate_precision(
+                    predictions=predictions,
+                    labels=labels,
+                    task=self.task,
+                    threshold=self.threshold,
+                    averaging_method=averaging_method,
+                )
+                metrics_results["Precision"] = np.atleast_1d(result)
+            elif metric_name == "f1":
+                result = metrics.calculate_f1_score(
+                    predictions=predictions,
+                    labels=labels,
+                    task=self.task,
+                    threshold=self.threshold,
+                    averaging_method=averaging_method,
+                )
+                metrics_results["F1"] = np.atleast_1d(result)
+            elif metric_name == "ap":
+                result = metrics.calculate_average_precision(
+                    predictions=predictions,
+                    labels=labels,
+                    task=self.task,
+                    averaging_method=averaging_method,
+                )
+                metrics_results["AP"] = np.atleast_1d(result)
+            elif metric_name == "auroc":
+                result = metrics.calculate_auroc(
+                    predictions=predictions,
+                    labels=labels,
+                    task=self.task,
+                    averaging_method=averaging_method,
+                )
+                metrics_results["AUROC"] = np.atleast_1d(result)
+            elif metric_name == "accuracy":
+                result = metrics.calculate_accuracy(
+                    predictions=predictions,
+                    labels=labels,
+                    task=self.task,
+                    num_classes=self.num_classes,
+                    threshold=self.threshold,
+                    averaging_method=averaging_method,
+                )
+                metrics_results["Accuracy"] = np.atleast_1d(result)
+
+        # Define column names for the DataFrame
+        columns = (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]) if per_class_metrics else ["Overall"]
+
+        # Create a DataFrame to organize metric results
+        metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
+        return pd.DataFrame.from_dict(metrics_data, orient="index", columns=columns)
+
+    def plot_metrics(
+        self,
+        predictions: np.ndarray,
+        labels: np.ndarray,
+        per_class_metrics: bool = False,
+    ):
+        """
+        Plot performance metrics for the given predictions and labels.
+
+        Args:
+            predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+            labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+            per_class_metrics (bool): If True, plots metrics for each class individually.
+
+        Raises:
+            ValueError: If the metrics cannot be calculated or plotting fails.
+
+        Returns:
+            None
+        """
+        # Calculate metrics using the provided predictions and labels
+        metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)
+
+        # Choose the plotting method based on whether per-class metrics are required
+        return plotting.plot_metrics_per_class(metrics_df, self.colors) if per_class_metrics else plotting.plot_overall_metrics(metrics_df, self.colors)
+
+    def plot_metrics_all_thresholds(
+        self,
+        predictions: np.ndarray,
+        labels: np.ndarray,
+        per_class_metrics: bool = False,
+    ):
+        """
+        Plot performance metrics across thresholds for the given predictions and labels.
+
+        Args:
+            predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+            labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+            per_class_metrics (bool): If True, plots metrics for each class individually.
+
+        Raises:
+            ValueError: If metrics calculation or plotting fails.
+
+        Returns:
+            None
+        """
+        # Save the original threshold value to restore it later
+        original_threshold = self.threshold
+
+        # Define a range of thresholds for analysis
+        thresholds = np.arange(0.05, 1.0, 0.05)
+
+        # Exclude metrics that are not threshold-dependent
+        metrics_to_plot = [m for m in self.metrics_list if m not in ["auroc", "ap"]]
+
+        if per_class_metrics:
+            # Define class names for plotting
+            class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+            # Initialize a dictionary to store metric values per class
+            metric_values_dict_per_class = {class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names}
+
+            # Compute metrics for each threshold
+            for thresh in thresholds:
+                self.threshold = thresh
+                metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=True)
+                for metric_name in metrics_to_plot:
+                    metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                    for class_name in class_names:
+                        value = metrics_df.loc[metric_label, class_name]
+                        metric_values_dict_per_class[class_name][metric_name].append(value)
+
+            # Restore the original threshold
+            self.threshold = original_threshold
+
+            # Plot metrics across thresholds per class
+            fig = plotting.plot_metrics_across_thresholds_per_class(
+                thresholds,
+                metric_values_dict_per_class,
+                metrics_to_plot,
+                class_names,
+                self.colors,
+            )
+        else:
+            # Initialize a dictionary to store overall metric values
+            metric_values_dict = {metric_name: [] for metric_name in metrics_to_plot}
+
+            # Compute metrics for each threshold
+            for thresh in thresholds:
+                self.threshold = thresh
+                metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics=False)
+                for metric_name in metrics_to_plot:
+                    metric_label = metric_name.capitalize() if metric_name != "f1" else "F1"
+                    value = metrics_df.loc[metric_label, "Overall"]
+                    metric_values_dict[metric_name].append(value)
+
+            # Restore the original threshold
+            self.threshold = original_threshold
+
+            # Plot metrics across thresholds
+            fig = plotting.plot_metrics_across_thresholds(
+                thresholds,
+                metric_values_dict,
+                metrics_to_plot,
+                self.colors,
+            )
+
+        return fig
+
+    def plot_confusion_matrix(
+        self,
+        predictions: np.ndarray,
+        labels: np.ndarray,
+    ):
+        """
+        Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.
+
+        Args:
+            predictions (np.ndarray): Model output predictions as a 2D NumPy array (probabilities or logits).
+            labels (np.ndarray): Ground truth labels as a 2D NumPy array.
+
+        Raises:
+            TypeError: If predictions or labels are not NumPy arrays.
+            ValueError: If predictions and labels have mismatched shapes or invalid dimensions.
+
+        Returns:
+            None
+        """
+        # Validate that predictions and labels are NumPy arrays and match in shape
+        if not isinstance(predictions, np.ndarray):
+            raise TypeError("predictions must be a NumPy array.")
+        if not isinstance(labels, np.ndarray):
+            raise TypeError("labels must be a NumPy array.")
+        if predictions.shape != labels.shape:
+            raise ValueError("predictions and labels must have the same shape.")
+        if predictions.ndim != 2:
+            raise ValueError("predictions and labels must be 2-dimensional arrays.")
+        if predictions.shape[1] != self.num_classes:
+            raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")
+
+        if self.task == "binary":
+            # Binarize predictions using the threshold
+            y_pred = (predictions >= self.threshold).astype(int).flatten()
+            y_true = labels.astype(int).flatten()
+
+            # Compute and normalize the confusion matrix
+            conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
+            conf_mat = np.round(conf_mat, 2)
+
+            return plotting.plot_confusion_matrices(conf_mat, self.task, self.classes)
+
+        if self.task == "multilabel":
+            # Binarize predictions for multilabel classification
+            y_pred = (predictions >= self.threshold).astype(int)
+            y_true = labels.astype(int)
+
+            # Compute confusion matrices for each class
+            conf_mats = []
+            class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
+            for i in range(self.num_classes):
+                conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
+                conf_mat = np.round(conf_mat, 2)
+                conf_mats.append(conf_mat)
+
+            return plotting.plot_confusion_matrices(np.array(conf_mats), self.task, class_names)
+
+        raise ValueError(f"Unsupported task type: {self.task}")