birdnet-analyzer 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +8 -0
- birdnet_analyzer/analyze/__init__.py +5 -0
- birdnet_analyzer/analyze/__main__.py +4 -0
- birdnet_analyzer/analyze/cli.py +25 -0
- birdnet_analyzer/analyze/core.py +245 -0
- birdnet_analyzer/analyze/utils.py +701 -0
- birdnet_analyzer/audio.py +372 -0
- birdnet_analyzer/cli.py +707 -0
- birdnet_analyzer/config.py +242 -0
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25280 -0
- birdnet_analyzer/embeddings/__init__.py +4 -0
- birdnet_analyzer/embeddings/__main__.py +3 -0
- birdnet_analyzer/embeddings/cli.py +13 -0
- birdnet_analyzer/embeddings/core.py +70 -0
- birdnet_analyzer/embeddings/utils.py +193 -0
- birdnet_analyzer/evaluation/__init__.py +195 -0
- birdnet_analyzer/evaluation/__main__.py +3 -0
- birdnet_analyzer/gui/__init__.py +23 -0
- birdnet_analyzer/gui/__main__.py +3 -0
- birdnet_analyzer/gui/analysis.py +174 -0
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -0
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -0
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -0
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -0
- birdnet_analyzer/gui/assets/gui.css +29 -0
- birdnet_analyzer/gui/assets/gui.js +94 -0
- birdnet_analyzer/gui/assets/img/birdnet-icon.ico +0 -0
- birdnet_analyzer/gui/assets/img/birdnet_logo.png +0 -0
- birdnet_analyzer/gui/assets/img/birdnet_logo_no_transparent.png +0 -0
- birdnet_analyzer/gui/assets/img/clo-logo-bird.svg +1 -0
- birdnet_analyzer/gui/embeddings.py +620 -0
- birdnet_analyzer/gui/evaluation.py +813 -0
- birdnet_analyzer/gui/localization.py +68 -0
- birdnet_analyzer/gui/multi_file.py +246 -0
- birdnet_analyzer/gui/review.py +527 -0
- birdnet_analyzer/gui/segments.py +191 -0
- birdnet_analyzer/gui/settings.py +129 -0
- birdnet_analyzer/gui/single_file.py +269 -0
- birdnet_analyzer/gui/species.py +95 -0
- birdnet_analyzer/gui/train.py +698 -0
- birdnet_analyzer/gui/utils.py +808 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -0
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -0
- birdnet_analyzer/lang/de.json +335 -0
- birdnet_analyzer/lang/en.json +335 -0
- birdnet_analyzer/lang/fi.json +335 -0
- birdnet_analyzer/lang/fr.json +335 -0
- birdnet_analyzer/lang/id.json +335 -0
- birdnet_analyzer/lang/pt-br.json +335 -0
- birdnet_analyzer/lang/ru.json +335 -0
- birdnet_analyzer/lang/se.json +335 -0
- birdnet_analyzer/lang/tlh.json +335 -0
- birdnet_analyzer/lang/zh_TW.json +335 -0
- birdnet_analyzer/model.py +1243 -0
- birdnet_analyzer/search/__init__.py +3 -0
- birdnet_analyzer/search/__main__.py +3 -0
- birdnet_analyzer/search/cli.py +12 -0
- birdnet_analyzer/search/core.py +78 -0
- birdnet_analyzer/search/utils.py +111 -0
- birdnet_analyzer/segments/__init__.py +3 -0
- birdnet_analyzer/segments/__main__.py +3 -0
- birdnet_analyzer/segments/cli.py +14 -0
- birdnet_analyzer/segments/core.py +78 -0
- birdnet_analyzer/segments/utils.py +394 -0
- birdnet_analyzer/species/__init__.py +3 -0
- birdnet_analyzer/species/__main__.py +3 -0
- birdnet_analyzer/species/cli.py +14 -0
- birdnet_analyzer/species/core.py +35 -0
- birdnet_analyzer/species/utils.py +75 -0
- birdnet_analyzer/train/__init__.py +3 -0
- birdnet_analyzer/train/__main__.py +3 -0
- birdnet_analyzer/train/cli.py +14 -0
- birdnet_analyzer/train/core.py +113 -0
- birdnet_analyzer/train/utils.py +847 -0
- birdnet_analyzer/translate.py +104 -0
- birdnet_analyzer/utils.py +419 -0
- birdnet_analyzer-2.0.0.dist-info/METADATA +129 -0
- birdnet_analyzer-2.0.0.dist-info/RECORD +117 -0
- birdnet_analyzer-2.0.0.dist-info/WHEEL +5 -0
- birdnet_analyzer-2.0.0.dist-info/entry_points.txt +11 -0
- birdnet_analyzer-2.0.0.dist-info/licenses/LICENSE +19 -0
- birdnet_analyzer-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,13 @@
|
|
1
|
+
from birdnet_analyzer.utils import runtime_error_handler
|
2
|
+
|
3
|
+
from birdnet_analyzer import embeddings
|
4
|
+
|
5
|
+
|
6
|
+
@runtime_error_handler
def main():
    """Command-line entry point for the embeddings extraction.

    Builds the embeddings argument parser, parses the command line and
    forwards every parsed option to `embeddings` as a keyword argument.
    """
    import birdnet_analyzer.cli as cli

    options = vars(cli.embeddings_parser().parse_args())
    embeddings(**options)
