birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +5 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +25 -25
- birdnet_analyzer/analyze/core.py +241 -245
- birdnet_analyzer/analyze/utils.py +692 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +709 -707
- birdnet_analyzer/config.py +242 -242
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +69 -70
- birdnet_analyzer/embeddings/utils.py +179 -193
- birdnet_analyzer/evaluation/__init__.py +196 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +175 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +28 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +619 -620
- birdnet_analyzer/gui/evaluation.py +795 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +245 -246
- birdnet_analyzer/gui/review.py +519 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +128 -129
- birdnet_analyzer/gui/single_file.py +267 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +696 -698
- birdnet_analyzer/gui/utils.py +810 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +334 -334
- birdnet_analyzer/lang/en.json +334 -334
- birdnet_analyzer/lang/fi.json +334 -334
- birdnet_analyzer/lang/fr.json +334 -334
- birdnet_analyzer/lang/id.json +334 -334
- birdnet_analyzer/lang/pt-br.json +334 -334
- birdnet_analyzer/lang/ru.json +334 -334
- birdnet_analyzer/lang/se.json +334 -334
- birdnet_analyzer/lang/tlh.json +334 -334
- birdnet_analyzer/lang/zh_TW.json +334 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +426 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
- birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,631 @@
|
|
1
|
+
"""
|
2
|
+
DataProcessor class for handling and transforming sample data with annotations and predictions.
|
3
|
+
|
4
|
+
This module defines the DataProcessor class, which processes prediction and annotation data,
|
5
|
+
aligns them with sampled time intervals, and generates tensors for further model training or evaluation.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import os
|
9
|
+
import warnings
|
10
|
+
from typing import ClassVar
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
import pandas as pd
|
14
|
+
|
15
|
+
from birdnet_analyzer.evaluation.preprocessing.utils import (
|
16
|
+
extract_recording_filename,
|
17
|
+
extract_recording_filename_from_filename,
|
18
|
+
read_and_concatenate_files_in_directory,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
class DataProcessor:
|
23
|
+
"""
|
24
|
+
Processor for handling and transforming sample data with annotations and predictions.
|
25
|
+
|
26
|
+
This class processes prediction and annotation data, aligning them with sampled time intervals,
|
27
|
+
and generates tensors for further model training or evaluation.
|
28
|
+
"""
|
29
|
+
|
30
|
+
# Default column mappings for predictions and annotations
|
31
|
+
DEFAULT_COLUMNS_PREDICTIONS: ClassVar[dict[str, str]] = {
|
32
|
+
"Start Time": "Start Time",
|
33
|
+
"End Time": "End Time",
|
34
|
+
"Class": "Class",
|
35
|
+
"Recording": "Recording",
|
36
|
+
"Duration": "Duration",
|
37
|
+
"Confidence": "Confidence",
|
38
|
+
}
|
39
|
+
|
40
|
+
DEFAULT_COLUMNS_ANNOTATIONS: ClassVar[dict[str, str]] = {
|
41
|
+
"Start Time": "Start Time",
|
42
|
+
"End Time": "End Time",
|
43
|
+
"Class": "Class",
|
44
|
+
"Recording": "Recording",
|
45
|
+
"Duration": "Duration",
|
46
|
+
}
|
47
|
+
|
48
|
+
def __init__(
    self,
    prediction_directory_path: str,
    annotation_directory_path: str,
    prediction_file_name: str | None = None,
    annotation_file_name: str | None = None,
    class_mapping: dict[str, str] | None = None,
    sample_duration: int = 3,
    min_overlap: float = 0.5,
    columns_predictions: dict[str, str] | None = None,
    columns_annotations: dict[str, str] | None = None,
    recording_duration: float | None = None,
) -> None:
    """
    Store the configuration, validate it, then load and process the data.

    Construction runs the complete pipeline: the column mappings and numeric
    parameters are validated, the prediction/annotation tables are loaded,
    samples are built, and the tensors are created.

    Args:
        prediction_directory_path (str): Folder containing the prediction files.
        annotation_directory_path (str): Folder containing the annotation files.
        prediction_file_name (str | None): Single prediction file to process.
        annotation_file_name (str | None): Single annotation file to process.
        class_mapping (dict[str, str] | None): Optional mapping applied to the
            class names found in the predictions.
        sample_duration (int): Length of each sample in seconds. Defaults to 3.
        min_overlap (float): Minimum overlap between an event and a sample for
            the event to count towards that sample. Defaults to 0.5.
        columns_predictions (dict[str, str] | None): Column name mapping for the
            prediction files; a copy of DEFAULT_COLUMNS_PREDICTIONS when None.
        columns_annotations (dict[str, str] | None): Column name mapping for the
            annotation files; a copy of DEFAULT_COLUMNS_ANNOTATIONS when None.
        recording_duration (float | None): Recording duration in seconds
            supplied by the caller. Defaults to None.

    Raises:
        ValueError: If a column mapping or numeric parameter fails validation.
    """
    # Input locations.
    self.prediction_directory_path: str = prediction_directory_path
    self.annotation_directory_path: str = annotation_directory_path
    self.prediction_file_name: str | None = prediction_file_name
    self.annotation_file_name: str | None = annotation_file_name

    # Sampling configuration.
    self.sample_duration: int = sample_duration
    self.min_overlap: float = min_overlap
    self.recording_duration: float | None = recording_duration
    self.class_mapping: dict[str, str] | None = class_mapping

    # Fall back to a private copy of the class-level defaults only when the
    # caller passed None; an explicitly supplied mapping is used as given.
    if columns_predictions is None:
        columns_predictions = self.DEFAULT_COLUMNS_PREDICTIONS.copy()
    if columns_annotations is None:
        columns_annotations = self.DEFAULT_COLUMNS_ANNOTATIONS.copy()
    self.columns_predictions: dict[str, str] = columns_predictions
    self.columns_annotations: dict[str, str] = columns_annotations

    # Results, filled in by the pipeline below.
    self.predictions_df: pd.DataFrame = pd.DataFrame()
    self.annotations_df: pd.DataFrame = pd.DataFrame()
    self.samples_df: pd.DataFrame = pd.DataFrame()
    self.classes: tuple[str, ...] = ()
    self.prediction_tensors: np.ndarray = np.array([])
    self.label_tensors: np.ndarray = np.array([])

    # Reject a bad configuration before touching any file.
    self._validate_columns()
    self._validate_parameters()

    # Run the pipeline.
    self.load_data()
    self.process_data()
    self.create_tensors()
|
119
|
+
|
120
|
+
def _validate_parameters(self) -> None:
    """
    Check the numeric configuration values.

    The checks run in a fixed order: sample duration, recording duration
    (only when one was supplied), then minimum overlap.

    Raises:
        ValueError: If the sample duration is not positive, the recording
            duration is not positive or is shorter than one sample, or the
            minimum overlap is not positive or exceeds the sample duration.
    """
    sample_len = self.sample_duration
    total_len = self.recording_duration
    overlap = self.min_overlap

    if sample_len <= 0:
        raise ValueError("Sample duration must be positive.")

    # The recording duration is optional; skip its checks when it is unset.
    if total_len is not None and total_len <= 0:
        raise ValueError("Recording duration must be greater than 0.")
    if total_len is not None and sample_len > total_len:
        raise ValueError("Sample duration cannot exceed the recording duration.")

    if overlap <= 0:
        raise ValueError("Min overlap must be greater than 0.")
    if overlap > sample_len:
        raise ValueError("Min overlap cannot exceed the sample duration.")
|
143
|
+
|
144
|
+
def _validate_columns(self) -> None:
    """
    Check that both column mappings name the essential fields.

    The fields "Start Time", "End Time" and "Class" must each be present in
    the prediction mapping and in the annotation mapping with a value other
    than None.

    Raises:
        ValueError: If a required field is absent or mapped to None. The
            prediction mapping is reported first when both are incomplete.
    """
    required = ("Start Time", "End Time", "Class")

    def unmapped(mapping: dict) -> list:
        # A field counts as missing when it is absent or mapped to None.
        return [field for field in required if mapping.get(field) is None]

    absent_in_predictions = unmapped(self.columns_predictions)
    absent_in_annotations = unmapped(self.columns_annotations)

    if absent_in_predictions:
        raise ValueError(f"Missing or None prediction columns: {', '.join(absent_in_predictions)}")
    if absent_in_annotations:
        raise ValueError(f"Missing or None annotation columns: {', '.join(absent_in_annotations)}")
|
164
|
+
|
165
|
+
def load_data(self) -> None:
    """
    Load the prediction and annotation tables and derive the class list.

    When either file name is unset, every file in the two directories is read
    and concatenated. Otherwise exactly the two named tab-separated files are
    read. Both tables get a 'source_file' column, are passed through
    `_prepare_dataframe`, and the optional class mapping is applied to the
    prediction classes. Finally `self.classes` is set to the sorted union of
    the non-NaN class names found in both tables.
    """
    read_whole_directories = self.prediction_file_name is None or self.annotation_file_name is None

    if read_whole_directories:
        self.predictions_df = read_and_concatenate_files_in_directory(self.prediction_directory_path)
        self.annotations_df = read_and_concatenate_files_in_directory(self.annotation_directory_path)

        # Guarantee the traceability column without overwriting one that the
        # directory reader may already have provided.
        for table in (self.predictions_df, self.annotations_df):
            if "source_file" not in table.columns:
                table["source_file"] = ""
    else:
        # Heuristic only: the prediction file is expected to start with the
        # annotation file's stem. A mismatch is reported but not fatal.
        annotation_stem = os.path.splitext(self.annotation_file_name)[0]
        if not self.prediction_file_name.startswith(annotation_stem):
            warnings.warn(
                "Prediction file name and annotation file name do not fully match, but proceeding anyway.",
                stacklevel=2,
            )

        prediction_path = os.path.join(self.prediction_directory_path, self.prediction_file_name)
        annotation_path = os.path.join(self.annotation_directory_path, self.annotation_file_name)

        self.predictions_df = pd.read_csv(prediction_path, sep="\t")
        self.annotations_df = pd.read_csv(annotation_path, sep="\t")

        # Record which file each row came from.
        self.predictions_df["source_file"] = self.prediction_file_name
        self.annotations_df["source_file"] = self.annotation_file_name

    # From here on both loading modes are handled the same way.
    self.predictions_df = self._prepare_dataframe(self.predictions_df, prediction=True)
    self.annotations_df = self._prepare_dataframe(self.annotations_df, prediction=False)

    prediction_class_col = self.get_column_name("Class", prediction=True)
    annotation_class_col = self.get_column_name("Class", prediction=False)

    # Only the predictions are remapped; names without a mapping entry are kept.
    if self.class_mapping:
        self.predictions_df[prediction_class_col] = self.predictions_df[prediction_class_col].apply(lambda x: self.class_mapping.get(x, x))

    seen_labels = set(self.predictions_df[prediction_class_col].unique()) | set(self.annotations_df[annotation_class_col].unique())

    # NaN entries are not classes and are dropped before sorting.
    self.classes = tuple(sorted({label for label in seen_labels if pd.notna(label)}))
|
236
|
+
|
237
|
+
def _prepare_dataframe(self, df: pd.DataFrame, prediction: bool) -> pd.DataFrame:
    """
    Add a 'recording_filename' column to a predictions or annotations table.

    The value comes from the mapped 'Recording' column when the table has
    one, otherwise from the 'source_file' column, otherwise it is an empty
    string. The table is modified in place and also returned.

    Args:
        df (pd.DataFrame): Table to extend.
        prediction (bool): True to use the prediction column mapping, False
            to use the annotation column mapping.

    Returns:
        pd.DataFrame: The same table, now carrying 'recording_filename'.
    """
    recording_col = self.get_column_name("Recording", prediction=prediction)
    available = df.columns

    if recording_col in available:
        # Preferred source: the explicit recording column.
        names = extract_recording_filename(df[recording_col])
    elif "source_file" in available:
        # Fallback: derive the recording from the file the row was read from.
        names = extract_recording_filename_from_filename(df["source_file"])
    else:
        # Nothing to derive a name from.
        names = ""

    df["recording_filename"] = names
    return df
|
265
|
+
|
266
|
+
def process_data(self) -> None:
    """
    Build `samples_df` by processing every recording found in the loaded data.

    Each recording named in the 'recording_filename' column of the predictions
    or the annotations is passed to `process_recording` together with its own
    rows from both tables. The per-recording results are stacked into
    `self.samples_df` with a fresh integer index.

    Recordings are visited in sorted order so the row order of `samples_df`
    is the same on every run. Results are collected in a list and
    concatenated once, so the accumulated frame is not re-copied for every
    recording. Empty per-recording results are skipped; if nothing remains,
    `samples_df` is an empty DataFrame.
    """
    self.samples_df = pd.DataFrame()  # Reset so a failure below never leaves stale samples

    predicted_names = set(self.predictions_df["recording_filename"].unique())
    annotated_names = set(self.annotations_df["recording_filename"].unique())

    # key=str keeps the sort well-defined even if a name is not a string
    # (for example a missing value).
    ordered_names = sorted(predicted_names | annotated_names, key=str)

    per_recording: list[pd.DataFrame] = []
    for recording_filename in ordered_names:
        # Rows of each table that belong to this recording.
        pred_df = self.predictions_df[self.predictions_df["recording_filename"] == recording_filename]
        annot_df = self.annotations_df[self.annotations_df["recording_filename"] == recording_filename]

        samples_df = self.process_recording(recording_filename, pred_df, annot_df)
        if not samples_df.empty:
            per_recording.append(samples_df)

    if per_recording:
        self.samples_df = pd.concat(per_recording, ignore_index=True)
|
290
|
+
|
291
|
+
def process_recording(self, recording_filename: str, pred_df: pd.DataFrame, annot_df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the sample table for one recording.

    The recording duration is determined first. A duration of zero or less
    yields an empty DataFrame; otherwise the sample intervals are created and
    then filled in, predictions first and annotations second.

    Args:
        recording_filename (str): Name of the recording.
        pred_df (pd.DataFrame): Prediction rows belonging to this recording.
        annot_df (pd.DataFrame): Annotation rows belonging to this recording.

    Returns:
        pd.DataFrame: The filled-in sample table, or an empty DataFrame when
        no positive duration could be determined.
    """
    total_seconds = self.determine_file_duration(pred_df, annot_df)

    # Without a positive duration there is nothing to split into samples.
    if total_seconds <= 0:
        return pd.DataFrame()

    sample_table = self.initialize_samples(recording_filename=recording_filename, file_duration=total_seconds)

    # Both updates modify sample_table in place.
    self.update_samples_with_predictions(pred_df, sample_table)
    self.update_samples_with_annotations(annot_df, sample_table)

    return sample_table
