birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +5 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +25 -25
- birdnet_analyzer/analyze/core.py +241 -245
- birdnet_analyzer/analyze/utils.py +692 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +709 -707
- birdnet_analyzer/config.py +242 -242
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +69 -70
- birdnet_analyzer/embeddings/utils.py +179 -193
- birdnet_analyzer/evaluation/__init__.py +196 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +175 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +28 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +619 -620
- birdnet_analyzer/gui/evaluation.py +795 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +245 -246
- birdnet_analyzer/gui/review.py +519 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +128 -129
- birdnet_analyzer/gui/single_file.py +267 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +696 -698
- birdnet_analyzer/gui/utils.py +810 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +334 -334
- birdnet_analyzer/lang/en.json +334 -334
- birdnet_analyzer/lang/fi.json +334 -334
- birdnet_analyzer/lang/fr.json +334 -334
- birdnet_analyzer/lang/id.json +334 -334
- birdnet_analyzer/lang/pt-br.json +334 -334
- birdnet_analyzer/lang/ru.json +334 -334
- birdnet_analyzer/lang/se.json +334 -334
- birdnet_analyzer/lang/tlh.json +334 -334
- birdnet_analyzer/lang/zh_TW.json +334 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +426 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
- birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,379 @@
|
|
1
|
+
"""
|
2
|
+
Module containing functions to plot performance metrics.
|
3
|
+
|
4
|
+
This script provides a variety of functions to visualize performance metrics in different formats,
|
5
|
+
including bar charts, line plots, and heatmaps. These visualizations help analyze metrics such as
|
6
|
+
overall performance, per-class performance, and performance across thresholds.
|
7
|
+
|
8
|
+
Functions:
|
9
|
+
- plot_overall_metrics: Plots a bar chart for overall performance metrics.
|
10
|
+
- plot_metrics_per_class: Plots metric values per class with unique lines and colors.
|
11
|
+
- plot_metrics_across_thresholds: Plots metrics across different thresholds.
|
12
|
+
- plot_metrics_across_thresholds_per_class: Plots metrics across thresholds for each class.
|
13
|
+
- plot_confusion_matrices: Visualizes confusion matrices for binary, multiclass, or multilabel tasks.
|
14
|
+
"""
|
15
|
+
|
16
|
+
from typing import Literal
|
17
|
+
|
18
|
+
import matplotlib.pyplot as plt
|
19
|
+
import numpy as np
|
20
|
+
import pandas as pd
|
21
|
+
import seaborn as sns
|
22
|
+
|
23
|
+
|
24
|
+
def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
    """
    Draw a bar chart with one bar per metric showing its overall score.

    Args:
        metrics_df (pd.DataFrame): Metric names as the index and an 'Overall' column holding
            one score per metric.
        colors (List[str]): Bar colors. An empty list falls back to matplotlib's active
            color cycle.

    Raises:
        TypeError: If `metrics_df` is not a DataFrame or `colors` is not a list.
        KeyError: If 'Overall' column is missing in `metrics_df`.
        ValueError: If `metrics_df` is empty.

    Returns:
        plt.Figure: The newly created figure holding the bar chart.
    """
    if not isinstance(metrics_df, pd.DataFrame):
        raise TypeError("metrics_df must be a pandas DataFrame.")
    if "Overall" not in metrics_df.columns:
        raise KeyError("metrics_df must contain an 'Overall' column.")
    if metrics_df.empty:
        raise ValueError("metrics_df is empty.")
    if not isinstance(colors, list):
        raise TypeError("colors must be a list.")

    # An empty palette means "use whatever the current rcParams color cycle is".
    palette = colors if colors else plt.rcParams["axes.prop_cycle"].by_key()["color"]

    metric_names = metrics_df.index
    scores = metrics_df["Overall"].to_numpy()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(metric_names, scores, color=palette[: len(metric_names)])

    ax.set_title("Overall Metric Scores", fontsize=16)
    ax.set_xlabel("Metrics", fontsize=12)
    ax.set_ylabel("Score", fontsize=12)
    # Slant the category labels so long metric names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    fig.tight_layout()

    return fig