|
@@ -0,0 +1,70 @@
|
|
1
|
+
def embeddings(
    input: str,
    database: str,
    *,
    overlap: float = 0.0,
    audio_speed: float = 1.0,
    fmin: int = 0,
    fmax: int = 15000,
    threads: int = 8,
    batch_size: int = 1,
):
    """
    Extract embeddings from audio files and store them in a database.

    Makes sure the model files are available (`ensure_model_exists`) and then
    hands all settings over to `birdnet_analyzer.embeddings.utils.run`.

    Args:
        input (str): Path to the input audio file or directory containing audio files.
        database (str): Path to the database where embeddings will be stored.
        overlap (float, optional): Overlap between consecutive audio segments in seconds. Defaults to 0.0.
        audio_speed (float, optional): Speed factor for audio processing. Defaults to 1.0.
        fmin (int, optional): Minimum frequency (in Hz) for audio analysis. Defaults to 0.
        fmax (int, optional): Maximum frequency (in Hz) for audio analysis. Defaults to 15000.
        threads (int, optional): Number of threads to use for processing. Defaults to 8.
        batch_size (int, optional): Number of audio segments to process in a single batch. Defaults to 1.
    Raises:
        ValueError: Propagated from `run` when an existing database was created
            with different bandpass or audio speed settings.
    Example:
        embeddings(
            input="path/to/audio",
            database="path/to/database",
            overlap=0.5,
            audio_speed=1.0,
            fmin=500,
            fmax=10000,
            threads=4,
            batch_size=2
        )
    """
    from birdnet_analyzer.embeddings.utils import run
    from birdnet_analyzer.utils import ensure_model_exists

    ensure_model_exists()

    # `run` takes its settings positionally, in exactly this order.
    settings = (overlap, audio_speed, fmin, fmax, threads, batch_size)
    run(input, database, *settings)
|
50
|
+
|
51
|
+
|
52
|
+
def get_database(db_path: str):
    """Get the database object. Creates or opens the database.

    If nothing exists at `db_path` yet, the parent directory is created (when
    the path has one) and a new database is created with the default usearch
    configuration. Otherwise the existing database is opened.

    Args:
        db_path: The path to the database.

    Returns:
        The database object.
    """
    import os

    from perch_hoplite.db import sqlite_usearch_impl

    if os.path.exists(db_path):
        return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path)

    # A bare name such as "my_db" has no directory component. os.path.dirname
    # then returns "" and os.makedirs("") raises FileNotFoundError, so the
    # parent is only created when there actually is one.
    parent_dir = os.path.dirname(db_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    return sqlite_usearch_impl.SQLiteUsearchDB.create(
        db_path=db_path,
        usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024),  # TODO dont hardcode this
    )