|
321
|
+
|
322
|
+
def determine_file_duration(self, pred_df: pd.DataFrame, annot_df: pd.DataFrame) -> float:
    """
    Determine the duration of a recording.

    An explicitly configured `recording_duration` always wins. Otherwise the
    largest non-missing value of the mapped 'Duration' column across both
    tables is used, and if that yields nothing positive, the largest
    non-missing value of the mapped 'End Time' column across both tables.

    Missing values and empty tables are ignored in every step. In particular,
    a recording that has annotations but no prediction rows (or the reverse)
    still gets the duration implied by the table that does have rows.

    Args:
        pred_df (pd.DataFrame): Prediction rows for the recording.
        annot_df (pd.DataFrame): Annotation rows for the recording.

    Returns:
        float: The duration in seconds, or 0.0 when none could be determined.
    """
    if self.recording_duration is not None:
        # The caller's value takes precedence over anything in the data.
        return self.recording_duration

    sources = ((pred_df, True), (annot_df, False))

    def largest_value(field_name: str) -> float:
        # Largest non-missing value of the mapped column over both tables.
        # Starts at 0.0, so negative entries can never win.
        best = 0.0
        for frame, is_prediction in sources:
            column = self.get_column_name(field_name, prediction=is_prediction)
            if column not in frame.columns:
                continue
            values = frame[column].dropna()
            # Skip empty columns explicitly: the max of an empty Series is
            # NaN, and Python's max() would keep that NaN once it is seen.
            if not values.empty:
                best = max(best, values.max())
        return best

    duration = largest_value("Duration")

    # No usable 'Duration' information: fall back to the latest end time.
    if duration <= 0:
        duration = largest_value("End Time")

    # Defensive guard against any remaining invalid value.
    if pd.isna(duration) or duration < 0:
        return 0.0

    return float(duration)