|
70
|
+
|
71
|
+
|
72
|
+
def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
    """
    Draw one line per metric across all classes, each with its own color and dash pattern.

    Args:
        metrics_df (pd.DataFrame): Metric names as the index and class names as the columns.
        colors (List[str]): Line colors, cycled when there are more metrics than colors.
            An empty list falls back to matplotlib's active color cycle.

    Raises:
        TypeError: If inputs are not of expected types.
        ValueError: If `metrics_df` is empty.

    Returns:
        plt.Figure: The newly created figure holding the line plot.
    """
    if not isinstance(metrics_df, pd.DataFrame):
        raise TypeError("metrics_df must be a pandas DataFrame.")
    if metrics_df.empty:
        raise ValueError("metrics_df is empty.")
    if not isinstance(colors, list):
        raise TypeError("colors must be a list.")

    palette = colors or plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Dash patterns cycled per metric so lines stay distinguishable when colors repeat.
    dash_patterns = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
    # Every metric shares the same x positions: the class columns.
    class_labels = metrics_df.columns

    fig, ax = plt.subplots(figsize=(10, 6))

    for idx, metric_name in enumerate(metrics_df.index):
        ax.plot(
            class_labels,
            metrics_df.loc[metric_name],
            label=metric_name,
            marker="o",
            markersize=8,
            linewidth=2,
            linestyle=dash_patterns[idx % len(dash_patterns)],
            color=palette[idx % len(palette)],
        )

    ax.set_title("Metric Scores per Class", fontsize=16)
    ax.set_xlabel("Class", fontsize=12)
    ax.set_ylabel("Score", fontsize=12)
    ax.legend(loc="lower right")
    ax.grid(True)
    fig.tight_layout()

    return fig
|
126
|
+
|
127
|
+
|
128
|
+
def plot_metrics_across_thresholds(
    thresholds: np.ndarray,
    metric_values_dict: dict[str, np.ndarray],
    metrics_to_plot: list[str],
    colors: list[str],
) -> plt.Figure:
    """
    Plots metrics across different thresholds.

    All inputs, including every requested metric, are validated before the figure is
    created. A validation error therefore never leaves an orphaned figure registered with
    pyplot that the caller has no handle to close.

    Args:
        thresholds (np.ndarray): Array of threshold values (x-axis).
        metric_values_dict (Dict[str, np.ndarray]): Dictionary mapping metric names to their
            values, one value per threshold.
        metrics_to_plot (List[str]): List of metric names to plot, in plotting order.
        colors (List[str]): List of colors for the lines, cycled as needed. An empty list
            falls back to matplotlib's active color cycle.

    Raises:
        TypeError: If inputs are not of expected types.
        ValueError: If `thresholds` is empty or a metric's values do not have the same
            length as `thresholds`.
        KeyError: If a name in `metrics_to_plot` is missing from `metric_values_dict`.

    Returns:
        plt.Figure: The newly created figure holding the line plot.
    """
    # Validate inputs
    if not isinstance(thresholds, np.ndarray):
        raise TypeError("thresholds must be a numpy ndarray.")
    if thresholds.size == 0:
        raise ValueError("thresholds array is empty.")
    if not isinstance(metric_values_dict, dict):
        raise TypeError("metric_values_dict must be a dictionary.")
    if not isinstance(metrics_to_plot, list):
        raise TypeError("metrics_to_plot must be a list.")
    if not isinstance(colors, list):
        raise TypeError("colors must be a list.")
    if len(colors) == 0:
        # Default to matplotlib's color cycle if colors are not provided
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Check every requested metric before allocating the figure (see docstring).
    for metric_name in metrics_to_plot:
        if metric_name not in metric_values_dict:
            raise KeyError(f"Metric '{metric_name}' not found in metric_values_dict.")
        if len(metric_values_dict[metric_name]) != len(thresholds):
            raise ValueError(f"Length of metric '{metric_name}' values does not match length of thresholds.")

    # Line styles for distinction when colors repeat
    line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot each metric against thresholds
    for i, metric_name in enumerate(metrics_to_plot):
        ax.plot(
            thresholds,
            metric_values_dict[metric_name],
            label=metric_name.capitalize(),
            linestyle=line_styles[i % len(line_styles)],
            linewidth=2,
            color=colors[i % len(colors)],
        )

    # Add titles, labels, legend, and format
    ax.set_title("Metrics across Different Thresholds", fontsize=16)
    ax.set_xlabel("Threshold", fontsize=12)
    ax.set_ylabel("Metric Score", fontsize=12)
    if metrics_to_plot:
        # With no labelled lines, a legend call would only draw an empty box and warn.
        ax.legend(loc="best")
    ax.grid(True)
    fig.tight_layout()

    return fig