|
@@ -0,0 +1,193 @@
|
|
1
|
+
"""Module used to extract embeddings for samples."""
|
2
|
+
|
3
|
+
import datetime
|
4
|
+
import os
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
import birdnet_analyzer.audio as audio
|
9
|
+
import birdnet_analyzer.config as cfg
|
10
|
+
import birdnet_analyzer.model as model
|
11
|
+
import birdnet_analyzer.utils as utils
|
12
|
+
from birdnet_analyzer.analyze.utils import get_raw_audio_from_file
|
13
|
+
from birdnet_analyzer.embeddings.core import get_database
|
14
|
+
|
15
|
+
|
16
|
+
from perch_hoplite.db import sqlite_usearch_impl
|
17
|
+
from perch_hoplite.db import interface as hoplite
|
18
|
+
from ml_collections import ConfigDict
|
19
|
+
from functools import partial
|
20
|
+
from tqdm import tqdm
|
21
|
+
from multiprocessing import Pool
|
22
|
+
|
23
|
+
|
24
|
+
# Dataset name used for every embedding source looked up in or inserted into the database.
DATASET_NAME: str = "birdnet_analyzer_dataset"
|
25
|
+
|
26
|
+
|
27
|
+
def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
    """Extract the embeddings of one audio file and store them in the database.

    The file is read in pieces of `cfg.FILE_SPLITTING_DURATION`. The chunks of
    each piece are passed through the model in batches of at most
    `cfg.BATCH_SIZE`; every embedding for which the database returns nothing
    for the same source and time range is inserted and committed.

    Args:
        item: (filepath, config) pair; the config is applied with `cfg.set_config`.
        db: Database the embeddings are looked up in and written to.

    Returns:
        None. Failures are printed and written to the error log, not raised.
    """
    fpath: str = item[0]
    cfg.set_config(item[1])

    offset = 0
    piece_length = cfg.FILE_SPLITTING_DURATION

    try:
        total_seconds = int(audio.get_audio_file_length(fpath))
    except Exception as ex:
        print(f"Error: Cannot analyze audio file {fpath}. File corrupt?\n", flush=True)
        utils.write_error_log(ex)

        return None

    started_at = datetime.datetime.now()
    print(f"Analyzing {fpath}", flush=True)

    source_id = fpath

    try:
        while offset < total_seconds:
            chunks = get_raw_audio_from_file(fpath, offset, piece_length)
            seg_start = offset
            seg_end = cfg.SIG_LENGTH + offset
            batch, batch_times = [], []
            last_index = len(chunks) - 1

            for index, chunk in enumerate(chunks):
                batch.append(chunk)
                batch_times.append([seg_start, seg_end])

                # Move the window forward for the next chunk
                seg_start += cfg.SIG_LENGTH - cfg.SIG_OVERLAP
                seg_end = seg_start + cfg.SIG_LENGTH

                # Keep collecting until the batch is full or the chunks run out
                if len(batch) < cfg.BATCH_SIZE and index < last_index:
                    continue

                vectors = model.embeddings(np.array(batch, dtype="float32"))

                for row, (t_start, t_end) in enumerate(batch_times):
                    existing = db.get_embeddings_by_source(DATASET_NAME, source_id, np.array([t_start, t_end]))

                    # Skip segments for which the lookup returned something
                    if existing.size != 0:
                        continue

                    source = hoplite.EmbeddingSource(DATASET_NAME, source_id, np.array([t_start, t_end]))
                    db.insert_embedding(vectors[row], source)
                    db.commit()

                batch, batch_times = [], []

            offset = offset + piece_length

    except Exception as ex:
        print(f"Error: Cannot analyze audio file {fpath}.", flush=True)
        utils.write_error_log(ex)

        return

    elapsed = (datetime.datetime.now() - started_at).total_seconds()
    print("Finished {} in {:.2f} seconds".format(fpath, elapsed), flush=True)
