birdnet-analyzer 2.1.0__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. birdnet_analyzer/__init__.py +9 -9
  2. birdnet_analyzer/analyze/__init__.py +19 -19
  3. birdnet_analyzer/analyze/__main__.py +3 -3
  4. birdnet_analyzer/analyze/cli.py +30 -30
  5. birdnet_analyzer/analyze/core.py +268 -246
  6. birdnet_analyzer/analyze/utils.py +700 -694
  7. birdnet_analyzer/audio.py +368 -368
  8. birdnet_analyzer/cli.py +732 -732
  9. birdnet_analyzer/config.py +243 -243
  10. birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13045 -13045
  11. birdnet_analyzer/embeddings/__init__.py +3 -3
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -12
  14. birdnet_analyzer/embeddings/core.py +70 -70
  15. birdnet_analyzer/embeddings/utils.py +173 -220
  16. birdnet_analyzer/evaluation/__init__.py +189 -189
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/metrics.py +388 -388
  19. birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -364
  20. birdnet_analyzer/evaluation/assessment/plotting.py +378 -378
  21. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -631
  22. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -98
  23. birdnet_analyzer/gui/__init__.py +19 -19
  24. birdnet_analyzer/gui/__main__.py +3 -3
  25. birdnet_analyzer/gui/analysis.py +179 -179
  26. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  27. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  28. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  30. birdnet_analyzer/gui/assets/gui.css +36 -36
  31. birdnet_analyzer/gui/assets/gui.js +89 -93
  32. birdnet_analyzer/gui/embeddings.py +638 -638
  33. birdnet_analyzer/gui/evaluation.py +801 -801
  34. birdnet_analyzer/gui/localization.py +75 -75
  35. birdnet_analyzer/gui/multi_file.py +265 -265
  36. birdnet_analyzer/gui/review.py +472 -472
  37. birdnet_analyzer/gui/segments.py +191 -191
  38. birdnet_analyzer/gui/settings.py +149 -149
  39. birdnet_analyzer/gui/single_file.py +264 -264
  40. birdnet_analyzer/gui/species.py +95 -95
  41. birdnet_analyzer/gui/train.py +687 -687
  42. birdnet_analyzer/gui/utils.py +803 -797
  43. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  44. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  80. birdnet_analyzer/lang/de.json +342 -341
  81. birdnet_analyzer/lang/en.json +342 -341
  82. birdnet_analyzer/lang/fi.json +342 -341
  83. birdnet_analyzer/lang/fr.json +342 -341
  84. birdnet_analyzer/lang/id.json +342 -341
  85. birdnet_analyzer/lang/pt-br.json +342 -341
  86. birdnet_analyzer/lang/ru.json +342 -341
  87. birdnet_analyzer/lang/se.json +342 -341
  88. birdnet_analyzer/lang/tlh.json +342 -341
  89. birdnet_analyzer/lang/zh_TW.json +342 -341
  90. birdnet_analyzer/model.py +1213 -1212
  91. birdnet_analyzer/search/__init__.py +3 -3
  92. birdnet_analyzer/search/__main__.py +3 -3
  93. birdnet_analyzer/search/cli.py +11 -11
  94. birdnet_analyzer/search/core.py +78 -78
  95. birdnet_analyzer/search/utils.py +104 -107
  96. birdnet_analyzer/segments/__init__.py +3 -3
  97. birdnet_analyzer/segments/__main__.py +3 -3
  98. birdnet_analyzer/segments/cli.py +13 -13
  99. birdnet_analyzer/segments/core.py +81 -81
  100. birdnet_analyzer/segments/utils.py +383 -383
  101. birdnet_analyzer/species/__init__.py +3 -3
  102. birdnet_analyzer/species/__main__.py +3 -3
  103. birdnet_analyzer/species/cli.py +13 -13
  104. birdnet_analyzer/species/core.py +35 -35
  105. birdnet_analyzer/species/utils.py +73 -74
  106. birdnet_analyzer/train/__init__.py +3 -3
  107. birdnet_analyzer/train/__main__.py +3 -3
  108. birdnet_analyzer/train/cli.py +13 -13
  109. birdnet_analyzer/train/core.py +113 -113
  110. birdnet_analyzer/train/utils.py +878 -877
  111. birdnet_analyzer/translate.py +132 -133
  112. birdnet_analyzer/utils.py +425 -425
  113. {birdnet_analyzer-2.1.0.dist-info → birdnet_analyzer-2.1.1.dist-info}/METADATA +147 -146
  114. birdnet_analyzer-2.1.1.dist-info/RECORD +124 -0
  115. {birdnet_analyzer-2.1.0.dist-info → birdnet_analyzer-2.1.1.dist-info}/licenses/LICENSE +18 -18
  116. birdnet_analyzer/playground.py +0 -5
  117. birdnet_analyzer-2.1.0.dist-info/RECORD +0 -125
  118. {birdnet_analyzer-2.1.0.dist-info → birdnet_analyzer-2.1.1.dist-info}/WHEEL +0 -0
  119. {birdnet_analyzer-2.1.0.dist-info → birdnet_analyzer-2.1.1.dist-info}/entry_points.txt +0 -0
  120. {birdnet_analyzer-2.1.0.dist-info → birdnet_analyzer-2.1.1.dist-info}/top_level.txt +0 -0