|
194
|
+
|
195
|
+
|
196
|
+
def plot_metrics_across_thresholds_per_class(
    thresholds: np.ndarray,
    metric_values_dict_per_class: dict[str, dict[str, np.ndarray]],
    metrics_to_plot: list[str],
    class_names: list[str],
    colors: list[str],
) -> plt.Figure:
    """
    Plots metrics across different thresholds per class, one subplot per class.

    Subplots are laid out in a near-square grid; unused grid cells are removed. All
    class/metric lookups are validated before the figure is created, so a validation error
    never leaves an orphaned figure registered with pyplot.

    Args:
        thresholds (np.ndarray): Array of threshold values (x-axis of every subplot).
        metric_values_dict_per_class (Dict[str, Dict[str, np.ndarray]]): Dictionary mapping class names
            to metric dictionaries, each containing metric names and their values across thresholds.
        metrics_to_plot (List[str]): List of metric names to plot in each subplot.
        class_names (List[str]): List of class names, in subplot order.
        colors (List[str]): List of colors for the lines, cycled as needed. An empty list
            falls back to matplotlib's active color cycle.

    Raises:
        TypeError: If inputs are not of expected types.
        ValueError: If `thresholds` or `class_names` is empty, or a metric's values do not
            have the same length as `thresholds`.
        KeyError: If a class is missing from `metric_values_dict_per_class`, or a metric is
            missing for a class.

    Returns:
        plt.Figure: The newly created figure holding the subplot grid.
    """
    # Validate inputs
    if not isinstance(thresholds, np.ndarray):
        raise TypeError("thresholds must be a numpy ndarray.")
    if thresholds.size == 0:
        raise ValueError("thresholds array is empty.")
    if not isinstance(metric_values_dict_per_class, dict):
        raise TypeError("metric_values_dict_per_class must be a dictionary.")
    if not isinstance(metrics_to_plot, list):
        raise TypeError("metrics_to_plot must be a list.")
    if not isinstance(class_names, list):
        raise TypeError("class_names must be a list.")
    if not isinstance(colors, list):
        raise TypeError("colors must be a list.")
    if len(colors) == 0:
        # Default to matplotlib's color cycle if colors are not provided
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    num_classes = len(class_names)
    if num_classes == 0:
        raise ValueError("class_names list is empty.")

    # Check every class/metric pair before allocating the figure (see docstring).
    # The check order (class, then its metrics) matches the order plotting walks them.
    for class_name in class_names:
        if class_name not in metric_values_dict_per_class:
            raise KeyError(f"Class '{class_name}' not found in metric_values_dict_per_class.")
        metric_values_dict = metric_values_dict_per_class[class_name]
        for metric_name in metrics_to_plot:
            if metric_name not in metric_values_dict:
                raise KeyError(f"Metric '{metric_name}' not found for class '{class_name}'.")
            if len(metric_values_dict[metric_name]) != len(thresholds):
                raise ValueError(
                    f"Length of metric '{metric_name}' values for class '{class_name}' "
                    + "does not match length of thresholds."
                )

    # Determine a near-square grid size for subplots
    n_cols = int(np.ceil(np.sqrt(num_classes)))
    n_rows = int(np.ceil(num_classes / n_cols))

    # squeeze=False always yields a 2-D array of Axes, so ravel() works for any grid size,
    # including the single-subplot case.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4), squeeze=False)
    axes = axes.ravel()

    # Line styles for distinction when colors repeat
    line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]

    # Plot each class into its own subplot
    for ax, class_name in zip(axes, class_names):
        metric_values_dict = metric_values_dict_per_class[class_name]

        for i, metric_name in enumerate(metrics_to_plot):
            ax.plot(
                thresholds,
                metric_values_dict[metric_name],
                label=metric_name.capitalize(),
                linestyle=line_styles[i % len(line_styles)],
                linewidth=2,
                color=colors[i % len(colors)],
            )

        ax.set_title(f"{class_name}", fontsize=12)
        ax.set_xlabel("Threshold", fontsize=10)
        ax.set_ylabel("Metric Score", fontsize=10)
        if metrics_to_plot:
            # With no labelled lines, a legend call would only draw an empty box and warn.
            ax.legend(loc="best", fontsize=8)
        ax.grid(True)

    # Remove grid cells left over when num_classes is not a perfect rectangle
    for ax in axes[num_classes:]:
        fig.delaxes(ax)

    fig.tight_layout()

    return fig
