birdnet-analyzer 2.0.1__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. birdnet_analyzer/analyze/__init__.py +14 -0
  2. birdnet_analyzer/analyze/cli.py +5 -0
  3. birdnet_analyzer/analyze/core.py +6 -1
  4. birdnet_analyzer/analyze/utils.py +42 -40
  5. birdnet_analyzer/audio.py +2 -2
  6. birdnet_analyzer/cli.py +41 -18
  7. birdnet_analyzer/config.py +4 -3
  8. birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
  9. birdnet_analyzer/embeddings/core.py +2 -1
  10. birdnet_analyzer/embeddings/utils.py +42 -1
  11. birdnet_analyzer/evaluation/__init__.py +6 -13
  12. birdnet_analyzer/evaluation/assessment/performance_assessor.py +12 -57
  13. birdnet_analyzer/evaluation/assessment/plotting.py +61 -62
  14. birdnet_analyzer/evaluation/preprocessing/data_processor.py +1 -1
  15. birdnet_analyzer/gui/analysis.py +5 -1
  16. birdnet_analyzer/gui/assets/gui.css +8 -0
  17. birdnet_analyzer/gui/embeddings.py +37 -18
  18. birdnet_analyzer/gui/evaluation.py +14 -8
  19. birdnet_analyzer/gui/multi_file.py +25 -5
  20. birdnet_analyzer/gui/review.py +16 -63
  21. birdnet_analyzer/gui/settings.py +25 -4
  22. birdnet_analyzer/gui/single_file.py +14 -17
  23. birdnet_analyzer/gui/train.py +7 -16
  24. birdnet_analyzer/gui/utils.py +42 -55
  25. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +1 -1
  26. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +1 -1
  27. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +108 -108
  28. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +1 -1
  29. birdnet_analyzer/lang/de.json +7 -0
  30. birdnet_analyzer/lang/en.json +7 -0
  31. birdnet_analyzer/lang/fi.json +7 -0
  32. birdnet_analyzer/lang/fr.json +7 -0
  33. birdnet_analyzer/lang/id.json +7 -0
  34. birdnet_analyzer/lang/pt-br.json +7 -0
  35. birdnet_analyzer/lang/ru.json +36 -29
  36. birdnet_analyzer/lang/se.json +7 -0
  37. birdnet_analyzer/lang/tlh.json +7 -0
  38. birdnet_analyzer/lang/zh_TW.json +7 -0
  39. birdnet_analyzer/model.py +21 -21
  40. birdnet_analyzer/search/core.py +1 -1
  41. birdnet_analyzer/utils.py +3 -4
  42. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/METADATA +18 -9
  43. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/RECORD +47 -47
  44. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/WHEEL +1 -1
  45. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
  46. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/entry_points.txt +0 -0
  47. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/licenses/LICENSE +0 -0
  48. {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/top_level.txt +0 -0
birdnet_analyzer/embeddings/core.py

@@ -8,6 +8,7 @@ def embeddings(
     fmax: int = 15000,
     threads: int = 8,
     batch_size: int = 1,
+    file_output: str | None = None,
 ):
     """
     Generates embeddings for audio files using the BirdNET-Analyzer.
@@ -46,7 +47,7 @@ def embeddings(
     from birdnet_analyzer.utils import ensure_model_exists

     ensure_model_exists()
-    run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batch_size)
+    run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batch_size, file_output)


 def get_database(db_path: str):
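For reference, a minimal usage sketch of the new parameter from Python; the first two positional arguments are assumptions inferred from the run(audio_input, database, ...) call shown above, not a documented signature:

from birdnet_analyzer.embeddings.core import embeddings

# Hypothetical call: argument positions other than file_output are assumptions.
embeddings(
    "example/soundscapes",                  # audio input folder (assumed parameter)
    "example/embeddings_db",                # embeddings database path (assumed parameter)
    file_output="example/embeddings_txt",   # new in 2.1.0: also write per-segment text files
)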
birdnet_analyzer/embeddings/utils.py

@@ -25,6 +25,7 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
     Args:
         item: (filepath, config)
     """
+
     # Get file path and restore cfg
     fpath: str = item[0]
     cfg.set_config(item[1])
@@ -124,7 +125,44 @@ def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
     db.commit()


-def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize):
+def create_file_output(output_path: str, db: sqlite_usearch_impl.SQLiteUsearchDB):
+    """Creates a file output for the database.
+
+    Args:
+        output_path: Path to the output file.
+        db: Database object.
+    """
+    # Check if output path exists
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    # Get all embeddings
+    embedding_ids = db.get_embedding_ids()
+
+    # Write embeddings to file
+    for embedding_id in embedding_ids:
+        embedding = db.get_embedding(embedding_id)
+        source = db.get_embedding_source(embedding_id)
+
+        # Get start and end time
+        start, end = source.offsets
+
+        source_id = source.source_id.rsplit(".", 1)[0]
+
+        filename = f"{source_id}_{start}_{end}.birdnet.embeddings.txt"
+
+        # Get the common prefix between the output path and the filename
+        common_prefix = os.path.commonpath([output_path, os.path.dirname(filename)])
+        relative_filename = os.path.relpath(filename, common_prefix)
+        target_path = os.path.join(output_path, relative_filename)
+
+        # Ensure the target directory exists
+        os.makedirs(os.path.dirname(target_path), exist_ok=True)
+
+        # Write embedding values to a text file
+        with open(target_path, "w") as f:
+            f.write(",".join(map(str, embedding.tolist())))
+
+def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize, file_output):
     ### Make sure to comment out appropriately if you are not using args. ###

     # Set input and output path
@@ -176,4 +214,7 @@ def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchs
     with Pool(cfg.CPU_THREADS) as p:
         tqdm(p.imap(partial(analyze_file, db=db), flist))

+    if file_output:
+        create_file_output(file_output, db)
+
     db.db.close()
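The files produced by create_file_output() are plain text, one per embedded segment, named {source_id}_{start}_{end}.birdnet.embeddings.txt and containing a single comma-separated vector. A small read-back sketch (the path is illustrative only):

import numpy as np

# Illustrative file name following the {source_id}_{start}_{end}.birdnet.embeddings.txt pattern
with open("example/embeddings_txt/soundscape_0.0_3.0.birdnet.embeddings.txt") as f:
    vector = np.array([float(v) for v in f.read().split(",")])

print(vector.shape)  # one flat embedding vector per segment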
birdnet_analyzer/evaluation/__init__.py

@@ -9,6 +9,7 @@ for columns, class mappings, and filtering based on selected classes or recordin
 import argparse
 import json
 import os
+from collections.abc import Sequence

 from birdnet_analyzer.evaluation.assessment.performance_assessor import (
     PerformanceAssessor,
@@ -25,7 +26,7 @@ def process_data(
     recording_duration: float | None = None,
     columns_annotations: dict[str, str] | None = None,
     columns_predictions: dict[str, str] | None = None,
-    selected_classes: list[str] | None = None,
+    selected_classes: Sequence[str] | None = None,
     selected_recordings: list[str] | None = None,
     metrics_list: tuple[str, ...] = ("accuracy", "precision", "recall"),
     threshold: float = 0.1,
@@ -61,14 +62,10 @@ def process_data(

     # Determine directory and file paths for annotations and predictions
     annotation_dir, annotation_file = (
-        (os.path.dirname(annotation_path), os.path.basename(annotation_path))
-        if os.path.isfile(annotation_path)
-        else (annotation_path, None)
+        (os.path.dirname(annotation_path), os.path.basename(annotation_path)) if os.path.isfile(annotation_path) else (annotation_path, None)
     )
     prediction_dir, prediction_file = (
-        (os.path.dirname(prediction_path), os.path.basename(prediction_path))
-        if os.path.isfile(prediction_path)
-        else (prediction_path, None)
+        (os.path.dirname(prediction_path), os.path.basename(prediction_path)) if os.path.isfile(prediction_path) else (prediction_path, None)
     )

     # Initialize the DataProcessor to handle and prepare data
@@ -120,6 +117,8 @@ def main():
     """
     Entry point for the script. Parses command-line arguments and orchestrates the performance assessment pipeline.
     """
+    import matplotlib.pyplot as plt
+
     # Set up argument parsing
     parser = argparse.ArgumentParser(description="Performance Assessor Core Script")
     parser.add_argument("--annotation_path", required=True, help="Path to annotation file or folder")
@@ -171,8 +170,6 @@ def main():
     if args.plot_metrics:
         pa.plot_metrics(predictions, labels, per_class_metrics=args.class_wise)
         if args.output_dir:
-            import matplotlib.pyplot as plt
-
             plt.savefig(os.path.join(args.output_dir, "metrics_plot.png"))
         else:
             plt.show()
@@ -180,8 +177,6 @@ def main():
     if args.plot_confusion_matrix:
         pa.plot_confusion_matrix(predictions, labels)
         if args.output_dir:
-            import matplotlib.pyplot as plt
-
             plt.savefig(os.path.join(args.output_dir, "confusion_matrix.png"))
         else:
             plt.show()
@@ -189,8 +184,6 @@ def main():
     if args.plot_metrics_all_thresholds:
         pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=args.class_wise)
         if args.output_dir:
-            import matplotlib.pyplot as plt
-
             plt.savefig(os.path.join(args.output_dir, "metrics_all_thresholds.png"))
         else:
             plt.show()
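The selected_classes annotation is loosened from list[str] to collections.abc.Sequence[str], so callers may pass tuples or other sequences without type-checker complaints; a tiny illustration with a placeholder function (not part of the package):

from collections.abc import Sequence

def keep_classes(selected_classes: Sequence[str] | None = None) -> list[str]:
    # Placeholder mirroring the loosened annotation: any sequence of strings is accepted.
    return list(selected_classes or [])

keep_classes(["Turdus merula"])    # a list type-checked before and still does
keep_classes(("Turdus merula",))   # a tuple only satisfies the new Sequence[str] annotation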
birdnet_analyzer/evaluation/assessment/performance_assessor.py

@@ -8,10 +8,9 @@ as well as utilities for generating related plots.

 from typing import Literal

-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
+from sklearn.metrics import confusion_matrix

 from birdnet_analyzer.evaluation.assessment import metrics, plotting

@@ -121,10 +120,7 @@ class PerformanceAssessor:
         if predictions.ndim != 2:
             raise ValueError("predictions and labels must be 2-dimensional arrays.")
         if predictions.shape[1] != self.num_classes:
-            raise ValueError(
-                f"The number of columns in predictions ({predictions.shape[1]}) "
-                + f"must match num_classes ({self.num_classes})."
-            )
+            raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")

         # Determine the averaging method for metrics
         if per_class_metrics and self.num_classes == 1:
@@ -192,11 +188,7 @@ class PerformanceAssessor:
             metrics_results["Accuracy"] = np.atleast_1d(result)

         # Define column names for the DataFrame
-        columns = (
-            (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)])
-            if per_class_metrics
-            else ["Overall"]
-        )
+        columns = (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]) if per_class_metrics else ["Overall"]

         # Create a DataFrame to organize metric results
         metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
@@ -207,7 +199,7 @@ class PerformanceAssessor:
         predictions: np.ndarray,
         labels: np.ndarray,
         per_class_metrics: bool = False,
-    ) -> None:
+    ):
         """
         Plot performance metrics for the given predictions and labels.

@@ -226,18 +218,14 @@ class PerformanceAssessor:
         metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)

         # Choose the plotting method based on whether per-class metrics are required
-        return (
-            plotting.plot_metrics_per_class(metrics_df, self.colors)
-            if per_class_metrics
-            else plotting.plot_overall_metrics(metrics_df, self.colors)
-        )
+        return plotting.plot_metrics_per_class(metrics_df, self.colors) if per_class_metrics else plotting.plot_overall_metrics(metrics_df, self.colors)

     def plot_metrics_all_thresholds(
         self,
         predictions: np.ndarray,
         labels: np.ndarray,
         per_class_metrics: bool = False,
-    ) -> None:
+    ):
         """
         Plot performance metrics across thresholds for the given predictions and labels.

@@ -266,9 +254,7 @@ class PerformanceAssessor:
         class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]

         # Initialize a dictionary to store metric values per class
-        metric_values_dict_per_class = {
-            class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names
-        }
+        metric_values_dict_per_class = {class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names}

         # Compute metrics for each threshold
         for thresh in thresholds:
@@ -321,7 +307,7 @@ class PerformanceAssessor:
         self,
         predictions: np.ndarray,
         labels: np.ndarray,
-    ) -> None:
+    ):
         """
         Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.

@@ -346,10 +332,7 @@ class PerformanceAssessor:
         if predictions.ndim != 2:
             raise ValueError("predictions and labels must be 2-dimensional arrays.")
         if predictions.shape[1] != self.num_classes:
-            raise ValueError(
-                f"The number of columns in predictions ({predictions.shape[1]}) "
-                + f"must match num_classes ({self.num_classes})."
-            )
+            raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")

         if self.task == "binary":
             # Binarize predictions using the threshold
@@ -360,13 +343,7 @@ class PerformanceAssessor:
             conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
             conf_mat = np.round(conf_mat, 2)

-            # Plot the confusion matrix
-            disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
-            fig, ax = plt.subplots(figsize=(6, 6))
-            disp.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
-            ax.set_title("Confusion Matrix")
-
-            return fig
+            return plotting.plot_confusion_matrices(conf_mat, self.task, self.classes)

         if self.task == "multilabel":
             # Binarize predictions for multilabel classification
@@ -376,34 +353,12 @@ class PerformanceAssessor:
             # Compute confusion matrices for each class
             conf_mats = []
             class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
             for i in range(self.num_classes):
                 conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
                 conf_mat = np.round(conf_mat, 2)
                 conf_mats.append(conf_mat)

-            # Determine grid size for subplots
-            num_matrices = self.num_classes
-            n_cols = int(np.ceil(np.sqrt(num_matrices)))
-            n_rows = int(np.ceil(num_matrices / n_cols))
-
-            # Create subplots for each confusion matrix
-            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
-            axes = axes.flatten()
-
-            # Plot each confusion matrix
-            for idx, (conf_mat, class_name) in enumerate(zip(conf_mats, class_names, strict=True)):
-                disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
-                disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f")
-                axes[idx].set_title(f"{class_name}")
-                axes[idx].set_xlabel("Predicted class")
-                axes[idx].set_ylabel("True class")
-
-            # Remove unused subplot axes
-            for ax in axes[num_matrices:]:
-                fig.delaxes(ax)
-
-            plt.tight_layout()
-
-            return fig
+            return plotting.plot_confusion_matrices(np.array(conf_mats), self.task, class_names)

         raise ValueError(f"Unsupported task type: {self.task}")
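plot_confusion_matrix() now hands all rendering to plotting.plot_confusion_matrices(): the binary branch passes a single 2x2 matrix, the multilabel branch a stacked (num_labels, 2, 2) array plus matching class names. A standalone sketch of calling the plotting helper directly with made-up data:

import numpy as np
from birdnet_analyzer.evaluation.assessment import plotting

# Example data only: two classes, each with a row-normalized 2x2 confusion matrix
conf_mats = np.array([
    [[0.9, 0.1],
     [0.2, 0.8]],
    [[0.7, 0.3],
     [0.4, 0.6]],
])

fig = plotting.plot_confusion_matrices(conf_mats, "multilabel", ["Class A", "Class B"])
fig.savefig("confusion_matrices.png")  # the returned matplotlib Figure can be saved or shown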
birdnet_analyzer/evaluation/assessment/plotting.py

@@ -18,10 +18,17 @@ from typing import Literal
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import seaborn as sns
+from sklearn.metrics import ConfusionMatrixDisplay

+MATPLOTLIB_BINARY_CONFUSION_MATRIX_FIGURE_NUM = "performance-tab-binary-confusion-matrix-plot"
+MATPLOTLIB_MULTICLASS_CONFUSION_MATRIX_FIGURE_NUM = "performance-tab-multiclass-confusion-matrix-plot"
+MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM = "performance-tab-overall-metrics-plot"
+MATPLOTLIB_PER_CLASS_METRICS_FIGURE_NUM = "performance-tab-per-class-metrics-plot"
+MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_FIGURE_NUM = "performance-tab-metrics-across-thresholds-plot"
+MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_PER_CLASS_FIGURE_NUM = "performance-tab-metrics-across-thresholds-per-class-plot"

-def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
+
+def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]):
     """
     Plots a bar chart for overall performance metrics.

@@ -55,7 +62,11 @@ plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Fig
     values = metrics_df["Overall"].to_numpy()  # Metric values

     # Plot bar chart
-    fig = plt.figure(figsize=(10, 6))
+    fig = plt.figure(MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)
+
     plt.bar(metrics, values, color=colors[: len(metrics)])

     # Add titles, labels, and format
@@ -64,12 +75,11 @@ plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Fig
     plt.ylabel("Score", fontsize=12)
     plt.xticks(rotation=45, ha="right", fontsize=10)
     plt.grid(axis="y", linestyle="--", alpha=0.7)
-    plt.tight_layout()

     return fig


-def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
+def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]):
     """
     Plots metric values per class, with each metric represented by a distinct color and line.

@@ -97,7 +107,10 @@ plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.F

     # Line styles for distinction
     line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
-    fig = plt.figure(figsize=(10, 6))
+    fig = plt.figure(MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)

     # Loop over each metric and plot it
     for i, metric_name in enumerate(metrics_df.index):
@@ -120,7 +133,6 @@ plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.F
     plt.ylabel("Score", fontsize=12)
     plt.legend(loc="lower right")
     plt.grid(True)
-    plt.tight_layout()

     return fig

@@ -130,7 +142,7 @@ plot_metrics_across_thresholds(
     metric_values_dict: dict[str, np.ndarray],
     metrics_to_plot: list[str],
     colors: list[str],
-) -> plt.Figure:
+):
     """
     Plots metrics across different thresholds.

@@ -164,7 +176,10 @@ plot_metrics_across_thresholds(

     # Line styles for distinction
     line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
-    fig = plt.figure(figsize=(10, 6))
+    fig = plt.figure(MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_FIGURE_NUM, figsize=(10, 6))
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)

     # Plot each metric against thresholds
     for i, metric_name in enumerate(metrics_to_plot):
@@ -188,7 +203,6 @@ plot_metrics_across_thresholds(
     plt.ylabel("Metric Score", fontsize=12)
     plt.legend(loc="best")
     plt.grid(True)
-    plt.tight_layout()

     return fig

@@ -199,7 +213,7 @@ plot_metrics_across_thresholds_per_class(
     metrics_to_plot: list[str],
     class_names: list[str],
     colors: list[str],
-) -> plt.Figure:
+):
     """
     Plots metrics across different thresholds per class.

@@ -244,7 +258,10 @@ plot_metrics_across_thresholds_per_class(
     n_rows = int(np.ceil(num_classes / n_cols))

     # Create subplots
-    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4))
+    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4), num=MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_PER_CLASS_FIGURE_NUM)
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)

     # Flatten axes for easy indexing
     axes = [axes] if num_classes == 1 else axes.flatten()
@@ -265,10 +282,7 @@ plot_metrics_across_thresholds_per_class(
                 raise KeyError(f"Metric '{metric_name}' not found for class '{class_name}'.")
             metric_values = metric_values_dict[metric_name]
             if len(metric_values) != len(thresholds):
-                raise ValueError(
-                    f"Length of metric '{metric_name}' values for class '{class_name}' "
-                    + "does not match length of thresholds."
-                )
+                raise ValueError(f"Length of metric '{metric_name}' values for class '{class_name}' " + "does not match length of thresholds.")
             ax.plot(
                 thresholds,
                 metric_values,
@@ -285,13 +299,6 @@ plot_metrics_across_thresholds_per_class(
         ax.legend(loc="best", fontsize=8)
         ax.grid(True)

-    # Hide any unused subplots
-    for j in range(num_classes, len(axes)):
-        fig.delaxes(axes[j])
-
-    # Adjust layout and show
-    plt.tight_layout()
-
     return fig


@@ -299,7 +306,7 @@ plot_confusion_matrices(
     conf_mat: np.ndarray,
     task: Literal["binary", "multiclass", "multilabel"],
     class_names: list[str],
-) -> plt.Figure:
+):
     """
     Plots confusion matrices for each class in a single figure with multiple subplots.

@@ -323,57 +330,49 @@ plot_confusion_matrices(
         raise ValueError("conf_mat is empty.")
     if not isinstance(task, str) or task not in ["binary", "multiclass", "multilabel"]:
         raise ValueError("Invalid task. Expected 'binary', 'multiclass', or 'multilabel'.")
-    if not isinstance(class_names, list):
-        raise TypeError("class_names must be a list.")
-    if len(class_names) == 0:
-        raise ValueError("class_names list is empty.")

     if task == "binary":
         # Binary classification expects a single 2x2 matrix
         if conf_mat.shape != (2, 2):
             raise ValueError("For binary task, conf_mat must be of shape (2, 2).")
-        if len(class_names) != 2:
-            raise ValueError("For binary task, class_names must have exactly two elements.")
-
-        # Plot single confusion matrix
-        fig = plt.figure(figsize=(4, 4))
-        sns.heatmap(conf_mat, annot=True, fmt=".2f", cmap="Reds", cbar=False)
-        plt.title("Confusion Matrix")
-        plt.xlabel("Predicted Class")
-        plt.ylabel("True Class")
-        plt.tight_layout()
+
+        disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
+        fig, ax = plt.subplots(num=MATPLOTLIB_BINARY_CONFUSION_MATRIX_FIGURE_NUM, figsize=(6, 6))
+
+        fig.tight_layout()
+        fig.set_dpi(300)
+        disp.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
+        ax.set_title("Confusion Matrix")
     else:
         # Multilabel or multiclass expects a set of 2x2 matrices
-        num_labels = conf_mat.shape[0]
+        num_matrices = conf_mat.shape[0]
+
         if conf_mat.shape[1:] != (2, 2):
             raise ValueError("For multilabel or multiclass task, conf_mat must have shape (num_labels, 2, 2).")
-        if len(class_names) != num_labels:
+        if len(class_names) != num_matrices:
             raise ValueError("Length of class_names must match number of labels in conf_mat.")

         # Determine grid size for subplots
-        n_cols = int(np.ceil(np.sqrt(num_labels)))
-        n_rows = int(np.ceil(num_labels / n_cols))
-
-        # Create subplots
-        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2))
-
-        # Flatten axes for easy indexing
-        axes = [axes] if num_labels == 1 else axes.flatten()
-
-        # Plot each class's confusion matrix
-        for i in range(num_labels):
-            cm = conf_mat[i]
-            ax = axes[i]
-            sns.heatmap(cm, annot=True, fmt=".2f", cmap="Reds", cbar=False, ax=ax)
-            ax.set_title(f"{class_names[i]}")
-            ax.set_xlabel("Predicted Class")
-            ax.set_ylabel("True Class")
-
-        # Hide any unused subplots
-        for j in range(num_labels, len(axes)):
-            fig.delaxes(axes[j])
+        n_cols = int(np.ceil(np.sqrt(num_matrices)))
+        n_rows = int(np.ceil(num_matrices / n_cols))
+
+        # Create subplots for each confusion matrix
+        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), num=MATPLOTLIB_MULTICLASS_CONFUSION_MATRIX_FIGURE_NUM)
+        fig.set_dpi(300)
+        axes = axes.flatten() if hasattr(axes, "flatten") else [axes]
+
+        # Plot each confusion matrix
+        for idx, (cf, class_name) in enumerate(zip(conf_mat, class_names, strict=True)):
+            disp = ConfusionMatrixDisplay(confusion_matrix=cf, display_labels=["Negative", "Positive"])
+            disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f")
+            axes[idx].set_title(f"{class_name}")
+            axes[idx].set_xlabel("Predicted class")
+            axes[idx].set_ylabel("True class")
+
+        # Remove unused subplot axes
+        for ax in axes[num_matrices:]:
+            fig.delaxes(ax)

-    # Adjust layout and show
     plt.tight_layout()

     return fig
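The pattern repeated throughout this file — plt.figure(SOME_FIGURE_NUM, ...) or plt.subplots(..., num=...) followed by fig.clear() — relies on matplotlib reusing the figure whose num label already exists, so repeated plotting from the GUI redraws into the same figure instead of accumulating new ones. A minimal sketch of that behavior (the label below is illustrative, not one of the new constants):

import matplotlib
matplotlib.use("Agg")  # headless backend, just for the sketch
import matplotlib.pyplot as plt

FIGURE_NUM = "demo-reused-figure"  # illustrative label, in the spirit of the new constants

first = plt.figure(FIGURE_NUM, figsize=(10, 6))
second = plt.figure(FIGURE_NUM, figsize=(10, 6))  # same label -> same Figure object is returned

assert first is second
second.clear()                 # wipe old axes before redrawing, as the updated helpers do
print(len(plt.get_fignums()))  # still exactly one open figure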
birdnet_analyzer/evaluation/preprocessing/data_processor.py

@@ -565,7 +565,7 @@ class DataProcessor:
         self,
         selected_classes: list[str] | None = None,
         selected_recordings: list[str] | None = None,
-    ) -> tuple[np.ndarray, np.ndarray, tuple[str]]:
+    ) -> tuple[np.ndarray, np.ndarray, tuple[str, ...]]:
         """
         Filters the prediction and label tensors based on selected classes and recordings.

birdnet_analyzer/gui/analysis.py

@@ -5,7 +5,6 @@ from pathlib import Path
 import gradio as gr

 import birdnet_analyzer.config as cfg
-import birdnet_analyzer.gui.localization as loc
 import birdnet_analyzer.gui.utils as gu
 from birdnet_analyzer import model
 from birdnet_analyzer.analyze.utils import (
@@ -55,6 +54,7 @@ def run_analysis(
     sf_thresh: float,
     custom_classifier_file,
     output_types: str,
+    additional_columns: list[str] | None,
     combine_tables: bool,
     locale: str,
     batch_size: int,
@@ -85,6 +85,7 @@ def run_analysis(
         sf_thresh: The threshold for the predicted species list.
         custom_classifier_file: Custom classifier to be used.
         output_type: The type of result to be generated.
+        additional_columns: Additional columns to be added to the result.
         output_filename: The filename for the combined output.
         locale: The translation to be used.
         batch_size: The number of samples in a batch.
@@ -92,6 +93,8 @@ def run_analysis(
         input_dir: The input directory.
         progress: The gradio progress bar.
     """
+    import birdnet_analyzer.gui.localization as loc
+
     if progress is not None:
         progress(0, desc=f"{loc.localize('progress-preparing')} ...")

@@ -128,6 +131,7 @@ def run_analysis(
         slist=slist,
         top_n=top_n if use_top_n else None,
         output=output_path,
+        additional_columns=additional_columns,
     )

     if species_list_choice == gu._CUSTOM_CLASSIFIER:
birdnet_analyzer/gui/assets/gui.css

@@ -26,4 +26,12 @@ footer {
     overflow: auto;
     flex-wrap: nowrap;
     padding-right: 5px;
+}
+
+#heart {
+    display: inline;
+    background-color: white;
+    border-radius: 3px;
+    margin-right: 3px;
+    padding: 1px;
 }
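The new #heart rule targets an element with that id in the GUI; in Gradio a component typically picks up such a rule through its elem_id. A hypothetical wiring sketch (the component and its id mapping are assumptions, not part of this diff):

import gradio as gr

CSS = "#heart { display: inline; background-color: white; border-radius: 3px; margin-right: 3px; padding: 1px; }"

with gr.Blocks(css=CSS) as demo:
    # Hypothetical: any component rendered with elem_id="heart" receives the rule above.
    gr.HTML("&#10084;", elem_id="heart")

demo.launch()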