birdnet_analyzer-2.0.0-py3-none-any.whl → birdnet_analyzer-2.0.1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (122)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/preprocessing/data_processor.py (new file)
@@ -0,0 +1,631 @@
+ """
+ DataProcessor class for handling and transforming sample data with annotations and predictions.
+
+ This module defines the DataProcessor class, which processes prediction and annotation data,
+ aligns them with sampled time intervals, and generates tensors for further model training or evaluation.
+ """
+
+ import os
+ import warnings
+ from typing import ClassVar
+
+ import numpy as np
+ import pandas as pd
+
+ from birdnet_analyzer.evaluation.preprocessing.utils import (
+     extract_recording_filename,
+     extract_recording_filename_from_filename,
+     read_and_concatenate_files_in_directory,
+ )
+
+
+ class DataProcessor:
+     """
+     Processor for handling and transforming sample data with annotations and predictions.
+
+     This class processes prediction and annotation data, aligning them with sampled time intervals,
+     and generates tensors for further model training or evaluation.
+     """
+
+     # Default column mappings for predictions and annotations
+     DEFAULT_COLUMNS_PREDICTIONS: ClassVar[dict[str, str]] = {
+         "Start Time": "Start Time",
+         "End Time": "End Time",
+         "Class": "Class",
+         "Recording": "Recording",
+         "Duration": "Duration",
+         "Confidence": "Confidence",
+     }
+
+     DEFAULT_COLUMNS_ANNOTATIONS: ClassVar[dict[str, str]] = {
+         "Start Time": "Start Time",
+         "End Time": "End Time",
+         "Class": "Class",
+         "Recording": "Recording",
+         "Duration": "Duration",
+     }
+
+     def __init__(
+         self,
+         prediction_directory_path: str,
+         annotation_directory_path: str,
+         prediction_file_name: str | None = None,
+         annotation_file_name: str | None = None,
+         class_mapping: dict[str, str] | None = None,
+         sample_duration: int = 3,
+         min_overlap: float = 0.5,
+         columns_predictions: dict[str, str] | None = None,
+         columns_annotations: dict[str, str] | None = None,
+         recording_duration: float | None = None,
+     ) -> None:
+         """
+         Initializes the DataProcessor by loading prediction and annotation data.
+
+         Args:
+             prediction_directory_path (str): Path to the folder containing prediction files.
+             annotation_directory_path (str): Path to the folder containing annotation files.
+             prediction_file_name (Optional[str]): Name of the prediction file to process.
+             annotation_file_name (Optional[str]): Name of the annotation file to process.
+             class_mapping (Optional[Dict[str, str]]): Optional dictionary mapping raw class
+                 names to standardized class names.
+             sample_duration (int, optional): Length of each data sample in seconds. Defaults to 3.
+             min_overlap (float, optional): Minimum overlap required between prediction and
+                 annotation to consider a match.
+             columns_predictions (Optional[Dict[str, str]], optional): Column name mappings for prediction files.
+             columns_annotations (Optional[Dict[str, str]], optional): Column name mappings for annotation files.
+             recording_duration (Optional[float], optional): User-specified recording duration in seconds.
+                 Defaults to None.
+
+         Raises:
+             ValueError: If any parameter is invalid (e.g., negative sample duration).
+         """
+         # Initialize instance variables
+         self.sample_duration: int = sample_duration
+         self.min_overlap: float = min_overlap
+         self.class_mapping: dict[str, str] | None = class_mapping
+
+         # Use provided column mappings or defaults
+         self.columns_predictions: dict[str, str] = columns_predictions if columns_predictions is not None else self.DEFAULT_COLUMNS_PREDICTIONS.copy()
+         self.columns_annotations: dict[str, str] = columns_annotations if columns_annotations is not None else self.DEFAULT_COLUMNS_ANNOTATIONS.copy()
+
+         self.recording_duration: float | None = recording_duration
+
+         # Paths and filenames
+         self.prediction_directory_path: str = prediction_directory_path
+         self.prediction_file_name: str | None = prediction_file_name
+         self.annotation_directory_path: str = annotation_directory_path
+         self.annotation_file_name: str | None = annotation_file_name
+
+         # DataFrames for predictions and annotations
+         self.predictions_df: pd.DataFrame = pd.DataFrame()
+         self.annotations_df: pd.DataFrame = pd.DataFrame()
+
+         # Placeholder for unique classes across predictions and annotations
+         self.classes: tuple[str, ...] = ()
+
+         # Placeholder for samples DataFrame and tensors
+         self.samples_df: pd.DataFrame = pd.DataFrame()
+         self.prediction_tensors: np.ndarray = np.array([])
+         self.label_tensors: np.ndarray = np.array([])
+
+         # Validate column mappings and parameters
+         self._validate_columns()
+         self._validate_parameters()
+
+         # Load and process data
+         self.load_data()
+         self.process_data()
+         self.create_tensors()
+
+     def _validate_parameters(self) -> None:
+         """
+         Validates the input parameters for correctness.
+
+         Raises:
+             ValueError: If sample duration, minimum overlap, or recording duration is invalid.
+         """
+         # Validate sample duration
+         if self.sample_duration <= 0:
+             raise ValueError("Sample duration must be positive.")
+
+         # Validate recording duration
+         if self.recording_duration is not None:
+             if self.recording_duration <= 0:
+                 raise ValueError("Recording duration must be greater than 0.")
+             if self.sample_duration > self.recording_duration:
+                 raise ValueError("Sample duration cannot exceed the recording duration.")
+
+         # Validate minimum overlap
+         if self.min_overlap <= 0:
+             raise ValueError("Min overlap must be greater than 0.")
+         if self.min_overlap > self.sample_duration:
+             raise ValueError("Min overlap cannot exceed the sample duration.")
+
+     def _validate_columns(self) -> None:
+         """
+         Validates that essential columns are provided in the column mappings.
+
+         Raises:
+             ValueError: If required columns are missing or have None values.
+         """
+         # Required columns for predictions and annotations
+         required_columns = ["Start Time", "End Time", "Class"]
+
+         # Check for missing or None columns in predictions
+         missing_pred_columns = [col for col in required_columns if col not in self.columns_predictions or self.columns_predictions[col] is None]
+
+         # Check for missing or None columns in annotations
+         missing_annot_columns = [col for col in required_columns if col not in self.columns_annotations or self.columns_annotations[col] is None]
+
+         if missing_pred_columns:
+             raise ValueError(f"Missing or None prediction columns: {', '.join(missing_pred_columns)}")
+         if missing_annot_columns:
+             raise ValueError(f"Missing or None annotation columns: {', '.join(missing_annot_columns)}")
+
+     def load_data(self) -> None:
+         """
+         Loads the prediction and annotation data into DataFrames.
+
+         Depending on whether specific files are provided, this method either reads all files
+         in the given directories or reads the specified files. The method also applies any
+         specified class mapping and prepares the data for further processing.
+
+         Raises:
+             ValueError: If file reading fails or data preparation encounters issues.
+         """
+         if self.prediction_file_name is None or self.annotation_file_name is None:
+             # Case: No specific files provided; load all files in directories.
+             self.predictions_df = read_and_concatenate_files_in_directory(self.prediction_directory_path)
+             self.annotations_df = read_and_concatenate_files_in_directory(self.annotation_directory_path)
+
+             # Ensure 'source_file' column exists for traceability
+             if "source_file" not in self.predictions_df.columns:
+                 self.predictions_df["source_file"] = ""
+
+             if "source_file" not in self.annotations_df.columns:
+                 self.annotations_df["source_file"] = ""
+
+             # Prepare DataFrames
+             self.predictions_df = self._prepare_dataframe(self.predictions_df, prediction=True)
+             self.annotations_df = self._prepare_dataframe(self.annotations_df, prediction=False)
+
+             # Apply class mapping to predictions if provided
+             if self.class_mapping:
+                 class_col_pred = self.get_column_name("Class", prediction=True)
+                 self.predictions_df[class_col_pred] = self.predictions_df[class_col_pred].apply(lambda x: self.class_mapping.get(x, x))
+         else:
+             # Case: Specific files are provided for predictions and annotations.
+             # Ensure filenames correspond to the same recording (heuristic check).
+             if not self.prediction_file_name.startswith(os.path.splitext(self.annotation_file_name)[0]):
+                 warnings.warn(
+                     "Prediction file name and annotation file name do not fully match, but proceeding anyway.",
+                     stacklevel=2,
+                 )
+
+             # Construct full file paths
+             prediction_file = os.path.join(self.prediction_directory_path, self.prediction_file_name)
+             annotation_file = os.path.join(self.annotation_directory_path, self.annotation_file_name)
+
+             # Load files into DataFrames
+             self.predictions_df = pd.read_csv(prediction_file, sep="\t")
+             self.annotations_df = pd.read_csv(annotation_file, sep="\t")
+
+             # Add 'source_file' column to identify origins
+             self.predictions_df["source_file"] = self.prediction_file_name
+             self.annotations_df["source_file"] = self.annotation_file_name
+
+             # Prepare DataFrames
+             self.predictions_df = self._prepare_dataframe(self.predictions_df, prediction=True)
+             self.annotations_df = self._prepare_dataframe(self.annotations_df, prediction=False)
+
+             # Apply class mapping to predictions if provided
+             if self.class_mapping:
+                 class_col_pred = self.get_column_name("Class", prediction=True)
+                 self.predictions_df[class_col_pred] = self.predictions_df[class_col_pred].apply(lambda x: self.class_mapping.get(x, x))
+
+         # Consolidate all unique classes from predictions and annotations
+         class_col_pred = self.get_column_name("Class", prediction=True)
+         class_col_annot = self.get_column_name("Class", prediction=False)
+
+         pred_classes = set(self.predictions_df[class_col_pred].unique())
+         annot_classes = set(self.annotations_df[class_col_annot].unique())
+
+         # Remove any NaN values from the union
+         all_classes = {cls for cls in pred_classes.union(annot_classes) if pd.notna(cls)}
+         self.classes = tuple(sorted(all_classes))
+
+     def _prepare_dataframe(self, df: pd.DataFrame, prediction: bool) -> pd.DataFrame:
+         """
+         Prepares a DataFrame by adding a 'recording_filename' column.
+
+         This method extracts the recording filename from either a specified 'Recording' column
+         or from the 'source_file' column to ensure traceability.
+
+         Args:
+             df (pd.DataFrame): The DataFrame to prepare.
+             prediction (bool): Whether the DataFrame is for predictions or annotations.
+
+         Returns:
+             pd.DataFrame: The prepared DataFrame with the added 'recording_filename' column.
+         """
+         # Determine the relevant column for extracting recording filenames
+         recording_col = self.get_column_name("Recording", prediction=prediction)
+
+         if recording_col in df.columns:
+             # Extract recording filename using the 'Recording' column
+             df["recording_filename"] = extract_recording_filename(df[recording_col])
+         elif "source_file" in df.columns:
+             # Fall back to extracting from the 'source_file' column
+             df["recording_filename"] = extract_recording_filename_from_filename(df["source_file"])
+         else:
+             # Assign a default empty string if no relevant columns exist
+             df["recording_filename"] = ""
+
+         return df
+
+     def process_data(self) -> None:
+         """
+         Processes the loaded data, aligns predictions and annotations with sample intervals,
+         and updates the samples DataFrame.
+
+         This method iterates through all recording filenames, processes each recording,
+         and aggregates the results into the `samples_df` attribute.
+         """
+         self.samples_df = pd.DataFrame()  # Initialize the samples DataFrame
+
+         # Get the unique set of recording filenames from both predictions and annotations
+         recording_filenames = set(self.predictions_df["recording_filename"].unique()).union(set(self.annotations_df["recording_filename"].unique()))
+
+         # Process each recording
+         for recording_filename in recording_filenames:
+             # Filter predictions and annotations for the current recording
+             pred_df = self.predictions_df[self.predictions_df["recording_filename"] == recording_filename]
+             annot_df = self.annotations_df[self.annotations_df["recording_filename"] == recording_filename]
+
+             # Generate sample intervals and annotations for the recording
+             samples_df = self.process_recording(recording_filename, pred_df, annot_df)
+
+             # Append the processed DataFrame to the overall samples DataFrame
+             self.samples_df = pd.concat([self.samples_df, samples_df], ignore_index=True)
+
+     def process_recording(self, recording_filename: str, pred_df: pd.DataFrame, annot_df: pd.DataFrame) -> pd.DataFrame:
+         """
+         Processes a single recording by determining its duration, initializing sample intervals,
+         and updating the intervals with predictions and annotations.
+
+         Args:
+             recording_filename (str): The name of the recording.
+             pred_df (pd.DataFrame): Predictions DataFrame specific to the recording.
+             annot_df (pd.DataFrame): Annotations DataFrame specific to the recording.
+
+         Returns:
+             pd.DataFrame: A DataFrame containing sample intervals with prediction and annotation data.
+         """
+         # Determine the duration of the recording
+         file_duration = self.determine_file_duration(pred_df, annot_df)
+
+         if file_duration <= 0:
+             # Return an empty DataFrame if the duration is invalid
+             return pd.DataFrame()
+
+         # Initialize sample intervals for the recording
+         samples_df = self.initialize_samples(recording_filename=recording_filename, file_duration=file_duration)
+
+         # Update the samples DataFrame with prediction data
+         self.update_samples_with_predictions(pred_df, samples_df)
+
+         # Update the samples DataFrame with annotation data
+         self.update_samples_with_annotations(annot_df, samples_df)
+
+         return samples_df
+
+     def determine_file_duration(self, pred_df: pd.DataFrame, annot_df: pd.DataFrame) -> float:
+         """
+         Determines the duration of the recording based on available dataframes or the specified recording duration.
+
+         This method prioritizes the explicitly set `recording_duration` if available.
+         Otherwise, it computes the duration from the `Duration` or `End Time` columns in the
+         predictions and annotations DataFrames. Handles edge cases where data may be incomplete
+         or missing.
+
+         Args:
+             pred_df (pd.DataFrame): Predictions DataFrame containing duration or end time information.
+             annot_df (pd.DataFrame): Annotations DataFrame containing duration or end time information.
+
+         Returns:
+             float: The determined duration of the recording. Defaults to 0 if no valid duration is found.
+         """
+         if self.recording_duration is not None:
+             # Use the explicitly provided recording duration
+             return self.recording_duration
+
+         duration = 0.0
+
+         # Extract the 'Duration' column from predictions if available
+         file_duration_col_pred = self.get_column_name("Duration", prediction=True)
+
+         file_duration_col_annot = self.get_column_name("Duration", prediction=False)
+
+         # Try to get duration from 'Duration' column in pred_df
+         if file_duration_col_pred in pred_df.columns and pred_df[file_duration_col_pred].notna().any():
+             duration = max(duration, pred_df[file_duration_col_pred].dropna().max())
+
+         # Try to get duration from 'Duration' column in annot_df
+         if file_duration_col_annot in annot_df.columns and annot_df[file_duration_col_annot].notna().any():
+             duration = max(duration, annot_df[file_duration_col_annot].dropna().max())
+
+         # If no duration is found, use the maximum 'End Time' value
+         if duration == 0.0:
+             end_time_col_pred = self.get_column_name("End Time", prediction=True)
+             end_time_col_annot = self.get_column_name("End Time", prediction=False)
+
+             max_end_pred = pred_df[end_time_col_pred].max() if end_time_col_pred in pred_df.columns else 0.0
+             max_end_annot = annot_df[end_time_col_annot].max() if end_time_col_annot in annot_df.columns else 0.0
+             duration = max(max_end_pred, max_end_annot)
+
+         # Handle invalid values (NaN or negative duration)
+         if pd.isna(duration) or duration < 0:
+             duration = 0.0
+
+         return duration
+
+     def initialize_samples(self, recording_filename: str, file_duration: float) -> pd.DataFrame:
+         """
+         Initializes a DataFrame of time-based sample intervals for the specified recording.
+
+         Samples are evenly spaced time intervals of length `sample_duration` that cover the
+         entire recording duration. Each sample is initialized with confidence scores and
+         annotation values for all classes.
+
+         Args:
+             recording_filename (str): The name of the recording.
+             file_duration (float): The total duration of the recording in seconds.
+
+         Returns:
+             pd.DataFrame: A DataFrame containing initialized sample intervals, confidence scores, and annotations.
+                 Returns an empty DataFrame if the file duration is less than or equal to 0.
+         """
+         if file_duration <= 0:
+             # Return an empty DataFrame if file duration is invalid
+             return pd.DataFrame()
+
+         # Generate start times for each sample interval
+         intervals = np.arange(0, file_duration, self.sample_duration)
+         if len(intervals) == 0:
+             intervals = np.array([0])
+
+         # Prepare sample structure
+         samples = {
+             "filename": recording_filename,
+             "sample_index": [],
+             "start_time": [],
+             "end_time": [],
+         }
+
+         for idx, start in enumerate(intervals):
+             samples["sample_index"].append(idx)
+             samples["start_time"].append(start)
+             samples["end_time"].append(min(start + self.sample_duration, file_duration))
+
+         # Initialize confidence scores and annotations for each class
+         for label in self.classes:
+             samples[f"{label}_confidence"] = [0.0] * len(samples["sample_index"])  # Float values
+             samples[f"{label}_annotation"] = [0] * len(samples["sample_index"])  # Integer values
+
+         return pd.DataFrame(samples)
+
+     def update_samples_with_predictions(self, pred_df: pd.DataFrame, samples_df: pd.DataFrame) -> None:
+         """
+         Updates the samples DataFrame with prediction confidence scores.
+
+         For each prediction in the predictions DataFrame, this method identifies overlapping
+         samples based on the specified `min_overlap`. It then updates the confidence scores
+         for those samples, retaining the maximum confidence value if multiple predictions overlap.
+
+         Args:
+             pred_df (pd.DataFrame): DataFrame containing prediction information.
+             samples_df (pd.DataFrame): DataFrame of samples to be updated with confidence scores.
+         """
+         # Retrieve the column names for predictions
+         class_col = self.get_column_name("Class", prediction=True)
+         start_time_col = self.get_column_name("Start Time", prediction=True)
+         end_time_col = self.get_column_name("End Time", prediction=True)
+         confidence_col = self.get_column_name("Confidence", prediction=True)
+
+         # Iterate through each prediction row
+         for _, row in pred_df.iterrows():
+             class_name = row[class_col]
+             if class_name not in self.classes:
+                 continue  # Skip predictions for classes not included in the predefined list
+
+             # Extract start and end times, and confidence score
+             begin_time = row[start_time_col]
+             end_time = row[end_time_col]
+             confidence = row.get(confidence_col, 0.0)
+
+             # Identify samples that overlap with the prediction based on min_overlap
+             sample_indices = samples_df[(samples_df["start_time"] <= end_time - self.min_overlap) & (samples_df["end_time"] >= begin_time + self.min_overlap)].index
+
+             # Update the confidence scores for the overlapping samples
+             for i in sample_indices:
+                 current_confidence = samples_df.loc[i, f"{class_name}_confidence"]
+                 samples_df.loc[i, f"{class_name}_confidence"] = max(current_confidence, confidence)
+
+     def update_samples_with_annotations(self, annot_df: pd.DataFrame, samples_df: pd.DataFrame) -> None:
+         """
+         Updates the samples DataFrame with annotations.
+
+         For each annotation in the annotations DataFrame, this method identifies overlapping
+         samples based on the specified `min_overlap`. It sets the annotation value to 1
+         for the overlapping samples.
+
+         Args:
+             annot_df (pd.DataFrame): DataFrame containing annotation information.
+             samples_df (pd.DataFrame): DataFrame of samples to be updated with annotations.
+         """
+         # Retrieve the column names for annotations
+         class_col = self.get_column_name("Class", prediction=False)
+         start_time_col = self.get_column_name("Start Time", prediction=False)
+         end_time_col = self.get_column_name("End Time", prediction=False)
+
+         # Iterate through each annotation row
+         for _, row in annot_df.iterrows():
+             class_name = row[class_col]
+             if class_name not in self.classes:
+                 continue  # Skip annotations for classes not included in the predefined list
+
+             # Extract start and end times
+             begin_time = row[start_time_col]
+             end_time = row[end_time_col]
+
+             # Identify samples that overlap with the annotation based on min_overlap
+             sample_indices = samples_df[(samples_df["start_time"] <= end_time - self.min_overlap) & (samples_df["end_time"] >= begin_time + self.min_overlap)].index
+
+             # Set annotation value to 1 for the overlapping samples
+             for i in sample_indices:
+                 samples_df.loc[i, f"{class_name}_annotation"] = 1
+
+     def create_tensors(self) -> None:
+         """
+         Creates prediction and label tensors from the samples DataFrame.
+
+         This method converts confidence scores and annotations for each class into
+         numpy arrays (tensors). It ensures that there are no NaN values in the DataFrame
+         before creating the tensors.
+
+         Raises:
+             ValueError: If NaN values are found in confidence or annotation columns.
+         """
+         if self.samples_df.empty:
+             # Initialize empty tensors if samples DataFrame is empty
+             self.prediction_tensors = np.empty((0, len(self.classes)), dtype=np.float32)
+             self.label_tensors = np.empty((0, len(self.classes)), dtype=np.int64)
+             return
+
+         # Check for NaN values in annotation columns
+         annotation_columns = [f"{cls}_annotation" for cls in self.classes]
+         if self.samples_df[annotation_columns].isna().to_numpy().any():
+             raise ValueError("NaN values found in annotation columns.")
+
+         # Check for NaN values in confidence columns
+         confidence_columns = [f"{cls}_confidence" for cls in self.classes]
+         if self.samples_df[confidence_columns].isna().to_numpy().any():
+             raise ValueError("NaN values found in confidence columns.")
+
+         # Convert confidence scores and annotations into numpy arrays (tensors)
+         self.prediction_tensors = self.samples_df[confidence_columns].to_numpy(dtype=np.float32)
+         self.label_tensors = self.samples_df[annotation_columns].to_numpy(dtype=np.int64)
+
+     def get_column_name(self, field_name: str, prediction: bool = True) -> str:
+         """
+         Retrieves the appropriate column name for the specified field.
+
+         This method checks the column mapping (for predictions or annotations) and
+         returns the corresponding column name. If the field is not mapped, it returns
+         the field name directly.
+
+         Args:
+             field_name (str): The name of the field (e.g., "Class", "Start Time").
+             prediction (bool): Whether to fetch the name from the predictions mapping (True)
+                 or annotations mapping (False).
+
+         Returns:
+             str: The column name corresponding to the field.
+
+         Raises:
+             TypeError: If `field_name` or `prediction` is None.
+         """
+         if field_name is None:
+             raise TypeError("field_name cannot be None.")
+         if prediction is None:
+             raise TypeError("prediction parameter cannot be None.")
+
+         # Select the appropriate mapping based on the `prediction` flag
+         mapping = self.columns_predictions if prediction else self.columns_annotations
+
+         if field_name in mapping and mapping[field_name] is not None:
+             return mapping[field_name]
+
+         return field_name
+
+     def get_sample_data(self) -> pd.DataFrame:
+         """
+         Retrieves the DataFrame containing all sample intervals, prediction scores, and annotations.
+
+         This method provides a copy of the `samples_df` DataFrame, ensuring that the original
+         data is not modified when accessed externally.
+
+         Returns:
+             pd.DataFrame: A copy of the `samples_df` DataFrame, which contains the sampled data.
+         """
+         # Return a copy of the samples DataFrame to preserve data integrity
+         return self.samples_df.copy()
+
+     def get_filtered_tensors(
+         self,
+         selected_classes: list[str] | None = None,
+         selected_recordings: list[str] | None = None,
+     ) -> tuple[np.ndarray, np.ndarray, tuple[str]]:
+         """
+         Filters the prediction and label tensors based on selected classes and recordings.
+
+         This method extracts subsets of the prediction and label tensors for specific classes
+         and/or recordings. It ensures that the filtered tensors correspond to valid classes
+         and recordings present in the sampled data.
+
+         Args:
+             selected_classes (List[str], optional): A list of class names to filter by. If None,
+                 all classes are included.
+             selected_recordings (List[str], optional): A list of recording filenames to filter by. If None,
+                 all recordings are included.
+
+         Returns:
+             Tuple[np.ndarray, np.ndarray, Tuple[str]]: A tuple containing:
+                 - Filtered prediction tensors (numpy.ndarray)
+                 - Filtered label tensors (numpy.ndarray)
+                 - Tuple of selected class names (Tuple[str])
+
+         Raises:
+             ValueError: If the `samples_df` is empty or missing required columns.
+             KeyError: If required confidence or annotation columns are missing in the DataFrame.
+         """
+         if self.samples_df.empty:
+             raise ValueError("samples_df is empty.")
+
+         if "filename" not in self.samples_df.columns:
+             raise ValueError("samples_df must contain a 'filename' column.")
+
+         # Determine the classes to filter by
+         classes = self.classes if selected_classes is None else tuple(cls for cls in selected_classes if cls in self.classes)
+
+         if not classes:
+             raise ValueError("No valid classes selected.")
+
+         # Create a mask for filtering samples
+         mask = pd.Series(True, index=self.samples_df.index)
+
+         # Apply recording-based filtering if specified
+         if selected_recordings is not None:
+             if selected_recordings:
+                 mask &= self.samples_df["filename"].isin(selected_recordings)
+             else:
+                 # If `selected_recordings` is an empty list, select no samples
+                 mask = pd.Series(False, index=self.samples_df.index)
+
+         # Filter the samples DataFrame using the mask
+         filtered_samples = self.samples_df.loc[mask]
+
+         # Prepare column names for confidence and annotation data
+         confidence_columns = [f"{cls}_confidence" for cls in classes]
+         annotation_columns = [f"{cls}_annotation" for cls in classes]
+
+         # Ensure all required columns are present in the filtered DataFrame
+         if not all(col in filtered_samples.columns for col in confidence_columns + annotation_columns):
+             raise KeyError("Required confidence or annotation columns are missing.")
+
+         # Convert filtered data into numpy arrays
+         predictions = filtered_samples[confidence_columns].to_numpy(dtype=np.float32)
+         labels = filtered_samples[annotation_columns].to_numpy(dtype=np.int64)
+
+         # Return the tensors and the list of filtered classes
+         return predictions, labels, classes
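
For orientation, here is a minimal usage sketch of the new DataProcessor class; the directory paths are placeholders, and it assumes tab-separated prediction/annotation tables using the default column names shown above. Note the overlap rule in the update methods: with the defaults (sample_duration=3, min_overlap=0.5), a detection spanning 2.0-5.0 s is assigned to both the 0-3 s and the 3-6 s samples, since each satisfies start_time <= end_time - 0.5 and end_time >= begin_time + 0.5.

    # Hypothetical example; "predictions/" and "annotations/" are placeholder paths.
    from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor

    processor = DataProcessor(
        prediction_directory_path="predictions/",
        annotation_directory_path="annotations/",
        sample_duration=3,   # length of each sample interval in seconds (default)
        min_overlap=0.5,     # seconds of overlap required to count a match (default)
    )

    # One row per sample interval, with per-class "<class>_confidence"
    # and "<class>_annotation" columns (returned as a copy).
    samples = processor.get_sample_data()

    # Row-aligned tensors: float32 confidences and int64 binary labels.
    predictions, labels, classes = processor.get_filtered_tensors()
    print(predictions.shape, labels.shape, classes)

These row-aligned tensors are presumably what the other new evaluation modules in this release (assessment/metrics.py, assessment/performance_assessor.py) consume downstream.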