birdnet-analyzer 2.0.1__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- birdnet_analyzer/analyze/__init__.py +14 -0
- birdnet_analyzer/analyze/cli.py +5 -0
- birdnet_analyzer/analyze/core.py +6 -1
- birdnet_analyzer/analyze/utils.py +42 -40
- birdnet_analyzer/audio.py +2 -2
- birdnet_analyzer/cli.py +41 -18
- birdnet_analyzer/config.py +4 -3
- birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
- birdnet_analyzer/embeddings/core.py +2 -1
- birdnet_analyzer/embeddings/utils.py +42 -1
- birdnet_analyzer/evaluation/__init__.py +6 -13
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +12 -57
- birdnet_analyzer/evaluation/assessment/plotting.py +61 -62
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +1 -1
- birdnet_analyzer/gui/analysis.py +5 -1
- birdnet_analyzer/gui/assets/gui.css +8 -0
- birdnet_analyzer/gui/embeddings.py +37 -18
- birdnet_analyzer/gui/evaluation.py +14 -8
- birdnet_analyzer/gui/multi_file.py +25 -5
- birdnet_analyzer/gui/review.py +16 -63
- birdnet_analyzer/gui/settings.py +25 -4
- birdnet_analyzer/gui/single_file.py +14 -17
- birdnet_analyzer/gui/train.py +7 -16
- birdnet_analyzer/gui/utils.py +42 -55
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +1 -1
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +1 -1
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +108 -108
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +1 -1
- birdnet_analyzer/lang/de.json +7 -0
- birdnet_analyzer/lang/en.json +7 -0
- birdnet_analyzer/lang/fi.json +7 -0
- birdnet_analyzer/lang/fr.json +7 -0
- birdnet_analyzer/lang/id.json +7 -0
- birdnet_analyzer/lang/pt-br.json +7 -0
- birdnet_analyzer/lang/ru.json +36 -29
- birdnet_analyzer/lang/se.json +7 -0
- birdnet_analyzer/lang/tlh.json +7 -0
- birdnet_analyzer/lang/zh_TW.json +7 -0
- birdnet_analyzer/model.py +21 -21
- birdnet_analyzer/search/core.py +1 -1
- birdnet_analyzer/utils.py +3 -4
- {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/METADATA +18 -9
- {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/RECORD +47 -47
- {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/WHEEL +1 -1
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
- {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/licenses/LICENSE +0 -0
- {birdnet_analyzer-2.0.1.dist-info → birdnet_analyzer-2.1.0.dist-info}/top_level.txt +0 -0
birdnet_analyzer/embeddings/core.py
CHANGED
@@ -8,6 +8,7 @@ def embeddings(
     fmax: int = 15000,
     threads: int = 8,
     batch_size: int = 1,
+    file_output: str | None = None,
 ):
     """
     Generates embeddings for audio files using the BirdNET-Analyzer.
@@ -46,7 +47,7 @@ def embeddings(
     from birdnet_analyzer.utils import ensure_model_exists

     ensure_model_exists()
-    run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batch_size)
+    run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batch_size, file_output)


 def get_database(db_path: str):
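
The new file_output argument is simply threaded from the embeddings() wrapper into run(). A minimal usage sketch, assuming this hunk belongs to birdnet_analyzer/embeddings/core.py and that audio_input and database are the two leading positional parameters (as the run() call above suggests); the paths are hypothetical:

    from birdnet_analyzer.embeddings.core import embeddings  # assumed import path

    embeddings(
        "recordings/",                  # audio_input: folder with audio files (hypothetical)
        "birdnet_embeddings.db",        # database: embeddings database path (hypothetical)
        file_output="embeddings_txt/",  # new in 2.1.0; leaving it as None keeps the old behaviour
    )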

birdnet_analyzer/embeddings/utils.py
CHANGED
@@ -25,6 +25,7 @@ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
     Args:
         item: (filepath, config)
     """
+
     # Get file path and restore cfg
     fpath: str = item[0]
     cfg.set_config(item[1])
@@ -124,7 +125,44 @@ def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
     db.commit()


-def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize):
+def create_file_output(output_path: str, db: sqlite_usearch_impl.SQLiteUsearchDB):
+    """Creates a file output for the database.
+
+    Args:
+        output_path: Path to the output file.
+        db: Database object.
+    """
+    # Check if output path exists
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    # Get all embeddings
+    embedding_ids = db.get_embedding_ids()
+
+    # Write embeddings to file
+    for embedding_id in embedding_ids:
+        embedding = db.get_embedding(embedding_id)
+        source = db.get_embedding_source(embedding_id)
+
+        # Get start and end time
+        start, end = source.offsets
+
+        source_id = source.source_id.rsplit(".", 1)[0]
+
+        filename = f"{source_id}_{start}_{end}.birdnet.embeddings.txt"
+
+        # Get the common prefix between the output path and the filename
+        common_prefix = os.path.commonpath([output_path, os.path.dirname(filename)])
+        relative_filename = os.path.relpath(filename, common_prefix)
+        target_path = os.path.join(output_path, relative_filename)
+
+        # Ensure the target directory exists
+        os.makedirs(os.path.dirname(target_path), exist_ok=True)
+
+        # Write embedding values to a text file
+        with open(target_path, "w") as f:
+            f.write(",".join(map(str, embedding.tolist())))
+
+def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize, file_output):
     ### Make sure to comment out appropriately if you are not using args. ###

     # Set input and output path
@@ -176,4 +214,7 @@ def run(audio_input, database, overlap, audio_speed, fmin, fmax, threads, batchsize, file_output):
     with Pool(cfg.CPU_THREADS) as p:
         tqdm(p.imap(partial(analyze_file, db=db), flist))

+    if file_output:
+        create_file_output(file_output, db)
+
     db.db.close()
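
Each file that create_file_output() writes holds a single comma-separated embedding vector, named <source>_<start>_<end>.birdnet.embeddings.txt. A small sketch for loading such a file back into NumPy (the file path below is hypothetical):

    import numpy as np

    def load_embedding(path: str) -> np.ndarray:
        # One vector per file; values are joined with "," by create_file_output().
        with open(path) as f:
            return np.array([float(v) for v in f.read().split(",")])

    vec = load_embedding("embeddings_txt/soundscape_0.0_3.0.birdnet.embeddings.txt")
    print(vec.shape)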

birdnet_analyzer/evaluation/__init__.py
CHANGED
@@ -9,6 +9,7 @@ for columns, class mappings, and filtering based on selected classes or recordings.
 import argparse
 import json
 import os
+from collections.abc import Sequence

 from birdnet_analyzer.evaluation.assessment.performance_assessor import (
     PerformanceAssessor,
@@ -25,7 +26,7 @@ def process_data(
     recording_duration: float | None = None,
     columns_annotations: dict[str, str] | None = None,
     columns_predictions: dict[str, str] | None = None,
-    selected_classes: list[str] | None = None,
+    selected_classes: Sequence[str] | None = None,
     selected_recordings: list[str] | None = None,
     metrics_list: tuple[str, ...] = ("accuracy", "precision", "recall"),
     threshold: float = 0.1,
@@ -61,14 +62,10 @@ def process_data(

     # Determine directory and file paths for annotations and predictions
     annotation_dir, annotation_file = (
-        (os.path.dirname(annotation_path), os.path.basename(annotation_path))
-        if os.path.isfile(annotation_path)
-        else (annotation_path, None)
+        (os.path.dirname(annotation_path), os.path.basename(annotation_path)) if os.path.isfile(annotation_path) else (annotation_path, None)
     )
     prediction_dir, prediction_file = (
-        (os.path.dirname(prediction_path), os.path.basename(prediction_path))
-        if os.path.isfile(prediction_path)
-        else (prediction_path, None)
+        (os.path.dirname(prediction_path), os.path.basename(prediction_path)) if os.path.isfile(prediction_path) else (prediction_path, None)
     )

     # Initialize the DataProcessor to handle and prepare data
@@ -120,6 +117,8 @@ def main():
     """
     Entry point for the script. Parses command-line arguments and orchestrates the performance assessment pipeline.
     """
+    import matplotlib.pyplot as plt
+
     # Set up argument parsing
     parser = argparse.ArgumentParser(description="Performance Assessor Core Script")
     parser.add_argument("--annotation_path", required=True, help="Path to annotation file or folder")
@@ -171,8 +170,6 @@ def main():
     if args.plot_metrics:
         pa.plot_metrics(predictions, labels, per_class_metrics=args.class_wise)
         if args.output_dir:
-            import matplotlib.pyplot as plt
-
             plt.savefig(os.path.join(args.output_dir, "metrics_plot.png"))
         else:
             plt.show()
@@ -180,8 +177,6 @@ def main():
     if args.plot_confusion_matrix:
         pa.plot_confusion_matrix(predictions, labels)
         if args.output_dir:
-            import matplotlib.pyplot as plt
-
             plt.savefig(os.path.join(args.output_dir, "confusion_matrix.png"))
         else:
             plt.show()
@@ -189,8 +184,6 @@ def main():
     if args.plot_metrics_all_thresholds:
         pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=args.class_wise)
         if args.output_dir:
-            import matplotlib.pyplot as plt
-
             plt.savefig(os.path.join(args.output_dir, "metrics_all_thresholds.png"))
         else:
             plt.show()
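
Annotating selected_classes as Sequence[str] instead of a concrete list type lets callers pass any ordered container. A tiny illustration of what the new annotation accepts (the helper below is illustrative, not part of the package):

    from collections.abc import Sequence

    def normalize_classes(selected_classes: Sequence[str] | None = None) -> list[str]:
        # Lists, tuples, and other sequences all satisfy Sequence[str].
        return list(selected_classes) if selected_classes is not None else []

    normalize_classes(["Turdus merula"])
    normalize_classes(("Turdus merula", "Erithacus rubecula"))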

birdnet_analyzer/evaluation/assessment/performance_assessor.py
CHANGED
@@ -8,10 +8,9 @@ as well as utilities for generating related plots.

 from typing import Literal

-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
+from sklearn.metrics import confusion_matrix

 from birdnet_analyzer.evaluation.assessment import metrics, plotting

@@ -121,10 +120,7 @@ class PerformanceAssessor:
         if predictions.ndim != 2:
             raise ValueError("predictions and labels must be 2-dimensional arrays.")
         if predictions.shape[1] != self.num_classes:
-            raise ValueError(
-                f"The number of columns in predictions ({predictions.shape[1]}) "
-                + f"must match num_classes ({self.num_classes})."
-            )
+            raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")

         # Determine the averaging method for metrics
         if per_class_metrics and self.num_classes == 1:
@@ -192,11 +188,7 @@ class PerformanceAssessor:
         metrics_results["Accuracy"] = np.atleast_1d(result)

         # Define column names for the DataFrame
-        columns = (
-            (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)])
-            if per_class_metrics
-            else ["Overall"]
-        )
+        columns = (self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]) if per_class_metrics else ["Overall"]

         # Create a DataFrame to organize metric results
         metrics_data = {key: np.atleast_1d(value) for key, value in metrics_results.items()}
@@ -207,7 +199,7 @@ class PerformanceAssessor:
         predictions: np.ndarray,
         labels: np.ndarray,
         per_class_metrics: bool = False,
-    ) -> plt.Figure:
+    ):
         """
         Plot performance metrics for the given predictions and labels.

@@ -226,18 +218,14 @@ class PerformanceAssessor:
         metrics_df = self.calculate_metrics(predictions, labels, per_class_metrics)

         # Choose the plotting method based on whether per-class metrics are required
-        return (
-            plotting.plot_metrics_per_class(metrics_df, self.colors)
-            if per_class_metrics
-            else plotting.plot_overall_metrics(metrics_df, self.colors)
-        )
+        return plotting.plot_metrics_per_class(metrics_df, self.colors) if per_class_metrics else plotting.plot_overall_metrics(metrics_df, self.colors)

     def plot_metrics_all_thresholds(
         self,
         predictions: np.ndarray,
         labels: np.ndarray,
         per_class_metrics: bool = False,
-    ) -> plt.Figure:
+    ):
         """
         Plot performance metrics across thresholds for the given predictions and labels.

@@ -266,9 +254,7 @@ class PerformanceAssessor:
         class_names = list(self.classes) if self.classes else [f"Class {i}" for i in range(self.num_classes)]

         # Initialize a dictionary to store metric values per class
-        metric_values_dict_per_class = {
-            class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names
-        }
+        metric_values_dict_per_class = {class_name: {metric: [] for metric in metrics_to_plot} for class_name in class_names}

         # Compute metrics for each threshold
         for thresh in thresholds:
@@ -321,7 +307,7 @@ class PerformanceAssessor:
         self,
         predictions: np.ndarray,
         labels: np.ndarray,
-    ) -> plt.Figure:
+    ):
         """
         Plot confusion matrices for each class using scikit-learn's ConfusionMatrixDisplay.

@@ -346,10 +332,7 @@ class PerformanceAssessor:
         if predictions.ndim != 2:
             raise ValueError("predictions and labels must be 2-dimensional arrays.")
         if predictions.shape[1] != self.num_classes:
-            raise ValueError(
-                f"The number of columns in predictions ({predictions.shape[1]}) "
-                + f"must match num_classes ({self.num_classes})."
-            )
+            raise ValueError(f"The number of columns in predictions ({predictions.shape[1]}) " + f"must match num_classes ({self.num_classes}).")

         if self.task == "binary":
             # Binarize predictions using the threshold
@@ -360,13 +343,7 @@ class PerformanceAssessor:
             conf_mat = confusion_matrix(y_true, y_pred, normalize="true")
             conf_mat = np.round(conf_mat, 2)

-
-            disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
-            fig, ax = plt.subplots(figsize=(6, 6))
-            disp.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
-            ax.set_title("Confusion Matrix")
-
-            return fig
+            return plotting.plot_confusion_matrices(conf_mat, self.task, self.classes)


         if self.task == "multilabel":
@@ -376,34 +353,12 @@ class PerformanceAssessor:
             # Compute confusion matrices for each class
             conf_mats = []
             class_names = self.classes if self.classes else [f"Class {i}" for i in range(self.num_classes)]
+
             for i in range(self.num_classes):
                 conf_mat = confusion_matrix(y_true[:, i], y_pred[:, i], normalize="true")
                 conf_mat = np.round(conf_mat, 2)
                 conf_mats.append(conf_mat)

-
-            num_matrices = self.num_classes
-            n_cols = int(np.ceil(np.sqrt(num_matrices)))
-            n_rows = int(np.ceil(num_matrices / n_cols))
-
-            # Create subplots for each confusion matrix
-            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
-            axes = axes.flatten()
-
-            # Plot each confusion matrix
-            for idx, (conf_mat, class_name) in enumerate(zip(conf_mats, class_names, strict=True)):
-                disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
-                disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f")
-                axes[idx].set_title(f"{class_name}")
-                axes[idx].set_xlabel("Predicted class")
-                axes[idx].set_ylabel("True class")
-
-            # Remove unused subplot axes
-            for ax in axes[num_matrices:]:
-                fig.delaxes(ax)
-
-            plt.tight_layout()
-
-            return fig
+            return plotting.plot_confusion_matrices(np.array(conf_mats), self.task, class_names)

         raise ValueError(f"Unsupported task type: {self.task}")
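
With this refactor, PerformanceAssessor.plot_confusion_matrix() no longer builds figures itself; both the binary and multilabel branches hand a normalized confusion matrix (or a stack of them) to plotting.plot_confusion_matrices(). A sketch of calling that helper directly, based only on the signature and shape checks shown in this diff (the matrices are made-up numbers):

    import numpy as np
    from birdnet_analyzer.evaluation.assessment import plotting

    # Binary task: a single normalized 2x2 matrix.
    binary_cm = np.array([[0.92, 0.08],
                          [0.20, 0.80]])
    fig = plotting.plot_confusion_matrices(binary_cm, task="binary", class_names=["Sturnus vulgaris"])

    # Multilabel task: one 2x2 matrix per class, stacked along the first axis.
    stacked = np.stack([binary_cm, binary_cm])
    fig = plotting.plot_confusion_matrices(stacked, task="multilabel", class_names=["Class 0", "Class 1"])
    fig.savefig("confusion_matrices.png", dpi=300)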

birdnet_analyzer/evaluation/assessment/plotting.py
CHANGED
@@ -18,10 +18,17 @@ from typing import Literal
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-
+from sklearn.metrics import ConfusionMatrixDisplay

+MATPLOTLIB_BINARY_CONFUSION_MATRIX_FIGURE_NUM = "performance-tab-binary-confusion-matrix-plot"
+MATPLOTLIB_MULTICLASS_CONFUSION_MATRIX_FIGURE_NUM = "performance-tab-multiclass-confusion-matrix-plot"
+MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM = "performance-tab-overall-metrics-plot"
+MATPLOTLIB_PER_CLASS_METRICS_FIGURE_NUM = "performance-tab-per-class-metrics-plot"
+MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_FIGURE_NUM = "performance-tab-metrics-across-thresholds-plot"
+MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_PER_CLASS_FIGURE_NUM = "performance-tab-metrics-across-thresholds-per-class-plot"

-def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
+
+def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]):
     """
     Plots a bar chart for overall performance metrics.

@@ -55,7 +62,11 @@ def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
     values = metrics_df["Overall"].to_numpy()  # Metric values

     # Plot bar chart
-    fig = plt.figure(figsize=(10, 6))
+    fig = plt.figure(MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)
+
     plt.bar(metrics, values, color=colors[: len(metrics)])

     # Add titles, labels, and format
@@ -64,12 +75,11 @@ def plot_overall_metrics(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
     plt.ylabel("Score", fontsize=12)
     plt.xticks(rotation=45, ha="right", fontsize=10)
     plt.grid(axis="y", linestyle="--", alpha=0.7)
-    plt.tight_layout()

     return fig


-def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
+def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]):
     """
     Plots metric values per class, with each metric represented by a distinct color and line.

@@ -97,7 +107,10 @@ def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:

     # Line styles for distinction
     line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
-    fig = plt.figure(figsize=(10, 6))
+    fig = plt.figure(MATPLOTLIB_OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)

     # Loop over each metric and plot it
     for i, metric_name in enumerate(metrics_df.index):
@@ -120,7 +133,6 @@ def plot_metrics_per_class(metrics_df: pd.DataFrame, colors: list[str]) -> plt.Figure:
     plt.ylabel("Score", fontsize=12)
     plt.legend(loc="lower right")
     plt.grid(True)
-    plt.tight_layout()

     return fig

@@ -130,7 +142,7 @@ def plot_metrics_across_thresholds(
     metric_values_dict: dict[str, np.ndarray],
     metrics_to_plot: list[str],
     colors: list[str],
-) -> plt.Figure:
+):
     """
     Plots metrics across different thresholds.

@@ -164,7 +176,10 @@ def plot_metrics_across_thresholds(

     # Line styles for distinction
     line_styles = ["-", "--", "-.", ":", (0, (5, 10)), (0, (5, 5)), (0, (3, 5, 1, 5))]
-    fig = plt.figure(figsize=(10, 6))
+    fig = plt.figure(MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_FIGURE_NUM, figsize=(10, 6))
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)

     # Plot each metric against thresholds
     for i, metric_name in enumerate(metrics_to_plot):
@@ -188,7 +203,6 @@ def plot_metrics_across_thresholds(
     plt.ylabel("Metric Score", fontsize=12)
     plt.legend(loc="best")
     plt.grid(True)
-    plt.tight_layout()

     return fig

@@ -199,7 +213,7 @@ def plot_metrics_across_thresholds_per_class(
     metrics_to_plot: list[str],
     class_names: list[str],
     colors: list[str],
-) -> plt.Figure:
+):
     """
     Plots metrics across different thresholds per class.

@@ -244,7 +258,10 @@ def plot_metrics_across_thresholds_per_class(
     n_rows = int(np.ceil(num_classes / n_cols))

     # Create subplots
-    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4))
+    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 4), num=MATPLOTLIB_ACROSS_METRICS_THRESHOLDS_PER_CLASS_FIGURE_NUM)
+    fig.clear()
+    fig.tight_layout(pad=0)
+    fig.set_dpi(300)

     # Flatten axes for easy indexing
     axes = [axes] if num_classes == 1 else axes.flatten()
@@ -265,10 +282,7 @@ def plot_metrics_across_thresholds_per_class(
                 raise KeyError(f"Metric '{metric_name}' not found for class '{class_name}'.")
             metric_values = metric_values_dict[metric_name]
             if len(metric_values) != len(thresholds):
-                raise ValueError(
-                    f"Length of metric '{metric_name}' values for class '{class_name}' "
-                    + "does not match length of thresholds."
-                )
+                raise ValueError(f"Length of metric '{metric_name}' values for class '{class_name}' " + "does not match length of thresholds.")
             ax.plot(
                 thresholds,
                 metric_values,
@@ -285,13 +299,6 @@ def plot_metrics_across_thresholds_per_class(
         ax.legend(loc="best", fontsize=8)
         ax.grid(True)

-    # Hide any unused subplots
-    for j in range(num_classes, len(axes)):
-        fig.delaxes(axes[j])
-
-    # Adjust layout and show
-    plt.tight_layout()
-
     return fig

@@ -299,7 +306,7 @@ def plot_confusion_matrices(
     conf_mat: np.ndarray,
     task: Literal["binary", "multiclass", "multilabel"],
     class_names: list[str],
-) -> plt.Figure:
+):
     """
     Plots confusion matrices for each class in a single figure with multiple subplots.

@@ -323,57 +330,49 @@ def plot_confusion_matrices(
         raise ValueError("conf_mat is empty.")
     if not isinstance(task, str) or task not in ["binary", "multiclass", "multilabel"]:
         raise ValueError("Invalid task. Expected 'binary', 'multiclass', or 'multilabel'.")
-    if not isinstance(class_names, list):
-        raise TypeError("class_names must be a list.")
-    if len(class_names) == 0:
-        raise ValueError("class_names list is empty.")

     if task == "binary":
         # Binary classification expects a single 2x2 matrix
         if conf_mat.shape != (2, 2):
             raise ValueError("For binary task, conf_mat must be of shape (2, 2).")
-
-
-
-
-        fig
-
-
-
-        plt.ylabel("True Class")
-        plt.tight_layout()
+
+        disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat, display_labels=["Negative", "Positive"])
+        fig, ax = plt.subplots(num=MATPLOTLIB_BINARY_CONFUSION_MATRIX_FIGURE_NUM, figsize=(6, 6))
+
+        fig.tight_layout()
+        fig.set_dpi(300)
+        disp.plot(cmap="Reds", ax=ax, colorbar=False, values_format=".2f")
+        ax.set_title("Confusion Matrix")
     else:
         # Multilabel or multiclass expects a set of 2x2 matrices
-
+        num_matrices = conf_mat.shape[0]
+
         if conf_mat.shape[1:] != (2, 2):
             raise ValueError("For multilabel or multiclass task, conf_mat must have shape (num_labels, 2, 2).")
-        if len(class_names) !=
+        if len(class_names) != num_matrices:
             raise ValueError("Length of class_names must match number of labels in conf_mat.")

         # Determine grid size for subplots
-        n_cols = int(np.ceil(np.sqrt(
-        n_rows = int(np.ceil(
-
-        # Create subplots
-        fig, axes = plt.subplots(n_rows, n_cols, figsize=(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for j in range(num_labels, len(axes)):
-            fig.delaxes(axes[j])
+        n_cols = int(np.ceil(np.sqrt(num_matrices)))
+        n_rows = int(np.ceil(num_matrices / n_cols))
+
+        # Create subplots for each confusion matrix
+        fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), num=MATPLOTLIB_MULTICLASS_CONFUSION_MATRIX_FIGURE_NUM)
+        fig.set_dpi(300)
+        axes = axes.flatten() if hasattr(axes, "flatten") else [axes]
+
+        # Plot each confusion matrix
+        for idx, (cf, class_name) in enumerate(zip(conf_mat, class_names, strict=True)):
+            disp = ConfusionMatrixDisplay(confusion_matrix=cf, display_labels=["Negative", "Positive"])
+            disp.plot(cmap="Reds", ax=axes[idx], colorbar=False, values_format=".2f")
+            axes[idx].set_title(f"{class_name}")
+            axes[idx].set_xlabel("Predicted class")
+            axes[idx].set_ylabel("True class")
+
+        # Remove unused subplot axes
+        for ax in axes[num_matrices:]:
+            fig.delaxes(ax)

-    # Adjust layout and show
     plt.tight_layout()

     return fig
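
The recurring plt.figure(<FIGURE_NUM>, ...) / fig.clear() pattern introduced throughout plotting.py reuses one named matplotlib figure per plot type instead of allocating a fresh figure on every call, which keeps repeated redraws (presumably the GUI performance tab, given the figure labels) from accumulating open figures. A standalone sketch of the same pattern, with a shortened constant name and made-up data:

    import matplotlib.pyplot as plt

    OVERALL_METRICS_FIGURE_NUM = "performance-tab-overall-metrics-plot"

    def plot_scores(scores: dict[str, float]):
        # A string passed as the figure identifier makes matplotlib reuse the figure
        # with that label, so calling this repeatedly redraws in place instead of
        # leaking new figures.
        fig = plt.figure(OVERALL_METRICS_FIGURE_NUM, figsize=(10, 6))
        fig.clear()
        fig.set_dpi(300)
        plt.bar(list(scores), list(scores.values()))
        return fig

    plot_scores({"precision": 0.8, "recall": 0.7})
    plot_scores({"precision": 0.9, "recall": 0.6})  # same Figure object, redrawn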

birdnet_analyzer/evaluation/preprocessing/data_processor.py
CHANGED
@@ -565,7 +565,7 @@ class DataProcessor:
         self,
         selected_classes: list[str] | None = None,
         selected_recordings: list[str] | None = None,
-    ) -> tuple[np.ndarray, np.ndarray, tuple[str]]:
+    ) -> tuple[np.ndarray, np.ndarray, tuple[str, ...]]:
         """
         Filters the prediction and label tensors based on selected classes and recordings.

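
The return-annotation fix in DataProcessor is worth a note: tuple[str] describes a tuple with exactly one string, while tuple[str, ...] is the variadic form matching a class-name tuple of any length:

    names_one: tuple[str] = ("Parus major",)                        # exactly one element
    names_any: tuple[str, ...] = ("Parus major", "Sitta europaea")  # any number of elements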
birdnet_analyzer/gui/analysis.py
CHANGED
@@ -5,7 +5,6 @@ from pathlib import Path
 import gradio as gr

 import birdnet_analyzer.config as cfg
-import birdnet_analyzer.gui.localization as loc
 import birdnet_analyzer.gui.utils as gu
 from birdnet_analyzer import model
 from birdnet_analyzer.analyze.utils import (
@@ -55,6 +54,7 @@ def run_analysis(
     sf_thresh: float,
     custom_classifier_file,
     output_types: str,
+    additional_columns: list[str] | None,
     combine_tables: bool,
     locale: str,
     batch_size: int,
@@ -85,6 +85,7 @@ def run_analysis(
         sf_thresh: The threshold for the predicted species list.
         custom_classifier_file: Custom classifier to be used.
         output_type: The type of result to be generated.
+        additional_columns: Additional columns to be added to the result.
         output_filename: The filename for the combined output.
         locale: The translation to be used.
         batch_size: The number of samples in a batch.
@@ -92,6 +93,8 @@ def run_analysis(
         input_dir: The input directory.
         progress: The gradio progress bar.
     """
+    import birdnet_analyzer.gui.localization as loc
+
     if progress is not None:
         progress(0, desc=f"{loc.localize('progress-preparing')} ...")

@@ -128,6 +131,7 @@ def run_analysis(
         slist=slist,
         top_n=top_n if use_top_n else None,
         output=output_path,
+        additional_columns=additional_columns,
     )

     if species_list_choice == gu._CUSTOM_CLASSIFIER:
|