birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
@@ -1,813 +1,795 @@
1
- import json
2
- import os
3
- import shutil
4
- import tempfile
5
- import typing
6
-
7
- import gradio as gr
8
- import matplotlib.pyplot as plt
9
- import pandas as pd
10
-
11
- import birdnet_analyzer.gui.localization as loc
12
- import birdnet_analyzer.gui.utils as gu
13
- from birdnet_analyzer.evaluation.assessment.performance_assessor import PerformanceAssessor
14
- from birdnet_analyzer.evaluation import process_data
15
- from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
16
-
17
-
18
class ProcessorState(typing.NamedTuple):
    """Immutable bundle of a configured DataProcessor and its staging dirs.

    Stored in a gr.State so later callbacks (metric calculation, table
    export) can reuse the already-parsed data and the temporary directories
    the uploaded files were copied into, instead of re-reading the uploads.
    """

    # Parsed/matched annotation + prediction samples (see DataProcessor).
    processor: DataProcessor
    # Temp directory holding copies of the uploaded annotation files.
    annotation_dir: str
    # Temp directory holding copies of the uploaded prediction files.
    prediction_dir: str
24
-
25
-
26
- def build_evaluation_tab():
27
- # Default columns for annotations
28
- annotation_default_columns = {
29
- "Start Time": "Begin Time (s)",
30
- "End Time": "End Time (s)",
31
- "Class": "Class",
32
- "Recording": "Begin File",
33
- "Duration": "File Duration (s)",
34
- }
35
-
36
- # Default columns for predictions
37
- prediction_default_columns = {
38
- "Start Time": "Begin Time (s)",
39
- "End Time": "End Time (s)",
40
- "Class": "Common Name",
41
- "Recording": "Begin File",
42
- "Duration": "File Duration (s)",
43
- "Confidence": "Confidence",
44
- }
45
-
46
- localized_column_labels = {
47
- "Start Time": loc.localize("eval-tab-column-start-time-label"),
48
- "End Time": loc.localize("eval-tab-column-end-time-label"),
49
- "Class": loc.localize("eval-tab-column-class-label"),
50
- "Recording": loc.localize("eval-tab-column-recording-label"),
51
- "Duration": loc.localize("eval-tab-column-duration-label"),
52
- "Confidence": loc.localize("eval-tab-column-confidence-label"),
53
- }
54
-
55
def download_class_mapping_template():
    """Ask the user for a save location and write a JSON class-mapping template."""
    try:
        # Five editable placeholder pairs mapping predicted -> annotated names.
        template = {f"Predicted Class Name {i}": f"Annotation Class Name {i}" for i in range(1, 6)}

        target = gu.save_file_dialog(
            state_key="eval-mapping-template",
            filetypes=("JSON (*.json)",),
            default_filename="class_mapping_template.json",
        )

        # A falsy path means the user cancelled the save dialog.
        if target:
            with open(target, "w") as f:
                json.dump(template, f, indent=4)

            gr.Info(loc.localize("eval-tab-info-mapping-template-saved"))
    except Exception as e:
        print(f"Error saving mapping template: {e}")
        raise gr.Error(f"{loc.localize('eval-tab-error-saving-mapping-template')} {e}") from e
79
-
80
def download_results_table(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Prompt for a path and export the computed metrics table as CSV or TSV."""
    # Everything must have been computed by calculate_metrics first.
    if any(x is None for x in (pa, predictions, labels)):
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)

    try:
        file_location = gu.save_file_dialog(
            state_key="eval-results-table",
            filetypes=("CSV (*.csv;*.CSV)", "TSV (*.tsv;*.TSV)"),
            default_filename="results_table.csv",
        )

        if file_location:
            metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise_value)

            # The chosen extension decides the delimiter; comma is the default.
            is_tsv = file_location.split(".")[-1].lower() == "tsv"
            metrics_df.to_csv(file_location, sep="\t" if is_tsv else ",", index=True)

            gr.Info(loc.localize("eval-tab-info-results-table-saved"))
    except Exception as e:
        print(f"Error saving results table: {e}")
        raise gr.Error(f"{loc.localize('eval-tab-error-saving-results-table')} {e}") from e
103
-
104
def download_data_table(processor_state: ProcessorState):
    """Prompt for a path and export the processor's per-sample data table."""
    # No state means metrics were never calculated in this session.
    if processor_state is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)

    try:
        file_location = gu.save_file_dialog(
            state_key="eval-data-table",
            filetypes=("CSV (*.csv)", "TSV (*.tsv;*.TSV)"),
            default_filename="data_table.csv",
        )

        if file_location:
            data_df = processor_state.processor.get_sample_data()

            # Extension picks the delimiter; no index column in the export.
            is_tsv = file_location.split(".")[-1].lower() == "tsv"
            data_df.to_csv(file_location, sep="\t" if is_tsv else ",", index=False)

            gr.Info(loc.localize("eval-tab-info-data-table-saved"))
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-saving-data-table')} {e}") from e
124
-
125
def get_columns_from_uploaded_files(files):
    """Return the sorted union of header column names across the uploaded tables."""
    found = set()

    for file_obj in files or []:
        try:
            # nrows=0 reads only the header row; sep=None lets pandas sniff
            # the delimiter (comma vs tab) per file.
            header = pd.read_csv(file_obj, sep=None, engine="python", nrows=0)
        except Exception as e:
            print(f"Error reading file {file_obj}: {e}")
            gr.Warning(f"{loc.localize('eval-tab-warning-error-reading-file')} {file_obj}")
        else:
            found.update(header.columns)

    return sorted(found)
138
-
139
def save_uploaded_files(files):
    """Copy uploaded files into a fresh temp directory; return its path, or None if no files."""
    if not files:
        return None

    staging_dir = tempfile.mkdtemp()

    # Keep only the base name so files from different source folders
    # land side by side in the staging directory.
    for src in files:
        shutil.copy(src, os.path.join(staging_dir, os.path.basename(src)))

    return staging_dir
150
-
151
- # Single initialize_processor that can reuse given directories.
152
# Single initialize_processor that can reuse given directories.
def initialize_processor(
    annotation_files,
    prediction_files,
    mapping_file_obj,
    sample_duration_value,
    min_overlap_value,
    recording_duration,
    ann_start_time,
    ann_end_time,
    ann_class,
    ann_recording,
    ann_duration,
    pred_start_time,
    pred_end_time,
    pred_class,
    pred_confidence,
    pred_recording,
    pred_duration,
    annotation_dir=None,
    prediction_dir=None,
):
    """Build a DataProcessor from the uploaded annotation/prediction files.

    Blank column selections fall back to the tab's default column names.
    If annotation_dir/prediction_dir are supplied they are reused instead
    of copying the uploads into new temp directories.

    Returns a 5-tuple (available classes, available recordings, processor,
    annotation_dir, prediction_dir); when either upload set is missing the
    first two are empty lists and the rest are None.
    """
    # Nothing to process until both sides have files.
    if not annotation_files or not prediction_files:
        return [], [], None, None, None

    # Stage the uploads only when the caller did not pass directories.
    if annotation_dir is None:
        annotation_dir = save_uploaded_files(annotation_files)

    if prediction_dir is None:
        prediction_dir = save_uploaded_files(prediction_files)

    # Fallback for annotation columns.
    ann_start_time = ann_start_time if ann_start_time else annotation_default_columns["Start Time"]
    ann_end_time = ann_end_time if ann_end_time else annotation_default_columns["End Time"]
    ann_class = ann_class if ann_class else annotation_default_columns["Class"]
    ann_recording = ann_recording if ann_recording else annotation_default_columns["Recording"]
    ann_duration = ann_duration if ann_duration else annotation_default_columns["Duration"]

    # Fallback for prediction columns.
    pred_start_time = pred_start_time if pred_start_time else prediction_default_columns["Start Time"]
    pred_end_time = pred_end_time if pred_end_time else prediction_default_columns["End Time"]
    pred_class = pred_class if pred_class else prediction_default_columns["Class"]
    pred_confidence = pred_confidence if pred_confidence else prediction_default_columns["Confidence"]
    pred_recording = pred_recording if pred_recording else prediction_default_columns["Recording"]
    pred_duration = pred_duration if pred_duration else prediction_default_columns["Duration"]

    cols_ann = {
        "Start Time": ann_start_time,
        "End Time": ann_end_time,
        "Class": ann_class,
        "Recording": ann_recording,
        "Duration": ann_duration,
    }
    cols_pred = {
        "Start Time": pred_start_time,
        "End Time": pred_end_time,
        "Class": pred_class,
        "Confidence": pred_confidence,
        "Recording": pred_recording,
        "Duration": pred_duration,
    }

    # Handle mapping file: if it has a temp_files attribute use that, otherwise assume it's a filepath.
    if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
        mapping_path = list(mapping_file_obj.temp_files)[0]
    else:
        mapping_path = mapping_file_obj if mapping_file_obj else None

    # The mapping JSON is expected to be a predicted->annotated name dict
    # (see download_class_mapping_template); None disables remapping.
    if mapping_path:
        with open(mapping_path, "r") as f:
            class_mapping = json.load(f)
    else:
        class_mapping = None

    try:
        proc = DataProcessor(
            prediction_directory_path=prediction_dir,
            prediction_file_name=None,
            annotation_directory_path=annotation_dir,
            annotation_file_name=None,
            class_mapping=class_mapping,
            sample_duration=sample_duration_value,
            min_overlap=min_overlap_value,
            columns_predictions=cols_pred,
            columns_annotations=cols_ann,
            recording_duration=recording_duration,
        )
        avail_classes = list(proc.classes)  # Ensure it's a list
        avail_recordings = proc.samples_df["filename"].unique().tolist()

        return avail_classes, avail_recordings, proc, annotation_dir, prediction_dir
    except KeyError as e:
        # A configured column name was not found in one of the files.
        print(f"Column missing in files: {e}")
        raise gr.Error(
            f"{loc.localize('eval-tab-error-missing-col')}: "
            + str(e)
            + f". {loc.localize('eval-tab-error-missing-col-info')}"
        ) from e
    except Exception as e:
        print(f"Error initializing processor: {e}")

        raise gr.Error(f"{loc.localize('eval-tab-error-init-processor')}:" + str(e)) from e