|
118
|
+
|
119
|
+
|
120
|
+
def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
    """Check that the database settings agree with the current configuration.

    Reads the "birdnet_analyzer_settings" metadata entry and compares its
    bandpass limits and audio speed with the values in `cfg`. If a KeyError is
    raised while reading them, the current values are written to the database
    and committed instead.

    Args:
        db: The database whose stored settings are checked.

    Raises:
        ValueError: If a stored setting differs from the current configuration.
    """
    setting_names = ("BANDPASS_FMIN", "BANDPASS_FMAX", "AUDIO_SPEED")

    try:
        stored = db.get_metadata("birdnet_analyzer_settings")

        if any(stored[name] != getattr(cfg, name) for name in setting_names):
            raise ValueError(
                "Database settings do not match current configuration. DB Settings are: fmin: {}, fmax: {}, audio_speed: {}".format(
                    *(stored[name] for name in setting_names)
                )
            )
    except KeyError:
        current = ConfigDict({name: getattr(cfg, name) for name in setting_names})
        db.insert_metadata("birdnet_analyzer_settings", current)
        db.commit()
|
139
|
+
|
140
|
+
|
141
|
+
def run(input, database, overlap, audio_speed, fmin, fmax, threads, batchsize):
    """Configure the analyzer and extract embeddings for all input files.

    Writes the clamped settings into the global `cfg`, opens (or creates) the
    database, checks its stored settings against the configuration and calls
    `analyze_file` for every input file. The database connection is closed
    when the function exits, also if an error is raised on the way.

    Args:
        input: Path to an audio file or to a directory of audio files.
        database: Path to the embeddings database.
        overlap: Segment overlap; clamped to the range [0.0, 2.9].
        audio_speed: Speed factor; values below 0.01 are raised to 0.01.
        fmin: Lower bandpass frequency; clamped to [0, cfg.SIG_FMAX].
        fmax: Upper bandpass frequency; clamped to [cfg.SIG_FMIN, cfg.SIG_FMAX].
        threads: Requested thread count (at least 1). It is used as
            `cfg.TFLITE_THREADS` for single-file input; `cfg.CPU_THREADS` is
            currently always forced to 1.
        batchsize: Number of segments per model batch (at least 1).

    Raises:
        ValueError: Propagated from `check_database_settings` when the database
            settings do not match the current configuration.
    """
    # Set input path
    cfg.INPUT_PATH = input
    is_directory = os.path.isdir(cfg.INPUT_PATH)

    # Parse input files
    if is_directory:
        cfg.FILE_LIST = utils.collect_audio_files(cfg.INPUT_PATH)
    else:
        cfg.FILE_LIST = [cfg.INPUT_PATH]

    # Set overlap
    cfg.SIG_OVERLAP = max(0.0, min(2.9, float(overlap)))

    # Set audio speed
    cfg.AUDIO_SPEED = max(0.01, audio_speed)

    # Set bandpass frequency range
    cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
    cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))

    # Set number of threads
    if is_directory:
        cfg.CPU_THREADS = max(1, int(threads))
        cfg.TFLITE_THREADS = 1
    else:
        cfg.CPU_THREADS = 1
        cfg.TFLITE_THREADS = max(1, int(threads))

    cfg.CPU_THREADS = 1  # TODO: with the current implementation, we can't use more than 1 thread

    # Set batch size
    cfg.BATCH_SIZE = max(1, int(batchsize))

    # Add config items to each file list entry.
    # We have to do this for Windows which does not
    # support fork() and thus each process has to
    # have its own config. USE LINUX!
    flist = [(f, cfg.get_config()) for f in cfg.FILE_LIST]

    db = get_database(database)

    # Everything after opening the database runs under try/finally so the
    # connection is released even when the settings check or a file fails.
    try:
        check_database_settings(db)

        # Analyze files
        if cfg.CPU_THREADS < 2:
            for entry in tqdm(flist):
                analyze_file(entry, db)
        else:
            with Pool(cfg.CPU_THREADS) as p:
                # The results have to be drained inside the `with` block:
                # leaving it terminates the pool, and an un-iterated tqdm
                # wrapper would neither wait for the workers nor show progress.
                for _ in tqdm(p.imap(partial(analyze_file, db=db), flist), total=len(flist)):
                    pass
    finally:
        db.db.close()
