birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/evaluation/__init__.py
@@ -1,195 +1,196 @@
- """
- Core script for assessing performance of prediction models against annotated data.
-
- This script uses the `DataProcessor` and `PerformanceAssessor` classes to process prediction and
- annotation data, compute metrics, and optionally generate plots. It supports flexible configurations
- for columns, class mappings, and filtering based on selected classes or recordings.
- """
-
- import argparse
- import json
- import os
- from typing import Optional, Dict, List, Tuple
-
- from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
- from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor
-
-
- def process_data(
-     annotation_path: str,
-     prediction_path: str,
-     mapping_path: Optional[str] = None,
-     sample_duration: float = 3.0,
-     min_overlap: float = 0.5,
-     recording_duration: Optional[float] = None,
-     columns_annotations: Optional[Dict[str, str]] = None,
-     columns_predictions: Optional[Dict[str, str]] = None,
-     selected_classes: Optional[List[str]] = None,
-     selected_recordings: Optional[List[str]] = None,
-     metrics_list: Tuple[str, ...] = ("accuracy", "precision", "recall"),
-     threshold: float = 0.1,
-     class_wise: bool = False,
- ):
-     """
-     Processes data, computes metrics, and prepares the performance assessment pipeline.
-
-     Args:
-         annotation_path (str): Path to the annotation file or folder.
-         prediction_path (str): Path to the prediction file or folder.
-         mapping_path (Optional[str]): Path to the class mapping JSON file, if applicable.
-         sample_duration (float): Duration of each sample interval in seconds.
-         min_overlap (float): Minimum overlap required between predictions and annotations.
-         recording_duration (Optional[float]): Total duration of the recordings, if known.
-         columns_annotations (Optional[Dict[str, str]]): Custom column mappings for annotations.
-         columns_predictions (Optional[Dict[str, str]]): Custom column mappings for predictions.
-         selected_classes (Optional[List[str]]): List of classes to include in the analysis.
-         selected_recordings (Optional[List[str]]): List of recordings to include in the analysis.
-         metrics_list (Tuple[str, ...]): Metrics to compute for performance assessment.
-         threshold (float): Confidence threshold for predictions.
-         class_wise (bool): Whether to calculate metrics on a per-class basis.
-
-     Returns:
-         Tuple: Metrics DataFrame, `PerformanceAssessor` object, predictions tensor, labels tensor.
-     """
-     # Load class mapping if provided
-     if mapping_path:
-         with open(mapping_path, "r") as f:
-             class_mapping = json.load(f)
-     else:
-         class_mapping = None
-
-     # Determine directory and file paths for annotations and predictions
-     annotation_dir, annotation_file = (
-         (os.path.dirname(annotation_path), os.path.basename(annotation_path))
-         if os.path.isfile(annotation_path)
-         else (annotation_path, None)
-     )
-     prediction_dir, prediction_file = (
-         (os.path.dirname(prediction_path), os.path.basename(prediction_path))
-         if os.path.isfile(prediction_path)
-         else (prediction_path, None)
-     )
-
-     # Initialize the DataProcessor to handle and prepare data
-     processor = DataProcessor(
-         prediction_directory_path=prediction_dir,
-         prediction_file_name=prediction_file,
-         annotation_directory_path=annotation_dir,
-         annotation_file_name=annotation_file,
-         class_mapping=class_mapping,
-         sample_duration=sample_duration,
-         min_overlap=min_overlap,
-         columns_predictions=columns_predictions,
-         columns_annotations=columns_annotations,
-         recording_duration=recording_duration,
-     )
-
-     # Get the available classes and recordings
-     available_classes = processor.classes
-     available_recordings = processor.samples_df["filename"].unique().tolist()
-
-     # Default to all classes or recordings if none are specified
-     if selected_classes is None:
-         selected_classes = available_classes
-     if selected_recordings is None:
-         selected_recordings = available_recordings
-
-     # Retrieve predictions and labels tensors for the selected classes and recordings
-     predictions, labels, classes = processor.get_filtered_tensors(selected_classes, selected_recordings)
-
-     num_classes = len(classes)
-     task = "binary" if num_classes == 1 else "multilabel"
-
-     # Initialize the PerformanceAssessor for computing metrics
-     pa = PerformanceAssessor(
-         num_classes=num_classes,
-         threshold=threshold,
-         classes=classes,
-         task=task,
-         metrics_list=metrics_list,
-     )
-
-     # Compute performance metrics
-     metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise)
-
-     return metrics_df, pa, predictions, labels
-
-
- def main():
-     """
-     Entry point for the script. Parses command-line arguments and orchestrates the performance assessment pipeline.
-     """
-     # Set up argument parsing
-     parser = argparse.ArgumentParser(description="Performance Assessor Core Script")
-     parser.add_argument("--annotation_path", required=True, help="Path to annotation file or folder")
-     parser.add_argument("--prediction_path", required=True, help="Path to prediction file or folder")
-     parser.add_argument("--mapping_path", help="Path to class mapping JSON file (optional)")
-     parser.add_argument("--sample_duration", type=float, default=3.0, help="Sample duration in seconds")
-     parser.add_argument("--min_overlap", type=float, default=0.5, help="Minimum overlap in seconds")
-     parser.add_argument("--recording_duration", type=float, help="Recording duration in seconds")
-     parser.add_argument("--columns_annotations", type=json.loads, help="JSON string for columns_annotations")
-     parser.add_argument("--columns_predictions", type=json.loads, help="JSON string for columns_predictions")
-     parser.add_argument("--selected_classes", nargs="+", help="List of selected classes")
-     parser.add_argument("--selected_recordings", nargs="+", help="List of selected recordings")
-     parser.add_argument("--metrics", nargs="+", default=["accuracy", "precision", "recall"], help="List of metrics")
-     parser.add_argument("--threshold", type=float, default=0.1, help="Threshold value (0-1)")
-     parser.add_argument("--class_wise", action="store_true", help="Calculate class-wise metrics")
-     parser.add_argument("--plot_metrics", action="store_true", help="Plot metrics")
-     parser.add_argument("--plot_confusion_matrix", action="store_true", help="Plot confusion matrix")
-     parser.add_argument("--plot_metrics_all_thresholds", action="store_true", help="Plot metrics for all thresholds")
-     parser.add_argument("--output_dir", help="Directory to save plots")
-
-     # Parse arguments
-     args = parser.parse_args()
-
-     # Process data and compute metrics
-     metrics_df, pa, predictions, labels = process_data(
-         annotation_path=args.annotation_path,
-         prediction_path=args.prediction_path,
-         mapping_path=args.mapping_path,
-         sample_duration=args.sample_duration,
-         min_overlap=args.min_overlap,
-         recording_duration=args.recording_duration,
-         columns_annotations=args.columns_annotations,
-         columns_predictions=args.columns_predictions,
-         selected_classes=args.selected_classes,
-         selected_recordings=args.selected_recordings,
-         metrics_list=args.metrics,
-         threshold=args.threshold,
-         class_wise=args.class_wise,
-     )
-
-     # Display the computed metrics
-     print(metrics_df)
-
-     # Create output directory if needed
-     if args.output_dir and not os.path.exists(args.output_dir):
-         os.makedirs(args.output_dir)
-
-     # Generate plots if specified
-     if args.plot_metrics:
-         pa.plot_metrics(predictions, labels, per_class_metrics=args.class_wise)
-         if args.output_dir:
-             import matplotlib.pyplot as plt
-
-             plt.savefig(os.path.join(args.output_dir, "metrics_plot.png"))
-         else:
-             plt.show()
-
-     if args.plot_confusion_matrix:
-         pa.plot_confusion_matrix(predictions, labels)
-         if args.output_dir:
-             import matplotlib.pyplot as plt
-
-             plt.savefig(os.path.join(args.output_dir, "confusion_matrix.png"))
-         else:
-             plt.show()
-
-     if args.plot_metrics_all_thresholds:
-         pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=args.class_wise)
-         if args.output_dir:
-             import matplotlib.pyplot as plt
-
-             plt.savefig(os.path.join(args.output_dir, "metrics_all_thresholds.png"))
-         else:
-             plt.show()
+ """
+ Core script for assessing performance of prediction models against annotated data.
+
+ This script uses the `DataProcessor` and `PerformanceAssessor` classes to process prediction and
+ annotation data, compute metrics, and optionally generate plots. It supports flexible configurations
+ for columns, class mappings, and filtering based on selected classes or recordings.
+ """
+
+ import argparse
+ import json
+ import os
+
+ from birdnet_analyzer.evaluation.assessment.performance_assessor import (
+     PerformanceAssessor,
+ )
+ from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
+
+
+ def process_data(
+     annotation_path: str,
+     prediction_path: str,
+     mapping_path: str | None = None,
+     sample_duration: float = 3.0,
+     min_overlap: float = 0.5,
+     recording_duration: float | None = None,
+     columns_annotations: dict[str, str] | None = None,
+     columns_predictions: dict[str, str] | None = None,
+     selected_classes: list[str] | None = None,
+     selected_recordings: list[str] | None = None,
+     metrics_list: tuple[str, ...] = ("accuracy", "precision", "recall"),
+     threshold: float = 0.1,
+     class_wise: bool = False,
+ ):
+     """
+     Processes data, computes metrics, and prepares the performance assessment pipeline.
+
+     Args:
+         annotation_path (str): Path to the annotation file or folder.
+         prediction_path (str): Path to the prediction file or folder.
+         mapping_path (Optional[str]): Path to the class mapping JSON file, if applicable.
+         sample_duration (float): Duration of each sample interval in seconds.
+         min_overlap (float): Minimum overlap required between predictions and annotations.
+         recording_duration (Optional[float]): Total duration of the recordings, if known.
+         columns_annotations (Optional[Dict[str, str]]): Custom column mappings for annotations.
+         columns_predictions (Optional[Dict[str, str]]): Custom column mappings for predictions.
+         selected_classes (Optional[List[str]]): List of classes to include in the analysis.
+         selected_recordings (Optional[List[str]]): List of recordings to include in the analysis.
+         metrics_list (Tuple[str, ...]): Metrics to compute for performance assessment.
+         threshold (float): Confidence threshold for predictions.
+         class_wise (bool): Whether to calculate metrics on a per-class basis.
+
+     Returns:
+         Tuple: Metrics DataFrame, `PerformanceAssessor` object, predictions tensor, labels tensor.
+     """
+     # Load class mapping if provided
+     if mapping_path:
+         with open(mapping_path) as f:
+             class_mapping = json.load(f)
+     else:
+         class_mapping = None
+
+     # Determine directory and file paths for annotations and predictions
+     annotation_dir, annotation_file = (
+         (os.path.dirname(annotation_path), os.path.basename(annotation_path))
+         if os.path.isfile(annotation_path)
+         else (annotation_path, None)
+     )
+     prediction_dir, prediction_file = (
+         (os.path.dirname(prediction_path), os.path.basename(prediction_path))
+         if os.path.isfile(prediction_path)
+         else (prediction_path, None)
+     )
+
+     # Initialize the DataProcessor to handle and prepare data
+     processor = DataProcessor(
+         prediction_directory_path=prediction_dir,
+         prediction_file_name=prediction_file,
+         annotation_directory_path=annotation_dir,
+         annotation_file_name=annotation_file,
+         class_mapping=class_mapping,
+         sample_duration=sample_duration,
+         min_overlap=min_overlap,
+         columns_predictions=columns_predictions,
+         columns_annotations=columns_annotations,
+         recording_duration=recording_duration,
+     )
+
+     # Get the available classes and recordings
+     available_classes = processor.classes
+     available_recordings = processor.samples_df["filename"].unique().tolist()
+
+     # Default to all classes or recordings if none are specified
+     if selected_classes is None:
+         selected_classes = available_classes
+     if selected_recordings is None:
+         selected_recordings = available_recordings
+
+     # Retrieve predictions and labels tensors for the selected classes and recordings
+     predictions, labels, classes = processor.get_filtered_tensors(selected_classes, selected_recordings)
+
+     num_classes = len(classes)
+     task = "binary" if num_classes == 1 else "multilabel"
+
+     # Initialize the PerformanceAssessor for computing metrics
+     pa = PerformanceAssessor(
+         num_classes=num_classes,
+         threshold=threshold,
+         classes=classes,
+         task=task,
+         metrics_list=metrics_list,
+     )
+
+     # Compute performance metrics
+     metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise)
+
+     return metrics_df, pa, predictions, labels
+
+
+ def main():
+     """
+     Entry point for the script. Parses command-line arguments and orchestrates the performance assessment pipeline.
+     """
+     # Set up argument parsing
+     parser = argparse.ArgumentParser(description="Performance Assessor Core Script")
+     parser.add_argument("--annotation_path", required=True, help="Path to annotation file or folder")
+     parser.add_argument("--prediction_path", required=True, help="Path to prediction file or folder")
+     parser.add_argument("--mapping_path", help="Path to class mapping JSON file (optional)")
+     parser.add_argument("--sample_duration", type=float, default=3.0, help="Sample duration in seconds")
+     parser.add_argument("--min_overlap", type=float, default=0.5, help="Minimum overlap in seconds")
+     parser.add_argument("--recording_duration", type=float, help="Recording duration in seconds")
+     parser.add_argument("--columns_annotations", type=json.loads, help="JSON string for columns_annotations")
+     parser.add_argument("--columns_predictions", type=json.loads, help="JSON string for columns_predictions")
+     parser.add_argument("--selected_classes", nargs="+", help="List of selected classes")
+     parser.add_argument("--selected_recordings", nargs="+", help="List of selected recordings")
+     parser.add_argument("--metrics", nargs="+", default=["accuracy", "precision", "recall"], help="List of metrics")
+     parser.add_argument("--threshold", type=float, default=0.1, help="Threshold value (0-1)")
+     parser.add_argument("--class_wise", action="store_true", help="Calculate class-wise metrics")
+     parser.add_argument("--plot_metrics", action="store_true", help="Plot metrics")
+     parser.add_argument("--plot_confusion_matrix", action="store_true", help="Plot confusion matrix")
+     parser.add_argument("--plot_metrics_all_thresholds", action="store_true", help="Plot metrics for all thresholds")
+     parser.add_argument("--output_dir", help="Directory to save plots")
+
+     # Parse arguments
+     args = parser.parse_args()
+
+     # Process data and compute metrics
+     metrics_df, pa, predictions, labels = process_data(
+         annotation_path=args.annotation_path,
+         prediction_path=args.prediction_path,
+         mapping_path=args.mapping_path,
+         sample_duration=args.sample_duration,
+         min_overlap=args.min_overlap,
+         recording_duration=args.recording_duration,
+         columns_annotations=args.columns_annotations,
+         columns_predictions=args.columns_predictions,
+         selected_classes=args.selected_classes,
+         selected_recordings=args.selected_recordings,
+         metrics_list=args.metrics,
+         threshold=args.threshold,
+         class_wise=args.class_wise,
+     )
+
+     # Display the computed metrics
+     print(metrics_df)
+
+     # Create output directory if needed
+     if args.output_dir and not os.path.exists(args.output_dir):
+         os.makedirs(args.output_dir)
+
+     # Generate plots if specified
+     if args.plot_metrics:
+         pa.plot_metrics(predictions, labels, per_class_metrics=args.class_wise)
+         if args.output_dir:
+             import matplotlib.pyplot as plt
+
+             plt.savefig(os.path.join(args.output_dir, "metrics_plot.png"))
+         else:
+             plt.show()
+
+     if args.plot_confusion_matrix:
+         pa.plot_confusion_matrix(predictions, labels)
+         if args.output_dir:
+             import matplotlib.pyplot as plt
+
+             plt.savefig(os.path.join(args.output_dir, "confusion_matrix.png"))
+         else:
+             plt.show()
+
+     if args.plot_metrics_all_thresholds:
+         pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=args.class_wise)
+         if args.output_dir:
+             import matplotlib.pyplot as plt
+
+             plt.savefig(os.path.join(args.output_dir, "metrics_all_thresholds.png"))
+         else:
+             plt.show()
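
The hunk above appears to be birdnet_analyzer/evaluation/__init__.py; its +196/-195 counts match entry 16 in the file list. Two changes stand out: the typing.Optional/Dict/List/Tuple annotations were replaced with PEP 604/585 forms such as str | None and dict[str, str], which are evaluated at import time and therefore imply Python 3.10+ unless the project constrains this elsewhere, and the redundant "r" mode was dropped from open(). The process_data() signature is otherwise unchanged, so 2.0.0-era calls keep working. A minimal sketch of such a call, assuming the hunk really is the package's __init__.py; the paths are placeholders and the keyword arguments simply mirror the signature shown in the diff:

    from birdnet_analyzer.evaluation import process_data

    # Placeholder paths -- point these at real annotation/prediction tables.
    metrics_df, pa, predictions, labels = process_data(
        annotation_path="annotations/",        # file or folder of annotations
        prediction_path="predictions/",        # file or folder of predictions
        sample_duration=3.0,                   # seconds per sample interval
        min_overlap=0.5,                       # required annotation/prediction overlap
        metrics_list=("precision", "recall"),  # subset of the default metrics
        threshold=0.1,                         # confidence threshold for predictions
        class_wise=True,                       # compute per-class metrics
    )
    print(metrics_df)                          # DataFrame of computed metric values
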
birdnet_analyzer/evaluation/__main__.py
@@ -1,3 +1,3 @@
- from birdnet_analyzer.evaluation import main
-
- main()
+ from birdnet_analyzer.evaluation import main
+
+ main()
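
This three-line __main__.py stub is what makes the evaluation subpackage runnable as python -m birdnet_analyzer.evaluation, which dispatches to the argparse-based main() shown in the previous hunk. A sketch of the equivalent programmatic invocation, using only flags that appear in the diff; the paths are again placeholders:

    import sys

    from birdnet_analyzer.evaluation import main

    # Equivalent to:
    #   python -m birdnet_analyzer.evaluation --annotation_path annotations/ \
    #       --prediction_path predictions/ --threshold 0.1 --class_wise
    sys.argv = [
        "birdnet_analyzer.evaluation",  # argv[0]; ignored by argparse
        "--annotation_path", "annotations/",
        "--prediction_path", "predictions/",
        "--threshold", "0.1",
        "--class_wise",
    ]
    main()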