|
371
|
+
|
372
|
+
def initialize_samples(self, recording_filename: str, file_duration: float) -> pd.DataFrame:
    """
    Create the empty sample table for one recording.

    The recording is cut into consecutive windows of `sample_duration`
    seconds starting at 0; the last window ends at `file_duration`. Every
    class in `self.classes` gets a '<class>_confidence' column filled with
    0.0 and a '<class>_annotation' column filled with 0.

    Args:
        recording_filename (str): Name stored in the 'filename' column.
        file_duration (float): Total duration of the recording in seconds.

    Returns:
        pd.DataFrame: One row per sample with the columns 'filename',
        'sample_index', 'start_time', 'end_time' and the per-class columns.
        An empty DataFrame when `file_duration` is zero or less.
    """
    if file_duration <= 0:
        return pd.DataFrame()

    window = self.sample_duration

    # Window start times; fall back to a single window starting at 0.
    starts = np.arange(0, file_duration, window)
    if len(starts) == 0:
        starts = np.array([0])
    count = len(starts)

    table = {
        "filename": recording_filename,
        "sample_index": list(range(count)),
        "start_time": list(starts),
        # The final window is clipped to the end of the recording.
        "end_time": [min(begin + window, file_duration) for begin in starts],
    }

    for label in self.classes:
        table[f"{label}_confidence"] = [0.0] * count  # Float values
        table[f"{label}_annotation"] = [0] * count  # Integer values

    return pd.DataFrame(table)
