birdnet-analyzer 2.0.1__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121) hide show
  1. birdnet_analyzer/__init__.py +9 -9
  2. birdnet_analyzer/analyze/__init__.py +19 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -3
  4. birdnet_analyzer/analyze/cli.py +30 -25
  5. birdnet_analyzer/analyze/core.py +268 -241
  6. birdnet_analyzer/analyze/utils.py +700 -692
  7. birdnet_analyzer/audio.py +368 -368
  8. birdnet_analyzer/cli.py +732 -709
  9. birdnet_analyzer/config.py +243 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
  11. birdnet_analyzer/embeddings/__init__.py +3 -3
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -12
  14. birdnet_analyzer/embeddings/core.py +70 -69
  15. birdnet_analyzer/embeddings/utils.py +173 -179
  16. birdnet_analyzer/evaluation/__init__.py +189 -196
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/metrics.py +388 -388
  19. birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -409
  20. birdnet_analyzer/evaluation/assessment/plotting.py +378 -379
  21. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -631
  22. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -98
  23. birdnet_analyzer/gui/__init__.py +19 -19
  24. birdnet_analyzer/gui/__main__.py +3 -3
  25. birdnet_analyzer/gui/analysis.py +179 -175
  26. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  27. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  28. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  30. birdnet_analyzer/gui/assets/gui.css +36 -28
  31. birdnet_analyzer/gui/assets/gui.js +89 -93
  32. birdnet_analyzer/gui/embeddings.py +638 -619
  33. birdnet_analyzer/gui/evaluation.py +801 -795
  34. birdnet_analyzer/gui/localization.py +75 -75
  35. birdnet_analyzer/gui/multi_file.py +265 -245
  36. birdnet_analyzer/gui/review.py +472 -519
  37. birdnet_analyzer/gui/segments.py +191 -191
  38. birdnet_analyzer/gui/settings.py +149 -128
  39. birdnet_analyzer/gui/single_file.py +264 -267
  40. birdnet_analyzer/gui/species.py +95 -95
  41. birdnet_analyzer/gui/train.py +687 -696
  42. birdnet_analyzer/gui/utils.py +803 -810
  43. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  44. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  80. birdnet_analyzer/lang/de.json +342 -334
  81. birdnet_analyzer/lang/en.json +342 -334
  82. birdnet_analyzer/lang/fi.json +342 -334
  83. birdnet_analyzer/lang/fr.json +342 -334
  84. birdnet_analyzer/lang/id.json +342 -334
  85. birdnet_analyzer/lang/pt-br.json +342 -334
  86. birdnet_analyzer/lang/ru.json +342 -334
  87. birdnet_analyzer/lang/se.json +342 -334
  88. birdnet_analyzer/lang/tlh.json +342 -334
  89. birdnet_analyzer/lang/zh_TW.json +342 -334
  90. birdnet_analyzer/model.py +1213 -1212
  91. birdnet_analyzer/search/__init__.py +3 -3
  92. birdnet_analyzer/search/__main__.py +3 -3
  93. birdnet_analyzer/search/cli.py +11 -11
  94. birdnet_analyzer/search/core.py +78 -78
  95. birdnet_analyzer/search/utils.py +104 -107
  96. birdnet_analyzer/segments/__init__.py +3 -3
  97. birdnet_analyzer/segments/__main__.py +3 -3
  98. birdnet_analyzer/segments/cli.py +13 -13
  99. birdnet_analyzer/segments/core.py +81 -81
  100. birdnet_analyzer/segments/utils.py +383 -383
  101. birdnet_analyzer/species/__init__.py +3 -3
  102. birdnet_analyzer/species/__main__.py +3 -3
  103. birdnet_analyzer/species/cli.py +13 -13
  104. birdnet_analyzer/species/core.py +35 -35
  105. birdnet_analyzer/species/utils.py +73 -74
  106. birdnet_analyzer/train/__init__.py +3 -3
  107. birdnet_analyzer/train/__main__.py +3 -3
  108. birdnet_analyzer/train/cli.py +13 -13
  109. birdnet_analyzer/train/core.py +113 -113
  110. birdnet_analyzer/train/utils.py +878 -877
  111. birdnet_analyzer/translate.py +132 -133
  112. birdnet_analyzer/utils.py +425 -426
  113. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/METADATA +147 -137
  114. birdnet_analyzer-2.1.1.dist-info/RECORD +124 -0
  115. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/WHEEL +1 -1
  116. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/licenses/LICENSE +18 -18
  117. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
  118. birdnet_analyzer/playground.py +0 -5
  119. birdnet_analyzer-2.0.1.dist-info/RECORD +0 -125
  120. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/entry_points.txt +0 -0
  121. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.1.dist-info}/top_level.txt +0 -0