253
-
254
- # update_selections is triggered when files or mapping file change.
255
- # It creates the temporary directories once and stores them along with the processor.
256
- # It now also receives the current selection values so that user selections are preserved.
257
# update_selections is triggered when files or mapping file change.
# It creates the temporary directories once and stores them along with the processor.
# It now also receives the current selection values so that user selections are preserved.
def update_selections(
    annotation_files,
    prediction_files,
    mapping_file_obj,
    sample_duration_value,
    min_overlap_value,
    recording_duration_value: str,
    ann_start_time,
    ann_end_time,
    ann_class,
    ann_recording,
    ann_duration,
    pred_start_time,
    pred_end_time,
    pred_class,
    pred_confidence,
    pred_recording,
    pred_duration,
    current_classes,
    current_recordings,
):
    """Rebuild the processor on any input change and refresh the pickers.

    Returns gr.update objects for the class and recording CheckboxGroups
    plus the new ProcessorState; selections that are still available are
    preserved, otherwise the groups default to everything.
    """
    # Blank or unparsable duration text means "no fixed recording duration".
    # NOTE(review): assumes recording_duration_value is a str (Textbox);
    # a None here would raise on .strip().
    if recording_duration_value.strip() == "":
        rec_dur = None
    else:
        try:
            rec_dur = float(recording_duration_value)
        except ValueError:
            rec_dur = None

    # Create temporary directories once.
    annotation_dir = save_uploaded_files(annotation_files)
    prediction_dir = save_uploaded_files(prediction_files)
    avail_classes, avail_recordings, proc, annotation_dir, prediction_dir = initialize_processor(
        annotation_files,
        prediction_files,
        mapping_file_obj,
        sample_duration_value,
        min_overlap_value,
        rec_dur,
        ann_start_time,
        ann_end_time,
        ann_class,
        ann_recording,
        ann_duration,
        pred_start_time,
        pred_end_time,
        pred_class,
        pred_confidence,
        pred_recording,
        pred_duration,
        annotation_dir,
        prediction_dir,
    )
    # Build a state dictionary to store the processor and the directories.
    state = ProcessorState(proc, annotation_dir, prediction_dir)
    # If no current selection exists, default to all available classes/recordings;
    # otherwise, preserve any selections that are still valid.
    new_classes = (
        avail_classes
        if not current_classes
        else [c for c in current_classes if c in avail_classes] or avail_classes
    )
    new_recordings = (
        avail_recordings
        if not current_recordings
        else [r for r in current_recordings if r in avail_recordings] or avail_recordings
    )

    return (
        gr.update(choices=avail_classes, value=new_classes),
        gr.update(choices=avail_recordings, value=new_recordings),
        state,
    )
330
-
331
- with gr.Tab(loc.localize("eval-tab-title")):
332
- # Custom CSS to match the layout style of other files and remove gray backgrounds.
333
- gr.Markdown(
334
- """
335
- <style>
336
- /* Grid layout for checkbox groups */
337
- .custom-checkbox-group {
338
- display: grid;
339
- grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
340
- grid-gap: 8px;
341
- }
342
- </style>
343
- """
344
- )
345
-
346
- processor_state = gr.State()
347
- pa_state = gr.State()
348
- predictions_state = gr.State()
349
- labels_state = gr.State()
350
- annotation_files_state = gr.State()
351
- prediction_files_state = gr.State()
352
-
353
def get_selection_tables(directory):
    """List the *.txt selection-table files directly inside *directory*."""
    from pathlib import Path

    return list(Path(directory).glob("*.txt"))
359
-
360
- # Update column dropdowns when files are uploaded.
361
# Update column dropdowns when files are uploaded.
def update_annotation_columns(uploaded_files):
    """Refresh the annotation column dropdowns from the uploaded files' headers."""
    # Leading empty entry lets the user leave a column unmapped.
    choices = ["", *get_columns_from_uploaded_files(uploaded_files)]
    updates = []

    for label in ("Start Time", "End Time", "Class", "Recording", "Duration"):
        preferred = annotation_default_columns.get(label)
        # Preselect the conventional column name only when it actually exists.
        updates.append(gr.update(choices=choices, value=preferred if preferred in choices else None))

    return updates
372
-
373
def update_prediction_columns(uploaded_files):
    """Refresh the prediction column dropdowns from the uploaded files' headers."""
    # Leading empty entry lets the user leave a column unmapped.
    choices = ["", *get_columns_from_uploaded_files(uploaded_files)]
    updates = []

    for label in ("Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"):
        preferred = prediction_default_columns.get(label)
        # Preselect the conventional column name only when it actually exists.
        updates.append(gr.update(choices=choices, value=preferred if preferred in choices else None))

    return updates
384
-
385
def get_selection_func(state_key, on_select):
    """Return a click handler that lets the user pick a folder of selection tables.

    on_select(files) must return the list of extra output updates appended
    after the three fixed outputs (raw file list, display matrix, group
    visibility update).
    """
    def select_directory_on_empty():  # Nishant - Function modified for For Folder selection
        folder = gu.select_folder(state_key=state_key)

        if folder:
            files = get_selection_tables(folder)
            # The conditional spans the whole expression: long listings are
            # truncated to 100 rows plus an "..." marker row.
            files_to_display = files[:100] + [["..."]] if len(files) > 100 else files
            return [files, files_to_display, gr.update(visible=True)] + on_select(files)

        # NOTE(review): this cancel branch returns only 2 values while the
        # success branch returns 3 + len(on_select(files)); verify the
        # registered outputs tolerate the shorter list.
        return ["", [[loc.localize("eval-tab-no-files-found")]]]

    return select_directory_on_empty