|
416
|
+
|
417
|
+
def update_samples_with_predictions(self, pred_df: pd.DataFrame, samples_df: pd.DataFrame) -> None:
    """
    Write prediction confidences into the sample table, in place.

    A prediction counts towards a sample when the sample starts no later
    than `min_overlap` before the prediction ends and ends no earlier than
    `min_overlap` after the prediction starts. For every such sample the
    '<class>_confidence' cell keeps the larger of its current value and the
    prediction's confidence. Predictions whose class is not in
    `self.classes` are ignored, and a prediction without a confidence value
    column contributes 0.0.

    Args:
        pred_df (pd.DataFrame): Prediction rows for one recording.
        samples_df (pd.DataFrame): Sample table to update.
    """
    label_key = self.get_column_name("Class", prediction=True)
    begin_key = self.get_column_name("Start Time", prediction=True)
    finish_key = self.get_column_name("End Time", prediction=True)
    score_key = self.get_column_name("Confidence", prediction=True)

    for _, prediction in pred_df.iterrows():
        label = prediction[label_key]
        if label in self.classes:
            event_begin = prediction[begin_key]
            event_finish = prediction[finish_key]
            score = prediction.get(score_key, 0.0)

            # Samples sharing at least min_overlap with the event.
            starts_in_time = samples_df["start_time"] <= event_finish - self.min_overlap
            ends_in_time = samples_df["end_time"] >= event_begin + self.min_overlap
            matching_rows = samples_df[starts_in_time & ends_in_time].index

            target = f"{label}_confidence"
            for row_id in matching_rows:
                # Keep the strongest confidence seen so far for this sample.
                previous = samples_df.loc[row_id, target]
                samples_df.loc[row_id, target] = max(previous, score)
