birdnet_analyzer-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. birdnet_analyzer/__init__.py +8 -0
  2. birdnet_analyzer/analyze/__init__.py +5 -0
  3. birdnet_analyzer/analyze/__main__.py +4 -0
  4. birdnet_analyzer/analyze/cli.py +25 -0
  5. birdnet_analyzer/analyze/core.py +245 -0
  6. birdnet_analyzer/analyze/utils.py +701 -0
  7. birdnet_analyzer/audio.py +372 -0
  8. birdnet_analyzer/cli.py +707 -0
  9. birdnet_analyzer/config.py +242 -0
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25280 -0
  11. birdnet_analyzer/embeddings/__init__.py +4 -0
  12. birdnet_analyzer/embeddings/__main__.py +3 -0
  13. birdnet_analyzer/embeddings/cli.py +13 -0
  14. birdnet_analyzer/embeddings/core.py +70 -0
  15. birdnet_analyzer/embeddings/utils.py +193 -0
  16. birdnet_analyzer/evaluation/__init__.py +195 -0
  17. birdnet_analyzer/evaluation/__main__.py +3 -0
  18. birdnet_analyzer/gui/__init__.py +23 -0
  19. birdnet_analyzer/gui/__main__.py +3 -0
  20. birdnet_analyzer/gui/analysis.py +174 -0
  21. birdnet_analyzer/gui/assets/arrow_down.svg +4 -0
  22. birdnet_analyzer/gui/assets/arrow_left.svg +4 -0
  23. birdnet_analyzer/gui/assets/arrow_right.svg +4 -0
  24. birdnet_analyzer/gui/assets/arrow_up.svg +4 -0
  25. birdnet_analyzer/gui/assets/gui.css +29 -0
  26. birdnet_analyzer/gui/assets/gui.js +94 -0
  27. birdnet_analyzer/gui/assets/img/birdnet-icon.ico +0 -0
  28. birdnet_analyzer/gui/assets/img/birdnet_logo.png +0 -0
  29. birdnet_analyzer/gui/assets/img/birdnet_logo_no_transparent.png +0 -0
  30. birdnet_analyzer/gui/assets/img/clo-logo-bird.svg +1 -0
  31. birdnet_analyzer/gui/embeddings.py +620 -0
  32. birdnet_analyzer/gui/evaluation.py +813 -0
  33. birdnet_analyzer/gui/localization.py +68 -0
  34. birdnet_analyzer/gui/multi_file.py +246 -0
  35. birdnet_analyzer/gui/review.py +527 -0
  36. birdnet_analyzer/gui/segments.py +191 -0
  37. birdnet_analyzer/gui/settings.py +129 -0
  38. birdnet_analyzer/gui/single_file.py +269 -0
  39. birdnet_analyzer/gui/species.py +95 -0
  40. birdnet_analyzer/gui/train.py +698 -0
  41. birdnet_analyzer/gui/utils.py +808 -0
  42. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -0
  43. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -0
  44. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -0
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -0
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -0
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -0
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -0
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -0
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -0
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -0
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -0
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -0
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -0
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -0
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -0
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -0
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -0
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -0
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -0
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -0
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -0
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -0
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -0
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -0
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -0
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -0
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -0
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -0
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -0
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -0
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -0
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -0
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -0
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -0
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -0
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -0
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -0
  79. birdnet_analyzer/lang/de.json +335 -0
  80. birdnet_analyzer/lang/en.json +335 -0
  81. birdnet_analyzer/lang/fi.json +335 -0
  82. birdnet_analyzer/lang/fr.json +335 -0
  83. birdnet_analyzer/lang/id.json +335 -0
  84. birdnet_analyzer/lang/pt-br.json +335 -0
  85. birdnet_analyzer/lang/ru.json +335 -0
  86. birdnet_analyzer/lang/se.json +335 -0
  87. birdnet_analyzer/lang/tlh.json +335 -0
  88. birdnet_analyzer/lang/zh_TW.json +335 -0
  89. birdnet_analyzer/model.py +1243 -0
  90. birdnet_analyzer/search/__init__.py +3 -0
  91. birdnet_analyzer/search/__main__.py +3 -0
  92. birdnet_analyzer/search/cli.py +12 -0
  93. birdnet_analyzer/search/core.py +78 -0
  94. birdnet_analyzer/search/utils.py +111 -0
  95. birdnet_analyzer/segments/__init__.py +3 -0
  96. birdnet_analyzer/segments/__main__.py +3 -0
  97. birdnet_analyzer/segments/cli.py +14 -0
  98. birdnet_analyzer/segments/core.py +78 -0
  99. birdnet_analyzer/segments/utils.py +394 -0
  100. birdnet_analyzer/species/__init__.py +3 -0
  101. birdnet_analyzer/species/__main__.py +3 -0
  102. birdnet_analyzer/species/cli.py +14 -0
  103. birdnet_analyzer/species/core.py +35 -0
  104. birdnet_analyzer/species/utils.py +75 -0
  105. birdnet_analyzer/train/__init__.py +3 -0
  106. birdnet_analyzer/train/__main__.py +3 -0
  107. birdnet_analyzer/train/cli.py +14 -0
  108. birdnet_analyzer/train/core.py +113 -0
  109. birdnet_analyzer/train/utils.py +847 -0
  110. birdnet_analyzer/translate.py +104 -0
  111. birdnet_analyzer/utils.py +419 -0
  112. birdnet_analyzer-2.0.0.dist-info/METADATA +129 -0
  113. birdnet_analyzer-2.0.0.dist-info/RECORD +117 -0
  114. birdnet_analyzer-2.0.0.dist-info/WHEEL +5 -0
  115. birdnet_analyzer-2.0.0.dist-info/entry_points.txt +11 -0
  116. birdnet_analyzer-2.0.0.dist-info/licenses/LICENSE +19 -0
  117. birdnet_analyzer-2.0.0.dist-info/top_level.txt +1 -0
birdnet_analyzer/embeddings/__init__.py
@@ -0,0 +1,4 @@
+ from birdnet_analyzer.embeddings.core import embeddings
+
+
+ __all__ = ["embeddings"]
birdnet_analyzer/embeddings/__main__.py
@@ -0,0 +1,3 @@
+ from birdnet_analyzer.embeddings.cli import main
+
+ main()
birdnet_analyzer/embeddings/cli.py
@@ -0,0 +1,13 @@
+ from birdnet_analyzer.utils import runtime_error_handler
+
+ from birdnet_analyzer import embeddings
+
+
+ @runtime_error_handler
+ def main():
+     import birdnet_analyzer.cli as cli
+
+     parser = cli.embeddings_parser()
+     args = parser.parse_args()
+
+     embeddings(**vars(args))
birdnet_analyzer/embeddings/core.py
@@ -0,0 +1,70 @@
+ def embeddings(
+     input: str,
+     database: str,
+     *,
+     overlap: float = 0.0,
+     audio_speed: float = 1.0,
+     fmin: int = 0,
+     fmax: int = 15000,
+     threads: int = 8,
+     batch_size: int = 1,
+ ):
+     """
+     Generates embeddings for audio files using the BirdNET-Analyzer.
+     This function processes audio files to extract embeddings, which are
+     representations of audio features. The embeddings can be used for
+     further analysis or comparison.
+     Args:
+         input (str): Path to the input audio file or directory containing audio files.
+         database (str): Path to the database where embeddings will be stored.
+         overlap (float, optional): Overlap between consecutive audio segments in seconds. Defaults to 0.0.
+         audio_speed (float, optional): Speed factor for audio processing. Defaults to 1.0.
+         fmin (int, optional): Minimum frequency (in Hz) for audio analysis. Defaults to 0.
+         fmax (int, optional): Maximum frequency (in Hz) for audio analysis. Defaults to 15000.
+         threads (int, optional): Number of threads to use for processing. Defaults to 8.
+         batch_size (int, optional): Number of audio segments to process in a single batch. Defaults to 1.
+     Raises:
+         FileNotFoundError: If the input path or database path does not exist.
+         ValueError: If any of the parameters are invalid.
+     Note:
+         Ensure that the required model files are downloaded and available before
+         calling this function. The `ensure_model_exists` function is used to
+         verify this.
+     Example:
+         embeddings(
+             input="path/to/audio",
+             database="path/to/database",
+             overlap=0.5,
+             audio_speed=1.0,
+             fmin=500,
+             fmax=10000,
+             threads=4,
+             batch_size=2
+         )
+     """
+     from birdnet_analyzer.embeddings.utils import run
+     from birdnet_analyzer.utils import ensure_model_exists
+
+     ensure_model_exists()
+     run(input, database, overlap, audio_speed, fmin, fmax, threads, batch_size)
+
+
+ def get_database(db_path: str):
+     """Get the database object. Creates or opens the database.
+     Args:
+         db_path: The path to the database.
+     Returns:
+         The database object.
+     """
+     import os
+
+     from perch_hoplite.db import sqlite_usearch_impl
+
+     if not os.path.exists(db_path):
+         os.makedirs(os.path.dirname(db_path), exist_ok=True)
+         db = sqlite_usearch_impl.SQLiteUsearchDB.create(
+             db_path=db_path,
+             usearch_cfg=sqlite_usearch_impl.get_default_usearch_config(embedding_dim=1024),  # TODO: don't hardcode this
+         )
+         return db
+     return sqlite_usearch_impl.SQLiteUsearchDB.create(db_path=db_path)
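For reference, a minimal usage sketch of the entry point above (the paths are placeholders, and the keyword values simply mirror the docstring example):

    from birdnet_analyzer.embeddings.core import embeddings

    # Extract embeddings for every audio file in a folder and store them in a
    # perch-hoplite database; get_database() creates the database on first use.
    embeddings(
        "path/to/audio",      # input file or directory
        "path/to/database",   # database location
        overlap=0.5,          # overlap between consecutive segments, in seconds
        threads=4,
        batch_size=2,
    )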
birdnet_analyzer/embeddings/utils.py
@@ -0,0 +1,193 @@
+ """Module used to extract embeddings for samples."""
+
+ import datetime
+ import os
+
+ import numpy as np
+
+ import birdnet_analyzer.audio as audio
+ import birdnet_analyzer.config as cfg
+ import birdnet_analyzer.model as model
+ import birdnet_analyzer.utils as utils
+ from birdnet_analyzer.analyze.utils import get_raw_audio_from_file
+ from birdnet_analyzer.embeddings.core import get_database
+
+
+ from perch_hoplite.db import sqlite_usearch_impl
+ from perch_hoplite.db import interface as hoplite
+ from ml_collections import ConfigDict
+ from functools import partial
+ from tqdm import tqdm
+ from multiprocessing import Pool
+
+
+ DATASET_NAME: str = "birdnet_analyzer_dataset"
+
+
+ def analyze_file(item, db: sqlite_usearch_impl.SQLiteUsearchDB):
+     """Extracts the embeddings for a file.
+
+     Args:
+         item: (filepath, config)
+     """
+     # Get file path and restore cfg
+     fpath: str = item[0]
+     cfg.set_config(item[1])
+
+     offset = 0
+     duration = cfg.FILE_SPLITTING_DURATION
+
+     try:
+         fileLengthSeconds = int(audio.get_audio_file_length(fpath))
+     except Exception as ex:
+         # Write error log
+         print(f"Error: Cannot analyze audio file {fpath}. File corrupt?\n", flush=True)
+         utils.write_error_log(ex)
+
+         return None
+
+     # Start time
+     start_time = datetime.datetime.now()
+
+     # Status
+     print(f"Analyzing {fpath}", flush=True)
+
+     source_id = fpath
+
+     # Process each chunk
+     try:
+         while offset < fileLengthSeconds:
+             chunks = get_raw_audio_from_file(fpath, offset, duration)
+             start, end = offset, cfg.SIG_LENGTH + offset
+             samples = []
+             timestamps = []
+
+             for c in range(len(chunks)):
+                 # Add to batch
+                 samples.append(chunks[c])
+                 timestamps.append([start, end])
+
+                 # Advance start and end
+                 start += cfg.SIG_LENGTH - cfg.SIG_OVERLAP
+                 end = start + cfg.SIG_LENGTH
+
+                 # Check if batch is full or last chunk
+                 if len(samples) < cfg.BATCH_SIZE and c < len(chunks) - 1:
+                     continue
+
+                 # Prepare sample and pass through model
+                 data = np.array(samples, dtype="float32")
+                 e = model.embeddings(data)
+
+                 # Add to results
+                 for i in range(len(samples)):
+                     # Get timestamp
+                     s_start, s_end = timestamps[i]
+
+                     # Check if embedding already exists
+                     existing_embedding = db.get_embeddings_by_source(
+                         DATASET_NAME, source_id, np.array([s_start, s_end])
+                     )
+
+                     if existing_embedding.size == 0:
+                         # Get prediction
+                         embeddings = e[i]
+
+                         # Store embeddings
+                         embeddings_source = hoplite.EmbeddingSource(DATASET_NAME, source_id, np.array([s_start, s_end]))
+
+                         # Insert into database
+                         db.insert_embedding(embeddings, embeddings_source)
+                         db.commit()
+
+                 # Reset batch
+                 samples = []
+                 timestamps = []
+
+             offset = offset + duration
+
+     except Exception as ex:
+         # Write error log
+         print(f"Error: Cannot analyze audio file {fpath}.", flush=True)
+         utils.write_error_log(ex)
+
+         return
+
+     delta_time = (datetime.datetime.now() - start_time).total_seconds()
+     print("Finished {} in {:.2f} seconds".format(fpath, delta_time), flush=True)
+
+
+ def check_database_settings(db: sqlite_usearch_impl.SQLiteUsearchDB):
+     try:
+         settings = db.get_metadata("birdnet_analyzer_settings")
+         if (
+             settings["BANDPASS_FMIN"] != cfg.BANDPASS_FMIN
+             or settings["BANDPASS_FMAX"] != cfg.BANDPASS_FMAX
+             or settings["AUDIO_SPEED"] != cfg.AUDIO_SPEED
+         ):
+             raise ValueError(
+                 "Database settings do not match current configuration. DB Settings are: fmin: {}, fmax: {}, audio_speed: {}".format(
+                     settings["BANDPASS_FMIN"], settings["BANDPASS_FMAX"], settings["AUDIO_SPEED"]
+                 )
+             )
+     except KeyError:
+         settings = ConfigDict(
+             {"BANDPASS_FMIN": cfg.BANDPASS_FMIN, "BANDPASS_FMAX": cfg.BANDPASS_FMAX, "AUDIO_SPEED": cfg.AUDIO_SPEED}
+         )
+         db.insert_metadata("birdnet_analyzer_settings", settings)
+         db.commit()
+
+
+ def run(input, database, overlap, audio_speed, fmin, fmax, threads, batchsize):
+     ### Make sure to comment out appropriately if you are not using args. ###
+
+     # Set input and output path
+     cfg.INPUT_PATH = input
+
+     # Parse input files
+     if os.path.isdir(cfg.INPUT_PATH):
+         cfg.FILE_LIST = utils.collect_audio_files(cfg.INPUT_PATH)
+     else:
+         cfg.FILE_LIST = [cfg.INPUT_PATH]
+
+     # Set overlap
+     cfg.SIG_OVERLAP = max(0.0, min(2.9, float(overlap)))
+
+     # Set audio speed
+     cfg.AUDIO_SPEED = max(0.01, audio_speed)
+
+     # Set bandpass frequency range
+     cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
+     cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))
+
+     # Set number of threads
+     if os.path.isdir(cfg.INPUT_PATH):
+         cfg.CPU_THREADS = max(1, int(threads))
+         cfg.TFLITE_THREADS = 1
+     else:
+         cfg.CPU_THREADS = 1
+         cfg.TFLITE_THREADS = max(1, int(threads))
+
+     cfg.CPU_THREADS = 1  # TODO: with the current implementation, we can't use more than 1 thread
+
+     # Set batch size
+     cfg.BATCH_SIZE = max(1, int(batchsize))
+
+     # Add config items to each file list entry.
+     # We have to do this for Windows which does not
+     # support fork() and thus each process has to
+     # have its own config. USE LINUX!
+     flist = [(f, cfg.get_config()) for f in cfg.FILE_LIST]
+
+     db = get_database(database)
+     check_database_settings(db)
+
+     # Analyze files
+     if cfg.CPU_THREADS < 2:
+         for entry in tqdm(flist):
+             analyze_file(entry, db)
+     else:
+         with Pool(cfg.CPU_THREADS) as p:
+             list(tqdm(p.imap(partial(analyze_file, db=db), flist), total=len(flist)))  # list() consumes the lazy iterator so the pool actually runs
+
+     db.db.close()
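Stored embeddings can be read back with the same perch-hoplite call analyze_file() uses for its duplicate check. A hedged sketch (paths are placeholders; the dataset name, the [start, end] offset convention, and the default 3-second segment length are taken from the code above):

    import numpy as np

    from birdnet_analyzer.embeddings.core import get_database

    db = get_database("path/to/database")

    # Query the first segment of one source file, keyed exactly like the insert:
    # dataset name, source path, and the segment's [start, end] offsets.
    existing = db.get_embeddings_by_source(
        "birdnet_analyzer_dataset", "path/to/audio/file.wav", np.array([0, 3.0])
    )
    print(existing.size)  # 0 means no embedding was stored for these offsets

    db.db.close()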
birdnet_analyzer/evaluation/__init__.py
@@ -0,0 +1,195 @@
+ """
+ Core script for assessing performance of prediction models against annotated data.
+
+ This script uses the `DataProcessor` and `PerformanceAssessor` classes to process prediction and
+ annotation data, compute metrics, and optionally generate plots. It supports flexible configurations
+ for columns, class mappings, and filtering based on selected classes or recordings.
+ """
+
+ import argparse
+ import json
+ import os
+ from typing import Optional, Dict, List, Tuple
+
+ from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
+ from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor
+
+
+ def process_data(
+     annotation_path: str,
+     prediction_path: str,
+     mapping_path: Optional[str] = None,
+     sample_duration: float = 3.0,
+     min_overlap: float = 0.5,
+     recording_duration: Optional[float] = None,
+     columns_annotations: Optional[Dict[str, str]] = None,
+     columns_predictions: Optional[Dict[str, str]] = None,
+     selected_classes: Optional[List[str]] = None,
+     selected_recordings: Optional[List[str]] = None,
+     metrics_list: Tuple[str, ...] = ("accuracy", "precision", "recall"),
+     threshold: float = 0.1,
+     class_wise: bool = False,
+ ):
+     """
+     Processes data, computes metrics, and prepares the performance assessment pipeline.
+
+     Args:
+         annotation_path (str): Path to the annotation file or folder.
+         prediction_path (str): Path to the prediction file or folder.
+         mapping_path (Optional[str]): Path to the class mapping JSON file, if applicable.
+         sample_duration (float): Duration of each sample interval in seconds.
+         min_overlap (float): Minimum overlap required between predictions and annotations.
+         recording_duration (Optional[float]): Total duration of the recordings, if known.
+         columns_annotations (Optional[Dict[str, str]]): Custom column mappings for annotations.
+         columns_predictions (Optional[Dict[str, str]]): Custom column mappings for predictions.
+         selected_classes (Optional[List[str]]): List of classes to include in the analysis.
+         selected_recordings (Optional[List[str]]): List of recordings to include in the analysis.
+         metrics_list (Tuple[str, ...]): Metrics to compute for performance assessment.
+         threshold (float): Confidence threshold for predictions.
+         class_wise (bool): Whether to calculate metrics on a per-class basis.
+
+     Returns:
+         Tuple: Metrics DataFrame, `PerformanceAssessor` object, predictions tensor, labels tensor.
+     """
+     # Load class mapping if provided
+     if mapping_path:
+         with open(mapping_path, "r") as f:
+             class_mapping = json.load(f)
+     else:
+         class_mapping = None
+
+     # Determine directory and file paths for annotations and predictions
+     annotation_dir, annotation_file = (
+         (os.path.dirname(annotation_path), os.path.basename(annotation_path))
+         if os.path.isfile(annotation_path)
+         else (annotation_path, None)
+     )
+     prediction_dir, prediction_file = (
+         (os.path.dirname(prediction_path), os.path.basename(prediction_path))
+         if os.path.isfile(prediction_path)
+         else (prediction_path, None)
+     )
+
+     # Initialize the DataProcessor to handle and prepare data
+     processor = DataProcessor(
+         prediction_directory_path=prediction_dir,
+         prediction_file_name=prediction_file,
+         annotation_directory_path=annotation_dir,
+         annotation_file_name=annotation_file,
+         class_mapping=class_mapping,
+         sample_duration=sample_duration,
+         min_overlap=min_overlap,
+         columns_predictions=columns_predictions,
+         columns_annotations=columns_annotations,
+         recording_duration=recording_duration,
+     )
+
+     # Get the available classes and recordings
+     available_classes = processor.classes
+     available_recordings = processor.samples_df["filename"].unique().tolist()
+
+     # Default to all classes or recordings if none are specified
+     if selected_classes is None:
+         selected_classes = available_classes
+     if selected_recordings is None:
+         selected_recordings = available_recordings
+
+     # Retrieve predictions and labels tensors for the selected classes and recordings
+     predictions, labels, classes = processor.get_filtered_tensors(selected_classes, selected_recordings)
+
+     num_classes = len(classes)
+     task = "binary" if num_classes == 1 else "multilabel"
+
+     # Initialize the PerformanceAssessor for computing metrics
+     pa = PerformanceAssessor(
+         num_classes=num_classes,
+         threshold=threshold,
+         classes=classes,
+         task=task,
+         metrics_list=metrics_list,
+     )
+
+     # Compute performance metrics
+     metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise)
+
+     return metrics_df, pa, predictions, labels
+
+
+ def main():
+     """
+     Entry point for the script. Parses command-line arguments and orchestrates the performance assessment pipeline.
+     """
+     # Set up argument parsing
+     parser = argparse.ArgumentParser(description="Performance Assessor Core Script")
+     parser.add_argument("--annotation_path", required=True, help="Path to annotation file or folder")
+     parser.add_argument("--prediction_path", required=True, help="Path to prediction file or folder")
+     parser.add_argument("--mapping_path", help="Path to class mapping JSON file (optional)")
+     parser.add_argument("--sample_duration", type=float, default=3.0, help="Sample duration in seconds")
+     parser.add_argument("--min_overlap", type=float, default=0.5, help="Minimum overlap in seconds")
+     parser.add_argument("--recording_duration", type=float, help="Recording duration in seconds")
+     parser.add_argument("--columns_annotations", type=json.loads, help="JSON string for columns_annotations")
+     parser.add_argument("--columns_predictions", type=json.loads, help="JSON string for columns_predictions")
+     parser.add_argument("--selected_classes", nargs="+", help="List of selected classes")
+     parser.add_argument("--selected_recordings", nargs="+", help="List of selected recordings")
+     parser.add_argument("--metrics", nargs="+", default=["accuracy", "precision", "recall"], help="List of metrics")
+     parser.add_argument("--threshold", type=float, default=0.1, help="Threshold value (0-1)")
+     parser.add_argument("--class_wise", action="store_true", help="Calculate class-wise metrics")
+     parser.add_argument("--plot_metrics", action="store_true", help="Plot metrics")
+     parser.add_argument("--plot_confusion_matrix", action="store_true", help="Plot confusion matrix")
+     parser.add_argument("--plot_metrics_all_thresholds", action="store_true", help="Plot metrics for all thresholds")
+     parser.add_argument("--output_dir", help="Directory to save plots")
+
+     # Parse arguments
+     args = parser.parse_args()
+
+     # Process data and compute metrics
+     metrics_df, pa, predictions, labels = process_data(
+         annotation_path=args.annotation_path,
+         prediction_path=args.prediction_path,
+         mapping_path=args.mapping_path,
+         sample_duration=args.sample_duration,
+         min_overlap=args.min_overlap,
+         recording_duration=args.recording_duration,
+         columns_annotations=args.columns_annotations,
+         columns_predictions=args.columns_predictions,
+         selected_classes=args.selected_classes,
+         selected_recordings=args.selected_recordings,
+         metrics_list=args.metrics,
+         threshold=args.threshold,
+         class_wise=args.class_wise,
+     )
+
+     # Display the computed metrics
+     print(metrics_df)
+
+     # Create output directory if needed
+     if args.output_dir and not os.path.exists(args.output_dir):
+         os.makedirs(args.output_dir)
+
+     # Generate plots if specified
+     if args.plot_metrics:
+         pa.plot_metrics(predictions, labels, per_class_metrics=args.class_wise)
+         import matplotlib.pyplot as plt  # imported here so plt is bound for both branches below
+
+         if args.output_dir:
+             plt.savefig(os.path.join(args.output_dir, "metrics_plot.png"))
+         else:
+             plt.show()
+
+     if args.plot_confusion_matrix:
+         pa.plot_confusion_matrix(predictions, labels)
+         import matplotlib.pyplot as plt
+
+         if args.output_dir:
+             plt.savefig(os.path.join(args.output_dir, "confusion_matrix.png"))
+         else:
+             plt.show()
+
+     if args.plot_metrics_all_thresholds:
+         pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=args.class_wise)
+         import matplotlib.pyplot as plt
+
+         if args.output_dir:
+             plt.savefig(os.path.join(args.output_dir, "metrics_all_thresholds.png"))
+         else:
+             plt.show()
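The same pipeline is available programmatically through process_data(). A minimal sketch, assuming annotation and prediction tables already exist at the placeholder paths:

    from birdnet_analyzer.evaluation import process_data

    metrics_df, pa, predictions, labels = process_data(
        annotation_path="path/to/annotations",  # file or folder
        prediction_path="path/to/predictions",  # file or folder
        sample_duration=3.0,                    # match the 3-second analysis segments
        min_overlap=0.5,
        metrics_list=("precision", "recall"),
        threshold=0.1,
        class_wise=True,
    )
    print(metrics_df)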
birdnet_analyzer/evaluation/__main__.py
@@ -0,0 +1,3 @@
+ from birdnet_analyzer.evaluation import main
+
+ main()
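Since `__main__.py` only calls `main()`, the evaluation pipeline can also be run from the command line, e.g. `python -m birdnet_analyzer.evaluation --annotation_path <annotations> --prediction_path <predictions>`; all other flags defined in `main()` above are optional.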
birdnet_analyzer/gui/__init__.py
@@ -0,0 +1,23 @@
+ def main():
+     import birdnet_analyzer.gui.multi_file as mfa
+     import birdnet_analyzer.gui.review as review
+     import birdnet_analyzer.gui.segments as gs
+     import birdnet_analyzer.gui.single_file as sfa
+     import birdnet_analyzer.gui.species as species
+     import birdnet_analyzer.gui.train as train
+     import birdnet_analyzer.gui.utils as gu
+     import birdnet_analyzer.gui.embeddings as embeddings
+     import birdnet_analyzer.gui.evaluation as evaluation
+
+     gu.open_window(
+         [
+             sfa.build_single_analysis_tab,
+             mfa.build_multi_analysis_tab,
+             train.build_train_tab,
+             gs.build_segments_tab,
+             review.build_review_tab,
+             species.build_species_tab,
+             embeddings.build_embeddings_tab,
+             evaluation.build_evaluation_tab,
+         ]
+     )
birdnet_analyzer/gui/__main__.py
@@ -0,0 +1,3 @@
+ from birdnet_analyzer.gui import main
+
+ main()
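Taken together, these two files mean the full tabbed GUI launches with `python -m birdnet_analyzer.gui`, with one tab per builder function passed to `gu.open_window()` above.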
birdnet_analyzer/gui/analysis.py
@@ -0,0 +1,174 @@
+ import concurrent.futures
+ import os
+ from pathlib import Path
+
+ import gradio as gr
+
+ import birdnet_analyzer.config as cfg
+ import birdnet_analyzer.gui.utils as gu
+ import birdnet_analyzer.gui.localization as loc
+ import birdnet_analyzer.model as model
+
+
+ from birdnet_analyzer.analyze.utils import analyze_file, combine_results, save_analysis_params
+
+ SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
+ ORIGINAL_LABELS_FILE = str(Path(SCRIPT_DIR).parent / cfg.LABELS_FILE)
+
+
+ def analyze_file_wrapper(entry):
+     """
+     Wrapper function for analyzing a file.
+
+     Args:
+         entry (tuple): A tuple where the first element is the file path and the
+                        remaining elements are arguments to be passed to the
+                        analyze.analyzeFile function.
+
+     Returns:
+         tuple: A tuple where the first element is the file path and the second
+                element is the result of the analyze.analyzeFile function.
+     """
+     return (entry[0], analyze_file(entry))
+
+
+ def run_analysis(
+     input_path: str,
+     output_path: str | None,
+     use_top_n: bool,
+     top_n: int,
+     confidence: float,
+     sensitivity: float,
+     overlap: float,
+     merge_consecutive: int,
+     audio_speed: float,
+     fmin: int,
+     fmax: int,
+     species_list_choice: str,
+     species_list_file,
+     lat: float,
+     lon: float,
+     week: int,
+     use_yearlong: bool,
+     sf_thresh: float,
+     custom_classifier_file,
+     output_types: str,
+     combine_tables: bool,
+     locale: str,
+     batch_size: int,
+     threads: int,
+     input_dir: str,
+     skip_existing: bool,
+     save_params: bool,
+     progress: gr.Progress | None,
+ ):
+     """Starts the analysis.
+
+     Args:
+         input_path: Either a file or directory.
+         output_path: The output path for the result; if None, the input_path is used.
+         confidence: The selected minimum confidence.
+         sensitivity: The selected sensitivity.
+         overlap: The selected segment overlap.
+         merge_consecutive: The number of consecutive segments to merge into one.
+         audio_speed: The selected audio speed.
+         fmin: The selected minimum bandpass frequency.
+         fmax: The selected maximum bandpass frequency.
+         species_list_choice: The choice for the species list.
+         species_list_file: The selected custom species list file.
+         lat: The selected latitude.
+         lon: The selected longitude.
+         week: The selected week of the year.
+         use_yearlong: Use yearlong instead of week.
+         sf_thresh: The threshold for the predicted species list.
+         custom_classifier_file: Custom classifier to be used.
+         output_types: The types of result files to be generated.
+         combine_tables: Whether to combine the individual result files.
+         locale: The translation to be used.
+         batch_size: The number of samples in a batch.
+         threads: The number of threads to be used.
+         input_dir: The input directory.
+         progress: The gradio progress bar.
+     """
+     if progress is not None:
+         progress(0, desc=f"{loc.localize('progress-preparing')} ...")
+
+     from birdnet_analyzer.analyze.core import _set_params
+
+     locale = locale.lower()
+     custom_classifier = custom_classifier_file if species_list_choice == gu._CUSTOM_CLASSIFIER else None
+     slist = species_list_file if species_list_choice == gu._CUSTOM_SPECIES else None
+     lat = lat if species_list_choice == gu._PREDICT_SPECIES else -1
+     lon = lon if species_list_choice == gu._PREDICT_SPECIES else -1
+     week = -1 if use_yearlong else week
+
+     flist = _set_params(
+         input=input_dir if input_dir else input_path,
+         min_conf=confidence,
+         custom_classifier=custom_classifier,
+         sensitivity=min(1.25, max(0.75, float(sensitivity))),
+         locale=locale,
+         overlap=max(0.0, min(2.9, float(overlap))),
+         merge_consecutive=max(1, int(merge_consecutive)),
+         audio_speed=max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed)),
+         fmin=max(0, min(cfg.SIG_FMAX, int(fmin))),
+         fmax=max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax))),
+         bs=max(1, int(batch_size)),
+         combine_results=combine_tables,
+         rtype=output_types,
+         skip_existing_results=skip_existing,
+         threads=max(1, int(threads)),
+         labels_file=ORIGINAL_LABELS_FILE,
+         sf_thresh=sf_thresh,
+         lat=lat,
+         lon=lon,
+         week=week,
+         slist=slist,
+         top_n=top_n if use_top_n else None,
+         output=output_path,
+     )
+
+     if species_list_choice == gu._CUSTOM_CLASSIFIER:
+         if custom_classifier_file is None:
+             raise gr.Error(loc.localize("validation-no-custom-classifier-selected"))
+
+         model.reset_custom_classifier()
+
+     gu.validate(cfg.FILE_LIST, loc.localize("validation-no-audio-files-found"))
+
+     result_list = []
+
+     if progress is not None:
+         progress(0, desc=f"{loc.localize('progress-starting')} ...")
+
+     # Analyze files
+     if cfg.CPU_THREADS < 2:
+         for entry in flist:
+             result_list.append(analyze_file_wrapper(entry))
+     else:
+         with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
+             futures = (executor.submit(analyze_file_wrapper, arg) for arg in flist)
+             for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
+                 if progress is not None:
+                     progress((i, len(flist)), total=len(flist), unit="files")
+                 result = f.result()
+
+                 result_list.append(result)
+
+     # Combine results?
+     if cfg.COMBINE_RESULTS:
+         combine_list = [[r[1] for r in result_list if r[0] == i[0]][0] for i in flist]
+         print(f"Combining results, writing to {cfg.OUTPUT_PATH}...", end="", flush=True)
+         combine_results(combine_list)
+         print("done!", flush=True)
+
+     if save_params:
+         save_analysis_params(os.path.join(cfg.OUTPUT_PATH, cfg.ANALYSIS_PARAMS_FILENAME))
+
+     return (
+         [[os.path.relpath(r[0], input_dir), bool(r[1])] for r in result_list]
+         if input_dir
+         else result_list[0][1]["csv"]
+         if result_list[0][1]
+         else None
+     )