@@ -1,379 +1,378 @@
1
- """
2
- Module containing functions to plot performance metrics.
3
-
4
- This script provides a variety of functions to visualize performance metrics in different formats,
5
- including bar charts, line plots, and heatmaps. These visualizations help analyze metrics such as
6
- overall performance, per-class performance, and performance across thresholds.
7
-
8
- Functions:
9
- - plot_overall_metrics: Plots a bar chart for overall performance metrics.
10
- - plot_metrics_per_class: Plots metric values per class with unique lines and colors.
11
- - plot_metrics_across_thresholds: Plots metrics across different thresholds.
12
- - plot_metrics_across_thresholds_per_class: Plots metrics across thresholds for each class.
13
- - plot_confusion_matrices: Visualizes confusion matrices for binary, multiclass, or multilabel tasks.
14
- """
15
-
16
- from typing import Literal
17
-
18
- import matplotlib.pyplot as plt
19
- import numpy as np
20
- import pandas as pd
21
- import seaborn as sns
22
-
23
-
24
- def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
25
- """
26
- Plots a bar chart for overall performance metrics.
27
-
28
- Args:
29
- metrics_df (pd.DataFrame): DataFrame containing metric names as index and an 'Overall' column.
30
- colors (List[str]): List of colors for the bars.
31
-
32
- Raises:
33
- TypeError: If `metrics_df` is not a DataFrame or `colors` is not a list.
34
- KeyError: If 'Overall' column is missing in `metrics_df`.
35
- ValueError: If `metrics_df` is empty.
36
-
37
- Returns:
38
- plt.Figure
39
- """
40
- # Validate input types and content
41
- if not isinstance(metrics_df, pd.DataFrame):
42
- raise TypeError("metrics_df must be a pandas DataFrame.")
43
- if "Overall" not in metrics_df.columns:
44
- raise KeyError("metrics_df must contain an 'Overall' column.")
45
- if metrics_df.empty:
46
- raise ValueError("metrics_df is empty.")
47
- if not isinstance(colors, list):
48
- raise TypeError("colors must be a list.")
49
- if len(colors) == 0:
50
- # Default to matplotlib's color cycle if colors are not provided
51
- colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
52
-
53
- # Extract metric names and values
54
- metrics = metrics_df.index # Metric names
55
- values = metrics_df["Overall"].to_numpy() # Metric values
56
-
57
- # Plot bar chart
58
- fig = plt.figure(figsize=(10, 6))
59
- plt.bar(metrics, values, color=colors[: len(metrics)])
60
-
61
- # Add titles, labels, and format
62
- plt.title("Overall Metric Scores", fontsize=16)
63
- plt.xlabel("Metrics", fontsize=12)
64
- plt.ylabel("Score", fontsize=12)
65
- plt.xticks(rotation=45, ha="right", fontsize=10)
66
- plt.grid(axis="y", linestyle="--", alpha=0.7)
67
- plt.tight_layout()
68
-
69
- return fig
70
-
71
-
72
- def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
73
- """
74
- Plots metric values per class, with each metric represented by a distinct color and line.
75
-
76
- Args:
77
- metrics_df (pd.DataFrame): DataFrame containing metrics as index and class names as columns.
78
- colors (List[str]): List of colors for the lines.
79
-
80
- Raises:
81
- TypeError: If inputs are not of expected types.
82
- ValueError: If `metrics_df` is empty.
83
-
84
- Returns:
85
- plt.Figure
86
- """
87
- # Validate inputs
88
- if not isinstance(metrics_df, pd.DataFrame):
89
- raise TypeError("metrics_df must be a pandas DataFrame.")
90
- if metrics_df.empty:
91
- raise ValueError("metrics_df is empty.")
92
- if not isinstance(colors, list):
93
- raise TypeError("colors must be a list.")
94
- if len(colors) == 0:
95
- # Default to matplotlib's color cycle if colors are not provided
96
- colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
97
-
98
- # Line styles for distinction
99
- line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
100
- fig = plt.figure(figsize=(10, 6))
101
-
102
- # Loop over each metric and plot it
103
- for i, metric_name in enumerate(metrics_df.index):
104
- values = metrics_df.loc[metric_name] # Metric values for each class
105
- classes = metrics_df.columns # Class labels
106
- plt.plot(
107
- classes,
108
- values,
109
- label=metric_name,
110
- marker="o",
111
- markersize=8,
112
- linewidth=2,
113
- linestyle=line_styles[i % len(line_styles)],
114
- color=colors[i % len(colors)],
115
- )
116
-
117
- # Add titles, labels, legend, and format
118
- plt.title("Metric Scores per Class", fontsize=16)
119
- plt.xlabel("Class", fontsize=12)
120
- plt.ylabel("Score", fontsize=12)
121
- plt.legend(loc="lower right")
122
- plt.grid(True)
123
- plt.tight_layout()
124
-
125
- return fig
126
-
127
-
128
- def plot_metrics_across_thresholds(
129
- thresholds: np.ndarray,
130
- metric_values_dict: dict[str, np.ndarray],
131
- metrics_to_plot: list[str],
132
- colors: list[str],
133
- ) -> plt.Figure:
134
- """
135
- Plots metrics across different thresholds.
136
-
137
- Args:
138
- thresholds (np.ndarray): Array of threshold values.
139
- metric_values_dict (Dict[str, np.ndarray]): Dictionary mapping metric names to their values.
140
- metrics_to_plot (List[str]): List of metric names to plot.
141
- colors (List[str]): List of colors for the lines.
142
-
143
- Raises:
144
- TypeError: If inputs are not of expected types.
145
- ValueError: If thresholds or metric values have mismatched lengths.
146
-
147
- Returns:
148
- plt.Figure
149
- """
150
- # Validate inputs
151
- if not isinstance(thresholds, np.ndarray):
152
- raise TypeError("thresholds must be a numpy ndarray.")
153
- if thresholds.size == 0:
154
- raise ValueError("thresholds array is empty.")
155
- if not isinstance(metric_values_dict, dict):
156
- raise TypeError("metric_values_dict must be a dictionary.")
157
- if not isinstance(metrics_to_plot, list):
158
- raise TypeError("metrics_to_plot must be a list.")
159
- if not isinstance(colors, list):
160
- raise TypeError("colors must be a list.")
161
- if len(colors) == 0:
162
- # Default to matplotlib's color cycle if colors are not provided
163
- colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
164
-
165
- # Line styles for distinction
166
- line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
167
- fig = plt.figure(figsize=(10, 6))
168
-
169
- # Plot each metric against thresholds
170
- for i, metric_name in enumerate(metrics_to_plot):
171
- if metric_name not in metric_values_dict:
172
- raise KeyError(f"Metric '{metric_name}' not found in metric_values_dict.")
173
- metric_values = metric_values_dict[metric_name]
174
- if len(metric_values) != len(thresholds):
175
- raise ValueError(f"Length of metric '{metric_name}' values does not match length of thresholds.")
176
- plt.plot(
177
- thresholds,
178
- metric_values,
179
- label=metric_name.capitalize(),
180
- linestyle=line_styles[i % len(line_styles)],
181
- linewidth=2,
182
- color=colors[i % len(colors)],
183
- )
184
-
185
- # Add titles, labels, legend, and format
186
- plt.title("Metrics across Different Thresholds", fontsize=16)
187
- plt.xlabel("Threshold", fontsize=12)
188
- plt.ylabel("Metric Score", fontsize=12)
189
- plt.legend(loc="best")
190
- plt.grid(True)
191
- plt.tight_layout()
192
-
193
- return fig
194
-
195
-
196
- def plot_metrics_across_thresholds_per_class(
197
- thresholds: np.ndarray,
198
- metric_values_dict_per_class: dict[str, dict[str, np.ndarray]],
199
- metrics_to_plot: list[str],
200
- class_names: list[str],
201
- colors: list[str],
202
- ) -> plt.Figure:
203
- """
204
- Plots metrics across different thresholds per class.
205
-
206
- Args:
207
- thresholds (np.ndarray): Array of threshold values.
208
- metric_values_dict_per_class (Dict[str, Dict[str, np.ndarray]]): Dictionary mapping class names
209
- to metric dictionaries, each containing metric names and their values across thresholds.
210
- metrics_to_plot (List[str]): List of metric names to plot.
211
- class_names (List[str]): List of class names.
212
- colors (List[str]): List of colors for the lines.
213
-
214
- Raises:
215
- TypeError: If inputs are not of expected types.
216
- ValueError: If inputs have mismatched lengths or are empty.
217
-
218
- Returns:
219
- plt.Figure
220
- """
221
- # Validate inputs
222
- if not isinstance(thresholds, np.ndarray):
223
- raise TypeError("thresholds must be a numpy ndarray.")
224
- if thresholds.size == 0:
225
- raise ValueError("thresholds array is empty.")
226
- if not isinstance(metric_values_dict_per_class, dict):
227
- raise TypeError("metric_values_dict_per_class must be a dictionary.")
228
- if not isinstance(metrics_to_plot, list):
229
- raise TypeError("metrics_to_plot must be a list.")
230
- if not isinstance(class_names, list):
231
- raise TypeError("class_names must be a list.")
232
- if not isinstance(colors, list):
233
- raise TypeError("colors must be a list.")
234
- if len(colors) == 0:
235
- # Default to matplotlib's color cycle if colors are not provided
236
- colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
237
-
238
- num_classes = len(class_names)
239
- if num_classes == 0:
240
- raise ValueError("class_names list is empty.")
241
-
242
- # Determine grid size for subplots
243
- n_cols = int(np.ceil(np.sqrt(num_classes)))
244
- n_rows = int(np.ceil(num_classes / n_cols))
245
-
246
- # Create subplots
247
- fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4))
248
-
249
- # Flatten axes for easy indexing
250
- axes = [axes] if num_classes == 1 else axes.flatten()
251
-
252
- # Line styles for distinction
253
- line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
254
-
255
- # Plot each class
256
- for class_idx, class_name in enumerate(class_names):
257
- if class_name not in metric_values_dict_per_class:
258
- raise KeyError(f"Class '{class_name}' not found in metric_values_dict_per_class.")
259
- ax = axes[class_idx]
260
- metric_values_dict = metric_values_dict_per_class[class_name]
261
-
262
- # Plot each metric for the current class
263
- for i, metric_name in enumerate(metrics_to_plot):
264
- if metric_name not in metric_values_dict:
265
- raise KeyError(f"Metric '{metric_name}' not found for class '{class_name}'.")
266
- metric_values = metric_values_dict[metric_name]
267
- if len(metric_values) != len(thresholds):
268
- raise ValueError(
269
- f"Length of metric '{metric_name}' values for class '{class_name}' "
270
- + "does not match length of thresholds."
271
- )
272
- ax.plot(
273
- thresholds,
274
- metric_values,
275
- label=metric_name.capitalize(),
276
- linestyle=line_styles[i % len(line_styles)],
277
- linewidth=2,
278
- color=colors[i % len(colors)],
279
- )
280
-
281
- # Add titles and labels for each subplot
282
- ax.set_title(f"{class_name}", fontsize=12)
283
- ax.set_xlabel("Threshold", fontsize=10)
284
- ax.set_ylabel("Metric Score", fontsize=10)
285
- ax.legend(loc="best", fontsize=8)
286
- ax.grid(True)
287
-
288
- # Hide any unused subplots
289
- for j in range(num_classes, len(axes)):
290
- fig.delaxes(axes[j])
291
-
292
- # Adjust layout and show
293
- plt.tight_layout()
294
-
295
- return fig
296
-
297
-
298
- def plot_confusion_matrices(
299
- conf_mat: np.ndarray,
300
- task: Literal["binary", "multiclass", "multilabel"],
301
- class_names: list[str],
302
- ) -> plt.Figure:
303
- """
304
- Plots confusion matrices for each class in a single figure with multiple subplots.
305
-
306
- Args:
307
- conf_mat (np.ndarray): Confusion matrix or matrices. For binary classification, a single 2x2 matrix.
308
- For multilabel or multiclass, an array of shape (num_classes, 2, 2).
309
- task (Literal["binary", "multiclass", "multilabel"]): Task type.
310
- class_names (List[str]): List of class names.
311
-
312
- Raises:
313
- TypeError: If inputs are not of expected types.
314
- ValueError: If confusion matrix dimensions or task specifications are invalid.
315
-
316
- Returns:
317
- plt.Figure
318
- """
319
- # Validate inputs
320
- if not isinstance(conf_mat, np.ndarray):
321
- raise TypeError("conf_mat must be a numpy ndarray.")
322
- if conf_mat.size == 0:
323
- raise ValueError("conf_mat is empty.")
324
- if not isinstance(task, str) or task not in ["binary", "multiclass", "multilabel"]:
325
- raise ValueError("Invalid task. Expected 'binary', 'multiclass', or 'multilabel'.")
326
- if not isinstance(class_names, list):
327
- raise TypeError("class_names must be a list.")
328
- if len(class_names) == 0:
329
- raise ValueError("class_names list is empty.")
330
-
331
- if task == "binary":
332
- # Binary classification expects a single 2x2 matrix
333
- if conf_mat.shape != (2, 2):
334
- raise ValueError("For binary task, conf_mat must be of shape (2, 2).")
335
- if len(class_names) != 2:
336
- raise ValueError("For binary task, class_names must have exactly two elements.")
337
-
338
- # Plot single confusion matrix
339
- fig = plt.figure(figsize=(4, 4))
340
- sns.heatmap(conf_mat, annot=True, fmt=".2f", cmap="Reds", cbar=False)
341
- plt.title("Confusion Matrix")
342
- plt.xlabel("Predicted Class")
343
- plt.ylabel("True Class")
344
- plt.tight_layout()
345
- else:
346
- # Multilabel or multiclass expects a set of 2x2 matrices
347
- num_labels = conf_mat.shape[0]
348
- if conf_mat.shape[1:] != (2, 2):
349
- raise ValueError("For multilabel or multiclass task, conf_mat must have shape (num_labels, 2, 2).")
350
- if len(class_names) != num_labels:
351
- raise ValueError("Length of class_names must match number of labels in conf_mat.")
352
-
353
- # Determine grid size for subplots
354
- n_cols = int(np.ceil(np.sqrt(num_labels)))
355
- n_rows = int(np.ceil(num_labels / n_cols))
356
-
357
- # Create subplots
358
- fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2))
359
-
360
- # Flatten axes for easy indexing
361
- axes = [axes] if num_labels == 1 else axes.flatten()
362
-
363
- # Plot each class's confusion matrix
364
- for i in range(num_labels):
365
- cm = conf_mat[i]
366
- ax = axes[i]
367
- sns.heatmap(cm, annot=True, fmt=".2f", cmap="Reds", cbar=False, ax=ax)
368
- ax.set_title(f"{class_names[i]}")
369
- ax.set_xlabel("Predicted Class")
370
- ax.set_ylabel("True Class")
371
-
372
- # Hide any unused subplots
373
- for j in range(num_labels, len(axes)):
374
- fig.delaxes(axes[j])
375
-
376
- # Adjust layout and show
377
- plt.tight_layout()
378
-
379
- return fig
1
+ """
2
+ Module containing functions to plot performance metrics.
3
+
4
+ This script provides a variety of functions to visualize performance metrics in different formats,
5
+ including bar charts, line plots, and heatmaps. These visualizations help analyze metrics such as
6
+ overall performance, per-class performance, and performance across thresholds.
7
+
8
+ Functions:
9
+ - plot_overall_metrics: Plots a bar chart for overall performance metrics.
10
+ - plot_metrics_per_class: Plots metric values per class with unique lines and colors.
11
+ - plot_metrics_across_thresholds: Plots metrics across different thresholds.
12
+ - plot_metrics_across_thresholds_per_class: Plots metrics across thresholds for each class.
13
+ - plot_confusion_matrices: Visualizes confusion matrices for binary, multiclass, or multilabel tasks.
14
+ """
15
+
16
+ from typing import Literal
17
+
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import pandas as pd
21
+ from sklearn.metrics import ConfusionMatrixDisplay
22
+
23
+ MATPLOTLIB_BINARY_CONFUSION_MATRIX_FIGURE_NUM = "performance-tab-binary-confusion-matrix-plot"
24
+ MATPLOTLIB_MULTICLASS_CONFUSION_MATRIX_FIGURE_NUM = "performance-tab-multiclass-confusion-matrix-plot"
25
+ MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM = "performance-tab-overall-metrics-plot"
26
+ MATPLOTLIB_PER_CLASS_METRICS_FIGURE_NUM = "performance-tab-per-class-metrics-plot"
27
+ MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_FIGURE_NUM = "performance-tab-metrics-across-thresholds-plot"
28
+ MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_PER_CLASS_FIGURE_NUM = "performance-tab-metrics-across-thresholds-per-class-plot"
29
+
30
+
31
+ def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]):
32
+ """
33
+ Plots a bar chart for overall performance metrics.
34
+
35
+ Args:
36
+ metrics_df (pd.DataFrame): DataFrame containing metric names as index and an 'Overall' column.
37
+ colors (List[str]): List of colors for the bars.
38
+
39
+ Raises:
40
+ TypeError: If `metrics_df` is not a DataFrame or `colors` is not a list.
41
+ KeyError: If 'Overall' column is missing in `metrics_df`.
42
+ ValueError: If `metrics_df` is empty.
43
+
44
+ Returns:
45
+ plt.Figure
46
+ """
47
+ # Validate input types and content
48
+ if not isinstance(metrics_df, pd.DataFrame):
49
+ raise TypeError("metrics_df must be a pandas DataFrame.")
50
+ if "Overall" not in metrics_df.columns:
51
+ raise KeyError("metrics_df must contain an 'Overall' column.")
52
+ if metrics_df.empty:
53
+ raise ValueError("metrics_df is empty.")
54
+ if not isinstance(colors, list):
55
+ raise TypeError("colors must be a list.")
56
+ if len(colors) == 0:
57
+ # Default to matplotlib's color cycle if colors are not provided
58
+ colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
59
+
60
+ # Extract metric names and values
61
+ metrics = metrics_df.index # Metric names
62
+ values = metrics_df["Overall"].to_numpy() # Metric values
63
+
64
+ # Plot bar chart
65
+ fig = plt.figure(MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
66
+ fig.clear()
67
+ fig.tight_layout(pad=0)
68
+ fig.set_dpi(300)
69
+
70
+ plt.bar(metrics, values, color=colors[: len(metrics)])
71
+
72
+ # Add titles, labels, and format
73
+ plt.title("Overall Metric Scores", fontsize=16)
74
+ plt.xlabel("Metrics", fontsize=12)
75
+ plt.ylabel("Score", fontsize=12)
76
+ plt.xticks(rotation=45, ha="right", fontsize=10)
77
+ plt.grid(axis="y", linestyle="--", alpha=0.7)
78
+
79
+ return fig
80
+
81
+
82
+ def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]):
83
+ """
84
+ Plots metric values per class, with each metric represented by a distinct color and line.
85
+
86
+ Args:
87
+ metrics_df (pd.DataFrame): DataFrame containing metrics as index and class names as columns.
88
+ colors (List[str]): List of colors for the lines.
89
+
90
+ Raises:
91
+ TypeError: If inputs are not of expected types.
92
+ ValueError: If `metrics_df` is empty.
93
+
94
+ Returns:
95
+ plt.Figure
96
+ """
97
+ # Validate inputs
98
+ if not isinstance(metrics_df, pd.DataFrame):
99
+ raise TypeError("metrics_df must be a pandas DataFrame.")
100
+ if metrics_df.empty:
101
+ raise ValueError("metrics_df is empty.")
102
+ if not isinstance(colors, list):
103
+ raise TypeError("colors must be a list.")
104
+ if len(colors) == 0:
105
+ # Default to matplotlib's color cycle if colors are not provided
106
+ colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
107
+
108
+ # Line styles for distinction
109
+ line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
110
+ fig = plt.figure(MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
111
+ fig.clear()
112
+ fig.tight_layout(pad=0)
113
+ fig.set_dpi(300)
114
+
115
+ # Loop over each metric and plot it
116
+ for i, metric_name in enumerate(metrics_df.index):
117
+ values = metrics_df.loc[metric_name] # Metric values for each class
118
+ classes = metrics_df.columns # Class labels
119
+ plt.plot(
120
+ classes,
121
+ values,
122
+ label=metric_name,
123
+ marker="o",
124
+ markersize=8,
125
+ linewidth=2,
126
+ linestyle=line_styles[i % len(line_styles)],
127
+ color=colors[i % len(colors)],
128
+ )
129
+
130
+ # Add titles, labels, legend, and format
131
+ plt.title("Metric Scores per Class", fontsize=16)
132
+ plt.xlabel("Class", fontsize=12)
133
+ plt.ylabel("Score", fontsize=12)
134
+ plt.legend(loc="lower right")
135
+ plt.grid(True)
136
+
137
+ return fig
138
+
139
+
140
+ def plot_metrics_across_thresholds(
141
+ thresholds: np.ndarray,
142
+ metric_values_dict: dict[str, np.ndarray],
143
+ metrics_to_plot: list[str],
144
+ colors: list[str],
145
+ ):
146
+ """
147
+ Plots metrics across different thresholds.
148
+
149
+ Args:
150
+ thresholds (np.ndarray): Array of threshold values.
151
+ metric_values_dict (Dict[str, np.ndarray]): Dictionary mapping metric names to their values.
152
+ metrics_to_plot (List[str]): List of metric names to plot.
153
+ colors (List[str]): List of colors for the lines.
154
+
155
+ Raises:
156
+ TypeError: If inputs are not of expected types.
157
+ ValueError: If thresholds or metric values have mismatched lengths.
158
+
159
+ Returns:
160
+ plt.Figure
161
+ """
162
+ # Validate inputs
163
+ if not isinstance(thresholds, np.ndarray):
164
+ raise TypeError("thresholds must be a numpy ndarray.")
165
+ if thresholds.size == 0:
166
+ raise ValueError("thresholds array is empty.")
167
+ if not isinstance(metric_values_dict, dict):
168
+ raise TypeError("metric_values_dict must be a dictionary.")
169
+ if not isinstance(metrics_to_plot, list):
170
+ raise TypeError("metrics_to_plot must be a list.")
171
+ if not isinstance(colors, list):
172
+ raise TypeError("colors must be a list.")
173
+ if len(colors) == 0:
174
+ # Default to matplotlib's color cycle if colors are not provided
175
+ colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
176
+
177
+ # Line styles for distinction
178
+ line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
179
+ fig = plt.figure(MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_FIGURE_NUM, figsize=(10, 6))
180
+ fig.clear()
181
+ fig.tight_layout(pad=0)
182
+ fig.set_dpi(300)
183
+
184
+ # Plot each metric against thresholds
185
+ for i, metric_name in enumerate(metrics_to_plot):
186
+ if metric_name not in metric_values_dict:
187
+ raise KeyError(f"Metric '{metric_name}' not found in metric_values_dict.")
188
+ metric_values = metric_values_dict[metric_name]
189
+ if len(metric_values) != len(thresholds):
190
+ raise ValueError(f"Length of metric '{metric_name}' values does not match length of thresholds.")
191
+ plt.plot(
192
+ thresholds,
193
+ metric_values,
194
+ label=metric_name.capitalize(),
195
+ linestyle=line_styles[i % len(line_styles)],
196
+ linewidth=2,
197
+ color=colors[i % len(colors)],
198
+ )
199
+
200
+ # Add titles, labels, legend, and format
201
+ plt.title("Metrics across Different Thresholds", fontsize=16)
202
+ plt.xlabel("Threshold", fontsize=12)
203
+ plt.ylabel("Metric Score", fontsize=12)
204
+ plt.legend(loc="best")
205
+ plt.grid(True)
206
+
207
+ return fig
208
+
209
+
210
def plot_metrics_across_thresholds_per_class(
    thresholds: np.ndarray,
    metric_values_dict_per_class: dict[str, dict[str, np.ndarray]],
    metrics_to_plot: list[str],
    class_names: list[str],
    colors: list[str],
):
    """
    Plots metrics across different thresholds per class, one subplot per class.

    Args:
        thresholds (np.ndarray): Array of threshold values.
        metric_values_dict_per_class (Dict[str, Dict[str, np.ndarray]]): Dictionary mapping class names
            to metric dictionaries, each containing metric names and their values across thresholds.
        metrics_to_plot (List[str]): List of metric names to plot.
        class_names (List[str]): List of class names.
        colors (List[str]): List of colors for the lines. If empty, matplotlib's
            default color cycle is used.

    Raises:
        TypeError: If inputs are not of expected types.
        ValueError: If inputs have mismatched lengths or are empty.
        KeyError: If a class or metric is missing from metric_values_dict_per_class.

    Returns:
        plt.Figure
    """
    # Validate inputs
    if not isinstance(thresholds, np.ndarray):
        raise TypeError("thresholds must be a numpy ndarray.")
    if thresholds.size == 0:
        raise ValueError("thresholds array is empty.")
    if not isinstance(metric_values_dict_per_class, dict):
        raise TypeError("metric_values_dict_per_class must be a dictionary.")
    if not isinstance(metrics_to_plot, list):
        raise TypeError("metrics_to_plot must be a list.")
    if not isinstance(class_names, list):
        raise TypeError("class_names must be a list.")
    if not isinstance(colors, list):
        raise TypeError("colors must be a list.")
    if len(colors) == 0:
        # Default to matplotlib's color cycle if colors are not provided
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    num_classes = len(class_names)
    if num_classes == 0:
        raise ValueError("class_names list is empty.")

    # Determine grid size for subplots (near-square layout)
    n_cols = int(np.ceil(np.sqrt(num_classes)))
    n_rows = int(np.ceil(num_classes / n_cols))

    # Create subplots. `clear=True` empties a previously used figure with this
    # figure number *before* the new axes are created. The previous approach of
    # calling fig.clear() after plt.subplots() detached the axes that had just
    # been created, so nothing plotted below appeared in the returned figure.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 5, n_rows * 4),
        num=MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_PER_CLASS_FIGURE_NUM,
        clear=True,
    )
    fig.tight_layout(pad=0)
    fig.set_dpi(300)

    # Flatten axes for easy indexing (plt.subplots returns a bare Axes for a 1x1 grid)
    axes = [axes] if num_classes == 1 else axes.flatten()

    # Line styles for distinction between metrics within one subplot
    line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]

    # Plot each class
    for class_idx, class_name in enumerate(class_names):
        if class_name not in metric_values_dict_per_class:
            raise KeyError(f"Class '{class_name}' not found in metric_values_dict_per_class.")
        ax = axes[class_idx]
        metric_values_dict = metric_values_dict_per_class[class_name]

        # Plot each metric for the current class
        for i, metric_name in enumerate(metrics_to_plot):
            if metric_name not in metric_values_dict:
                raise KeyError(f"Metric '{metric_name}' not found for class '{class_name}'.")
            metric_values = metric_values_dict[metric_name]
            if len(metric_values) != len(thresholds):
                raise ValueError(f"Length of metric '{metric_name}' values for class '{class_name}' " + "does not match length of thresholds.")
            ax.plot(
                thresholds,
                metric_values,
                label=metric_name.capitalize(),
                linestyle=line_styles[i % len(line_styles)],
                linewidth=2,
                color=colors[i % len(colors)],
            )

        # Add titles and labels for each subplot
        ax.set_title(f"{class_name}", fontsize=12)
        ax.set_xlabel("Threshold", fontsize=10)
        ax.set_ylabel("Metric Score", fontsize=10)
        ax.legend(loc="best", fontsize=8)
        ax.grid(True)

    return fig