|
453
|
+
|
454
|
+
def update_samples_with_annotations(self, annot_df: pd.DataFrame, samples_df: pd.DataFrame) -> None:
    """
    Mark annotated classes in the sample table, in place.

    An annotation counts towards a sample when the sample starts no later
    than `min_overlap` before the annotation ends and ends no earlier than
    `min_overlap` after the annotation starts. For every such sample the
    '<class>_annotation' cell is set to 1. Annotations whose class is not in
    `self.classes` are ignored.

    Args:
        annot_df (pd.DataFrame): Annotation rows for one recording.
        samples_df (pd.DataFrame): Sample table to update.
    """
    label_key = self.get_column_name("Class", prediction=False)
    begin_key = self.get_column_name("Start Time", prediction=False)
    finish_key = self.get_column_name("End Time", prediction=False)

    for _, annotation in annot_df.iterrows():
        label = annotation[label_key]
        if label in self.classes:
            event_begin = annotation[begin_key]
            event_finish = annotation[finish_key]

            # Samples sharing at least min_overlap with the event.
            starts_in_time = samples_df["start_time"] <= event_finish - self.min_overlap
            ends_in_time = samples_df["end_time"] >= event_begin + self.min_overlap
            matching_rows = samples_df[starts_in_time & ends_in_time].index

            # Flag all matching samples at once.
            samples_df.loc[matching_rows, f"{label}_annotation"] = 1