397
-
398
- with gr.Row():
399
- with gr.Column():
400
- annotation_select_directory_btn = gr.Button(loc.localize("eval-tab-annotation-selection-button-label"))
401
- annotation_directory_input = gr.Matrix(
402
- interactive=False,
403
- headers=[
404
- loc.localize("eval-tab-selections-column-file-header"),
405
- ],
406
- )
407
-
408
- with gr.Column():
409
- prediction_select_directory_btn = gr.Button(loc.localize("eval-tab-prediction-selection-button-label"))
410
- prediction_directory_input = gr.Matrix(
411
- interactive=False,
412
- headers=[
413
- loc.localize("eval-tab-selections-column-file-header"),
414
- ],
415
- )
416
-
417
- # ----------------------- Annotations Columns Box -----------------------
418
- with gr.Group(visible=False) as annotation_group:
419
- with gr.Accordion(loc.localize("eval-tab-annotation-col-accordion-label"), open=True):
420
- with gr.Row():
421
- annotation_columns: dict[str, gr.Dropdown] = {}
422
-
423
- for col in ["Start Time", "End Time", "Class", "Recording", "Duration"]:
424
- annotation_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
425
-
426
- # ----------------------- Predictions Columns Box -----------------------
427
- with gr.Group(visible=False) as prediction_group:
428
- with gr.Accordion(loc.localize("eval-tab-prediction-col-accordion-label"), open=True):
429
- with gr.Row():
430
- prediction_columns: dict[str, gr.Dropdown] = {}
431
-
432
- for col in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]:
433
- prediction_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
434
-
435
- # ----------------------- Class Mapping Box -----------------------
436
- with gr.Group(visible=False) as mapping_group:
437
- with gr.Accordion(loc.localize("eval-tab-class-mapping-accordion-label"), open=False):
438
- with gr.Row():
439
- mapping_file = gr.File(
440
- label=loc.localize("eval-tab-upload-mapping-file-label"),
441
- file_count="single",
442
- file_types=[".json"],
443
- )
444
- download_mapping_button = gr.DownloadButton(
445
- label=loc.localize("eval-tab-mapping-file-template-download-button-label")
446
- )
447
-
448
- download_mapping_button.click(fn=download_class_mapping_template)
449
-
450
- # ----------------------- Classes and Recordings Selection Box -----------------------
451
- with gr.Group(visible=False) as class_recording_group:
452
- with gr.Accordion(loc.localize("eval-tab-select-classes-recordings-accordion-label"), open=False):
453
- with gr.Row():
454
- with gr.Column():
455
- select_classes_checkboxgroup = gr.CheckboxGroup(
456
- choices=[],
457
- value=[],
458
- label=loc.localize("eval-tab-select-classes-checkboxgroup-label"),
459
- info=loc.localize("eval-tab-select-classes-checkboxgroup-info"),
460
- interactive=True,
461
- elem_classes="custom-checkbox-group",
462
- )
463
-
464
- with gr.Column():
465
- select_recordings_checkboxgroup = gr.CheckboxGroup(
466
- choices=[],
467
- value=[],
468
- label=loc.localize("eval-tab-select-recordings-checkboxgroup-label"),
469
- info=loc.localize("eval-tab-select-recordings-checkboxgroup-info"),
470
- interactive=True,
471
- elem_classes="custom-checkbox-group",
472
- )
473
-
474
- # ----------------------- Parameters Box -----------------------
475
- with gr.Group():
476
- with gr.Accordion(loc.localize("eval-tab-parameters-accordion-label"), open=False):
477
- with gr.Row():
478
- sample_duration = gr.Number(
479
- value=3,
480
- label=loc.localize("eval-tab-sample-duration-number-label"),
481
- precision=0,
482
- info=loc.localize("eval-tab-sample-duration-number-info"),
483
- )
484
- recording_duration = gr.Textbox(
485
- label=loc.localize("eval-tab-recording-duration-textbox-label"),
486
- placeholder=loc.localize("eval-tab-recording-duration-textbox-placeholder"),
487
- info=loc.localize("eval-tab-recording-duration-textbox-info"),
488
- )
489
- min_overlap = gr.Number(
490
- value=0.5,
491
- label=loc.localize("eval-tab-min-overlap-number-label"),
492
- info=loc.localize("eval-tab-min-overlap-number-info"),
493
- )
494
- threshold = gr.Slider(
495
- minimum=0.01,
496
- maximum=0.99,
497
- value=0.1,
498
- label=loc.localize("eval-tab-threshold-number-label"),
499
- info=loc.localize("eval-tab-threshold-number-info"),
500
- )
501
- class_wise = gr.Checkbox(
502
- label=loc.localize("eval-tab-classwise-checkbox-label"),
503
- value=False,
504
- info=loc.localize("eval-tab-classwise-checkbox-info"),
505
- )
506
-
507
- # ----------------------- Metrics Box -----------------------
508
- with gr.Group():
509
- with gr.Accordion(loc.localize("eval-tab-metrics-accordian-label"), open=False):
510
- with gr.Row():
511
- metric_info = {
512
- "AUROC": loc.localize("eval-tab-auroc-checkbox-info"),
513
- "Precision": loc.localize("eval-tab-precision-checkbox-info"),
514
- "Recall": loc.localize("eval-tab-recall-checkbox-info"),
515
- "F1 Score": loc.localize("eval-tab-f1-score-checkbox-info"),
516
- "Average Precision (AP)": loc.localize("eval-tab-ap-checkbox-info"),
517
- "Accuracy": loc.localize("eval-tab-accuracy-checkbox-info"),
518
- }
519
- metrics_checkboxes = {}
520
-
521
- for metric_name, description in metric_info.items():
522
- metrics_checkboxes[metric_name.lower()] = gr.Checkbox(
523
- label=metric_name, value=True, info=description
524
- )
525
-
526
- # ----------------------- Actions Box -----------------------
527
-
528
- calculate_button = gr.Button(loc.localize("eval-tab-calculate-metrics-button-label"), variant="huggingface")
529
-
530
- with gr.Column(visible=False) as action_col:
531
- with gr.Row():
532
- plot_metrics_button = gr.Button(loc.localize("eval-tab-plot-metrics-button-label"))
533
- plot_confusion_button = gr.Button(loc.localize("eval-tab-plot-confusion-matrix-button-label"))
534
- plot_metrics_all_thresholds_button = gr.Button(
535
- loc.localize("eval-tab-plot-metrics-all-thresholds-button-label")
536
- )
537
-
538
- with gr.Row():
539
- download_results_button = gr.DownloadButton(loc.localize("eval-tab-result-table-download-button-label"))
540
- download_data_button = gr.DownloadButton(loc.localize("eval-tab-data-table-download-button-label"))
541
-
542
- download_results_button.click(
543
- fn=download_results_table,
544
- inputs=[pa_state, predictions_state, labels_state, class_wise],
545
- )
546
- download_data_button.click(fn=download_data_table, inputs=[processor_state])
547
- metric_table = gr.Dataframe(show_label=False, type="pandas", visible=False, interactive=False)
548
- plot_output = gr.Plot(visible=False, show_label=False)
549
-
550
- # Update available selections (classes and recordings) and the processor state when files or mapping file change.
551
- # Also pass the current selection values so that user selections are preserved.
552
- for comp in list(annotation_columns.values()) + list(prediction_columns.values()) + [mapping_file]:
553
- comp.change(
554
- fn=update_selections,
555
- inputs=[
556
- annotation_files_state,
557
- prediction_files_state,
558
- mapping_file,
559
- sample_duration,
560
- min_overlap,
561
- recording_duration,
562
- annotation_columns["Start Time"],
563
- annotation_columns["End Time"],
564
- annotation_columns["Class"],
565
- annotation_columns["Recording"],
566
- annotation_columns["Duration"],
567
- prediction_columns["Start Time"],
568
- prediction_columns["End Time"],
569
- prediction_columns["Class"],
570
- prediction_columns["Confidence"],
571
- prediction_columns["Recording"],
572
- prediction_columns["Duration"],
573
- select_classes_checkboxgroup,
574
- select_recordings_checkboxgroup,
575
- ],
576
- outputs=[
577
- select_classes_checkboxgroup,
578
- select_recordings_checkboxgroup,
579
- processor_state,
580
- ],
581
- )
582
-
583
- # calculate_metrics now uses the stored temporary directories from processor_state.
584
- # The function now accepts selected_classes and selected_recordings as inputs.
585
# calculate_metrics now uses the stored temporary directories from processor_state.
# The function now accepts selected_classes and selected_recordings as inputs.
def calculate_metrics(
    mapping_file_obj,
    sample_duration_value,
    min_overlap_value,
    recording_duration_value: str,
    ann_start_time,
    ann_end_time,
    ann_class,
    ann_recording,
    ann_duration,
    pred_start_time,
    pred_end_time,
    pred_class,
    pred_confidence,
    pred_recording,
    pred_duration,
    threshold_value,
    class_wise_value,
    selected_classes_list,
    selected_recordings_list,
    proc_state: ProcessorState,
    *metrics_checkbox_values,
):
    """Run the evaluation pipeline and return the updates for the results UI.

    The trailing *metrics_checkbox_values are the metric checkbox states,
    passed positionally in the same order metrics_checkboxes was built, so
    the zip below pairs each boolean with its (lower-cased) metric key.

    Raises gr.Error when no class is selectable, the recording duration
    does not parse, or process_data fails.
    """
    selected_metrics = []

    # Pair each checkbox value with its metric key (dicts preserve order).
    for value, (m_lower, _) in zip(metrics_checkbox_values, metrics_checkboxes.items()):
        if value:
            selected_metrics.append(m_lower)

    # Map lower-cased UI labels to the identifiers process_data expects.
    valid_metrics = {
        "accuracy": "accuracy",
        "recall": "recall",
        "precision": "precision",
        "f1 score": "f1",
        "average precision (ap)": "ap",
        "auroc": "auroc",
    }
    metrics = tuple([valid_metrics[m] for m in selected_metrics if m in valid_metrics])

    # Fall back to available classes from processor state if none selected.
    if not selected_classes_list and proc_state and proc_state.processor:
        selected_classes_list = list(proc_state.processor.classes)

    if not selected_classes_list:
        raise gr.Error(loc.localize("eval-tab-error-no-class-selected"))

    # Blank duration means "no fixed recording duration"; anything else
    # must parse as a float (unlike update_selections, bad input errors).
    if recording_duration_value.strip() == "":
        rec_dur = None
    else:
        try:
            rec_dur = float(recording_duration_value)
        except ValueError as e:
            raise gr.Error(loc.localize("eval-tab-error-no-valid-recording-duration")) from e

    # Mapping upload may be a gradio file object (temp_files) or a plain path.
    if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
        mapping_path = list(mapping_file_obj.temp_files)[0]
    else:
        mapping_path = mapping_file_obj if mapping_file_obj else None

    try:
        metrics_df, pa, preds, labs = process_data(
            annotation_path=proc_state.annotation_dir,
            prediction_path=proc_state.prediction_dir,
            mapping_path=mapping_path,
            sample_duration=sample_duration_value,
            min_overlap=min_overlap_value,
            recording_duration=rec_dur,
            columns_annotations={
                "Start Time": ann_start_time,
                "End Time": ann_end_time,
                "Class": ann_class,
                "Recording": ann_recording,
                "Duration": ann_duration,
            },
            columns_predictions={
                "Start Time": pred_start_time,
                "End Time": pred_end_time,
                "Class": pred_class,
                "Confidence": pred_confidence,
                "Recording": pred_recording,
                "Duration": pred_duration,
            },
            selected_classes=selected_classes_list,
            selected_recordings=selected_recordings_list,
            metrics_list=metrics,
            threshold=threshold_value,
            class_wise=class_wise_value,
        )

        # Transpose so metrics become rows; blank header for the index column.
        table = metrics_df.T.reset_index(names=[""])

        # Outputs: metric table, action column visibility, then states; the
        # two bare gr.update() leave the checkbox groups unchanged.
        return (
            gr.update(value=table, visible=True),
            gr.update(visible=True),
            pa,
            preds,
            labs,
            gr.update(),
            gr.update(),
            proc_state,
        )
    except Exception as e:
        print("Error processing data:", e)
        raise gr.Error(f"{loc.localize('eval-tab-error-during-processing')}: {e}") from e