@@ -1,801 +1,801 @@
1
- import json
2
- import os
3
- import shutil
4
- import tempfile
5
- import typing
6
-
7
- import gradio as gr
8
- import matplotlib.pyplot as plt
9
- import pandas as pd
10
-
11
- import birdnet_analyzer.gui.localization as loc
12
- import birdnet_analyzer.gui.utils as gu
13
- from birdnet_analyzer.evaluation import process_data
14
- from birdnet_analyzer.evaluation.assessment.performance_assessor import (
15
- PerformanceAssessor,
16
- )
17
- from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
18
-
19
-
20
- class ProcessorState(typing.NamedTuple):
21
- """State of the DataProcessor."""
22
-
23
- processor: DataProcessor
24
- annotation_dir: str
25
- prediction_dir: str
26
-
27
-
28
- def build_evaluation_tab():
29
- # Default columns for annotations
30
- annotation_default_columns = {
31
- "Start Time": "Begin Time (s)",
32
- "End Time": "End Time (s)",
33
- "Class": "Class",
34
- "Recording": "Begin File",
35
- "Duration": "File Duration (s)",
36
- }
37
-
38
- # Default columns for predictions
39
- prediction_default_columns = {
40
- "Start Time": "Begin Time (s)",
41
- "End Time": "End Time (s)",
42
- "Class": "Common Name",
43
- "Recording": "Begin File",
44
- "Duration": "File Duration (s)",
45
- "Confidence": "Confidence",
46
- }
47
-
48
- localized_column_labels = {
49
- "Start Time": loc.localize("eval-tab-column-start-time-label"),
50
- "End Time": loc.localize("eval-tab-column-end-time-label"),
51
- "Class": loc.localize("eval-tab-column-class-label"),
52
- "Recording": loc.localize("eval-tab-column-recording-label"),
53
- "Duration": loc.localize("eval-tab-column-duration-label"),
54
- "Confidence": loc.localize("eval-tab-column-confidence-label"),
55
- }
56
-
57
- def download_class_mapping_template():
58
- try:
59
- template_mapping = {
60
- "Predicted Class Name 1": "Annotation Class Name 1",
61
- "Predicted Class Name 2": "Annotation Class Name 2",
62
- "Predicted Class Name 3": "Annotation Class Name 3",
63
- "Predicted Class Name 4": "Annotation Class Name 4",
64
- "Predicted Class Name 5": "Annotation Class Name 5",
65
- }
66
-
67
- file_location = gu.save_file_dialog(
68
- state_key="eval-mapping-template",
69
- filetypes=("JSON (*.json)",),
70
- default_filename="class_mapping_template.json",
71
- )
72
-
73
- if file_location:
74
- with open(file_location, "w") as f:
75
- json.dump(template_mapping, f, indent=4)
76
-
77
- gr.Info(loc.localize("eval-tab-info-mapping-template-saved"))
78
- except Exception as e:
79
- print(f"Error saving mapping template: {e}")
80
- raise gr.Error(f"{loc.localize('eval-tab-error-saving-mapping-template')} {e}") from e
81
-
82
- def download_results_table(pa: PerformanceAssessor, predictions, labels, class_wise_value):
83
- if pa is None or predictions is None or labels is None:
84
- raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
85
-
86
- try:
87
- file_location = gu.save_file_dialog(
88
- state_key="eval-results-table",
89
- filetypes=("CSV (*.csv;*.CSV)", "TSV (*.tsv;*.TSV)"),
90
- default_filename="results_table.csv",
91
- )
92
-
93
- if file_location:
94
- metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise_value)
95
-
96
- if file_location.split(".")[-1].lower() == "tsv":
97
- metrics_df.to_csv(file_location, sep="\t", index=True)
98
- else:
99
- metrics_df.to_csv(file_location, index=True)
100
-
101
- gr.Info(loc.localize("eval-tab-info-results-table-saved"))
102
- except Exception as e:
103
- print(f"Error saving results table: {e}")
104
- raise gr.Error(f"{loc.localize('eval-tab-error-saving-results-table')} {e}") from e
105
-
106
- def download_data_table(processor_state: ProcessorState):
107
- if processor_state is None:
108
- raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
109
- try:
110
- file_location = gu.save_file_dialog(
111
- state_key="eval-data-table",
112
- filetypes=("CSV (*.csv)", "TSV (*.tsv;*.TSV)"),
113
- default_filename="data_table.csv",
114
- )
115
- if file_location:
116
- data_df = processor_state.processor.get_sample_data()
117
-
118
- if file_location.split(".")[-1].lower() == "tsv":
119
- data_df.to_csv(file_location, sep="\t", index=False)
120
- else:
121
- data_df.to_csv(file_location, index=False)
122
-
123
- gr.Info(loc.localize("eval-tab-info-data-table-saved"))
124
- except Exception as e:
125
- raise gr.Error(f"{loc.localize('eval-tab-error-saving-data-table')} {e}") from e
126
-
127
- def get_columns_from_uploaded_files(files):
128
- columns = set()
129
-
130
- if files:
131
- for file_obj in files:
132
- try:
133
- df = pd.read_csv(file_obj, sep=None, engine="python", nrows=0)
134
- columns.update(df.columns)
135
- except Exception as e:
136
- print(f"Error reading file {file_obj}: {e}")
137
- gr.Warning(f"{loc.localize('eval-tab-warning-error-reading-file')} {file_obj}")
138
-
139
- return sorted(columns)
140
-
141
- def save_uploaded_files(files):
142
- if not files:
143
- return None
144
-
145
- temp_dir = tempfile.mkdtemp()
146
-
147
- for file_obj in files:
148
- dest_path = os.path.join(temp_dir, os.path.basename(file_obj))
149
- shutil.copy(file_obj, dest_path)
150
-
151
- return temp_dir
152
-
153
- # Single initialize_processor that can reuse given directories.
154
- def initialize_processor(
155
- annotation_files,
156
- prediction_files,
157
- mapping_file_obj,
158
- sample_duration_value,
159
- min_overlap_value,
160
- recording_duration,
161
- ann_start_time,
162
- ann_end_time,
163
- ann_class,
164
- ann_recording,
165
- ann_duration,
166
- pred_start_time,
167
- pred_end_time,
168
- pred_class,
169
- pred_confidence,
170
- pred_recording,
171
- pred_duration,
172
- annotation_dir=None,
173
- prediction_dir=None,
174
- ):
175
- if not annotation_files or not prediction_files:
176
- return [], [], None, None, None
177
-
178
- if annotation_dir is None:
179
- annotation_dir = save_uploaded_files(annotation_files)
180
-
181
- if prediction_dir is None:
182
- prediction_dir = save_uploaded_files(prediction_files)
183
-
184
- # Fallback for annotation columns.
185
- ann_start_time = ann_start_time if ann_start_time else annotation_default_columns["Start Time"]
186
- ann_end_time = ann_end_time if ann_end_time else annotation_default_columns["End Time"]
187
- ann_class = ann_class if ann_class else annotation_default_columns["Class"]
188
- ann_recording = ann_recording if ann_recording else annotation_default_columns["Recording"]
189
- ann_duration = ann_duration if ann_duration else annotation_default_columns["Duration"]
190
-
191
- # Fallback for prediction columns.
192
- pred_start_time = pred_start_time if pred_start_time else prediction_default_columns["Start Time"]
193
- pred_end_time = pred_end_time if pred_end_time else prediction_default_columns["End Time"]
194
- pred_class = pred_class if pred_class else prediction_default_columns["Class"]
195
- pred_confidence = pred_confidence if pred_confidence else prediction_default_columns["Confidence"]
196
- pred_recording = pred_recording if pred_recording else prediction_default_columns["Recording"]
197
- pred_duration = pred_duration if pred_duration else prediction_default_columns["Duration"]
198
-
199
- cols_ann = {
200
- "Start Time": ann_start_time,
201
- "End Time": ann_end_time,
202
- "Class": ann_class,
203
- "Recording": ann_recording,
204
- "Duration": ann_duration,
205
- }
206
- cols_pred = {
207
- "Start Time": pred_start_time,
208
- "End Time": pred_end_time,
209
- "Class": pred_class,
210
- "Confidence": pred_confidence,
211
- "Recording": pred_recording,
212
- "Duration": pred_duration,
213
- }
214
-
215
- # Handle mapping file: if it has a temp_files attribute use that, otherwise assume it's a filepath.
216
- if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
217
- mapping_path = list(mapping_file_obj.temp_files)[0]
218
- else:
219
- mapping_path = mapping_file_obj if mapping_file_obj else None
220
-
221
- if mapping_path:
222
- with open(mapping_path) as f:
223
- class_mapping = json.load(f)
224
- else:
225
- class_mapping = None
226
-
227
- try:
228
- proc = DataProcessor(
229
- prediction_directory_path=prediction_dir,
230
- prediction_file_name=None,
231
- annotation_directory_path=annotation_dir,
232
- annotation_file_name=None,
233
- class_mapping=class_mapping,
234
- sample_duration=sample_duration_value,
235
- min_overlap=min_overlap_value,
236
- columns_predictions=cols_pred,
237
- columns_annotations=cols_ann,
238
- recording_duration=recording_duration,
239
- )
240
- avail_classes = list(proc.classes) # Ensure it's a list
241
- avail_recordings = proc.samples_df["filename"].unique().tolist()
242
-
243
- return avail_classes, avail_recordings, proc, annotation_dir, prediction_dir
244
- except KeyError as e:
245
- print(f"Column missing in files: {e}")
246
- raise gr.Error(f"{loc.localize('eval-tab-error-missing-col')}: " + str(e) + f". {loc.localize('eval-tab-error-missing-col-info')}") from e
247
- except Exception as e:
248
- print(f"Error initializing processor: {e}")
249
-
250
- raise gr.Error(f"{loc.localize('eval-tab-error-init-processor')}:" + str(e)) from e
251
-
252
- # update_selections is triggered when files or mapping file change.
253
- # It creates the temporary directories once and stores them along with the processor.
254
- # It now also receives the current selection values so that user selections are preserved.
255
- def update_selections(
256
- annotation_files,
257
- prediction_files,
258
- mapping_file_obj,
259
- sample_duration_value,
260
- min_overlap_value,
261
- recording_duration_value: str,
262
- ann_start_time,
263
- ann_end_time,
264
- ann_class,
265
- ann_recording,
266
- ann_duration,
267
- pred_start_time,
268
- pred_end_time,
269
- pred_class,
270
- pred_confidence,
271
- pred_recording,
272
- pred_duration,
273
- current_classes,
274
- current_recordings,
275
- ):
276
- if recording_duration_value.strip() == "":
277
- rec_dur = None
278
- else:
279
- try:
280
- rec_dur = float(recording_duration_value)
281
- except ValueError:
282
- rec_dur = None
283
-
284
- # Create temporary directories once.
285
- annotation_dir = save_uploaded_files(annotation_files)
286
- prediction_dir = save_uploaded_files(prediction_files)
287
- avail_classes, avail_recordings, proc, annotation_dir, prediction_dir = initialize_processor(
288
- annotation_files,
289
- prediction_files,
290
- mapping_file_obj,
291
- sample_duration_value,
292
- min_overlap_value,
293
- rec_dur,
294
- ann_start_time,
295
- ann_end_time,
296
- ann_class,
297
- ann_recording,
298
- ann_duration,
299
- pred_start_time,
300
- pred_end_time,
301
- pred_class,
302
- pred_confidence,
303
- pred_recording,
304
- pred_duration,
305
- annotation_dir,
306
- prediction_dir,
307
- )
308
- # Build a state dictionary to store the processor and the directories.
309
- state = ProcessorState(proc, annotation_dir, prediction_dir)
310
- # If no current selection exists, default to all available classes/recordings;
311
- # otherwise, preserve any selections that are still valid.
312
- new_classes = avail_classes if not current_classes else [c for c in current_classes if c in avail_classes] or avail_classes
313
- new_recordings = avail_recordings if not current_recordings else [r for r in current_recordings if r in avail_recordings] or avail_recordings
314
-
315
- return (
316
- gr.update(choices=avail_classes, value=new_classes),
317
- gr.update(choices=avail_recordings, value=new_recordings),
318
- state,
319
- )
320
-
321
- with gr.Tab(loc.localize("eval-tab-title")):
322
- # Custom CSS to match the layout style of other files and remove gray backgrounds.
323
- gr.Markdown(
324
- """
325
- <style>
326
- /* Grid layout for checkbox groups */
327
- .custom-checkbox-group {
328
- display: grid;
329
- grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
330
- grid-gap: 8px;
331
- }
332
- </style>
333
- """
334
- )
335
-
336
- processor_state = gr.State()
337
- pa_state = gr.State()
338
- predictions_state = gr.State()
339
- labels_state = gr.State()
340
- annotation_files_state = gr.State()
341
- prediction_files_state = gr.State()
342
- plot_name_state = gr.State()
343
-
344
- def get_selection_tables(directory):
345
- from pathlib import Path
346
-
347
- directory = Path(directory)
348
-
349
- return list(directory.glob("*.txt"))
350
-
351
- # Update column dropdowns when files are uploaded.
352
- def update_annotation_columns(uploaded_files):
353
- cols = get_columns_from_uploaded_files(uploaded_files)
354
- cols = ["", *cols]
355
- updates = []
356
-
357
- for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]:
358
- default_val = annotation_default_columns.get(label)
359
- val = default_val if default_val in cols else None
360
- updates.append(gr.update(choices=cols, value=val))
361
-
362
- return updates
363
-
364
- def update_prediction_columns(uploaded_files):
365
- cols = get_columns_from_uploaded_files(uploaded_files)
366
- cols = ["", *cols]
367
- updates = []
368
-
369
- for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]:
370
- default_val = prediction_default_columns.get(label)
371
- val = default_val if default_val in cols else None
372
- updates.append(gr.update(choices=cols, value=val))
373
-
374
- return updates
375
-
376
- def get_selection_func(state_key, on_select):
377
- def select_directory_on_empty(): # Nishant - Function modified for For Folder selection
378
- folder = gu.select_folder(state_key=state_key)
379
-
380
- if folder:
381
- files = get_selection_tables(folder)
382
- files_to_display = files[:100] + [["..."]] if len(files) > 100 else files
383
- return [files, files_to_display, gr.update(visible=True), *on_select(files)]
384
-
385
- return ["", [[loc.localize("eval-tab-no-files-found")]]]
386
-
387
- return select_directory_on_empty
388
-
389
- with gr.Row():
390
- with gr.Column():
391
- annotation_select_directory_btn = gr.Button(loc.localize("eval-tab-annotation-selection-button-label"))
392
- annotation_directory_input = gr.Matrix(
393
- interactive=False,
394
- headers=[
395
- loc.localize("eval-tab-selections-column-file-header"),
396
- ],
397
- )
398
-
399
- with gr.Column():
400
- prediction_select_directory_btn = gr.Button(loc.localize("eval-tab-prediction-selection-button-label"))
401
- prediction_directory_input = gr.Matrix(
402
- interactive=False,
403
- headers=[
404
- loc.localize("eval-tab-selections-column-file-header"),
405
- ],
406
- )
407
-
408
- # ----------------------- Annotations Columns Box -----------------------
409
- with (
410
- gr.Group(visible=False) as annotation_group,
411
- gr.Accordion(loc.localize("eval-tab-annotation-col-accordion-label"), open=True),
412
- gr.Row(),
413
- ):
414
- annotation_columns: dict[str, gr.Dropdown] = {}
415
-
416
- for col in ["Start Time", "End Time", "Class", "Recording", "Duration"]:
417
- annotation_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
418
-
419
- # ----------------------- Predictions Columns Box -----------------------
420
- with (
421
- gr.Group(visible=False) as prediction_group,
422
- gr.Accordion(loc.localize("eval-tab-prediction-col-accordion-label"), open=True),
423
- gr.Row(),
424
- ):
425
- prediction_columns: dict[str, gr.Dropdown] = {}
426
-
427
- for col in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]:
428
- prediction_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
429
-
430
- # ----------------------- Class Mapping Box -----------------------
431
- with gr.Group(visible=False) as mapping_group:
432
- with gr.Accordion(loc.localize("eval-tab-class-mapping-accordion-label"), open=False), gr.Row():
433
- mapping_file = gr.File(
434
- label=loc.localize("eval-tab-upload-mapping-file-label"),
435
- file_count="single",
436
- file_types=[".json"],
437
- )
438
- download_mapping_button = gr.DownloadButton(label=loc.localize("eval-tab-mapping-file-template-download-button-label"))
439
-
440
- download_mapping_button.click(fn=download_class_mapping_template)
441
-
442
- # ----------------------- Classes and Recordings Selection Box -----------------------
443
- with (
444
- gr.Group(visible=False) as class_recording_group,
445
- gr.Accordion(loc.localize("eval-tab-select-classes-recordings-accordion-label"), open=False),
446
- gr.Row(),
447
- ):
448
- with gr.Column():
449
- select_classes_checkboxgroup = gr.CheckboxGroup(
450
- choices=[],
451
- value=[],
452
- label=loc.localize("eval-tab-select-classes-checkboxgroup-label"),
453
- info=loc.localize("eval-tab-select-classes-checkboxgroup-info"),
454
- interactive=True,
455
- elem_classes="custom-checkbox-group",
456
- )
457
-
458
- with gr.Column():
459
- select_recordings_checkboxgroup = gr.CheckboxGroup(
460
- choices=[],
461
- value=[],
462
- label=loc.localize("eval-tab-select-recordings-checkboxgroup-label"),
463
- info=loc.localize("eval-tab-select-recordings-checkboxgroup-info"),
464
- interactive=True,
465
- elem_classes="custom-checkbox-group",
466
- )
467
-
468
- # ----------------------- Parameters Box -----------------------
469
- with gr.Group(), gr.Accordion(loc.localize("eval-tab-parameters-accordion-label"), open=False), gr.Row():
470
- sample_duration = gr.Number(
471
- value=3,
472
- label=loc.localize("eval-tab-sample-duration-number-label"),
473
- precision=0,
474
- info=loc.localize("eval-tab-sample-duration-number-info"),
475
- )
476
- recording_duration = gr.Textbox(
477
- label=loc.localize("eval-tab-recording-duration-textbox-label"),
478
- placeholder=loc.localize("eval-tab-recording-duration-textbox-placeholder"),
479
- info=loc.localize("eval-tab-recording-duration-textbox-info"),
480
- )
481
- min_overlap = gr.Number(
482
- value=0.5,
483
- label=loc.localize("eval-tab-min-overlap-number-label"),
484
- info=loc.localize("eval-tab-min-overlap-number-info"),
485
- )
486
- threshold = gr.Slider(
487
- minimum=0.01,
488
- maximum=0.99,
489
- value=0.1,
490
- label=loc.localize("eval-tab-threshold-number-label"),
491
- info=loc.localize("eval-tab-threshold-number-info"),
492
- )
493
- class_wise = gr.Checkbox(
494
- label=loc.localize("eval-tab-classwise-checkbox-label"),
495
- value=False,
496
- info=loc.localize("eval-tab-classwise-checkbox-info"),
497
- )
498
-
499
- # ----------------------- Metrics Box -----------------------
500
- with gr.Group(), gr.Accordion(loc.localize("eval-tab-metrics-accordian-label"), open=False), gr.Row():
501
- metric_info = {
502
- "AUROC": loc.localize("eval-tab-auroc-checkbox-info"),
503
- "Precision": loc.localize("eval-tab-precision-checkbox-info"),
504
- "Recall": loc.localize("eval-tab-recall-checkbox-info"),
505
- "F1 Score": loc.localize("eval-tab-f1-score-checkbox-info"),
506
- "Average Precision (AP)": loc.localize("eval-tab-ap-checkbox-info"),
507
- "Accuracy": loc.localize("eval-tab-accuracy-checkbox-info"),
508
- }
509
- metrics_checkboxes = {}
510
-
511
- for metric_name, description in metric_info.items():
512
- metrics_checkboxes[metric_name.lower()] = gr.Checkbox(label=metric_name, value=True, info=description)
513
-
514
- # ----------------------- Actions Box -----------------------
515
-
516
- calculate_button = gr.Button(loc.localize("eval-tab-calculate-metrics-button-label"), variant="huggingface")
517
-
518
- with gr.Column(visible=False) as action_col:
519
- with gr.Row():
520
- plot_metrics_button = gr.Button(loc.localize("eval-tab-plot-metrics-button-label"))
521
- plot_confusion_button = gr.Button(loc.localize("eval-tab-plot-confusion-matrix-button-label"))
522
- plot_metrics_all_thresholds_button = gr.Button(loc.localize("eval-tab-plot-metrics-all-thresholds-button-label"))
523
-
524
- with gr.Row():
525
- download_results_button = gr.DownloadButton(loc.localize("eval-tab-result-table-download-button-label"))
526
- download_data_button = gr.DownloadButton(loc.localize("eval-tab-data-table-download-button-label"))
527
-
528
- download_results_button.click(
529
- fn=download_results_table,
530
- inputs=[pa_state, predictions_state, labels_state, class_wise],
531
- )
532
- download_data_button.click(fn=download_data_table, inputs=[processor_state])
533
- metric_table = gr.Dataframe(show_label=False, type="pandas", visible=False, interactive=False)
534
-
535
- with gr.Group(visible=False) as plot_group:
536
- plot_output = gr.Plot(show_label=False)
537
- plot_output_dl_btn = gr.Button("Download plot", size="sm")
538
-
539
- # Update available selections (classes and recordings) and the processor state when files or mapping file change.
540
- # Also pass the current selection values so that user selections are preserved.
541
- for comp in list(annotation_columns.values()) + list(prediction_columns.values()) + [mapping_file]:
542
- comp.change(
543
- fn=update_selections,
544
- inputs=[
545
- annotation_files_state,
546
- prediction_files_state,
547
- mapping_file,
548
- sample_duration,
549
- min_overlap,
550
- recording_duration,
551
- annotation_columns["Start Time"],
552
- annotation_columns["End Time"],
553
- annotation_columns["Class"],
554
- annotation_columns["Recording"],
555
- annotation_columns["Duration"],
556
- prediction_columns["Start Time"],
557
- prediction_columns["End Time"],
558
- prediction_columns["Class"],
559
- prediction_columns["Confidence"],
560
- prediction_columns["Recording"],
561
- prediction_columns["Duration"],
562
- select_classes_checkboxgroup,
563
- select_recordings_checkboxgroup,
564
- ],
565
- outputs=[
566
- select_classes_checkboxgroup,
567
- select_recordings_checkboxgroup,
568
- processor_state,
569
- ],
570
- )
571
-
572
- # calculate_metrics now uses the stored temporary directories from processor_state.
573
- # The function now accepts selected_classes and selected_recordings as inputs.
574
- def calculate_metrics(
575
- mapping_file_obj,
576
- sample_duration_value,
577
- min_overlap_value,
578
- recording_duration_value: str,
579
- ann_start_time,
580
- ann_end_time,
581
- ann_class,
582
- ann_recording,
583
- ann_duration,
584
- pred_start_time,
585
- pred_end_time,
586
- pred_class,
587
- pred_confidence,
588
- pred_recording,
589
- pred_duration,
590
- threshold_value,
591
- class_wise_value,
592
- selected_classes_list,
593
- selected_recordings_list,
594
- proc_state: ProcessorState,
595
- *metrics_checkbox_values,
596
- ):
597
- selected_metrics = []
598
-
599
- for value, (m_lower, _) in zip(metrics_checkbox_values, metrics_checkboxes.items(), strict=True):
600
- if value:
601
- selected_metrics.append(m_lower)
602
-
603
- valid_metrics = {
604
- "accuracy": "accuracy",
605
- "recall": "recall",
606
- "precision": "precision",
607
- "f1 score": "f1",
608
- "average precision (ap)": "ap",
609
- "auroc": "auroc",
610
- }
611
- metrics = tuple(valid_metrics[m] for m in selected_metrics if m in valid_metrics)
612
-
613
- # Fall back to available classes from processor state if none selected.
614
- if not selected_classes_list and proc_state and proc_state.processor:
615
- selected_classes_list = list(proc_state.processor.classes)
616
-
617
- if not selected_classes_list:
618
- raise gr.Error(loc.localize("eval-tab-error-no-class-selected"))
619
-
620
- if recording_duration_value.strip() == "":
621
- rec_dur = None
622
- else:
623
- try:
624
- rec_dur = float(recording_duration_value)
625
- except ValueError as e:
626
- raise gr.Error(loc.localize("eval-tab-error-no-valid-recording-duration")) from e
627
-
628
- if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
629
- mapping_path = list(mapping_file_obj.temp_files)[0]
630
- else:
631
- mapping_path = mapping_file_obj if mapping_file_obj else None
632
-
633
- try:
634
- metrics_df, pa, preds, labs = process_data(
635
- annotation_path=proc_state.annotation_dir,
636
- prediction_path=proc_state.prediction_dir,
637
- mapping_path=mapping_path,
638
- sample_duration=sample_duration_value,
639
- min_overlap=min_overlap_value,
640
- recording_duration=rec_dur,
641
- columns_annotations={
642
- "Start Time": ann_start_time,
643
- "End Time": ann_end_time,
644
- "Class": ann_class,
645
- "Recording": ann_recording,
646
- "Duration": ann_duration,
647
- },
648
- columns_predictions={
649
- "Start Time": pred_start_time,
650
- "End Time": pred_end_time,
651
- "Class": pred_class,
652
- "Confidence": pred_confidence,
653
- "Recording": pred_recording,
654
- "Duration": pred_duration,
655
- },
656
- selected_classes=selected_classes_list,
657
- selected_recordings=selected_recordings_list,
658
- metrics_list=metrics,
659
- threshold=threshold_value,
660
- class_wise=class_wise_value,
661
- )
662
-
663
- table = metrics_df.T.reset_index(names=[""])
664
-
665
- return (
666
- gr.update(value=table, visible=True),
667
- gr.update(visible=True),
668
- pa,
669
- preds,
670
- labs,
671
- gr.update(),
672
- gr.update(),
673
- proc_state,
674
- )
675
- except Exception as e:
676
- print("Error processing data:", e)
677
- raise gr.Error(f"{loc.localize('eval-tab-error-during-processing')}: {e}") from e
678
-
679
- # Updated calculate_button click now passes the selected classes and recordings.
680
- calculate_button.click(
681
- calculate_metrics,
682
- inputs=[
683
- mapping_file,
684
- sample_duration,
685
- min_overlap,
686
- recording_duration,
687
- annotation_columns["Start Time"],
688
- annotation_columns["End Time"],
689
- annotation_columns["Class"],
690
- annotation_columns["Recording"],
691
- annotation_columns["Duration"],
692
- prediction_columns["Start Time"],
693
- prediction_columns["End Time"],
694
- prediction_columns["Class"],
695
- prediction_columns["Confidence"],
696
- prediction_columns["Recording"],
697
- prediction_columns["Duration"],
698
- threshold,
699
- class_wise,
700
- select_classes_checkboxgroup,
701
- select_recordings_checkboxgroup,
702
- processor_state,
703
- *list(metrics_checkboxes.values()),
704
- ],
705
- outputs=[
706
- metric_table,
707
- action_col,
708
- pa_state,
709
- predictions_state,
710
- labels_state,
711
- select_classes_checkboxgroup,
712
- select_recordings_checkboxgroup,
713
- processor_state,
714
- ],
715
- )
716
-
717
def plot_metrics(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Render the metrics plot for the current results and reveal the plot group.

    Returns updates for (plot_group visibility, plot_output figure, plot_name_state).
    Raises a user-facing gr.Error when metrics have not been calculated yet
    or when plotting fails.
    """
    if pa is None or predictions is None or labels is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        fig = pa.plot_metrics(predictions, labels, per_class_metrics=class_wise_value)
        # Close the figure so matplotlib does not keep it open; Gradio keeps its own handle.
        plt.close(fig)

        return gr.update(visible=True), gr.update(value=fig), "metrics"
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics')}: {e}") from e
727
-
728
- plot_metrics_button.click(
729
- plot_metrics,
730
- inputs=[pa_state, predictions_state, labels_state, class_wise],
731
- outputs=[plot_group, plot_output, plot_name_state],
732
- )
733
-
734
def plot_confusion_matrix(pa: PerformanceAssessor, predictions, labels):
    """Render the confusion-matrix plot and reveal the plot group.

    Returns updates for (plot_group visibility, plot_output figure, plot_name_state).
    Raises a user-facing gr.Error when metrics have not been calculated yet
    or when plotting fails.
    """
    if pa is None or predictions is None or labels is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        fig = pa.plot_confusion_matrix(predictions, labels)
        # Close the figure so matplotlib does not keep it open; Gradio keeps its own handle.
        plt.close(fig)

        return gr.update(visible=True), fig, "confusion_matrix"
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-confusion-matrix')}: {e}") from e
744
-
745
- plot_confusion_button.click(
746
- plot_confusion_matrix,
747
- inputs=[pa_state, predictions_state, labels_state],
748
- outputs=[plot_group, plot_output, plot_name_state],
749
- )
750
-
751
- annotation_select_directory_btn.click(
752
- get_selection_func("eval-annotations-dir", update_annotation_columns),
753
- outputs=[annotation_files_state, annotation_directory_input, annotation_group]
754
- + [annotation_columns[label] for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]],
755
- show_progress=True,
756
- )
757
-
758
- prediction_select_directory_btn.click(
759
- get_selection_func("eval-predictions-dir", update_prediction_columns),
760
- outputs=[prediction_files_state, prediction_directory_input, prediction_group]
761
- + [prediction_columns[label] for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]],
762
- show_progress=True,
763
- )
764
-
765
def toggle_after_selection(annotation_files, prediction_files):
    """Reveal the mapping and class/recording groups once both file sets exist.

    Coerces the condition to bool so Gradio receives a proper boolean for
    ``visible`` instead of a (possibly truthy) file list.
    Returns one update per output component (mapping_group, class_recording_group).
    """
    both_selected = bool(annotation_files and prediction_files)
    return [gr.update(visible=both_selected)] * 2
767
-
768
- annotation_directory_input.change(
769
- toggle_after_selection,
770
- inputs=[annotation_files_state, prediction_files_state],
771
- outputs=[mapping_group, class_recording_group],
772
- )
773
-
774
- prediction_directory_input.change(
775
- toggle_after_selection,
776
- inputs=[annotation_files_state, prediction_files_state],
777
- outputs=[mapping_group, class_recording_group],
778
- )
779
-
780
def plot_metrics_all_thresholds(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Plot every selected metric across the full threshold range and show it.

    Returns updates for (plot_group visibility, plot_output figure, plot_name_state).
    Raises a user-facing gr.Error when metrics have not been calculated yet
    or when plotting fails.
    """
    if pa is None or predictions is None or labels is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        figure = pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=class_wise_value)
        # Release the figure from pyplot's registry; Gradio renders from the object itself.
        plt.close(figure)
        return gr.update(visible=True), gr.update(value=figure), "metrics_all_thresholds"
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics-all-thresholds')}: {e}") from e
790
-
791
- plot_metrics_all_thresholds_button.click(
792
- plot_metrics_all_thresholds,
793
- inputs=[pa_state, predictions_state, labels_state, class_wise],
794
- outputs=[plot_group, plot_output, plot_name_state],
795
- )
796
-
797
- plot_output_dl_btn.click(gu.download_plot, inputs=[plot_output, plot_name_state])
798
-
799
-
800
# Allow previewing the evaluation tab standalone when this module is run directly.
if __name__ == "__main__":
    gu.open_window(build_evaluation_tab)