|
487
|
+
|
488
|
+
def create_tensors(self) -> None:
    """
    Builds `prediction_tensors` and `label_tensors` from `samples_df`.

    For every class in `self.classes`, the `<class>_confidence` column feeds a
    float32 numpy array and the `<class>_annotation` column feeds an int64 numpy
    array. Both arrays are stored on the instance. An empty `samples_df` produces
    zero-row arrays with one column per class.

    Raises:
        ValueError: If NaN values are found in annotation or confidence columns.
    """
    frame = self.samples_df

    # Nothing sampled: store zero-row arrays that still carry one column per class
    if frame.empty:
        self.prediction_tensors = np.empty((0, len(self.classes)), dtype=np.float32)
        self.label_tensors = np.empty((0, len(self.classes)), dtype=np.int64)
        return

    # Annotation columns are validated first, confidence columns second
    label_cols = [f"{name}_annotation" for name in self.classes]
    labels_have_nan = frame[label_cols].isna().to_numpy().any()
    if labels_have_nan:
        raise ValueError("NaN values found in annotation columns.")

    score_cols = [f"{name}_confidence" for name in self.classes]
    scores_have_nan = frame[score_cols].isna().to_numpy().any()
    if scores_have_nan:
        raise ValueError("NaN values found in confidence columns.")

    # Materialize the validated columns as numpy arrays
    self.prediction_tensors = frame[score_cols].to_numpy(dtype=np.float32)
    self.label_tensors = frame[label_cols].to_numpy(dtype=np.int64)
|
518
|
+
|
519
|
+
def get_column_name(self, field_name: str, prediction: bool = True) -> str:
    """
    Resolves a logical field name to the column name configured for it.

    The lookup uses `self.columns_predictions` when `prediction` is truthy and
    `self.columns_annotations` otherwise. When the field has no entry, or its
    entry is None, the field name itself is returned.

    Args:
        field_name (str): The name of the field (e.g., "Class", "Start Time").
        prediction (bool): Whether to look the name up in the predictions mapping (True)
            or the annotations mapping (False).

    Returns:
        str: The mapped column name, or `field_name` when no usable mapping exists.

    Raises:
        TypeError: If `field_name` or `prediction` is None.
    """
    # Reject None arguments, checking field_name before prediction
    guards = (
        (field_name, "field_name cannot be None."),
        (prediction, "prediction parameter cannot be None."),
    )
    for value, message in guards:
        if value is None:
            raise TypeError(message)

    # Choose which mapping to consult
    if prediction:
        lookup = self.columns_predictions
    else:
        lookup = self.columns_annotations

    # A missing key and a key mapped to None both fall back to the field name
    mapped = lookup.get(field_name)
    return field_name if mapped is None else mapped
|
550
|
+
|
551
|
+
def get_sample_data(self) -> pd.DataFrame:
    """
    Returns a copy of the `samples_df` DataFrame held by this instance.

    Because a copy is handed out, changes a caller makes to the returned
    DataFrame do not affect the instance's own `samples_df`.

    Returns:
        pd.DataFrame: A copy of `samples_df`.
    """
    # Hand out a copy so external edits cannot reach the stored frame
    snapshot = self.samples_df.copy()
    return snapshot