|
296
|
+
|
297
|
+
|
298
|
+
def plot_confusion_matrices(
    conf_mat: np.ndarray,
    task: Literal["binary", "multiclass", "multilabel"],
    class_names: list[str],
) -> plt.Figure:
    """
    Plots confusion matrices for each class in a single figure with multiple subplots.

    Binary task: a single annotated heatmap whose tick labels are `class_names`. Multiclass
    and multilabel tasks: one 2x2 heatmap per class in a near-square grid, each titled with
    its class name; unused grid cells are removed.

    Args:
        conf_mat (np.ndarray): Confusion matrix or matrices. For binary classification, a single 2x2 matrix
            (rows: true class, columns: predicted class). For multilabel or multiclass, an array of shape
            (num_classes, 2, 2).
        task (Literal["binary", "multiclass", "multilabel"]): Task type.
        class_names (List[str]): List of class names. For the binary task these are used as the
            tick labels, in the same order as the matrix rows/columns (NOTE(review): assumes the
            caller orders them like the matrix, e.g. ["Negative", "Positive"] for labels 0/1).

    Raises:
        TypeError: If inputs are not of expected types.
        ValueError: If confusion matrix dimensions or task specifications are invalid.

    Returns:
        plt.Figure: The newly created figure holding the heatmap(s).
    """
    # Validate inputs
    if not isinstance(conf_mat, np.ndarray):
        raise TypeError("conf_mat must be a numpy ndarray.")
    if conf_mat.size == 0:
        raise ValueError("conf_mat is empty.")
    if not isinstance(task, str) or task not in ["binary", "multiclass", "multilabel"]:
        raise ValueError("Invalid task. Expected 'binary', 'multiclass', or 'multilabel'.")
    if not isinstance(class_names, list):
        raise TypeError("class_names must be a list.")
    if len(class_names) == 0:
        raise ValueError("class_names list is empty.")

    if task == "binary":
        # Binary classification expects a single 2x2 matrix
        if conf_mat.shape != (2, 2):
            raise ValueError("For binary task, conf_mat must be of shape (2, 2).")
        if len(class_names) != 2:
            raise ValueError("For binary task, class_names must have exactly two elements.")

        fig, ax = plt.subplots(figsize=(4, 4))
        # Use the validated class names as tick labels instead of seaborn's default 0/1.
        sns.heatmap(
            conf_mat,
            annot=True,
            fmt=".2f",
            cmap="Reds",
            cbar=False,
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax,
        )
        ax.set_title("Confusion Matrix")
        ax.set_xlabel("Predicted Class")
        ax.set_ylabel("True Class")
        fig.tight_layout()
        return fig

    # Multilabel or multiclass expects a stack of 2x2 matrices. Check the shape before
    # reading shape[0] so a 0-d array raises ValueError rather than IndexError.
    if conf_mat.shape[1:] != (2, 2):
        raise ValueError("For multilabel or multiclass task, conf_mat must have shape (num_labels, 2, 2).")
    num_labels = conf_mat.shape[0]
    if len(class_names) != num_labels:
        raise ValueError("Length of class_names must match number of labels in conf_mat.")

    # Determine a near-square grid size for subplots
    n_cols = int(np.ceil(np.sqrt(num_labels)))
    n_rows = int(np.ceil(num_labels / n_cols))

    # squeeze=False always yields a 2-D array of Axes, so ravel() works for any grid size.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False)
    axes = axes.ravel()

    # Plot each class's confusion matrix
    for ax, cm, class_name in zip(axes, conf_mat, class_names):
        sns.heatmap(cm, annot=True, fmt=".2f", cmap="Reds", cbar=False, ax=ax)
        ax.set_title(f"{class_name}")
        ax.set_xlabel("Predicted Class")
        ax.set_ylabel("True Class")

    # Remove grid cells left over when num_labels is not a perfect rectangle
    for ax in axes[num_labels:]:
        fig.delaxes(ax)

    fig.tight_layout()

    return fig
|
File without changes
|