1
+ import json
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import typing
6
+
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+
11
+ import birdnet_analyzer.gui.localization as loc
12
+ import birdnet_analyzer.gui.utils as gu
13
+ from birdnet_analyzer.evaluation import process_data
14
+ from birdnet_analyzer.evaluation.assessment.performance_assessor import (
15
+ PerformanceAssessor,
16
+ )
17
+ from birdnet_analyzer.evaluation.preprocessing.data_processor import DataProcessor
18
+
19
+
20
class ProcessorState(typing.NamedTuple):
    """State of the DataProcessor."""

    # The initialized DataProcessor built from the uploaded files.
    processor: DataProcessor
    # Temporary directory holding copies of the uploaded annotation files.
    annotation_dir: str
    # Temporary directory holding copies of the uploaded prediction files.
    prediction_dir: str
26
+
27
+
28
+ def build_evaluation_tab():
29
+ # Default columns for annotations
30
+ annotation_default_columns = {
31
+ "Start Time": "Begin Time (s)",
32
+ "End Time": "End Time (s)",
33
+ "Class": "Class",
34
+ "Recording": "Begin File",
35
+ "Duration": "File Duration (s)",
36
+ }
37
+
38
+ # Default columns for predictions
39
+ prediction_default_columns = {
40
+ "Start Time": "Begin Time (s)",
41
+ "End Time": "End Time (s)",
42
+ "Class": "Common Name",
43
+ "Recording": "Begin File",
44
+ "Duration": "File Duration (s)",
45
+ "Confidence": "Confidence",
46
+ }
47
+
48
+ localized_column_labels = {
49
+ "Start Time": loc.localize("eval-tab-column-start-time-label"),
50
+ "End Time": loc.localize("eval-tab-column-end-time-label"),
51
+ "Class": loc.localize("eval-tab-column-class-label"),
52
+ "Recording": loc.localize("eval-tab-column-recording-label"),
53
+ "Duration": loc.localize("eval-tab-column-duration-label"),
54
+ "Confidence": loc.localize("eval-tab-column-confidence-label"),
55
+ }
56
+
57
def download_class_mapping_template():
    """Save an example class-mapping JSON (predicted name -> annotation name) to a user-chosen path."""
    try:
        # Placeholder entries showing the expected JSON structure.
        template_mapping = {
            "Predicted Class Name 1": "Annotation Class Name 1",
            "Predicted Class Name 2": "Annotation Class Name 2",
            "Predicted Class Name 3": "Annotation Class Name 3",
            "Predicted Class Name 4": "Annotation Class Name 4",
            "Predicted Class Name 5": "Annotation Class Name 5",
        }

        file_location = gu.save_file_dialog(
            state_key="eval-mapping-template",
            filetypes=("JSON (*.json)",),
            default_filename="class_mapping_template.json",
        )

        # An empty/None location means the user cancelled the dialog.
        if file_location:
            with open(file_location, "w") as f:
                json.dump(template_mapping, f, indent=4)

            gr.Info(loc.localize("eval-tab-info-mapping-template-saved"))
    except Exception as e:
        print(f"Error saving mapping template: {e}")
        raise gr.Error(f"{loc.localize('eval-tab-error-saving-mapping-template')} {e}") from e