303
+
304
+
305
def plot_confusion_matrices(
    conf_mat: np.ndarray,
    task: Literal["binary", "multiclass", "multilabel"],
    class_names: list[str],
):
    """
    Renders confusion matrices in a single figure.

    For a binary task a lone 2x2 matrix is drawn; for multiclass/multilabel
    tasks, one 2x2 per-label matrix is drawn in its own subplot, arranged in a
    near-square grid.

    Args:
        conf_mat (np.ndarray): Confusion matrix or matrices. Shape (2, 2) for
            binary classification, or (num_classes, 2, 2) for
            multiclass/multilabel.
        task (Literal["binary", "multiclass", "multilabel"]): Task type.
        class_names (List[str]): List of class names (used as subplot titles
            for multiclass/multilabel).

    Raises:
        TypeError: If inputs are not of expected types.
        ValueError: If confusion matrix dimensions or task specifications are invalid.

    Returns:
        plt.Figure
    """
    # Basic input checks shared by all task types
    if not isinstance(conf_mat, np.ndarray):
        raise TypeError("conf_mat must be a numpy ndarray.")
    if conf_mat.size == 0:
        raise ValueError("conf_mat is empty.")
    if not isinstance(task, str) or task not in ["binary", "multiclass", "multilabel"]:
        raise ValueError("Invalid task. Expected 'binary', 'multiclass', or 'multilabel'.")

    if task == "binary":
        # A single 2x2 matrix is expected for binary classification
        if conf_mat.shape != (2, 2):
            raise ValueError("For binary task, conf_mat must be of shape (2, 2).")

        fig, ax = plt.subplots(num=MATPLOTLIB_BINARY_CONFUSION_MATRIX_FIGURE_NUM, figsize=(6, 6))
        fig.tight_layout()
        fig.set_dpi(300)

        display = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
        display.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
        ax.set_title("Confusion Matrix")
        return fig

    # Multiclass/multilabel: one 2x2 matrix per label
    n_matrices = conf_mat.shape[0]

    if conf_mat.shape[1:] != (2, 2):
        raise ValueError("For multilabel or multiclass task, conf_mat must have shape (num_labels, 2, 2).")
    if len(class_names) != n_matrices:
        raise ValueError("Length of class_names must match number of labels in conf_mat.")

    # Arrange subplots in a near-square grid
    grid_cols = int(np.ceil(np.sqrt(n_matrices)))
    grid_rows = int(np.ceil(n_matrices / grid_cols))

    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(4 * grid_cols, 4 * grid_rows), num=MATPLOTLIB_MULTICLASS_CONFUSION_MATRIX_FIGURE_NUM)
    fig.set_dpi(300)

    # plt.subplots returns a bare Axes for a 1x1 grid; normalize to a flat list
    axes = axes.flatten() if hasattr(axes, "flatten") else [axes]

    # Draw one confusion matrix per class
    for pos, (matrix, label) in enumerate(zip(conf_mat, class_names, strict=True)):
        panel = axes[pos]
        ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=["Negative", "Positive"]).plot(
            cmap="Reds", ax=panel, colorbar=False, values_format=".2f"
        )
        panel.set_title(f"{label}")
        panel.set_xlabel("Predicted class")
        panel.set_ylabel("True class")

    # Drop grid cells that received no matrix
    for spare in axes[n_matrices:]:
        fig.delaxes(spare)

    plt.tight_layout()

    return fig