|
563
|
+
|
564
|
+
def get_filtered_tensors(
|
565
|
+
self,
|
566
|
+
selected_classes: list[str] | None = None,
|
567
|
+
selected_recordings: list[str] | None = None,
|
568
|
+
) -> tuple[np.ndarray, np.ndarray, tuple[str]]:
|
569
|
+
"""
|
570
|
+
Filters the prediction and label tensors based on selected classes and recordings.
|
571
|
+
|
572
|
+
This method extracts subsets of the prediction and label tensors for specific classes
|
573
|
+
and/or recordings. It ensures that the filtered tensors correspond to valid classes
|
574
|
+
and recordings present in the sampled data.
|
575
|
+
|
576
|
+
Args:
|
577
|
+
selected_classes (List[str], optional): A list of class names to filter by. If None,
|
578
|
+
all classes are included.
|
579
|
+
selected_recordings (List[str], optional): A list of recording filenames to filter by. If None,
|
580
|
+
all recordings are included.
|
581
|
+
|
582
|
+
Returns:
|
583
|
+
Tuple[np.ndarray, np.ndarray, Tuple[str]]: A tuple containing:
|
584
|
+
- Filtered prediction tensors (numpy.ndarray)
|
585
|
+
- Filtered label tensors (numpy.ndarray)
|
586
|
+
- Tuple of selected class names (Tuple[str])
|
587
|
+
|
588
|
+
Raises:
|
589
|
+
ValueError: If the `samples_df` is empty or missing required columns.
|
590
|
+
KeyError: If required confidence or annotation columns are missing in the DataFrame.
|
591
|
+
"""
|
592
|
+
if self.samples_df.empty:
|
593
|
+
raise ValueError("samples_df is empty.")
|
594
|
+
|
595
|
+
if "filename" not in self.samples_df.columns:
|
596
|
+
raise ValueError("samples_df must contain a 'filename' column.")
|
597
|
+
|
598
|
+
# Determine the classes to filter by
|
599
|
+
classes = self.classes if selected_classes is None else tuple(cls for cls in selected_classes if cls in self.classes)
|
600
|
+
|
601
|
+
if not classes:
|
602
|
+
raise ValueError("No valid classes selected.")
|
603
|
+
|
604
|
+
# Create a mask for filtering samples
|
605
|
+
mask = pd.Series(True, index=self.samples_df.index)
|
606
|
+
|
607
|
+
# Apply recording-based filtering if specified
|
608
|
+
if selected_recordings is not None:
|
609
|
+
if selected_recordings:
|
610
|
+
mask &= self.samples_df["filename"].isin(selected_recordings)
|
611
|
+
else:
|
612
|
+
# If `selected_recordings` is an empty list, select no samples
|
613
|
+
mask = pd.Series(False, index=self.samples_df.index)
|
614
|
+
|
615
|
+
# Filter the samples DataFrame using the mask
|
616
|
+
filtered_samples = self.samples_df.loc[mask]
|
617
|
+
|
618
|
+
# Prepare column names for confidence and annotation data
|
619
|
+
confidence_columns = [f"{cls}_confidence" for cls in classes]
|
620
|
+
annotation_columns = [f"{cls}_annotation" for cls in classes]
|
621
|
+
|
622
|
+
# Ensure all required columns are present in the filtered DataFrame
|
623
|
+
if not all(col in filtered_samples.columns for col in confidence_columns + annotation_columns):
|
624
|
+
raise KeyError("Required confidence or annotation columns are missing.")
|
625
|
+
|
626
|
+
# Convert filtered data into numpy arrays
|
627
|
+
predictions = filtered_samples[confidence_columns].to_numpy(dtype=np.float32)
|
628
|
+
labels = filtered_samples[annotation_columns].to_numpy(dtype=np.int64)
|
629
|
+
|
630
|
+
# Return the tensors and the list of filtered classes
|
631
|
+
return predictions, labels, classes
|