689
-
690
- # Updated calculate_button click now passes the selected classes and recordings.
691
- calculate_button.click(
692
- calculate_metrics,
693
- inputs=[
694
- mapping_file,
695
- sample_duration,
696
- min_overlap,
697
- recording_duration,
698
- annotation_columns["Start Time"],
699
- annotation_columns["End Time"],
700
- annotation_columns["Class"],
701
- annotation_columns["Recording"],
702
- annotation_columns["Duration"],
703
- prediction_columns["Start Time"],
704
- prediction_columns["End Time"],
705
- prediction_columns["Class"],
706
- prediction_columns["Confidence"],
707
- prediction_columns["Recording"],
708
- prediction_columns["Duration"],
709
- threshold,
710
- class_wise,
711
- select_classes_checkboxgroup,
712
- select_recordings_checkboxgroup,
713
- processor_state,
714
- ]
715
- + [checkbox for checkbox in metrics_checkboxes.values()],
716
- outputs=[
717
- metric_table,
718
- action_col,
719
- pa_state,
720
- predictions_state,
721
- labels_state,
722
- select_classes_checkboxgroup,
723
- select_recordings_checkboxgroup,
724
- processor_state,
725
- ],
726
- )
727
-
728
- def plot_metrics(pa: PerformanceAssessor, predictions, labels, class_wise_value):
729
- if pa is None or predictions is None or labels is None:
730
- raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
731
- try:
732
- fig = pa.plot_metrics(predictions, labels, per_class_metrics=class_wise_value)
733
- plt.close(fig)
734
-
735
- return gr.update(visible=True, value=fig)
736
- except Exception as e:
737
- raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics')}: {e}") from e
738
-
739
- plot_metrics_button.click(
740
- plot_metrics,
741
- inputs=[pa_state, predictions_state, labels_state, class_wise],
742
- outputs=[plot_output],
743
- )
744
-
745
- def plot_confusion_matrix(pa: PerformanceAssessor, predictions, labels):
746
- if pa is None or predictions is None or labels is None:
747
- raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
748
- try:
749
- fig = pa.plot_confusion_matrix(predictions, labels)
750
- plt.close(fig)
751
-
752
- return gr.update(visible=True, value=fig)
753
- except Exception as e:
754
- raise gr.Error(f"{loc.localize('eval-tab-error-plotting-confusion-matrix')}: {e}") from e
755
-
756
- plot_confusion_button.click(
757
- plot_confusion_matrix,
758
- inputs=[pa_state, predictions_state, labels_state],
759
- outputs=[plot_output],
760
- )
761
-
762
- annotation_select_directory_btn.click(
763
- get_selection_func("eval-annotations-dir", update_annotation_columns),
764
- outputs=[annotation_files_state, annotation_directory_input, annotation_group]
765
- + [annotation_columns[label] for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]],
766
- show_progress=True,
767
- )
768
-
769
- prediction_select_directory_btn.click(
770
- get_selection_func("eval-predictions-dir", update_prediction_columns),
771
- outputs=[prediction_files_state, prediction_directory_input, prediction_group]
772
- + [
773
- prediction_columns[label]
774
- for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]
775
- ],
776
- show_progress=True,
777
- )
778
-
779
- def toggle_after_selection(annotation_files, prediction_files):
780
- return [gr.update(visible=annotation_files and prediction_files)] * 2
781
-
782
- annotation_directory_input.change(
783
- toggle_after_selection,
784
- inputs=[annotation_files_state, prediction_files_state],
785
- outputs=[mapping_group, class_recording_group],
786
- )
787
-
788
- prediction_directory_input.change(
789
- toggle_after_selection,
790
- inputs=[annotation_files_state, prediction_files_state],
791
- outputs=[mapping_group, class_recording_group],
792
- )
793
-
794
- def plot_metrics_all_thresholds(pa: PerformanceAssessor, predictions, labels, class_wise_value):
795
- if pa is None or predictions is None or labels is None:
796
- raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
797
- try:
798
- fig = pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=class_wise_value)
799
- plt.close(fig)
800
-
801
- return gr.update(visible=True, value=fig)
802
- except Exception as e:
803
- raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics-all-thresholds')}: {e}") from e
804
-
805
- plot_metrics_all_thresholds_button.click(
806
- plot_metrics_all_thresholds,
807
- inputs=[pa_state, predictions_state, labels_state, class_wise],
808
- outputs=[plot_output],
809
- )
810
-
811
-
812
- if __name__ == "__main__":
813
- gu.open_window(build_evaluation_tab)
1
+ import json
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import typing
6
+
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+
11
+ import birdnet_analyzer.gui.localization as loc
12
+ import birdnet_analyzer.gui.utils as gu
13
+ from birdnet_analyzer.evaluation import process_data
14
+ from birdnet_analyzer.evaluation.assessment.performance_assessor import (
15
+ PerformanceAssessor,
16
+ )
17
+ from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
18
+
19
+
20
+ class ProcessorState(typing.NamedTuple):
21
+ """State of the DataProcessor."""
22
+
23
+ processor: DataProcessor
24
+ annotation_dir: str
25
+ prediction_dir: str
26
+
27
+
28
+ def build_evaluation_tab():
29
+ # Default columns for annotations
30
+ annotation_default_columns = {
31
+ "Start Time": "Begin Time (s)",
32
+ "End Time": "End Time (s)",
33
+ "Class": "Class",
34
+ "Recording": "Begin File",
35
+ "Duration": "File Duration (s)",
36
+ }
37
+
38
+ # Default columns for predictions
39
+ prediction_default_columns = {
40
+ "Start Time": "Begin Time (s)",
41
+ "End Time": "End Time (s)",
42
+ "Class": "Common Name",
43
+ "Recording": "Begin File",
44
+ "Duration": "File Duration (s)",
45
+ "Confidence": "Confidence",
46
+ }
47
+
48
+ localized_column_labels = {
49
+ "Start Time": loc.localize("eval-tab-column-start-time-label"),
50
+ "End Time": loc.localize("eval-tab-column-end-time-label"),
51
+ "Class": loc.localize("eval-tab-column-class-label"),
52
+ "Recording": loc.localize("eval-tab-column-recording-label"),
53
+ "Duration": loc.localize("eval-tab-column-duration-label"),
54
+ "Confidence": loc.localize("eval-tab-column-confidence-label"),
55
+ }
56
+
57
+ def download_class_mapping_template():
58
+ try:
59
+ template_mapping = {
60
+ "Predicted Class Name 1": "Annotation Class Name 1",
61
+ "Predicted Class Name 2": "Annotation Class Name 2",
62
+ "Predicted Class Name 3": "Annotation Class Name 3",
63
+ "Predicted Class Name 4": "Annotation Class Name 4",
64
+ "Predicted Class Name 5": "Annotation Class Name 5",
65
+ }
66
+
67
+ file_location = gu.save_file_dialog(
68
+ state_key="eval-mapping-template",
69
+ filetypes=("JSON (*.json)",),
70
+ default_filename="class_mapping_template.json",
71
+ )
72
+
73
+ if file_location:
74
+ with open(file_location, "w") as f:
75
+ json.dump(template_mapping, f, indent=4)
76
+
77
+ gr.Info(loc.localize("eval-tab-info-mapping-template-saved"))
78
+ except Exception as e:
79
+ print(f"Error saving mapping template: {e}")
80
+ raise gr.Error(f"{loc.localize('eval-tab-error-saving-mapping-template')} {e}") from e
81
+
82
+ def download_results_table(pa: PerformanceAssessor, predictions, labels, class_wise_value):
83
+ if pa is None or predictions is None or labels is None:
84
+ raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
85
+
86
+ try:
87
+ file_location = gu.save_file_dialog(
88
+ state_key="eval-results-table",
89
+ filetypes=("CSV (*.csv;*.CSV)", "TSV (*.tsv;*.TSV)"),
90
+ default_filename="results_table.csv",
91
+ )
92
+
93
+ if file_location:
94
+ metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise_value)
95
+
96
+ if file_location.split(".")[-1].lower() == "tsv":
97
+ metrics_df.to_csv(file_location, sep="\t", index=True)
98
+ else:
99
+ metrics_df.to_csv(file_location, index=True)
100
+
101
+ gr.Info(loc.localize("eval-tab-info-results-table-saved"))
102
+ except Exception as e:
103
+ print(f"Error saving results table: {e}")
104
+ raise gr.Error(f"{loc.localize('eval-tab-error-saving-results-table')} {e}") from e
105
+
106
+ def download_data_table(processor_state: ProcessorState):
107
+ if processor_state is None:
108
+ raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
109
+ try:
110
+ file_location = gu.save_file_dialog(
111
+ state_key="eval-data-table",
112
+ filetypes=("CSV (*.csv)", "TSV (*.tsv;*.TSV)"),
113
+ default_filename="data_table.csv",
114
+ )
115
+ if file_location:
116
+ data_df = processor_state.processor.get_sample_data()
117
+
118
+ if file_location.split(".")[-1].lower() == "tsv":
119
+ data_df.to_csv(file_location, sep="\t", index=False)
120
+ else:
121
+ data_df.to_csv(file_location, index=False)
122
+
123
+ gr.Info(loc.localize("eval-tab-info-data-table-saved"))
124
+ except Exception as e:
125
+ raise gr.Error(f"{loc.localize('eval-tab-error-saving-data-table')} {e}") from e
126
+
127
+ def get_columns_from_uploaded_files(files):
128
+ columns = set()
129
+
130
+ if files:
131
+ for file_obj in files:
132
+ try:
133
+ df = pd.read_csv(file_obj, sep=None, engine="python", nrows=0)
134
+ columns.update(df.columns)
135
+ except Exception as e:
136
+ print(f"Error reading file {file_obj}: {e}")
137
+ gr.Warning(f"{loc.localize('eval-tab-warning-error-reading-file')} {file_obj}")
138
+
139
+ return sorted(columns)
140
+
141
+ def save_uploaded_files(files):
142
+ if not files:
143
+ return None
144
+
145
+ temp_dir = tempfile.mkdtemp()
146
+
147
+ for file_obj in files:
148
+ dest_path = os.path.join(temp_dir, os.path.basename(file_obj))
149
+ shutil.copy(file_obj, dest_path)
150
+
151
+ return temp_dir
152
+
153
+ # Single initialize_processor that can reuse given directories.
154
+ def initialize_processor(
155
+ annotation_files,
156
+ prediction_files,
157
+ mapping_file_obj,
158
+ sample_duration_value,
159
+ min_overlap_value,
160
+ recording_duration,
161
+ ann_start_time,
162
+ ann_end_time,
163
+ ann_class,
164
+ ann_recording,
165
+ ann_duration,
166
+ pred_start_time,
167
+ pred_end_time,
168
+ pred_class,
169
+ pred_confidence,
170
+ pred_recording,
171
+ pred_duration,
172
+ annotation_dir=None,
173
+ prediction_dir=None,
174
+ ):
175
+ if not annotation_files or not prediction_files:
176
+ return [], [], None, None, None
177
+
178
+ if annotation_dir is None:
179
+ annotation_dir = save_uploaded_files(annotation_files)
180
+
181
+ if prediction_dir is None:
182
+ prediction_dir = save_uploaded_files(prediction_files)
183
+
184
+ # Fallback for annotation columns.
185
+ ann_start_time = ann_start_time if ann_start_time else annotation_default_columns["Start Time"]
186
+ ann_end_time = ann_end_time if ann_end_time else annotation_default_columns["End Time"]
187
+ ann_class = ann_class if ann_class else annotation_default_columns["Class"]
188
+ ann_recording = ann_recording if ann_recording else annotation_default_columns["Recording"]
189
+ ann_duration = ann_duration if ann_duration else annotation_default_columns["Duration"]
190
+
191
+ # Fallback for prediction columns.
192
+ pred_start_time = pred_start_time if pred_start_time else prediction_default_columns["Start Time"]
193
+ pred_end_time = pred_end_time if pred_end_time else prediction_default_columns["End Time"]
194
+ pred_class = pred_class if pred_class else prediction_default_columns["Class"]
195
+ pred_confidence = pred_confidence if pred_confidence else prediction_default_columns["Confidence"]
196
+ pred_recording = pred_recording if pred_recording else prediction_default_columns["Recording"]
197
+ pred_duration = pred_duration if pred_duration else prediction_default_columns["Duration"]
198
+
199
+ cols_ann = {
200
+ "Start Time": ann_start_time,
201
+ "End Time": ann_end_time,
202
+ "Class": ann_class,
203
+ "Recording": ann_recording,
204
+ "Duration": ann_duration,
205
+ }
206
+ cols_pred = {
207
+ "Start Time": pred_start_time,
208
+ "End Time": pred_end_time,
209
+ "Class": pred_class,
210
+ "Confidence": pred_confidence,
211
+ "Recording": pred_recording,
212
+ "Duration": pred_duration,
213
+ }
214
+
215
+ # Handle mapping file: if it has a temp_files attribute use that, otherwise assume it's a filepath.
216
+ if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
217
+ mapping_path = list(mapping_file_obj.temp_files)[0]
218
+ else:
219
+ mapping_path = mapping_file_obj if mapping_file_obj else None
220
+
221
+ if mapping_path:
222
+ with open(mapping_path) as f:
223
+ class_mapping = json.load(f)
224
+ else:
225
+ class_mapping = None
226
+
227
+ try:
228
+ proc = DataProcessor(
229
+ prediction_directory_path=prediction_dir,
230
+ prediction_file_name=None,
231
+ annotation_directory_path=annotation_dir,
232
+ annotation_file_name=None,
233
+ class_mapping=class_mapping,
234
+ sample_duration=sample_duration_value,
235
+ min_overlap=min_overlap_value,
236
+ columns_predictions=cols_pred,
237
+ columns_annotations=cols_ann,
238
+ recording_duration=recording_duration,
239
+ )
240
+ avail_classes = list(proc.classes) # Ensure it's a list
241
+ avail_recordings = proc.samples_df["filename"].unique().tolist()
242
+
243
+ return avail_classes, avail_recordings, proc, annotation_dir, prediction_dir
244
+ except KeyError as e:
245
+ print(f"Column missing in files: {e}")
246
+ raise gr.Error(f"{loc.localize('eval-tab-error-missing-col')}: " + str(e) + f". {loc.localize('eval-tab-error-missing-col-info')}") from e
247
+ except Exception as e:
248
+ print(f"Error initializing processor: {e}")
249
+
250
+ raise gr.Error(f"{loc.localize('eval-tab-error-init-processor')}:" + str(e)) from e
251
+
252
+ # update_selections is triggered when files or mapping file change.
253
+ # It creates the temporary directories once and stores them along with the processor.
254
+ # It now also receives the current selection values so that user selections are preserved.
255
+ def update_selections(
256
+ annotation_files,
257
+ prediction_files,
258
+ mapping_file_obj,
259
+ sample_duration_value,
260
+ min_overlap_value,
261
+ recording_duration_value: str,
262
+ ann_start_time,
263
+ ann_end_time,
264
+ ann_class,
265
+ ann_recording,
266
+ ann_duration,
267
+ pred_start_time,
268
+ pred_end_time,
269
+ pred_class,
270
+ pred_confidence,
271
+ pred_recording,
272
+ pred_duration,
273
+ current_classes,
274
+ current_recordings,
275
+ ):
276
+ if recording_duration_value.strip() == "":
277
+ rec_dur = None
278
+ else:
279
+ try:
280
+ rec_dur = float(recording_duration_value)
281
+ except ValueError:
282
+ rec_dur = None
283
+
284
+ # Create temporary directories once.
285
+ annotation_dir = save_uploaded_files(annotation_files)
286
+ prediction_dir = save_uploaded_files(prediction_files)
287
+ avail_classes, avail_recordings, proc, annotation_dir, prediction_dir = initialize_processor(
288
+ annotation_files,
289
+ prediction_files,
290
+ mapping_file_obj,
291
+ sample_duration_value,
292
+ min_overlap_value,
293
+ rec_dur,
294
+ ann_start_time,
295
+ ann_end_time,
296
+ ann_class,
297
+ ann_recording,
298
+ ann_duration,
299
+ pred_start_time,
300
+ pred_end_time,
301
+ pred_class,
302
+ pred_confidence,
303
+ pred_recording,
304
+ pred_duration,
305
+ annotation_dir,
306
+ prediction_dir,
307
+ )
308
+ # Build a state dictionary to store the processor and the directories.
309
+ state = ProcessorState(proc, annotation_dir, prediction_dir)
310
+ # If no current selection exists, default to all available classes/recordings;
311
+ # otherwise, preserve any selections that are still valid.
312
+ new_classes = avail_classes if not current_classes else [c for c in current_classes if c in avail_classes] or avail_classes
313
+ new_recordings = avail_recordings if not current_recordings else [r for r in current_recordings if r in avail_recordings] or avail_recordings
314
+
315
+ return (
316
+ gr.update(choices=avail_classes, value=new_classes),
317
+ gr.update(choices=avail_recordings, value=new_recordings),
318
+ state,
319
+ )
320
+
321
+ with gr.Tab(loc.localize("eval-tab-title")):
322
+ # Custom CSS to match the layout style of other files and remove gray backgrounds.
323
+ gr.Markdown(
324
+ """
325
+ <style>
326
+ /* Grid layout for checkbox groups */
327
+ .custom-checkbox-group {
328
+ display: grid;
329
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
330
+ grid-gap: 8px;
331
+ }
332
+ </style>
333
+ """
334
+ )
335
+
336
+ processor_state = gr.State()
337
+ pa_state = gr.State()
338
+ predictions_state = gr.State()
339
+ labels_state = gr.State()
340
+ annotation_files_state = gr.State()
341
+ prediction_files_state = gr.State()
342
+
343
+ def get_selection_tables(directory):
344
+ from pathlib import Path
345
+
346
+ directory = Path(directory)
347
+
348
+ return list(directory.glob("*.txt"))
349
+
350
+ # Update column dropdowns when files are uploaded.
351
+ def update_annotation_columns(uploaded_files):
352
+ cols = get_columns_from_uploaded_files(uploaded_files)
353
+ cols = ["", *cols]
354
+ updates = []
355
+
356
+ for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]:
357
+ default_val = annotation_default_columns.get(label)
358
+ val = default_val if default_val in cols else None
359
+ updates.append(gr.update(choices=cols, value=val))
360
+
361
+ return updates
362
+
363
+ def update_prediction_columns(uploaded_files):
364
+ cols = get_columns_from_uploaded_files(uploaded_files)
365
+ cols = ["", *cols]
366
+ updates = []
367
+
368
+ for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]:
369
+ default_val = prediction_default_columns.get(label)
370
+ val = default_val if default_val in cols else None
371
+ updates.append(gr.update(choices=cols, value=val))
372
+
373
+ return updates
374
+
375
+ def get_selection_func(state_key, on_select):
376
+ def select_directory_on_empty(): # Nishant - Function modified for For Folder selection
377
+ folder = gu.select_folder(state_key=state_key)
378
+
379
+ if folder:
380
+ files = get_selection_tables(folder)
381
+ files_to_display = files[:100] + [["..."]] if len(files) > 100 else files
382
+ return [files, files_to_display, gr.update(visible=True), *on_select(files)]
383
+
384
+ return ["", [[loc.localize("eval-tab-no-files-found")]]]
385
+
386
+ return select_directory_on_empty
387
+
388
+ with gr.Row():
389
+ with gr.Column():
390
+ annotation_select_directory_btn = gr.Button(loc.localize("eval-tab-annotation-selection-button-label"))
391
+ annotation_directory_input = gr.Matrix(
392
+ interactive=False,
393
+ headers=[
394
+ loc.localize("eval-tab-selections-column-file-header"),
395
+ ],
396
+ )
397
+
398
+ with gr.Column():
399
+ prediction_select_directory_btn = gr.Button(loc.localize("eval-tab-prediction-selection-button-label"))
400
+ prediction_directory_input = gr.Matrix(
401
+ interactive=False,
402
+ headers=[
403
+ loc.localize("eval-tab-selections-column-file-header"),
404
+ ],
405
+ )
406
+
407
+ # ----------------------- Annotations Columns Box -----------------------
408
+ with (
409
+ gr.Group(visible=False) as annotation_group,
410
+ gr.Accordion(loc.localize("eval-tab-annotation-col-accordion-label"), open=True),
411
+ gr.Row(),
412
+ ):
413
+ annotation_columns: dict[str, gr.Dropdown] = {}
414
+
415
+ for col in ["Start Time", "End Time", "Class", "Recording", "Duration"]:
416
+ annotation_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
417
+
418
+ # ----------------------- Predictions Columns Box -----------------------
419
+ with (
420
+ gr.Group(visible=False) as prediction_group,
421
+ gr.Accordion(loc.localize("eval-tab-prediction-col-accordion-label"), open=True),
422
+ gr.Row(),
423
+ ):
424
+ prediction_columns: dict[str, gr.Dropdown] = {}
425
+
426
+ for col in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]:
427
+ prediction_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
428
+
429
+ # ----------------------- Class Mapping Box -----------------------
430
+ with gr.Group(visible=False) as mapping_group:
431
+ with gr.Accordion(loc.localize("eval-tab-class-mapping-accordion-label"), open=False), gr.Row():
432
+ mapping_file = gr.File(
433
+ label=loc.localize("eval-tab-upload-mapping-file-label"),
434
+ file_count="single",
435
+ file_types=[".json"],
436
+ )
437
+ download_mapping_button = gr.DownloadButton(label=loc.localize("eval-tab-mapping-file-template-download-button-label"))
438
+
439
+ download_mapping_button.click(fn=download_class_mapping_template)
440
+
441
+ # ----------------------- Classes and Recordings Selection Box -----------------------
442
+ with (
443
+ gr.Group(visible=False) as class_recording_group,
444
+ gr.Accordion(loc.localize("eval-tab-select-classes-recordings-accordion-label"), open=False),
445
+ gr.Row(),
446
+ ):
447
+ with gr.Column():
448
+ select_classes_checkboxgroup = gr.CheckboxGroup(
449
+ choices=[],
450
+ value=[],
451
+ label=loc.localize("eval-tab-select-classes-checkboxgroup-label"),
452
+ info=loc.localize("eval-tab-select-classes-checkboxgroup-info"),
453
+ interactive=True,
454
+ elem_classes="custom-checkbox-group",
455
+ )
456
+
457
+ with gr.Column():
458
+ select_recordings_checkboxgroup = gr.CheckboxGroup(
459
+ choices=[],
460
+ value=[],
461
+ label=loc.localize("eval-tab-select-recordings-checkboxgroup-label"),
462
+ info=loc.localize("eval-tab-select-recordings-checkboxgroup-info"),
463
+ interactive=True,
464
+ elem_classes="custom-checkbox-group",
465
+ )
466
+
467
+ # ----------------------- Parameters Box -----------------------
468
+ with gr.Group(), gr.Accordion(loc.localize("eval-tab-parameters-accordion-label"), open=False), gr.Row():
469
+ sample_duration = gr.Number(
470
+ value=3,
471
+ label=loc.localize("eval-tab-sample-duration-number-label"),
472
+ precision=0,
473
+ info=loc.localize("eval-tab-sample-duration-number-info"),
474
+ )
475
+ recording_duration = gr.Textbox(
476
+ label=loc.localize("eval-tab-recording-duration-textbox-label"),
477
+ placeholder=loc.localize("eval-tab-recording-duration-textbox-placeholder"),
478
+ info=loc.localize("eval-tab-recording-duration-textbox-info"),
479
+ )
480
+ min_overlap = gr.Number(
481
+ value=0.5,
482
+ label=loc.localize("eval-tab-min-overlap-number-label"),
483
+ info=loc.localize("eval-tab-min-overlap-number-info"),
484
+ )
485
+ threshold = gr.Slider(
486
+ minimum=0.01,
487
+ maximum=0.99,
488
+ value=0.1,
489
+ label=loc.localize("eval-tab-threshold-number-label"),
490
+ info=loc.localize("eval-tab-threshold-number-info"),
491
+ )
492
+ class_wise = gr.Checkbox(
493
+ label=loc.localize("eval-tab-classwise-checkbox-label"),
494
+ value=False,
495
+ info=loc.localize("eval-tab-classwise-checkbox-info"),
496
+ )
497
+
498
+ # ----------------------- Metrics Box -----------------------
499
+ with gr.Group(), gr.Accordion(loc.localize("eval-tab-metrics-accordian-label"), open=False), gr.Row():
500
+ metric_info = {
501
+ "AUROC": loc.localize("eval-tab-auroc-checkbox-info"),
502
+ "Precision": loc.localize("eval-tab-precision-checkbox-info"),
503
+ "Recall": loc.localize("eval-tab-recall-checkbox-info"),
504
+ "F1 Score": loc.localize("eval-tab-f1-score-checkbox-info"),
505
+ "Average Precision (AP)": loc.localize("eval-tab-ap-checkbox-info"),
506
+ "Accuracy": loc.localize("eval-tab-accuracy-checkbox-info"),
507
+ }
508
+ metrics_checkboxes = {}
509
+
510
+ for metric_name, description in metric_info.items():
511
+ metrics_checkboxes[metric_name.lower()] = gr.Checkbox(label=metric_name, value=True, info=description)
512
+
513
+ # ----------------------- Actions Box -----------------------
514
+
515
+ calculate_button = gr.Button(loc.localize("eval-tab-calculate-metrics-button-label"), variant="huggingface")
516
+
517
+ with gr.Column(visible=False) as action_col:
518
+ with gr.Row():
519
+ plot_metrics_button = gr.Button(loc.localize("eval-tab-plot-metrics-button-label"))
520
+ plot_confusion_button = gr.Button(loc.localize("eval-tab-plot-confusion-matrix-button-label"))
521
+ plot_metrics_all_thresholds_button = gr.Button(loc.localize("eval-tab-plot-metrics-all-thresholds-button-label"))
522
+
523
+ with gr.Row():
524
+ download_results_button = gr.DownloadButton(loc.localize("eval-tab-result-table-download-button-label"))
525
+ download_data_button = gr.DownloadButton(loc.localize("eval-tab-data-table-download-button-label"))
526
+
527
+ download_results_button.click(
528
+ fn=download_results_table,
529
+ inputs=[pa_state, predictions_state, labels_state, class_wise],
530
+ )
531
+ download_data_button.click(fn=download_data_table, inputs=[processor_state])
532
+ metric_table = gr.Dataframe(show_label=False, type="pandas", visible=False, interactive=False)
533
+ plot_output = gr.Plot(visible=False, show_label=False)
534
+
535
+ # Update available selections (classes and recordings) and the processor state when files or mapping file change.
536
+ # Also pass the current selection values so that user selections are preserved.
537
+ for comp in list(annotation_columns.values()) + list(prediction_columns.values()) + [mapping_file]:
538
+ comp.change(
539
+ fn=update_selections,
540
+ inputs=[
541
+ annotation_files_state,
542
+ prediction_files_state,
543
+ mapping_file,
544
+ sample_duration,
545
+ min_overlap,
546
+ recording_duration,
547
+ annotation_columns["Start Time"],
548
+ annotation_columns["End Time"],
549
+ annotation_columns["Class"],
550
+ annotation_columns["Recording"],
551
+ annotation_columns["Duration"],
552
+ prediction_columns["Start Time"],
553
+ prediction_columns["End Time"],
554
+ prediction_columns["Class"],
555
+ prediction_columns["Confidence"],
556
+ prediction_columns["Recording"],
557
+ prediction_columns["Duration"],
558
+ select_classes_checkboxgroup,
559
+ select_recordings_checkboxgroup,
560
+ ],
561
+ outputs=[
562
+ select_classes_checkboxgroup,
563
+ select_recordings_checkboxgroup,
564
+ processor_state,
565
+ ],
566
+ )
567
+
568
+ # calculate_metrics now uses the stored temporary directories from processor_state.
569
+ # The function now accepts selected_classes and selected_recordings as inputs.
570
+ def calculate_metrics(
571
+ mapping_file_obj,
572
+ sample_duration_value,
573
+ min_overlap_value,
574
+ recording_duration_value: str,
575
+ ann_start_time,
576
+ ann_end_time,
577
+ ann_class,
578
+ ann_recording,
579
+ ann_duration,
580
+ pred_start_time,
581
+ pred_end_time,
582
+ pred_class,
583
+ pred_confidence,
584
+ pred_recording,
585
+ pred_duration,
586
+ threshold_value,
587
+ class_wise_value,
588
+ selected_classes_list,
589
+ selected_recordings_list,
590
+ proc_state: ProcessorState,
591
+ *metrics_checkbox_values,
592
+ ):
593
+ selected_metrics = []
594
+
595
+ for value, (m_lower, _) in zip(metrics_checkbox_values, metrics_checkboxes.items(), strict=True):
596
+ if value:
597
+ selected_metrics.append(m_lower)
598
+
599
+ valid_metrics = {
600
+ "accuracy": "accuracy",
601
+ "recall": "recall",
602
+ "precision": "precision",
603
+ "f1 score": "f1",
604
+ "average precision (ap)": "ap",
605
+ "auroc": "auroc",
606
+ }
607
+ metrics = tuple([valid_metrics[m] for m in selected_metrics if m in valid_metrics])
608
+
609
+ # Fall back to available classes from processor state if none selected.
610
+ if not selected_classes_list and proc_state and proc_state.processor:
611
+ selected_classes_list = list(proc_state.processor.classes)
612
+
613
+ if not selected_classes_list:
614
+ raise gr.Error(loc.localize("eval-tab-error-no-class-selected"))
615
+
616
+ if recording_duration_value.strip() == "":
617
+ rec_dur = None
618
+ else:
619
+ try:
620
+ rec_dur = float(recording_duration_value)
621
+ except ValueError as e:
622
+ raise gr.Error(loc.localize("eval-tab-error-no-valid-recording-duration")) from e
623
+
624
+ if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
625
+ mapping_path = list(mapping_file_obj.temp_files)[0]
626
+ else:
627
+ mapping_path = mapping_file_obj if mapping_file_obj else None
628
+
629
+ try:
630
+ metrics_df, pa, preds, labs = process_data(
631
+ annotation_path=proc_state.annotation_dir,
632
+ prediction_path=proc_state.prediction_dir,
633
+ mapping_path=mapping_path,
634
+ sample_duration=sample_duration_value,
635
+ min_overlap=min_overlap_value,
636
+ recording_duration=rec_dur,
637
+ columns_annotations={
638
+ "Start Time": ann_start_time,
639
+ "End Time": ann_end_time,
640
+ "Class": ann_class,
641
+ "Recording": ann_recording,
642
+ "Duration": ann_duration,
643
+ },
644
+ columns_predictions={
645
+ "Start Time": pred_start_time,
646
+ "End Time": pred_end_time,
647
+ "Class": pred_class,
648
+ "Confidence": pred_confidence,
649
+ "Recording": pred_recording,
650
+ "Duration": pred_duration,
651
+ },
652
+ selected_classes=selected_classes_list,
653
+ selected_recordings=selected_recordings_list,
654
+ metrics_list=metrics,
655
+ threshold=threshold_value,
656
+ class_wise=class_wise_value,
657
+ )
658
+
659
+ table = metrics_df.T.reset_index(names=[""])
660
+
661
+ return (
662
+ gr.update(value=table, visible=True),
663
+ gr.update(visible=True),
664
+ pa,
665
+ preds,
666
+ labs,
667
+ gr.update(),
668
+ gr.update(),
669
+ proc_state,
670
+ )
671
+ except Exception as e:
672
+ print("Error processing data:", e)
673
+ raise gr.Error(f"{loc.localize('eval-tab-error-during-processing')}: {e}") from e
674
+
675
+ # Updated calculate_button click now passes the selected classes and recordings.
676
+ calculate_button.click(
677
+ calculate_metrics,
678
+ inputs=[
679
+ mapping_file,
680
+ sample_duration,
681
+ min_overlap,
682
+ recording_duration,
683
+ annotation_columns["Start Time"],
684
+ annotation_columns["End Time"],
685
+ annotation_columns["Class"],
686
+ annotation_columns["Recording"],
687
+ annotation_columns["Duration"],
688
+ prediction_columns["Start Time"],
689
+ prediction_columns["End Time"],
690
+ prediction_columns["Class"],
691
+ prediction_columns["Confidence"],
692
+ prediction_columns["Recording"],
693
+ prediction_columns["Duration"],
694
+ threshold,
695
+ class_wise,
696
+ select_classes_checkboxgroup,
697
+ select_recordings_checkboxgroup,
698
+ processor_state,
699
+ *list(metrics_checkboxes.values()),
700
+ ],
701
+ outputs=[
702
+ metric_table,
703
+ action_col,
704
+ pa_state,
705
+ predictions_state,
706
+ labels_state,
707
+ select_classes_checkboxgroup,
708
+ select_recordings_checkboxgroup,
709
+ processor_state,
710
+ ],
711
+ )
712
+
713
def plot_metrics(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Plot the computed metrics and return a Gradio update showing the figure.

    Raises gr.Error when metrics have not been calculated yet (any cached
    state is None) or when the assessor fails while plotting.
    """
    metrics_missing = pa is None or predictions is None or labels is None
    if metrics_missing:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)

    try:
        figure = pa.plot_metrics(predictions, labels, per_class_metrics=class_wise_value)
        # Close the figure so matplotlib does not keep it alive; Gradio renders it from the object.
        plt.close(figure)
        return gr.update(value=figure, visible=True)
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics')}: {e}") from e
723
+
724
# Wire the metrics-plot button to plot_metrics using the cached assessor,
# predictions, labels, and the class-wise checkbox; result goes to the shared plot area.
plot_metrics_button.click(
    plot_metrics,
    inputs=[pa_state, predictions_state, labels_state, class_wise],
    outputs=[plot_output],
)
729
+
730
def plot_confusion_matrix(pa: PerformanceAssessor, predictions, labels):
    """Render the confusion matrix and return a Gradio update showing it.

    Raises gr.Error when metrics have not been calculated yet (any cached
    state is None) or when plotting fails.
    """
    not_ready = pa is None or predictions is None or labels is None
    if not_ready:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)

    try:
        figure = pa.plot_confusion_matrix(predictions, labels)
        # Release the figure from matplotlib's registry; Gradio displays the returned object.
        plt.close(figure)
        return gr.update(value=figure, visible=True)
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-confusion-matrix')}: {e}") from e
740
+
741
# Wire the confusion-matrix button to plot_confusion_matrix using the cached
# assessor, predictions, and labels; result goes to the shared plot area.
plot_confusion_button.click(
    plot_confusion_matrix,
    inputs=[pa_state, predictions_state, labels_state],
    outputs=[plot_output],
)
746
+
747
# Directory picker for annotation files: updates the stored file list, the
# directory textbox, the annotation column group, and each of the five
# annotation column dropdowns (Start Time, End Time, Class, Recording, Duration).
annotation_select_directory_btn.click(
    get_selection_func("eval-annotations-dir", update_annotation_columns),
    outputs=[annotation_files_state, annotation_directory_input, annotation_group]
    + [annotation_columns[label] for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]],
    show_progress=True,
)
753
+
754
# Directory picker for prediction files: same pattern as the annotation picker,
# with the extra "Confidence" column dropdown specific to predictions.
prediction_select_directory_btn.click(
    get_selection_func("eval-predictions-dir", update_prediction_columns),
    outputs=[prediction_files_state, prediction_directory_input, prediction_group]
    + [prediction_columns[label] for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]],
    show_progress=True,
)
760
+
761
def toggle_after_selection(annotation_files, prediction_files):
    """Return visibility updates for the mapping and class/recording groups.

    Both groups become visible only once annotation AND prediction files
    have been selected.
    """
    both_selected = annotation_files and prediction_files
    visibility_update = gr.update(visible=both_selected)
    return [visibility_update, visibility_update]
763
+
764
# Whenever either directory textbox changes, re-evaluate whether both file
# sets are present and show/hide the mapping and class/recording groups.
annotation_directory_input.change(
    toggle_after_selection,
    inputs=[annotation_files_state, prediction_files_state],
    outputs=[mapping_group, class_recording_group],
)

prediction_directory_input.change(
    toggle_after_selection,
    inputs=[annotation_files_state, prediction_files_state],
    outputs=[mapping_group, class_recording_group],
)
775
+
776
def plot_metrics_all_thresholds(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Plot metrics across all thresholds and return a Gradio update with the figure.

    Raises gr.Error when metrics have not been calculated yet (any cached
    state is None) or when the assessor fails while plotting.
    """
    state_missing = pa is None or predictions is None or labels is None
    if state_missing:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)

    try:
        figure = pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=class_wise_value)
        # Close the figure immediately; Gradio renders the returned object itself.
        plt.close(figure)
        return gr.update(value=figure, visible=True)
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics-all-thresholds')}: {e}") from e
786
+
787
# Wire the all-thresholds plot button to plot_metrics_all_thresholds using the
# cached assessor, predictions, labels, and the class-wise checkbox.
plot_metrics_all_thresholds_button.click(
    plot_metrics_all_thresholds,
    inputs=[pa_state, predictions_state, labels_state, class_wise],
    outputs=[plot_output],
)
792
+
793
+
794
# Allow running this module directly: opens the evaluation tab in its own window.
if __name__ == "__main__":
    gu.open_window(build_evaluation_tab)