|
@@ -0,0 +1,195 @@
|
|
1
|
+
"""
|
2
|
+
Core script for assessing performance of prediction models against annotated data.
|
3
|
+
|
4
|
+
This script uses the `DataProcessor` and `PerformanceAssessor` classes to process prediction and
|
5
|
+
annotation data, compute metrics, and optionally generate plots. It supports flexible configurations
|
6
|
+
for columns, class mappings, and filtering based on selected classes or recordings.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import argparse
|
10
|
+
import json
|
11
|
+
import os
|
12
|
+
from typing import Optional, Dict, List, Tuple
|
13
|
+
|
14
|
+
from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
|
15
|
+
from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor
|
16
|
+
|
17
|
+
|
18
|
+
def process_data(
    annotation_path: str,
    prediction_path: str,
    mapping_path: Optional[str] = None,
    sample_duration: float = 3.0,
    min_overlap: float = 0.5,
    recording_duration: Optional[float] = None,
    columns_annotations: Optional[Dict[str, str]] = None,
    columns_predictions: Optional[Dict[str, str]] = None,
    selected_classes: Optional[List[str]] = None,
    selected_recordings: Optional[List[str]] = None,
    metrics_list: Tuple[str, ...] = ("accuracy", "precision", "recall"),
    threshold: float = 0.1,
    class_wise: bool = False,
):
    """
    Load annotations and predictions and compute performance metrics for them.

    Args:
        annotation_path (str): Path to the annotation file or folder.
        prediction_path (str): Path to the prediction file or folder.
        mapping_path (Optional[str]): Path to the class mapping JSON file, if applicable.
        sample_duration (float): Duration of each sample interval in seconds.
        min_overlap (float): Minimum overlap required between predictions and annotations.
        recording_duration (Optional[float]): Total duration of the recordings, if known.
        columns_annotations (Optional[Dict[str, str]]): Custom column mappings for annotations.
        columns_predictions (Optional[Dict[str, str]]): Custom column mappings for predictions.
        selected_classes (Optional[List[str]]): Classes to include; all available classes if None.
        selected_recordings (Optional[List[str]]): Recordings to include; all available recordings if None.
        metrics_list (Tuple[str, ...]): Metrics to compute for performance assessment.
        threshold (float): Confidence threshold for predictions.
        class_wise (bool): Whether to calculate metrics on a per-class basis.

    Returns:
        Tuple: The result of `PerformanceAssessor.calculate_metrics`, the
        `PerformanceAssessor` object, and the predictions and labels returned
        by `DataProcessor.get_filtered_tensors`.
    """
    # Optional class mapping
    class_mapping = None
    if mapping_path:
        with open(mapping_path, "r") as mapping_file:
            class_mapping = json.load(mapping_file)

    # A file path is split into (directory, file name); anything else is
    # treated as a directory with no specific file.
    if os.path.isfile(annotation_path):
        annotation_dir, annotation_file = os.path.split(annotation_path)
    else:
        annotation_dir, annotation_file = annotation_path, None

    if os.path.isfile(prediction_path):
        prediction_dir, prediction_file = os.path.split(prediction_path)
    else:
        prediction_dir, prediction_file = prediction_path, None

    processor = DataProcessor(
        prediction_directory_path=prediction_dir,
        prediction_file_name=prediction_file,
        annotation_directory_path=annotation_dir,
        annotation_file_name=annotation_file,
        class_mapping=class_mapping,
        sample_duration=sample_duration,
        min_overlap=min_overlap,
        columns_predictions=columns_predictions,
        columns_annotations=columns_annotations,
        recording_duration=recording_duration,
    )

    # Fall back to everything the processor found when no selection was made
    available_classes = processor.classes
    available_recordings = processor.samples_df["filename"].unique().tolist()
    classes_to_use = available_classes if selected_classes is None else selected_classes
    recordings_to_use = available_recordings if selected_recordings is None else selected_recordings

    predictions, labels, classes = processor.get_filtered_tensors(classes_to_use, recordings_to_use)

    num_classes = len(classes)
    if num_classes == 1:
        task = "binary"
    else:
        task = "multilabel"

    pa = PerformanceAssessor(
        num_classes=num_classes,
        threshold=threshold,
        classes=classes,
        task=task,
        metrics_list=metrics_list,
    )

    metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise)

    return metrics_df, pa, predictions, labels
|
116
|
+
|
117
|
+
|
118
|
+
def _save_or_show_plot(output_dir, filename):
    """Save the current matplotlib figure into `output_dir`, or show it if no directory is given."""
    # Imported here so matplotlib is only loaded when a plot was requested,
    # and so `plt` is bound on both the save and the show path.
    import matplotlib.pyplot as plt

    if output_dir:
        plt.savefig(os.path.join(output_dir, filename))
    else:
        plt.show()


def main():
    """
    Entry point for the script. Parses command-line arguments and orchestrates the performance assessment pipeline.

    The computed metrics are printed. Each plot requested on the command line
    is written to `--output_dir` when that option is given and shown
    interactively otherwise.
    """
    # Set up argument parsing
    parser = argparse.ArgumentParser(description="Performance Assessor Core Script")
    parser.add_argument("--annotation_path", required=True, help="Path to annotation file or folder")
    parser.add_argument("--prediction_path", required=True, help="Path to prediction file or folder")
    parser.add_argument("--mapping_path", help="Path to class mapping JSON file (optional)")
    parser.add_argument("--sample_duration", type=float, default=3.0, help="Sample duration in seconds")
    parser.add_argument("--min_overlap", type=float, default=0.5, help="Minimum overlap in seconds")
    parser.add_argument("--recording_duration", type=float, help="Recording duration in seconds")
    parser.add_argument("--columns_annotations", type=json.loads, help="JSON string for columns_annotations")
    parser.add_argument("--columns_predictions", type=json.loads, help="JSON string for columns_predictions")
    parser.add_argument("--selected_classes", nargs="+", help="List of selected classes")
    parser.add_argument("--selected_recordings", nargs="+", help="List of selected recordings")
    parser.add_argument("--metrics", nargs="+", default=["accuracy", "precision", "recall"], help="List of metrics")
    parser.add_argument("--threshold", type=float, default=0.1, help="Threshold value (0-1)")
    parser.add_argument("--class_wise", action="store_true", help="Calculate class-wise metrics")
    parser.add_argument("--plot_metrics", action="store_true", help="Plot metrics")
    parser.add_argument("--plot_confusion_matrix", action="store_true", help="Plot confusion matrix")
    parser.add_argument("--plot_metrics_all_thresholds", action="store_true", help="Plot metrics for all thresholds")
    parser.add_argument("--output_dir", help="Directory to save plots")

    # Parse arguments
    args = parser.parse_args()

    # Process data and compute metrics
    metrics_df, pa, predictions, labels = process_data(
        annotation_path=args.annotation_path,
        prediction_path=args.prediction_path,
        mapping_path=args.mapping_path,
        sample_duration=args.sample_duration,
        min_overlap=args.min_overlap,
        recording_duration=args.recording_duration,
        columns_annotations=args.columns_annotations,
        columns_predictions=args.columns_predictions,
        selected_classes=args.selected_classes,
        selected_recordings=args.selected_recordings,
        metrics_list=args.metrics,
        threshold=args.threshold,
        class_wise=args.class_wise,
    )

    # Display the computed metrics
    print(metrics_df)

    # Create output directory if needed
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Generate plots if specified
    if args.plot_metrics:
        pa.plot_metrics(predictions, labels, per_class_metrics=args.class_wise)
        _save_or_show_plot(args.output_dir, "metrics_plot.png")

    if args.plot_confusion_matrix:
        pa.plot_confusion_matrix(predictions, labels)
        _save_or_show_plot(args.output_dir, "confusion_matrix.png")

    if args.plot_metrics_all_thresholds:
        pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=args.class_wise)
        _save_or_show_plot(args.output_dir, "metrics_all_thresholds.png")
|
@@ -0,0 +1,23 @@
|
|
1
|
+
def main():
    """Open the GUI window with all tabs.

    Imports the GUI tab modules and passes their tab builder functions, in
    the order listed below, to `gu.open_window`.
    """
    import birdnet_analyzer.gui.multi_file as mfa
    import birdnet_analyzer.gui.review as review
    import birdnet_analyzer.gui.segments as gs
    import birdnet_analyzer.gui.single_file as sfa
    import birdnet_analyzer.gui.species as species
    import birdnet_analyzer.gui.train as train
    import birdnet_analyzer.gui.utils as gu
    import birdnet_analyzer.gui.embeddings as embeddings
    import birdnet_analyzer.gui.evaluation as evaluation

    tab_builders = [
        sfa.build_single_analysis_tab,
        mfa.build_multi_analysis_tab,
        train.build_train_tab,
        gs.build_segments_tab,
        review.build_review_tab,
        species.build_species_tab,
        embeddings.build_embeddings_tab,
        evaluation.build_evaluation_tab,
    ]

    gu.open_window(tab_builders)
|
@@ -0,0 +1,174 @@
|
|
1
|
+
import concurrent.futures
|
2
|
+
import os
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import gradio as gr
|
6
|
+
|
7
|
+
import birdnet_analyzer.config as cfg
|
8
|
+
import birdnet_analyzer.gui.utils as gu
|
9
|
+
import birdnet_analyzer.gui.localization as loc
|
10
|
+
import birdnet_analyzer.model as model
|
11
|
+
|
12
|
+
|
13
|
+
from birdnet_analyzer.analyze.utils import analyze_file, combine_results, save_analysis_params
|
14
|
+
|
15
|
+
# Absolute path of the directory that contains this module.
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
|
16
|
+
# cfg.LABELS_FILE as configured at import time, resolved against the parent of this module's directory.
ORIGINAL_LABELS_FILE = str(Path(SCRIPT_DIR).parent / cfg.LABELS_FILE)
|
17
|
+
|
18
|
+
|
19
|
+
def analyze_file_wrapper(entry):
    """
    Run `analyze_file` on an entry and pair the result with the file path.

    Args:
        entry (tuple): The work item; its first element is the file path. The
            whole entry is handed to `analyze_file` unchanged.

    Returns:
        tuple: `(file path, result of analyze_file(entry))`.
    """
    fpath = entry[0]

    return fpath, analyze_file(entry)
|
33
|
+
|
34
|
+
|
35
|
+
def run_analysis(
    input_path: str,
    output_path: str | None,
    use_top_n: bool,
    top_n: int,
    confidence: float,
    sensitivity: float,
    overlap: float,
    merge_consecutive: int,
    audio_speed: float,
    fmin: int,
    fmax: int,
    species_list_choice: str,
    species_list_file,
    lat: float,
    lon: float,
    week: int,
    use_yearlong: bool,
    sf_thresh: float,
    custom_classifier_file,
    output_types: str,
    combine_tables: bool,
    locale: str,
    batch_size: int,
    threads: int,
    input_dir: str,
    skip_existing: bool,
    save_params: bool,
    progress: gr.Progress | None,
):
    """Starts the analysis.

    Args:
        input_path: Either a file or directory. Only used if `input_dir` is empty.
        output_path: The output path for the result, if None the input_path is used
        use_top_n: If False, `top_n` is ignored and None is passed on instead.
        top_n: The top-n value passed on when `use_top_n` is set.
        confidence: The selected minimum confidence.
        sensitivity: The selected sensitivity; clamped to [0.75, 1.25].
        overlap: The selected segment overlap; clamped to [0.0, 2.9].
        merge_consecutive: The number of consecutive segments to merge into one (at least 1).
        audio_speed: The selected audio speed. A negative value x is converted
            to 1 / |x| (not below 0.1); any other value is raised to at least 1.0.
        fmin: The selected minimum bandpass frequency.
        fmax: The selected maximum bandpass frequency.
        species_list_choice: The choice for the species list.
        species_list_file: The selected custom species list file.
        lat: The selected latitude; replaced by -1 unless species prediction is chosen.
        lon: The selected longitude; replaced by -1 unless species prediction is chosen.
        week: The selected week of the year.
        use_yearlong: Use yearlong instead of week (week is replaced by -1).
        sf_thresh: The threshold for the predicted species list.
        custom_classifier_file: Custom classifier to be used.
        output_types: The type(s) of result to be generated.
        combine_tables: Whether the per-file results should be combined.
        locale: The translation to be used.
        batch_size: The number of samples in a batch.
        threads: The number of threads to be used.
        input_dir: The input directory; takes precedence over `input_path`.
        skip_existing: Passed on as `skip_existing_results`.
        save_params: Whether to save the analysis parameters into the output path.
        progress: The gradio progress bar.

    Returns:
        For directory input a list of `[path relative to input_dir, bool(result)]`
        rows, one per analyzed file. For a single file the "csv" entry of its
        result, or None if the result is empty.

    Raises:
        gr.Error: If the custom classifier option is chosen but no classifier
            file was selected.
    """
    if progress is not None:
        progress(0, desc=f"{loc.localize('progress-preparing')} ...")

    from birdnet_analyzer.analyze.core import _set_params

    use_custom_classifier = species_list_choice == gu._CUSTOM_CLASSIFIER

    # Reject an incomplete selection before any parameter is applied, instead
    # of only after `_set_params` has already run.
    if use_custom_classifier and custom_classifier_file is None:
        raise gr.Error(loc.localize("validation-no-custom-classifier-selected"))

    locale = locale.lower()
    custom_classifier = custom_classifier_file if use_custom_classifier else None
    slist = species_list_file if species_list_choice == gu._CUSTOM_SPECIES else None
    lat = lat if species_list_choice == gu._PREDICT_SPECIES else -1
    lon = lon if species_list_choice == gu._PREDICT_SPECIES else -1
    week = -1 if use_yearlong else week

    flist = _set_params(
        input=input_dir if input_dir else input_path,
        min_conf=confidence,
        custom_classifier=custom_classifier,
        sensitivity=min(1.25, max(0.75, float(sensitivity))),
        locale=locale,
        overlap=max(0.0, min(2.9, float(overlap))),
        merge_consecutive=max(1, int(merge_consecutive)),
        audio_speed=max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed)),
        fmin=max(0, min(cfg.SIG_FMAX, int(fmin))),
        fmax=max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax))),
        bs=max(1, int(batch_size)),
        combine_results=combine_tables,
        rtype=output_types,
        skip_existing_results=skip_existing,
        threads=max(1, int(threads)),
        labels_file=ORIGINAL_LABELS_FILE,
        sf_thresh=sf_thresh,
        lat=lat,
        lon=lon,
        week=week,
        slist=slist,
        top_n=top_n if use_top_n else None,
        output=output_path,
    )

    if use_custom_classifier:
        model.reset_custom_classifier()

    gu.validate(cfg.FILE_LIST, loc.localize("validation-no-audio-files-found"))

    result_list = []
    total_files = len(flist)

    if progress is not None:
        progress(0, desc=f"{loc.localize('progress-starting')} ...")

    # Analyze files
    if cfg.CPU_THREADS < 2:
        for i, entry in enumerate(flist, start=1):
            result_list.append(analyze_file_wrapper(entry))

            # Report progress here as well, like the multi-process branch does
            if progress is not None:
                progress((i, total_files), total=total_files, unit="files")
    else:
        with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
            futures = [executor.submit(analyze_file_wrapper, arg) for arg in flist]

            for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
                if progress is not None:
                    progress((i, total_files), total=total_files, unit="files")

                result_list.append(f.result())

    # Combine results?
    if cfg.COMBINE_RESULTS:
        # Results may arrive in completion order; index them by path once so
        # they can be put back into the order of `flist` in linear time.
        # setdefault keeps the first result per path, like the original scan.
        results_by_path = {}
        for path, result in result_list:
            results_by_path.setdefault(path, result)

        combine_list = [results_by_path[entry[0]] for entry in flist]
        print(f"Combining results, writing to {cfg.OUTPUT_PATH}...", end="", flush=True)
        combine_results(combine_list)
        print("done!", flush=True)

    if save_params:
        save_analysis_params(os.path.join(cfg.OUTPUT_PATH, cfg.ANALYSIS_PARAMS_FILENAME))

    if input_dir:
        return [[os.path.relpath(path, input_dir), bool(result)] for path, result in result_list]

    single_result = result_list[0][1]

    return single_result["csv"] if single_result else None
|