81
+
82
def download_results_table(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Recompute the metrics table and save it to a user-chosen CSV/TSV file.

    Raises a user-facing gr.Error when metrics have not been calculated yet
    or when saving fails.
    """
    if pa is None or predictions is None or labels is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)

    try:
        file_location = gu.save_file_dialog(
            state_key="eval-results-table",
            filetypes=("CSV (*.csv;*.CSV)", "TSV (*.tsv;*.TSV)"),
            default_filename="results_table.csv",
        )

        # An empty/None location means the user cancelled the dialog.
        if file_location:
            metrics_df = pa.calculate_metrics(predictions, labels, per_class_metrics=class_wise_value)

            # Pick the separator from the chosen file extension.
            if file_location.split(".")[-1].lower() == "tsv":
                metrics_df.to_csv(file_location, sep="\t", index=True)
            else:
                metrics_df.to_csv(file_location, index=True)

            gr.Info(loc.localize("eval-tab-info-results-table-saved"))
    except Exception as e:
        print(f"Error saving results table: {e}")
        raise gr.Error(f"{loc.localize('eval-tab-error-saving-results-table')} {e}") from e
105
+
106
def download_data_table(processor_state: ProcessorState):
    """Save the processor's per-sample data table to a user-chosen CSV/TSV file.

    Raises a user-facing gr.Error when metrics have not been calculated yet
    (no processor state exists) or when saving fails.
    """
    if processor_state is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        file_location = gu.save_file_dialog(
            state_key="eval-data-table",
            # Match the filter pattern used by download_results_table so
            # upper-case .CSV files are selectable here as well.
            filetypes=("CSV (*.csv;*.CSV)", "TSV (*.tsv;*.TSV)"),
            default_filename="data_table.csv",
        )
        # An empty/None location means the user cancelled the dialog.
        if file_location:
            data_df = processor_state.processor.get_sample_data()

            # Pick the separator from the chosen file extension.
            if file_location.split(".")[-1].lower() == "tsv":
                data_df.to_csv(file_location, sep="\t", index=False)
            else:
                data_df.to_csv(file_location, index=False)

            gr.Info(loc.localize("eval-tab-info-data-table-saved"))
    except Exception as e:
        # Log to console for debugging, mirroring the other download handlers.
        print(f"Error saving data table: {e}")
        raise gr.Error(f"{loc.localize('eval-tab-error-saving-data-table')} {e}") from e
126
+
127
def get_columns_from_uploaded_files(files):
    """Collect the union of column names across the given tabular files.

    Only headers are read (``nrows=0``); the separator is sniffed per file.
    Unreadable files are skipped with a console log and a user-facing warning.
    Returns an alphabetically sorted list of column names ([] for no files).
    """
    found = set()

    for path in files or []:
        try:
            header_df = pd.read_csv(path, sep=None, engine="python", nrows=0)
        except Exception as e:
            print(f"Error reading file {path}: {e}")
            gr.Warning(f"{loc.localize('eval-tab-warning-error-reading-file')} {path}")
        else:
            found.update(header_df.columns)

    return sorted(found)
140
+
141
def save_uploaded_files(files):
    """Copy the given files into a fresh temporary directory.

    Returns the new directory path, or None when no files were provided.
    NOTE(review): the temporary directory is never removed — confirm whether
    cleanup is handled elsewhere in the GUI lifecycle.
    """
    if not files:
        return None

    target_dir = tempfile.mkdtemp()

    for src in files:
        shutil.copy(src, os.path.join(target_dir, os.path.basename(src)))

    return target_dir
152
+
153
+ # Single initialize_processor that can reuse given directories.
154
# Single initialize_processor that can reuse given directories.
def initialize_processor(
    annotation_files,
    prediction_files,
    mapping_file_obj,
    sample_duration_value,
    min_overlap_value,
    recording_duration,
    ann_start_time,
    ann_end_time,
    ann_class,
    ann_recording,
    ann_duration,
    pred_start_time,
    pred_end_time,
    pred_class,
    pred_confidence,
    pred_recording,
    pred_duration,
    annotation_dir=None,
    prediction_dir=None,
):
    """Build a DataProcessor from the uploaded annotation/prediction files.

    Empty column selections fall back to the module defaults. When
    annotation_dir/prediction_dir are given, the already-created temp
    directories are reused instead of copying the files again.

    Returns (available_classes, available_recordings, processor,
    annotation_dir, prediction_dir); returns empty results when either
    file list is missing. Raises a user-facing gr.Error on failure.
    """
    if not annotation_files or not prediction_files:
        return [], [], None, None, None

    if annotation_dir is None:
        annotation_dir = save_uploaded_files(annotation_files)

    if prediction_dir is None:
        prediction_dir = save_uploaded_files(prediction_files)

    # Fallback for annotation columns.
    ann_start_time = ann_start_time if ann_start_time else annotation_default_columns["Start Time"]
    ann_end_time = ann_end_time if ann_end_time else annotation_default_columns["End Time"]
    ann_class = ann_class if ann_class else annotation_default_columns["Class"]
    ann_recording = ann_recording if ann_recording else annotation_default_columns["Recording"]
    ann_duration = ann_duration if ann_duration else annotation_default_columns["Duration"]

    # Fallback for prediction columns.
    pred_start_time = pred_start_time if pred_start_time else prediction_default_columns["Start Time"]
    pred_end_time = pred_end_time if pred_end_time else prediction_default_columns["End Time"]
    pred_class = pred_class if pred_class else prediction_default_columns["Class"]
    pred_confidence = pred_confidence if pred_confidence else prediction_default_columns["Confidence"]
    pred_recording = pred_recording if pred_recording else prediction_default_columns["Recording"]
    pred_duration = pred_duration if pred_duration else prediction_default_columns["Duration"]

    cols_ann = {
        "Start Time": ann_start_time,
        "End Time": ann_end_time,
        "Class": ann_class,
        "Recording": ann_recording,
        "Duration": ann_duration,
    }
    cols_pred = {
        "Start Time": pred_start_time,
        "End Time": pred_end_time,
        "Class": pred_class,
        "Confidence": pred_confidence,
        "Recording": pred_recording,
        "Duration": pred_duration,
    }

    # Handle mapping file: if it has a temp_files attribute use that, otherwise assume it's a filepath.
    if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
        mapping_path = list(mapping_file_obj.temp_files)[0]
    else:
        mapping_path = mapping_file_obj if mapping_file_obj else None

    if mapping_path:
        with open(mapping_path) as f:
            class_mapping = json.load(f)
    else:
        class_mapping = None

    try:
        proc = DataProcessor(
            prediction_directory_path=prediction_dir,
            prediction_file_name=None,
            annotation_directory_path=annotation_dir,
            annotation_file_name=None,
            class_mapping=class_mapping,
            sample_duration=sample_duration_value,
            min_overlap=min_overlap_value,
            columns_predictions=cols_pred,
            columns_annotations=cols_ann,
            recording_duration=recording_duration,
        )
        avail_classes = list(proc.classes)  # Ensure it's a list
        avail_recordings = proc.samples_df["filename"].unique().tolist()

        return avail_classes, avail_recordings, proc, annotation_dir, prediction_dir
    except KeyError as e:
        # A configured column name was not found in the files.
        print(f"Column missing in files: {e}")
        raise gr.Error(f"{loc.localize('eval-tab-error-missing-col')}: " + str(e) + f". {loc.localize('eval-tab-error-missing-col-info')}") from e
    except Exception as e:
        print(f"Error initializing processor: {e}")

        raise gr.Error(f"{loc.localize('eval-tab-error-init-processor')}:" + str(e)) from e
251
+
252
+ # update_selections is triggered when files or mapping file change.
253
+ # It creates the temporary directories once and stores them along with the processor.
254
+ # It now also receives the current selection values so that user selections are preserved.
255
# update_selections is triggered when files or mapping file change.
# It creates the temporary directories once and stores them along with the processor.
# It now also receives the current selection values so that user selections are preserved.
def update_selections(
    annotation_files,
    prediction_files,
    mapping_file_obj,
    sample_duration_value,
    min_overlap_value,
    recording_duration_value: str,
    ann_start_time,
    ann_end_time,
    ann_class,
    ann_recording,
    ann_duration,
    pred_start_time,
    pred_end_time,
    pred_class,
    pred_confidence,
    pred_recording,
    pred_duration,
    current_classes,
    current_recordings,
):
    """Rebuild the processor and refresh the class/recording checkbox groups.

    Invalid or empty recording-duration input is silently treated as None here
    (unlike calculate_metrics, which raises on invalid input).
    Returns updates for (class checkboxgroup, recording checkboxgroup,
    processor_state).
    """
    if recording_duration_value.strip() == "":
        rec_dur = None
    else:
        try:
            rec_dur = float(recording_duration_value)
        except ValueError:
            rec_dur = None

    # Create temporary directories once.
    annotation_dir = save_uploaded_files(annotation_files)
    prediction_dir = save_uploaded_files(prediction_files)
    avail_classes, avail_recordings, proc, annotation_dir, prediction_dir = initialize_processor(
        annotation_files,
        prediction_files,
        mapping_file_obj,
        sample_duration_value,
        min_overlap_value,
        rec_dur,
        ann_start_time,
        ann_end_time,
        ann_class,
        ann_recording,
        ann_duration,
        pred_start_time,
        pred_end_time,
        pred_class,
        pred_confidence,
        pred_recording,
        pred_duration,
        annotation_dir,
        prediction_dir,
    )
    # Build a state dictionary to store the processor and the directories.
    state = ProcessorState(proc, annotation_dir, prediction_dir)
    # If no current selection exists, default to all available classes/recordings;
    # otherwise, preserve any selections that are still valid.
    new_classes = avail_classes if not current_classes else [c for c in current_classes if c in avail_classes] or avail_classes
    new_recordings = avail_recordings if not current_recordings else [r for r in current_recordings if r in avail_recordings] or avail_recordings

    return (
        gr.update(choices=avail_classes, value=new_classes),
        gr.update(choices=avail_recordings, value=new_recordings),
        state,
    )
320
+
321
+ with gr.Tab(loc.localize("eval-tab-title")):
322
+ # Custom CSS to match the layout style of other files and remove gray backgrounds.
323
+ gr.Markdown(
324
+ """
325
+ <style>
326
+ /* Grid layout for checkbox groups */
327
+ .custom-checkbox-group {
328
+ display: grid;
329
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
330
+ grid-gap: 8px;
331
+ }
332
+ </style>
333
+ """
334
+ )
335
+
336
+ processor_state = gr.State()
337
+ pa_state = gr.State()
338
+ predictions_state = gr.State()
339
+ labels_state = gr.State()
340
+ annotation_files_state = gr.State()
341
+ prediction_files_state = gr.State()
342
+ plot_name_state = gr.State()
343
+
344
def get_selection_tables(directory):
    """Return all selection-table files (*.txt) in *directory* as Path objects.

    The result is sorted so the file listing shown in the GUI is
    deterministic across platforms (glob order is filesystem-dependent).
    """
    from pathlib import Path

    return sorted(Path(directory).glob("*.txt"))
350
+
351
+ # Update column dropdowns when files are uploaded.
352
# Update column dropdowns when files are uploaded.
def update_annotation_columns(uploaded_files):
    """Refresh the annotation column dropdowns from the uploaded files' headers.

    Each dropdown gets the discovered columns (plus an empty choice) and is
    preselected with the module default when that column exists in the files.
    """
    available = ["", *get_columns_from_uploaded_files(uploaded_files)]
    labels = ["Start Time", "End Time", "Class", "Recording", "Duration"]

    return [
        gr.update(
            choices=available,
            value=annotation_default_columns.get(label) if annotation_default_columns.get(label) in available else None,
        )
        for label in labels
    ]
363
+
364
def update_prediction_columns(uploaded_files):
    """Refresh the prediction column dropdowns from the uploaded files' headers.

    Each dropdown gets the discovered columns (plus an empty choice) and is
    preselected with the module default when that column exists in the files.
    """
    available = ["", *get_columns_from_uploaded_files(uploaded_files)]
    labels = ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]

    return [
        gr.update(
            choices=available,
            value=prediction_default_columns.get(label) if prediction_default_columns.get(label) in available else None,
        )
        for label in labels
    ]
375
+
376
def get_selection_func(state_key, on_select):
    """Build a click handler that asks for a folder and lists its selection tables.

    state_key: key under which the chosen folder is remembered by the GUI.
    on_select: callback producing dropdown updates from the found files.
    """
    def select_directory_on_empty():  # Nishant - Function modified for For Folder selection
        folder = gu.select_folder(state_key=state_key)

        if folder:
            files = get_selection_tables(folder)
            # Cap the table preview at 100 rows, appending an ellipsis row.
            files_to_display = [*files[:100], ["..."]] if len(files) > 100 else files
            return [files, files_to_display, gr.update(visible=True), *on_select(files)]

        # NOTE(review): this branch returns only 2 values while the success branch
        # returns 3 + len(on_select(files)); confirm Gradio tolerates the arity
        # mismatch with the registered outputs when no folder is chosen.
        return ["", [[loc.localize("eval-tab-no-files-found")]]]

    return select_directory_on_empty
388
+
389
+ with gr.Row():
390
+ with gr.Column():
391
+ annotation_select_directory_btn = gr.Button(loc.localize("eval-tab-annotation-selection-button-label"))
392
+ annotation_directory_input = gr.Matrix(
393
+ interactive=False,
394
+ headers=[
395
+ loc.localize("eval-tab-selections-column-file-header"),
396
+ ],
397
+ )
398
+
399
+ with gr.Column():
400
+ prediction_select_directory_btn = gr.Button(loc.localize("eval-tab-prediction-selection-button-label"))
401
+ prediction_directory_input = gr.Matrix(
402
+ interactive=False,
403
+ headers=[
404
+ loc.localize("eval-tab-selections-column-file-header"),
405
+ ],
406
+ )
407
+
408
+ # ----------------------- Annotations Columns Box -----------------------
409
+ with (
410
+ gr.Group(visible=False) as annotation_group,
411
+ gr.Accordion(loc.localize("eval-tab-annotation-col-accordion-label"), open=True),
412
+ gr.Row(),
413
+ ):
414
+ annotation_columns: dict[str, gr.Dropdown] = {}
415
+
416
+ for col in ["Start Time", "End Time", "Class", "Recording", "Duration"]:
417
+ annotation_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
418
+
419
+ # ----------------------- Predictions Columns Box -----------------------
420
+ with (
421
+ gr.Group(visible=False) as prediction_group,
422
+ gr.Accordion(loc.localize("eval-tab-prediction-col-accordion-label"), open=True),
423
+ gr.Row(),
424
+ ):
425
+ prediction_columns: dict[str, gr.Dropdown] = {}
426
+
427
+ for col in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]:
428
+ prediction_columns[col] = gr.Dropdown(choices=[], label=localized_column_labels[col])
429
+
430
+ # ----------------------- Class Mapping Box -----------------------
431
+ with gr.Group(visible=False) as mapping_group:
432
+ with gr.Accordion(loc.localize("eval-tab-class-mapping-accordion-label"), open=False), gr.Row():
433
+ mapping_file = gr.File(
434
+ label=loc.localize("eval-tab-upload-mapping-file-label"),
435
+ file_count="single",
436
+ file_types=[".json"],
437
+ )
438
+ download_mapping_button = gr.DownloadButton(label=loc.localize("eval-tab-mapping-file-template-download-button-label"))
439
+
440
+ download_mapping_button.click(fn=download_class_mapping_template)
441
+
442
+ # ----------------------- Classes and Recordings Selection Box -----------------------
443
+ with (
444
+ gr.Group(visible=False) as class_recording_group,
445
+ gr.Accordion(loc.localize("eval-tab-select-classes-recordings-accordion-label"), open=False),
446
+ gr.Row(),
447
+ ):
448
+ with gr.Column():
449
+ select_classes_checkboxgroup = gr.CheckboxGroup(
450
+ choices=[],
451
+ value=[],
452
+ label=loc.localize("eval-tab-select-classes-checkboxgroup-label"),
453
+ info=loc.localize("eval-tab-select-classes-checkboxgroup-info"),
454
+ interactive=True,
455
+ elem_classes="custom-checkbox-group",
456
+ )
457
+
458
+ with gr.Column():
459
+ select_recordings_checkboxgroup = gr.CheckboxGroup(
460
+ choices=[],
461
+ value=[],
462
+ label=loc.localize("eval-tab-select-recordings-checkboxgroup-label"),
463
+ info=loc.localize("eval-tab-select-recordings-checkboxgroup-info"),
464
+ interactive=True,
465
+ elem_classes="custom-checkbox-group",
466
+ )
467
+
468
+ # ----------------------- Parameters Box -----------------------
469
+ with gr.Group(), gr.Accordion(loc.localize("eval-tab-parameters-accordion-label"), open=False), gr.Row():
470
+ sample_duration = gr.Number(
471
+ value=3,
472
+ label=loc.localize("eval-tab-sample-duration-number-label"),
473
+ precision=0,
474
+ info=loc.localize("eval-tab-sample-duration-number-info"),
475
+ )
476
+ recording_duration = gr.Textbox(
477
+ label=loc.localize("eval-tab-recording-duration-textbox-label"),
478
+ placeholder=loc.localize("eval-tab-recording-duration-textbox-placeholder"),
479
+ info=loc.localize("eval-tab-recording-duration-textbox-info"),
480
+ )
481
+ min_overlap = gr.Number(
482
+ value=0.5,
483
+ label=loc.localize("eval-tab-min-overlap-number-label"),
484
+ info=loc.localize("eval-tab-min-overlap-number-info"),
485
+ )
486
+ threshold = gr.Slider(
487
+ minimum=0.01,
488
+ maximum=0.99,
489
+ value=0.1,
490
+ label=loc.localize("eval-tab-threshold-number-label"),
491
+ info=loc.localize("eval-tab-threshold-number-info"),
492
+ )
493
+ class_wise = gr.Checkbox(
494
+ label=loc.localize("eval-tab-classwise-checkbox-label"),
495
+ value=False,
496
+ info=loc.localize("eval-tab-classwise-checkbox-info"),
497
+ )
498
+
499
+ # ----------------------- Metrics Box -----------------------
500
+ with gr.Group(), gr.Accordion(loc.localize("eval-tab-metrics-accordian-label"), open=False), gr.Row():
501
+ metric_info = {
502
+ "AUROC": loc.localize("eval-tab-auroc-checkbox-info"),
503
+ "Precision": loc.localize("eval-tab-precision-checkbox-info"),
504
+ "Recall": loc.localize("eval-tab-recall-checkbox-info"),
505
+ "F1 Score": loc.localize("eval-tab-f1-score-checkbox-info"),
506
+ "Average Precision (AP)": loc.localize("eval-tab-ap-checkbox-info"),
507
+ "Accuracy": loc.localize("eval-tab-accuracy-checkbox-info"),
508
+ }
509
+ metrics_checkboxes = {}
510
+
511
+ for metric_name, description in metric_info.items():
512
+ metrics_checkboxes[metric_name.lower()] = gr.Checkbox(label=metric_name, value=True, info=description)
513
+
514
+ # ----------------------- Actions Box -----------------------
515
+
516
+ calculate_button = gr.Button(loc.localize("eval-tab-calculate-metrics-button-label"), variant="huggingface")
517
+
518
+ with gr.Column(visible=False) as action_col:
519
+ with gr.Row():
520
+ plot_metrics_button = gr.Button(loc.localize("eval-tab-plot-metrics-button-label"))
521
+ plot_confusion_button = gr.Button(loc.localize("eval-tab-plot-confusion-matrix-button-label"))
522
+ plot_metrics_all_thresholds_button = gr.Button(loc.localize("eval-tab-plot-metrics-all-thresholds-button-label"))
523
+
524
+ with gr.Row():
525
+ download_results_button = gr.DownloadButton(loc.localize("eval-tab-result-table-download-button-label"))
526
+ download_data_button = gr.DownloadButton(loc.localize("eval-tab-data-table-download-button-label"))
527
+
528
+ download_results_button.click(
529
+ fn=download_results_table,
530
+ inputs=[pa_state, predictions_state, labels_state, class_wise],
531
+ )
532
+ download_data_button.click(fn=download_data_table, inputs=[processor_state])
533
+ metric_table = gr.Dataframe(show_label=False, type="pandas", visible=False, interactive=False)
534
+
535
+ with gr.Group(visible=False) as plot_group:
536
+ plot_output = gr.Plot(show_label=False)
537
+ plot_output_dl_btn = gr.Button("Download plot", size="sm")
538
+
539
+ # Update available selections (classes and recordings) and the processor state when files or mapping file change.
540
+ # Also pass the current selection values so that user selections are preserved.
541
+ for comp in list(annotation_columns.values()) + list(prediction_columns.values()) + [mapping_file]:
542
+ comp.change(
543
+ fn=update_selections,
544
+ inputs=[
545
+ annotation_files_state,
546
+ prediction_files_state,
547
+ mapping_file,
548
+ sample_duration,
549
+ min_overlap,
550
+ recording_duration,
551
+ annotation_columns["Start Time"],
552
+ annotation_columns["End Time"],
553
+ annotation_columns["Class"],
554
+ annotation_columns["Recording"],
555
+ annotation_columns["Duration"],
556
+ prediction_columns["Start Time"],
557
+ prediction_columns["End Time"],
558
+ prediction_columns["Class"],
559
+ prediction_columns["Confidence"],
560
+ prediction_columns["Recording"],
561
+ prediction_columns["Duration"],
562
+ select_classes_checkboxgroup,
563
+ select_recordings_checkboxgroup,
564
+ ],
565
+ outputs=[
566
+ select_classes_checkboxgroup,
567
+ select_recordings_checkboxgroup,
568
+ processor_state,
569
+ ],
570
+ )
571
+
572
+ # calculate_metrics now uses the stored temporary directories from processor_state.
573
+ # The function now accepts selected_classes and selected_recordings as inputs.
574
# calculate_metrics now uses the stored temporary directories from processor_state.
# The function now accepts selected_classes and selected_recordings as inputs.
def calculate_metrics(
    mapping_file_obj,
    sample_duration_value,
    min_overlap_value,
    recording_duration_value: str,
    ann_start_time,
    ann_end_time,
    ann_class,
    ann_recording,
    ann_duration,
    pred_start_time,
    pred_end_time,
    pred_class,
    pred_confidence,
    pred_recording,
    pred_duration,
    threshold_value,
    class_wise_value,
    selected_classes_list,
    selected_recordings_list,
    proc_state: ProcessorState,
    *metrics_checkbox_values,
):
    """Run the evaluation pipeline and return the metrics table plus state.

    The trailing *metrics_checkbox_values align positionally with
    metrics_checkboxes (insertion order), selecting which metrics to compute.
    Returns updates/values for (metric_table, action_col, pa_state,
    predictions_state, labels_state, class checkboxgroup, recording
    checkboxgroup, processor_state). Raises a user-facing gr.Error on
    invalid input or processing failure.
    """
    selected_metrics = []

    # Map each checkbox value to its metric key (dict insertion order matches inputs).
    for value, (m_lower, _) in zip(metrics_checkbox_values, metrics_checkboxes.items(), strict=True):
        if value:
            selected_metrics.append(m_lower)

    # Translate GUI labels to the metric identifiers process_data expects.
    valid_metrics = {
        "accuracy": "accuracy",
        "recall": "recall",
        "precision": "precision",
        "f1 score": "f1",
        "average precision (ap)": "ap",
        "auroc": "auroc",
    }
    metrics = tuple(valid_metrics[m] for m in selected_metrics if m in valid_metrics)

    # Fall back to available classes from processor state if none selected.
    if not selected_classes_list and proc_state and proc_state.processor:
        selected_classes_list = list(proc_state.processor.classes)

    if not selected_classes_list:
        raise gr.Error(loc.localize("eval-tab-error-no-class-selected"))

    if recording_duration_value.strip() == "":
        rec_dur = None
    else:
        try:
            rec_dur = float(recording_duration_value)
        except ValueError as e:
            raise gr.Error(loc.localize("eval-tab-error-no-valid-recording-duration")) from e

    # Mapping file may be a Gradio file object (temp_files) or a plain path.
    if mapping_file_obj and hasattr(mapping_file_obj, "temp_files"):
        mapping_path = list(mapping_file_obj.temp_files)[0]
    else:
        mapping_path = mapping_file_obj if mapping_file_obj else None

    try:
        metrics_df, pa, preds, labs = process_data(
            annotation_path=proc_state.annotation_dir,
            prediction_path=proc_state.prediction_dir,
            mapping_path=mapping_path,
            sample_duration=sample_duration_value,
            min_overlap=min_overlap_value,
            recording_duration=rec_dur,
            columns_annotations={
                "Start Time": ann_start_time,
                "End Time": ann_end_time,
                "Class": ann_class,
                "Recording": ann_recording,
                "Duration": ann_duration,
            },
            columns_predictions={
                "Start Time": pred_start_time,
                "End Time": pred_end_time,
                "Class": pred_class,
                "Confidence": pred_confidence,
                "Recording": pred_recording,
                "Duration": pred_duration,
            },
            selected_classes=selected_classes_list,
            selected_recordings=selected_recordings_list,
            metrics_list=metrics,
            threshold=threshold_value,
            class_wise=class_wise_value,
        )

        # Transpose so metrics become rows; the index becomes an unnamed column.
        table = metrics_df.T.reset_index(names=[""])

        return (
            gr.update(value=table, visible=True),
            gr.update(visible=True),
            pa,
            preds,
            labs,
            gr.update(),
            gr.update(),
            proc_state,
        )
    except Exception as e:
        print("Error processing data:", e)
        raise gr.Error(f"{loc.localize('eval-tab-error-during-processing')}: {e}") from e
678
+
679
# Updated calculate_button click now passes the selected classes and recordings.
# NOTE(review): the inputs list is positional — its order must match the
# (off-screen) calculate_metrics signature exactly: mapping/duration settings,
# the five annotation column dropdowns, the six prediction column dropdowns
# (predictions additionally have "Confidence"), threshold/class-wise toggles,
# the class/recording selections, the processor state, and finally the metric
# checkboxes splatted last. TODO confirm against the calculate_metrics def.
calculate_button.click(
    calculate_metrics,
    inputs=[
        mapping_file,
        sample_duration,
        min_overlap,
        recording_duration,
        annotation_columns["Start Time"],
        annotation_columns["End Time"],
        annotation_columns["Class"],
        annotation_columns["Recording"],
        annotation_columns["Duration"],
        prediction_columns["Start Time"],
        prediction_columns["End Time"],
        prediction_columns["Class"],
        prediction_columns["Confidence"],
        prediction_columns["Recording"],
        prediction_columns["Duration"],
        threshold,
        class_wise,
        select_classes_checkboxgroup,
        select_recordings_checkboxgroup,
        processor_state,
        *list(metrics_checkboxes.values()),
    ],
    # Outputs: metrics table + action column visibility, then the three state
    # holders (assessor, predictions, labels), the two checkbox groups, and
    # the processor state — order mirrors calculate_metrics' return tuple.
    outputs=[
        metric_table,
        action_col,
        pa_state,
        predictions_state,
        labels_state,
        select_classes_checkboxgroup,
        select_recordings_checkboxgroup,
        processor_state,
    ],
)
717
def plot_metrics(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Render the metrics plot for the current assessor state.

    Returns updates for (plot group visibility, plot figure, plot name state).
    Raises a localized ``gr.Error`` when metrics have not been calculated yet,
    or when plotting itself fails.
    """
    # Guard: all three pieces of state must exist before plotting.
    if any(item is None for item in (pa, predictions, labels)):
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        figure = pa.plot_metrics(predictions, labels, per_class_metrics=class_wise_value)
        # Close the figure so matplotlib does not keep it alive in its global
        # registry; Gradio renders the figure object itself.
        plt.close(figure)
        return gr.update(visible=True), gr.update(value=figure), "metrics"
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics')}: {e}") from e
727
+
728
# Wire the "plot metrics" button: reveal the plot group, set the figure, and
# record the plot name so the download button can label the saved file.
plot_metrics_button.click(
    plot_metrics,
    inputs=[pa_state, predictions_state, labels_state, class_wise],
    outputs=[plot_group, plot_output, plot_name_state],
)
733
+
734
def plot_confusion_matrix(pa: PerformanceAssessor, predictions, labels):
    """Render the confusion-matrix plot for the current assessor state.

    Returns updates for (plot group visibility, plot figure, plot name state).
    Raises a localized ``gr.Error`` when metrics have not been calculated yet,
    or when plotting itself fails.
    """
    if pa is None or predictions is None or labels is None:
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        fig = pa.plot_confusion_matrix(predictions, labels)
        # Close the figure so matplotlib does not keep it alive; Gradio
        # renders the figure object itself.
        plt.close(fig)

        # Consistency fix: wrap the figure in gr.update(value=...) like the
        # sibling handlers (plot_metrics, plot_metrics_all_thresholds) do.
        # Both forms set the component value, but the explicit update keeps
        # all three plot callbacks parallel.
        return gr.update(visible=True), gr.update(value=fig), "confusion_matrix"
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-confusion-matrix')}: {e}") from e
744
+
745
# Wire the "plot confusion matrix" button; same output trio as the other
# plot buttons so the shared plot area and download name stay in sync.
plot_confusion_button.click(
    plot_confusion_matrix,
    inputs=[pa_state, predictions_state, labels_state],
    outputs=[plot_group, plot_output, plot_name_state],
)
750
+
751
# Directory pickers. The selection callback's outputs are positional: the
# file-list state, the directory textbox, the group visibility update, then
# one update per column dropdown — the dropdown order must match the label
# list spliced in below.
annotation_select_directory_btn.click(
    get_selection_func("eval-annotations-dir", update_annotation_columns),
    outputs=[annotation_files_state, annotation_directory_input, annotation_group]
    + [annotation_columns[label] for label in ["Start Time", "End Time", "Class", "Recording", "Duration"]],
    show_progress="full",
)

# Predictions additionally expose a "Confidence" column dropdown.
prediction_select_directory_btn.click(
    get_selection_func("eval-predictions-dir", update_prediction_columns),
    outputs=[prediction_files_state, prediction_directory_input, prediction_group]
    + [prediction_columns[label] for label in ["Start Time", "End Time", "Class", "Confidence", "Recording", "Duration"]],
    show_progress="full",
)
764
+
765
def toggle_after_selection(annotation_files, prediction_files):
    """Show both follow-up groups once annotation AND prediction files exist."""
    both_selected = annotation_files and prediction_files
    visibility_update = gr.update(visible=both_selected)
    # Two outputs (mapping group and class/recording group) get the same update.
    return [visibility_update, visibility_update]
767
+
768
# Re-evaluate group visibility whenever either directory textbox changes;
# both handlers share toggle_after_selection and the same inputs/outputs so
# the groups appear only after BOTH directories have been selected.
annotation_directory_input.change(
    toggle_after_selection,
    inputs=[annotation_files_state, prediction_files_state],
    outputs=[mapping_group, class_recording_group],
)

prediction_directory_input.change(
    toggle_after_selection,
    inputs=[annotation_files_state, prediction_files_state],
    outputs=[mapping_group, class_recording_group],
)
779
+
780
def plot_metrics_all_thresholds(pa: PerformanceAssessor, predictions, labels, class_wise_value):
    """Render the metrics-across-all-thresholds plot for the current state.

    Returns updates for (plot group visibility, plot figure, plot name state).
    Raises a localized ``gr.Error`` when metrics have not been calculated yet,
    or when plotting itself fails.
    """
    # Guard: all three pieces of state must exist before plotting.
    if any(item is None for item in (pa, predictions, labels)):
        raise gr.Error(loc.localize("eval-tab-error-calc-metrics-first"), print_exception=False)
    try:
        figure = pa.plot_metrics_all_thresholds(predictions, labels, per_class_metrics=class_wise_value)
        # Close the figure so matplotlib does not keep it alive in its global
        # registry; Gradio renders the figure object itself.
        plt.close(figure)
        return gr.update(visible=True), gr.update(value=figure), "metrics_all_thresholds"
    except Exception as e:
        raise gr.Error(f"{loc.localize('eval-tab-error-plotting-metrics-all-thresholds')}: {e}") from e
790
+
791
# Wire the "all thresholds" plot button; same output trio as the other plot
# buttons so the shared plot area and download name stay in sync.
plot_metrics_all_thresholds_button.click(
    plot_metrics_all_thresholds,
    inputs=[pa_state, predictions_state, labels_state, class_wise],
    outputs=[plot_group, plot_output, plot_name_state],
)

# Download button: saves whichever plot is currently displayed, using the
# recorded plot name state for the file name.
plot_output_dl_btn.click(gu.download_plot, inputs=[plot_output, plot_name_state])
798
+
799
+
800
+ if __name__ == "__main__":
801
+ gu.open_window(build_evaluation_tab)