birdnet-analyzer 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +19 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +30 -25
  5. birdnet_analyzer/analyze/core.py +246 -245
  6. birdnet_analyzer/analyze/utils.py +694 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +732 -707
  9. birdnet_analyzer/config.py +243 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +70 -70
  15. birdnet_analyzer/embeddings/utils.py +220 -193
  16. birdnet_analyzer/evaluation/__init__.py +189 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +378 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +179 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +36 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +638 -620
  35. birdnet_analyzer/gui/evaluation.py +801 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +265 -246
  38. birdnet_analyzer/gui/review.py +472 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +149 -129
  41. birdnet_analyzer/gui/single_file.py +264 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +687 -698
  44. birdnet_analyzer/gui/utils.py +797 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +341 -334
  83. birdnet_analyzer/lang/en.json +341 -334
  84. birdnet_analyzer/lang/fi.json +341 -334
  85. birdnet_analyzer/lang/fr.json +341 -334
  86. birdnet_analyzer/lang/id.json +341 -334
  87. birdnet_analyzer/lang/pt-br.json +341 -334
  88. birdnet_analyzer/lang/ru.json +341 -334
  89. birdnet_analyzer/lang/se.json +341 -334
  90. birdnet_analyzer/lang/tlh.json +341 -334
  91. birdnet_analyzer/lang/zh_TW.json +341 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +425 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/METADATA +146 -129
  117. birdnet_analyzer-2.1.0.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
  121. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/entry_points.txt +0 -0
  123. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/top_level.txt +0 -0
birdnet_analyzer/gui/train.py
@@ -1,698 +1,687 @@
1
- import multiprocessing
2
- import os
3
- from functools import partial
4
- from pathlib import Path
5
-
6
- import gradio as gr
7
-
8
- import birdnet_analyzer.config as cfg
9
- import birdnet_analyzer.gui.localization as loc
10
- import birdnet_analyzer.gui.utils as gu
11
- import birdnet_analyzer.utils as utils
12
-
13
- _GRID_MAX_HEIGHT = 240
14
-
15
-
16
- def select_subdirectories(state_key=None):
17
- """Creates a directory selection dialog.
18
-
19
- Returns:
20
- A tuples of (directory, list of subdirectories) or (None, None) if the dialog was canceled.
21
- """
22
- dir_name = gu.select_folder(state_key=state_key)
23
-
24
- if dir_name:
25
- subdirs = utils.list_subdirectories(dir_name)
26
- labels = []
27
-
28
- for folder in subdirs:
29
- labels_in_folder = folder.split(",")
30
-
31
- for label in labels_in_folder:
32
- if label not in labels:
33
- labels.append(label)
34
-
35
- return dir_name, [[label] for label in sorted(labels)]
36
-
37
- return None, None
38
-
39
-
40
- @gu.gui_runtime_error_handler
41
- def start_training(
42
- data_dir,
43
- test_data_dir,
44
- crop_mode,
45
- crop_overlap,
46
- fmin,
47
- fmax,
48
- output_dir,
49
- classifier_name,
50
- model_save_mode,
51
- cache_mode,
52
- cache_file,
53
- cache_file_name,
54
- autotune,
55
- autotune_trials,
56
- autotune_executions_per_trials,
57
- epochs,
58
- batch_size,
59
- learning_rate,
60
- focal_loss,
61
- focal_loss_gamma,
62
- focal_loss_alpha,
63
- hidden_units,
64
- dropout,
65
- label_smoothing,
66
- use_mixup,
67
- upsampling_ratio,
68
- upsampling_mode,
69
- model_format,
70
- audio_speed,
71
- progress=gr.Progress(),
72
- ):
73
- """Starts the training of a custom classifier.
74
-
75
- Args:
76
- data_dir: Directory containing the training data.
77
- test_data_dir: Directory containing the test data.
78
- crop_mode: Mode for cropping audio samples.
79
- crop_overlap: Overlap ratio for audio segments.
80
- fmin: Minimum frequency for bandpass filtering.
81
- fmax: Maximum frequency for bandpass filtering.
82
- output_dir: Directory to save the trained model.
83
- classifier_name: Name of the custom classifier.
84
- model_save_mode: Save mode for the model (replace or append).
85
- cache_mode: Cache mode for training data (load, save, or None).
86
- cache_file: Path to the cache file.
87
- cache_file_name: Name of the cache file.
88
- autotune: Whether to use hyperparameter autotuning.
89
- autotune_trials: Number of trials for autotuning.
90
- autotune_executions_per_trials: Number of executions per autotuning trial.
91
- epochs: Number of training epochs.
92
- batch_size: Batch size for training.
93
- learning_rate: Learning rate for the optimizer.
94
- focal_loss: Whether to use focal loss for training.
95
- focal_loss_gamma: Gamma parameter for focal loss.
96
- focal_loss_alpha: Alpha parameter for focal loss.
97
- hidden_units: Number of hidden units in the droput: Dropout rate for regularization.
98
- dropout: Dropout rate for regularization.
99
- label_smoothing: Whether to apply label smoothing for training.
100
- use_mixup: Whether to use mixup data augmentation.
101
- upsampling_ratio: Ratio for upsampling underrepresented classes.
102
- upsampling_mode: Mode for upsampling (repeat, mean, smote).
103
- model_format: Format to save the trained model (tflite, raven, both).
104
- audio_speed: Speed factor for audio playback.
105
-
106
- Returns:
107
- Returns a matplotlib.pyplot figure.
108
- """
109
- import matplotlib
110
- import matplotlib.pyplot as plt
111
-
112
- from birdnet_analyzer.train.utils import train_model
113
-
114
- # Skip training data validation when cache mode is "load"
115
- if cache_mode != "load":
116
- gu.validate(data_dir, loc.localize("validation-no-training-data-selected"))
117
-
118
- gu.validate(output_dir, loc.localize("validation-no-directory-for-classifier-selected"))
119
- gu.validate(classifier_name, loc.localize("validation-no-valid-classifier-name"))
120
-
121
- if not epochs or epochs < 0:
122
- raise gr.Error(loc.localize("validation-no-valid-epoch-number"))
123
-
124
- if not batch_size or batch_size < 0:
125
- raise gr.Error(loc.localize("validation-no-valid-batch-size"))
126
-
127
- if not learning_rate or learning_rate < 0:
128
- raise gr.Error(loc.localize("validation-no-valid-learning-rate"))
129
-
130
- if fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax:
131
- raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]")
132
-
133
- cfg.TRAIN_WITH_FOCAL_LOSS = focal_loss
134
- cfg.FOCAL_LOSS_GAMMA = max(0.0, float(focal_loss_gamma))
135
- cfg.FOCAL_LOSS_ALPHA = max(0.0, min(1.0, float(focal_loss_alpha)))
136
-
137
- if not hidden_units or hidden_units < 0:
138
- hidden_units = 0
139
-
140
- cfg.TRAIN_DROPOUT = max(0.0, min(1.0, float(dropout)))
141
-
142
- if progress is not None:
143
- progress((0, epochs), desc=loc.localize("progress-build-classifier"), unit="epochs")
144
-
145
- cfg.TRAIN_DATA_PATH = data_dir
146
- cfg.TEST_DATA_PATH = test_data_dir
147
- cfg.SAMPLE_CROP_MODE = crop_mode
148
- cfg.SIG_OVERLAP = max(0.0, min(2.9, float(crop_overlap)))
149
- cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
150
- cfg.TRAIN_EPOCHS = int(epochs)
151
- cfg.TRAIN_BATCH_SIZE = int(batch_size)
152
- cfg.TRAIN_LEARNING_RATE = learning_rate
153
- cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
154
- cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing
155
- cfg.TRAIN_WITH_MIXUP = use_mixup
156
- cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
157
- cfg.UPSAMPLING_MODE = upsampling_mode
158
- cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
159
-
160
- cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
161
- cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))
162
-
163
- cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
164
- cfg.TRAIN_CACHE_MODE = cache_mode
165
- cfg.TRAIN_CACHE_FILE = os.path.join(cache_file, cache_file_name) if cache_mode == "save" else cache_file
166
- cfg.TFLITE_THREADS = 1
167
- cfg.CPU_THREADS = max(1, multiprocessing.cpu_count() - 1) # let's use everything we have (well, almost)
168
-
169
- if cache_mode == "load" and not os.path.isfile(cfg.TRAIN_CACHE_FILE):
170
- raise gr.Error(loc.localize("validation-no-cache-file-selected"))
171
-
172
- cfg.AUTOTUNE = autotune
173
- cfg.AUTOTUNE_TRIALS = autotune_trials
174
- cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = int(autotune_executions_per_trials)
175
-
176
- cfg.AUDIO_SPEED = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed))
177
-
178
- def data_load_progression(num_files, num_total_files, label):
179
- if progress is not None:
180
- progress(
181
- (num_files, num_total_files),
182
- total=num_total_files,
183
- unit="files",
184
- desc=f"{loc.localize('progress-loading-data')} '{label}'",
185
- )
186
-
187
- def epoch_progression(epoch, logs=None):
188
- if progress is not None:
189
- if epoch + 1 == epochs:
190
- progress(
191
- (epoch + 1, epochs),
192
- total=epochs,
193
- unit="epochs",
194
- desc=f"{loc.localize('progress-saving')} {cfg.CUSTOM_CLASSIFIER}",
195
- )
196
- else:
197
- progress((epoch + 1, epochs), total=epochs, unit="epochs", desc=loc.localize("progress-training"))
198
-
199
- def trial_progression(trial):
200
- if progress is not None:
201
- progress(
202
- (trial, autotune_trials), total=autotune_trials, unit="trials", desc=loc.localize("progress-autotune")
203
- )
204
-
205
- try:
206
- history_result = train_model(
207
- on_epoch_end=epoch_progression,
208
- on_trial_result=trial_progression,
209
- on_data_load_end=data_load_progression,
210
- autotune_directory=gu.APPDIR if utils.FROZEN else "autotune",
211
- )
212
-
213
- # Unpack history and metrics
214
- history, metrics = history_result
215
- except Exception as e:
216
- if e.args and len(e.args) > 1:
217
- raise gr.Error(loc.localize(e.args[1]))
218
- else:
219
- raise gr.Error(f"{e}")
220
-
221
- if len(history.epoch) < epochs:
222
- gr.Info(loc.localize("training-tab-early-stoppage-msg"))
223
-
224
- auprc = history.history["val_AUPRC"]
225
- auroc = history.history["val_AUROC"]
226
-
227
- matplotlib.use("agg")
228
-
229
- fig = plt.figure()
230
- plt.plot(auprc, label="AUPRC")
231
- plt.plot(auroc, label="AUROC")
232
- plt.legend()
233
- plt.xlabel("Epoch")
234
-
235
- return fig, metrics
236
-
237
-
238
- def build_train_tab():
239
- with gr.Tab(loc.localize("training-tab-title")):
240
- input_directory_state = gr.State()
241
- output_directory_state = gr.State()
242
- test_data_dir_state = gr.State()
243
-
244
- with gr.Row():
245
- with gr.Column():
246
- select_directory_btn = gr.Button(loc.localize("training-tab-input-selection-button-label"))
247
- directory_input = gr.List(
248
- headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
249
- interactive=False,
250
- max_height=_GRID_MAX_HEIGHT,
251
- )
252
- select_directory_btn.click(
253
- partial(select_subdirectories, state_key="train-data-dir"),
254
- outputs=[input_directory_state, directory_input],
255
- show_progress=False,
256
- )
257
-
258
- select_test_directory_btn = gr.Button(loc.localize("training-tab-test-data-selection-button-label"))
259
- test_directory_input = gr.List(
260
- headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
261
- interactive=False,
262
- max_height=_GRID_MAX_HEIGHT,
263
- )
264
- select_test_directory_btn.click(
265
- partial(select_subdirectories, state_key="test-data-dir"),
266
- outputs=[test_data_dir_state, test_directory_input],
267
- show_progress=False,
268
- )
269
-
270
- with gr.Column():
271
- select_classifier_directory_btn = gr.Button(loc.localize("training-tab-select-output-button-label"))
272
-
273
- with gr.Column():
274
- classifier_name = gr.Textbox(
275
- "CustomClassifier",
276
- visible=False,
277
- info=loc.localize("training-tab-classifier-textbox-info"),
278
- )
279
- output_format = gr.Radio(
280
- ["tflite", "raven", (loc.localize("training-tab-output-format-both"), "both")],
281
- value=cfg.TRAINED_MODEL_OUTPUT_FORMAT,
282
- label=loc.localize("training-tab-output-format-radio-label"),
283
- info=loc.localize("training-tab-output-format-radio-info"),
284
- visible=False,
285
- )
286
-
287
- def select_directory_and_update_tb():
288
- dir_name = gu.select_folder(state_key="train-output-dir")
289
-
290
- if dir_name:
291
- return (
292
- dir_name,
293
- gr.Textbox(label=dir_name, visible=True),
294
- gr.Radio(visible=True, interactive=True),
295
- )
296
-
297
- return None, None
298
-
299
- select_classifier_directory_btn.click(
300
- select_directory_and_update_tb,
301
- outputs=[output_directory_state, classifier_name, output_format],
302
- show_progress=False,
303
- )
304
-
305
- with gr.Row():
306
- cache_file_state = gr.State()
307
- cache_mode = gr.Radio(
308
- [
309
- (loc.localize("training-tab-cache-mode-radio-option-none"), None),
310
- (loc.localize("training-tab-cache-mode-radio-option-load"), "load"),
311
- (loc.localize("training-tab-cache-mode-radio-option-save"), "save"),
312
- ],
313
- value=cfg.TRAIN_CACHE_MODE,
314
- label=loc.localize("training-tab-cache-mode-radio-label"),
315
- info=loc.localize("training-tab-cache-mode-radio-info"),
316
- )
317
- with gr.Column(visible=False) as new_cache_file_row:
318
- select_cache_file_directory_btn = gr.Button(
319
- loc.localize("training-tab-cache-select-directory-button-label")
320
- )
321
-
322
- with gr.Column():
323
- cache_file_name = gr.Textbox(
324
- "train_cache.npz",
325
- visible=False,
326
- info=loc.localize("training-tab-cache-file-name-textbox-info"),
327
- )
328
-
329
- def select_directory_and_update():
330
- dir_name = gu.select_folder(state_key="train-data-cache-file-output")
331
-
332
- if dir_name:
333
- return (
334
- dir_name,
335
- gr.Textbox(label=dir_name, visible=True),
336
- )
337
-
338
- return None, None
339
-
340
- select_cache_file_directory_btn.click(
341
- select_directory_and_update,
342
- outputs=[cache_file_state, cache_file_name],
343
- show_progress=False,
344
- )
345
-
346
- with gr.Column(visible=False) as load_cache_file_row:
347
- selected_cache_file_btn = gr.Button(loc.localize("training-tab-cache-select-file-button-label"))
348
- cache_file_input = gr.File(file_types=[".npz"], visible=False, interactive=False)
349
-
350
- def on_cache_file_selection_click():
351
- file = gu.select_file(("NPZ file (*.npz)",), state_key="train_data_cache_file")
352
-
353
- if file:
354
- return file, gr.File(value=file, visible=True)
355
-
356
- return None, None
357
-
358
- selected_cache_file_btn.click(
359
- on_cache_file_selection_click,
360
- outputs=[cache_file_state, cache_file_input],
361
- show_progress=False,
362
- )
363
-
364
- def on_cache_mode_change(value):
365
- return (
366
- gr.update(visible=value == "save"),
367
- gr.update(visible=value == "load"),
368
- gr.update(interactive=value != "load"),
369
- [],
370
- gr.update(interactive=value != "load"),
371
- [],
372
- gr.update(interactive=value != "load"),
373
- gr.update(interactive=value != "load"),
374
- gr.update(interactive=value != "load"),
375
- gr.update(interactive=value != "load"),
376
- gr.update(interactive=value != "load"),
377
- )
378
-
379
- with gr.Row():
380
- fmin_number = gr.Number(
381
- cfg.SIG_FMIN,
382
- minimum=0,
383
- label=loc.localize("inference-settings-fmin-number-label"),
384
- info=loc.localize("inference-settings-fmin-number-info"),
385
- )
386
-
387
- fmax_number = gr.Number(
388
- cfg.SIG_FMAX,
389
- minimum=0,
390
- label=loc.localize("inference-settings-fmax-number-label"),
391
- info=loc.localize("inference-settings-fmax-number-info"),
392
- )
393
-
394
- with gr.Row():
395
- audio_speed_slider = gr.Slider(
396
- minimum=-10,
397
- maximum=10,
398
- value=cfg.AUDIO_SPEED,
399
- step=1,
400
- label=loc.localize("training-tab-audio-speed-slider-label"),
401
- info=loc.localize("training-tab-audio-speed-slider-info"),
402
- )
403
-
404
- with gr.Row():
405
- crop_mode = gr.Radio(
406
- [
407
- (loc.localize("training-tab-crop-mode-radio-option-center"), "center"),
408
- (loc.localize("training-tab-crop-mode-radio-option-first"), "first"),
409
- (loc.localize("training-tab-crop-mode-radio-option-segments"), "segments"),
410
- (loc.localize("training-tab-crop-mode-radio-option-smart"), "smart"),
411
- ],
412
- value="center",
413
- label=loc.localize("training-tab-crop-mode-radio-label"),
414
- info=loc.localize("training-tab-crop-mode-radio-info"),
415
- )
416
-
417
- crop_overlap = gr.Slider(
418
- minimum=0,
419
- maximum=2.99,
420
- value=cfg.SIG_OVERLAP,
421
- step=0.01,
422
- label=loc.localize("training-tab-crop-overlap-number-label"),
423
- info=loc.localize("training-tab-crop-overlap-number-info"),
424
- visible=False,
425
- )
426
-
427
- def on_crop_select(new_crop_mode):
428
- # Make overlap slider visible for both "segments" and "smart" crop modes
429
- return gr.Number(
430
- visible=new_crop_mode in ["segments", "smart"], interactive=new_crop_mode in ["segments", "smart"]
431
- )
432
-
433
- crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)
434
-
435
- cache_mode.change(
436
- on_cache_mode_change,
437
- inputs=cache_mode,
438
- outputs=[
439
- new_cache_file_row,
440
- load_cache_file_row,
441
- select_directory_btn,
442
- directory_input,
443
- select_test_directory_btn,
444
- test_directory_input,
445
- fmin_number,
446
- fmax_number,
447
- audio_speed_slider,
448
- crop_mode,
449
- crop_overlap,
450
- ],
451
- show_progress=False,
452
- )
453
-
454
- autotune_cb = gr.Checkbox(
455
- cfg.AUTOTUNE,
456
- label=loc.localize("training-tab-autotune-checkbox-label"),
457
- info=loc.localize("training-tab-autotune-checkbox-info"),
458
- )
459
-
460
- with gr.Column(visible=False) as autotune_params:
461
- with gr.Row():
462
- autotune_trials = gr.Number(
463
- cfg.AUTOTUNE_TRIALS,
464
- label=loc.localize("training-tab-autotune-trials-number-label"),
465
- info=loc.localize("training-tab-autotune-trials-number-info"),
466
- minimum=1,
467
- )
468
- autotune_executions_per_trials = gr.Number(
469
- cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
470
- minimum=1,
471
- label=loc.localize("training-tab-autotune-executions-number-label"),
472
- info=loc.localize("training-tab-autotune-executions-number-info"),
473
- )
474
-
475
- with gr.Column() as custom_params:
476
- with gr.Row():
477
- epoch_number = gr.Number(
478
- cfg.TRAIN_EPOCHS,
479
- minimum=1,
480
- step=1,
481
- label=loc.localize("training-tab-epochs-number-label"),
482
- info=loc.localize("training-tab-epochs-number-info"),
483
- )
484
- batch_size_number = gr.Number(
485
- 32,
486
- minimum=1,
487
- step=8,
488
- label=loc.localize("training-tab-batchsize-number-label"),
489
- info=loc.localize("training-tab-batchsize-number-info"),
490
- )
491
- learning_rate_number = gr.Number(
492
- cfg.TRAIN_LEARNING_RATE,
493
- minimum=0.0001,
494
- step=0.0001,
495
- label=loc.localize("training-tab-learningrate-number-label"),
496
- info=loc.localize("training-tab-learningrate-number-info"),
497
- )
498
-
499
- with gr.Row():
500
- hidden_units_number = gr.Number(
501
- cfg.TRAIN_HIDDEN_UNITS,
502
- minimum=0,
503
- step=64,
504
- label=loc.localize("training-tab-hiddenunits-number-label"),
505
- info=loc.localize("training-tab-hiddenunits-number-info"),
506
- )
507
- dropout_number = gr.Number(
508
- cfg.TRAIN_DROPOUT,
509
- minimum=0.0,
510
- maximum=0.9,
511
- step=0.1,
512
- label=loc.localize("training-tab-dropout-number-label"),
513
- info=loc.localize("training-tab-dropout-number-info"),
514
- )
515
- use_label_smoothing = gr.Checkbox(
516
- cfg.TRAIN_WITH_LABEL_SMOOTHING,
517
- label=loc.localize("training-tab-use-labelsmoothing-checkbox-label"),
518
- info=loc.localize("training-tab-use-labelsmoothing-checkbox-info"),
519
- show_label=True,
520
- )
521
-
522
- with gr.Row():
523
- upsampling_mode = gr.Radio(
524
- [
525
- (loc.localize("training-tab-upsampling-radio-option-repeat"), "repeat"),
526
- (loc.localize("training-tab-upsampling-radio-option-mean"), "mean"),
527
- (loc.localize("training-tab-upsampling-radio-option-linear"), "linear"),
528
- ("SMOTE", "smote"),
529
- ],
530
- value=cfg.UPSAMPLING_MODE,
531
- label=loc.localize("training-tab-upsampling-radio-label"),
532
- info=loc.localize("training-tab-upsampling-radio-info"),
533
- )
534
- upsampling_ratio = gr.Slider(
535
- 0.0,
536
- 1.0,
537
- cfg.UPSAMPLING_RATIO,
538
- step=0.05,
539
- label=loc.localize("training-tab-upsampling-ratio-slider-label"),
540
- info=loc.localize("training-tab-upsampling-ratio-slider-info"),
541
- )
542
-
543
- with gr.Row():
544
- use_mixup = gr.Checkbox(
545
- cfg.TRAIN_WITH_MIXUP,
546
- label=loc.localize("training-tab-use-mixup-checkbox-label"),
547
- info=loc.localize("training-tab-use-mixup-checkbox-info"),
548
- show_label=True,
549
- )
550
- use_focal_loss = gr.Checkbox(
551
- cfg.TRAIN_WITH_FOCAL_LOSS,
552
- label=loc.localize("training-tab-use-focal-loss-checkbox-label"),
553
- info=loc.localize("training-tab-use-focal-loss-checkbox-info"),
554
- show_label=True,
555
- )
556
-
557
- with gr.Row(visible=False) as focal_loss_params:
558
- with gr.Column():
559
- focal_loss_gamma = gr.Slider(
560
- minimum=0.5,
561
- maximum=5.0,
562
- value=cfg.FOCAL_LOSS_GAMMA,
563
- step=0.1,
564
- label=loc.localize("training-tab-focal-loss-gamma-slider-label"),
565
- info=loc.localize("training-tab-focal-loss-gamma-slider-info"),
566
- interactive=True,
567
- )
568
- focal_loss_alpha = gr.Slider(
569
- minimum=0.1,
570
- maximum=0.9,
571
- value=cfg.FOCAL_LOSS_ALPHA,
572
- step=0.05,
573
- label=loc.localize("training-tab-focal-loss-alpha-slider-label"),
574
- info=loc.localize("training-tab-focal-loss-alpha-slider-info"),
575
- interactive=True,
576
- )
577
-
578
- def on_focal_loss_change(value):
579
- return gr.Row(visible=value)
580
-
581
- use_focal_loss.change(
582
- on_focal_loss_change, inputs=use_focal_loss, outputs=focal_loss_params, show_progress=False
583
- )
584
-
585
- def on_autotune_change(value):
586
- return (
587
- gr.Column(visible=not value),
588
- gr.Column(visible=value),
589
- gr.Row(visible=not value and use_focal_loss.value),
590
- )
591
-
592
- autotune_cb.change(
593
- on_autotune_change,
594
- inputs=autotune_cb,
595
- outputs=[custom_params, autotune_params, focal_loss_params],
596
- show_progress=False,
597
- )
598
-
599
- model_save_mode = gr.Radio(
600
- [
601
- (loc.localize("training-tab-model-save-mode-radio-option-replace"), "replace"),
602
- (loc.localize("training-tab-model-save-mode-radio-option-append"), "append"),
603
- ],
604
- value=cfg.TRAINED_MODEL_SAVE_MODE,
605
- label=loc.localize("training-tab-model-save-mode-radio-label"),
606
- info=loc.localize("training-tab-model-save-mode-radio-info"),
607
- )
608
-
609
- train_history_plot = gr.Plot()
610
- metrics_table = gr.Dataframe(
611
- headers=["Class", "Precision", "Recall", "F1 Score", "AUPRC", "AUROC", "Samples"],
612
- visible=False,
613
- label="Model Performance Metrics (Default Threshold 0.5)",
614
- )
615
- start_training_button = gr.Button(
616
- loc.localize("training-tab-start-training-button-label"), variant="huggingface"
617
- )
618
-
619
- def train_and_show_metrics(*args):
620
- history, metrics = start_training(*args)
621
-
622
- # If metrics are available (test data was provided), create table
623
- if metrics:
624
- # Create dataframe data with metrics
625
- table_data = []
626
-
627
- # Add overall metrics row first
628
- table_data.append(
629
- [
630
- "OVERALL (Macro-avg)",
631
- f"{metrics['macro_precision_default']:.4f}",
632
- f"{metrics['macro_recall_default']:.4f}",
633
- f"{metrics['macro_f1_default']:.4f}",
634
- f"{metrics['macro_auprc']:.4f}",
635
- f"{metrics['macro_auroc']:.4f}",
636
- "",
637
- ]
638
- )
639
-
640
- # Add class-specific metrics
641
- for class_name, class_metrics in metrics["class_metrics"].items():
642
- distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0})
643
- table_data.append(
644
- [
645
- class_name,
646
- f"{class_metrics['precision_default']:.4f}",
647
- f"{class_metrics['recall_default']:.4f}",
648
- f"{class_metrics['f1_default']:.4f}",
649
- f"{class_metrics['auprc']:.4f}",
650
- f"{class_metrics['auroc']:.4f}",
651
- f"{distribution['count']} ({distribution['percentage']:.2f}%)",
652
- ]
653
- )
654
-
655
- return history, gr.Dataframe(visible=True, value=table_data)
656
- else:
657
- # No metrics available, just return history and hide table
658
- return history, gr.Dataframe(visible=False)
659
-
660
- start_training_button.click(
661
- train_and_show_metrics,
662
- inputs=[
663
- input_directory_state,
664
- test_data_dir_state,
665
- crop_mode,
666
- crop_overlap,
667
- fmin_number,
668
- fmax_number,
669
- output_directory_state,
670
- classifier_name,
671
- model_save_mode,
672
- cache_mode,
673
- cache_file_state,
674
- cache_file_name,
675
- autotune_cb,
676
- autotune_trials,
677
- autotune_executions_per_trials,
678
- epoch_number,
679
- batch_size_number,
680
- learning_rate_number,
681
- use_focal_loss,
682
- focal_loss_gamma,
683
- focal_loss_alpha,
684
- hidden_units_number,
685
- dropout_number,
686
- use_label_smoothing,
687
- use_mixup,
688
- upsampling_ratio,
689
- upsampling_mode,
690
- output_format,
691
- audio_speed_slider,
692
- ],
693
- outputs=[train_history_plot, metrics_table],
694
- )
695
-
696
-
697
- if __name__ == "__main__":
698
- gu.open_window(build_train_tab)
1
+ import multiprocessing
2
+ import os
3
+ from functools import partial
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+
8
+ import birdnet_analyzer.config as cfg
9
+ import birdnet_analyzer.gui.localization as loc
10
+ import birdnet_analyzer.gui.utils as gu
11
+ from birdnet_analyzer import utils
12
+ from birdnet_analyzer.gui.settings import APPDIR
13
+
14
+ _GRID_MAX_HEIGHT = 240
15
+
16
+
17
def select_subdirectories(state_key=None):
    """Open a folder picker and collect the class labels found in the chosen folder.

    Each subdirectory name is split on "," and every resulting piece is treated
    as one label; duplicates are collapsed.

    Args:
        state_key: Forwarded unchanged to ``gu.select_folder``.

    Returns:
        A tuple ``(directory, rows)`` where ``rows`` holds one single-element
        list per unique label in sorted order, or ``(None, None)`` if no
        folder was selected.
    """
    chosen_dir = gu.select_folder(state_key=state_key)

    # Guard clause: nothing was picked in the dialog.
    if not chosen_dir:
        return None, None

    unique_labels = {
        part
        for subdir_name in utils.list_subdirectories(chosen_dir)
        for part in subdir_name.split(",")
    }

    return chosen_dir, [[name] for name in sorted(unique_labels)]
39
+
40
+
41
@gu.gui_runtime_error_handler
def start_training(
    data_dir,
    test_data_dir,
    crop_mode,
    crop_overlap,
    fmin,
    fmax,
    output_dir,
    classifier_name,
    model_save_mode,
    cache_mode,
    cache_file,
    cache_file_name,
    autotune,
    autotune_trials,
    autotune_executions_per_trials,
    epochs,
    batch_size,
    learning_rate,
    focal_loss,
    focal_loss_gamma,
    focal_loss_alpha,
    hidden_units,
    dropout,
    label_smoothing,
    use_mixup,
    upsampling_ratio,
    upsampling_mode,
    model_format,
    audio_speed,
    progress=gr.Progress(),
):
    """Starts the training of a custom classifier.

    Validates the GUI inputs, writes them into the global ``cfg`` module, runs
    ``train_model`` with progress callbacks and plots the validation curves.

    Args:
        data_dir: Directory containing the training data. Not validated when ``cache_mode`` is "load".
        test_data_dir: Directory containing the test data.
        crop_mode: Mode for cropping audio samples.
        crop_overlap: Overlap for audio segments; clamped to [0.0, 2.9].
        fmin: Minimum frequency for bandpass filtering; must not be below ``cfg.SIG_FMIN`` or above ``fmax``.
        fmax: Maximum frequency for bandpass filtering; must not exceed ``cfg.SIG_FMAX``.
        output_dir: Directory to save the trained model.
        classifier_name: Name of the custom classifier; joined with ``output_dir``.
        model_save_mode: Save mode for the model (replace or append).
        cache_mode: Cache mode for training data ("load", "save", or None).
        cache_file: In "load" mode the cache file to read; in "save" mode the directory to write into.
        cache_file_name: File name of the cache file, used in "save" mode only.
        autotune: Whether to use hyperparameter autotuning.
        autotune_trials: Number of trials for autotuning.
        autotune_executions_per_trials: Number of executions per autotuning trial.
        epochs: Number of training epochs.
        batch_size: Batch size for training.
        learning_rate: Learning rate for the optimizer.
        focal_loss: Whether to use focal loss for training.
        focal_loss_gamma: Gamma parameter for focal loss; negative values are raised to 0.0.
        focal_loss_alpha: Alpha parameter for focal loss; clamped to [0.0, 1.0].
        hidden_units: Number of hidden units; empty or negative values are treated as 0.
        dropout: Dropout rate for regularization; clamped to [0.0, 1.0].
        label_smoothing: Whether to apply label smoothing for training.
        use_mixup: Whether to use mixup data augmentation.
        upsampling_ratio: Ratio for upsampling underrepresented classes; clamped to [0, 1].
        upsampling_mode: Mode for upsampling (repeat, mean, linear, smote).
        model_format: Format to save the trained model (tflite, raven, both).
        audio_speed: Speed setting from the GUI. Negative values ``-n`` become ``1/n``
            (at least 0.1); other values are used directly (at least 1.0).
        progress: Gradio progress tracker used to report loading, tuning and training progress.

    Returns:
        A tuple ``(fig, metrics)``: a matplotlib figure plotting the "val_AUPRC" and
        "val_AUROC" history, and the metrics object returned by ``train_model``.

    Raises:
        gr.Error: If an input fails validation, no usable cache location was
            selected, or ``train_model`` raises.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    from birdnet_analyzer.train.utils import train_model

    # Skip training data validation when cache mode is "load"
    if cache_mode != "load":
        gu.validate(data_dir, loc.localize("validation-no-training-data-selected"))

    gu.validate(output_dir, loc.localize("validation-no-directory-for-classifier-selected"))
    gu.validate(classifier_name, loc.localize("validation-no-valid-classifier-name"))

    if not epochs or epochs < 0:
        raise gr.Error(loc.localize("validation-no-valid-epoch-number"))

    if not batch_size or batch_size < 0:
        raise gr.Error(loc.localize("validation-no-valid-batch-size"))

    if not learning_rate or learning_rate < 0:
        raise gr.Error(loc.localize("validation-no-valid-learning-rate"))

    if fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax:
        raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]")

    cfg.TRAIN_WITH_FOCAL_LOSS = focal_loss
    cfg.FOCAL_LOSS_GAMMA = max(0.0, float(focal_loss_gamma))
    cfg.FOCAL_LOSS_ALPHA = max(0.0, min(1.0, float(focal_loss_alpha)))

    if not hidden_units or hidden_units < 0:
        hidden_units = 0

    cfg.TRAIN_DROPOUT = max(0.0, min(1.0, float(dropout)))

    if progress is not None:
        progress((0, epochs), desc=loc.localize("progress-build-classifier"), unit="epochs")

    cfg.TRAIN_DATA_PATH = data_dir
    cfg.TEST_DATA_PATH = test_data_dir
    cfg.SAMPLE_CROP_MODE = crop_mode
    cfg.SIG_OVERLAP = max(0.0, min(2.9, float(crop_overlap)))
    cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
    cfg.TRAIN_EPOCHS = int(epochs)
    cfg.TRAIN_BATCH_SIZE = int(batch_size)
    cfg.TRAIN_LEARNING_RATE = learning_rate
    cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
    cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing
    cfg.TRAIN_WITH_MIXUP = use_mixup
    cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
    cfg.UPSAMPLING_MODE = upsampling_mode
    cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format

    cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
    cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))

    cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
    cfg.TRAIN_CACHE_MODE = cache_mode

    if cache_mode == "save":
        # os.path.join raises TypeError when the directory (or file name) is None,
        # so report a missing selection as a regular validation error instead.
        if not cache_file or not cache_file_name:
            raise gr.Error(loc.localize("validation-no-cache-file-selected"))

        cfg.TRAIN_CACHE_FILE = os.path.join(cache_file, cache_file_name)
    else:
        cfg.TRAIN_CACHE_FILE = cache_file

    cfg.TFLITE_THREADS = 1
    cfg.CPU_THREADS = max(1, multiprocessing.cpu_count() - 1)  # let's use everything we have (well, almost)

    # Check for an empty selection first: os.path.isfile(None) raises TypeError
    # instead of returning False, which would bypass the localized message.
    if cache_mode == "load" and not (cfg.TRAIN_CACHE_FILE and os.path.isfile(cfg.TRAIN_CACHE_FILE)):
        raise gr.Error(loc.localize("validation-no-cache-file-selected"))

    cfg.AUTOTUNE = autotune
    cfg.AUTOTUNE_TRIALS = autotune_trials
    cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = int(autotune_executions_per_trials)

    cfg.AUDIO_SPEED = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed))

    def data_load_progression(num_files, num_total_files, label):
        # Reports how many files of the given label have been loaded so far.
        if progress is not None:
            progress(
                (num_files, num_total_files),
                total=num_total_files,
                unit="files",
                desc=f"{loc.localize('progress-loading-data')} '{label}'",
            )

    def epoch_progression(epoch, logs=None):
        # `epoch` is zero-based; the final epoch switches the description to "saving".
        if progress is not None:
            if epoch + 1 == epochs:
                progress(
                    (epoch + 1, epochs),
                    total=epochs,
                    unit="epochs",
                    desc=f"{loc.localize('progress-saving')} {cfg.CUSTOM_CLASSIFIER}",
                )
            else:
                progress((epoch + 1, epochs), total=epochs, unit="epochs", desc=loc.localize("progress-training"))

    def trial_progression(trial):
        if progress is not None:
            progress((trial, autotune_trials), total=autotune_trials, unit="trials", desc=loc.localize("progress-autotune"))

    try:
        history_result = train_model(
            on_epoch_end=epoch_progression,
            on_trial_result=trial_progression,
            on_data_load_end=data_load_progression,
            autotune_directory=APPDIR if utils.FROZEN else "autotune",
        )

        # Unpack history and metrics
        history, metrics = history_result
    except Exception as e:
        # When a second exception argument is present it is used as a localization key.
        if e.args and len(e.args) > 1:
            raise gr.Error(loc.localize(e.args[1])) from e

        raise gr.Error(f"{e}") from e

    if len(history.epoch) < epochs:
        gr.Info(loc.localize("training-tab-early-stoppage-msg"))

    auprc = history.history["val_AUPRC"]
    auroc = history.history["val_AUROC"]

    matplotlib.use("agg")

    fig = plt.figure()
    plt.plot(auprc, label="AUPRC")
    plt.plot(auroc, label="AUROC")
    plt.legend()
    plt.xlabel("Epoch")

    return fig, metrics
235
+
236
+
237
def build_train_tab():
    """Builds the training tab of the GUI.

    Creates the widgets for selecting training/test data, output location,
    cache handling, audio preprocessing and hyperparameters, wires their
    event handlers, and connects the start button to ``start_training``.
    Must be called inside an active Gradio Blocks context.
    """
    with gr.Tab(loc.localize("training-tab-title")):
        input_directory_state = gr.State()
        output_directory_state = gr.State()
        test_data_dir_state = gr.State()

        with gr.Row():
            with gr.Column():
                select_directory_btn = gr.Button(loc.localize("training-tab-input-selection-button-label"))
                directory_input = gr.List(
                    headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
                    interactive=False,
                    max_height=_GRID_MAX_HEIGHT,
                )
                select_directory_btn.click(
                    partial(select_subdirectories, state_key="train-data-dir"),
                    outputs=[input_directory_state, directory_input],
                    show_progress=False,
                )

                select_test_directory_btn = gr.Button(loc.localize("training-tab-test-data-selection-button-label"))
                test_directory_input = gr.List(
                    headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
                    interactive=False,
                    max_height=_GRID_MAX_HEIGHT,
                )
                select_test_directory_btn.click(
                    partial(select_subdirectories, state_key="test-data-dir"),
                    outputs=[test_data_dir_state, test_directory_input],
                    show_progress=False,
                )

            with gr.Column():
                select_classifier_directory_btn = gr.Button(loc.localize("training-tab-select-output-button-label"))

                with gr.Column():
                    classifier_name = gr.Textbox(
                        "CustomClassifier",
                        visible=False,
                        info=loc.localize("training-tab-classifier-textbox-info"),
                    )
                    output_format = gr.Radio(
                        ["tflite", "raven", (loc.localize("training-tab-output-format-both"), "both")],
                        value=cfg.TRAINED_MODEL_OUTPUT_FORMAT,
                        label=loc.localize("training-tab-output-format-radio-label"),
                        info=loc.localize("training-tab-output-format-radio-info"),
                        visible=False,
                    )

                def select_directory_and_update_tb():
                    dir_name = gu.select_folder(state_key="train-output-dir")

                    if dir_name:
                        return (
                            dir_name,
                            gr.Textbox(label=dir_name, visible=True),
                            gr.Radio(visible=True, interactive=True),
                        )

                    # The handler has three outputs, so three values must be returned even
                    # when no folder was picked. Clear the stored directory and hide the
                    # dependent widgets again (their values are left untouched).
                    return None, gr.Textbox(visible=False), gr.Radio(visible=False)

                select_classifier_directory_btn.click(
                    select_directory_and_update_tb,
                    outputs=[output_directory_state, classifier_name, output_format],
                    show_progress=False,
                )

        with gr.Row():
            cache_file_state = gr.State()
            cache_mode = gr.Radio(
                [
                    (loc.localize("training-tab-cache-mode-radio-option-none"), None),
                    (loc.localize("training-tab-cache-mode-radio-option-load"), "load"),
                    (loc.localize("training-tab-cache-mode-radio-option-save"), "save"),
                ],
                value=cfg.TRAIN_CACHE_MODE,
                label=loc.localize("training-tab-cache-mode-radio-label"),
                info=loc.localize("training-tab-cache-mode-radio-info"),
            )
            with gr.Column(visible=False) as new_cache_file_row:
                select_cache_file_directory_btn = gr.Button(loc.localize("training-tab-cache-select-directory-button-label"))

                with gr.Column():
                    cache_file_name = gr.Textbox(
                        "train_cache.npz",
                        visible=False,
                        info=loc.localize("training-tab-cache-file-name-textbox-info"),
                    )

                def select_directory_and_update():
                    dir_name = gu.select_folder(state_key="train-data-cache-file-output")

                    if dir_name:
                        return (
                            dir_name,
                            gr.Textbox(label=dir_name, visible=True),
                        )

                    return None, None

                select_cache_file_directory_btn.click(
                    select_directory_and_update,
                    outputs=[cache_file_state, cache_file_name],
                    show_progress=False,
                )

            with gr.Column(visible=False) as load_cache_file_row:
                selected_cache_file_btn = gr.Button(loc.localize("training-tab-cache-select-file-button-label"))
                cache_file_input = gr.File(file_types=[".npz"], visible=False, interactive=False)

                def on_cache_file_selection_click():
                    file = gu.select_file(("NPZ file (*.npz)",), state_key="train_data_cache_file")

                    if file:
                        return file, gr.File(value=file, visible=True)

                    return None, None

                selected_cache_file_btn.click(
                    on_cache_file_selection_click,
                    outputs=[cache_file_state, cache_file_input],
                    show_progress=False,
                )

            def on_cache_mode_change(value):
                # One entry per output of `cache_mode.change` below, in the same order:
                # show the matching cache widgets, empty both class lists and disable
                # the data/preprocessing inputs while a cache is being loaded.
                return (
                    gr.update(visible=value == "save"),
                    gr.update(visible=value == "load"),
                    gr.update(interactive=value != "load"),
                    [],
                    gr.update(interactive=value != "load"),
                    [],
                    gr.update(interactive=value != "load"),
                    gr.update(interactive=value != "load"),
                    gr.update(interactive=value != "load"),
                    gr.update(interactive=value != "load"),
                    gr.update(interactive=value != "load"),
                )

        with gr.Row():
            fmin_number = gr.Number(
                cfg.SIG_FMIN,
                minimum=0,
                label=loc.localize("inference-settings-fmin-number-label"),
                info=loc.localize("inference-settings-fmin-number-info"),
            )

            fmax_number = gr.Number(
                cfg.SIG_FMAX,
                minimum=0,
                label=loc.localize("inference-settings-fmax-number-label"),
                info=loc.localize("inference-settings-fmax-number-info"),
            )

        with gr.Row():
            audio_speed_slider = gr.Slider(
                minimum=-10,
                maximum=10,
                value=cfg.AUDIO_SPEED,
                step=1,
                label=loc.localize("training-tab-audio-speed-slider-label"),
                info=loc.localize("training-tab-audio-speed-slider-info"),
            )

        with gr.Row():
            crop_mode = gr.Radio(
                [
                    (loc.localize("training-tab-crop-mode-radio-option-center"), "center"),
                    (loc.localize("training-tab-crop-mode-radio-option-first"), "first"),
                    (loc.localize("training-tab-crop-mode-radio-option-segments"), "segments"),
                    (loc.localize("training-tab-crop-mode-radio-option-smart"), "smart"),
                ],
                value="center",
                label=loc.localize("training-tab-crop-mode-radio-label"),
                info=loc.localize("training-tab-crop-mode-radio-info"),
            )

            crop_overlap = gr.Slider(
                minimum=0,
                maximum=2.99,
                value=cfg.SIG_OVERLAP,
                step=0.01,
                label=loc.localize("training-tab-crop-overlap-number-label"),
                info=loc.localize("training-tab-crop-overlap-number-info"),
                visible=False,
            )

            def on_crop_select(new_crop_mode):
                # Make overlap slider visible for both "segments" and "smart" crop modes.
                # `crop_overlap` is a Slider, so send a type-agnostic update rather than
                # an instance of a different component class.
                uses_overlap = new_crop_mode in ["segments", "smart"]

                return gr.update(visible=uses_overlap, interactive=uses_overlap)

            crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)

        cache_mode.change(
            on_cache_mode_change,
            inputs=cache_mode,
            outputs=[
                new_cache_file_row,
                load_cache_file_row,
                select_directory_btn,
                directory_input,
                select_test_directory_btn,
                test_directory_input,
                fmin_number,
                fmax_number,
                audio_speed_slider,
                crop_mode,
                crop_overlap,
            ],
            show_progress=False,
        )

        autotune_cb = gr.Checkbox(
            cfg.AUTOTUNE,
            label=loc.localize("training-tab-autotune-checkbox-label"),
            info=loc.localize("training-tab-autotune-checkbox-info"),
        )

        with gr.Column(visible=False) as autotune_params, gr.Row():
            autotune_trials = gr.Number(
                cfg.AUTOTUNE_TRIALS,
                label=loc.localize("training-tab-autotune-trials-number-label"),
                info=loc.localize("training-tab-autotune-trials-number-info"),
                minimum=1,
            )
            autotune_executions_per_trials = gr.Number(
                cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
                minimum=1,
                label=loc.localize("training-tab-autotune-executions-number-label"),
                info=loc.localize("training-tab-autotune-executions-number-info"),
            )

        with gr.Column() as custom_params:
            with gr.Row():
                epoch_number = gr.Number(
                    cfg.TRAIN_EPOCHS,
                    minimum=1,
                    step=1,
                    label=loc.localize("training-tab-epochs-number-label"),
                    info=loc.localize("training-tab-epochs-number-info"),
                )
                batch_size_number = gr.Number(
                    32,
                    minimum=1,
                    step=8,
                    label=loc.localize("training-tab-batchsize-number-label"),
                    info=loc.localize("training-tab-batchsize-number-info"),
                )
                learning_rate_number = gr.Number(
                    cfg.TRAIN_LEARNING_RATE,
                    minimum=0.0001,
                    step=0.0001,
                    label=loc.localize("training-tab-learningrate-number-label"),
                    info=loc.localize("training-tab-learningrate-number-info"),
                )

            with gr.Row():
                hidden_units_number = gr.Number(
                    cfg.TRAIN_HIDDEN_UNITS,
                    minimum=0,
                    step=64,
                    label=loc.localize("training-tab-hiddenunits-number-label"),
                    info=loc.localize("training-tab-hiddenunits-number-info"),
                )
                dropout_number = gr.Number(
                    cfg.TRAIN_DROPOUT,
                    minimum=0.0,
                    maximum=0.9,
                    step=0.1,
                    label=loc.localize("training-tab-dropout-number-label"),
                    info=loc.localize("training-tab-dropout-number-info"),
                )
                use_label_smoothing = gr.Checkbox(
                    cfg.TRAIN_WITH_LABEL_SMOOTHING,
                    label=loc.localize("training-tab-use-labelsmoothing-checkbox-label"),
                    info=loc.localize("training-tab-use-labelsmoothing-checkbox-info"),
                    show_label=True,
                )

            with gr.Row():
                upsampling_mode = gr.Radio(
                    [
                        (loc.localize("training-tab-upsampling-radio-option-repeat"), "repeat"),
                        (loc.localize("training-tab-upsampling-radio-option-mean"), "mean"),
                        (loc.localize("training-tab-upsampling-radio-option-linear"), "linear"),
                        ("SMOTE", "smote"),
                    ],
                    value=cfg.UPSAMPLING_MODE,
                    label=loc.localize("training-tab-upsampling-radio-label"),
                    info=loc.localize("training-tab-upsampling-radio-info"),
                )
                upsampling_ratio = gr.Slider(
                    0.0,
                    1.0,
                    cfg.UPSAMPLING_RATIO,
                    step=0.05,
                    label=loc.localize("training-tab-upsampling-ratio-slider-label"),
                    info=loc.localize("training-tab-upsampling-ratio-slider-info"),
                )

            with gr.Row():
                use_mixup = gr.Checkbox(
                    cfg.TRAIN_WITH_MIXUP,
                    label=loc.localize("training-tab-use-mixup-checkbox-label"),
                    info=loc.localize("training-tab-use-mixup-checkbox-info"),
                    show_label=True,
                )
                use_focal_loss = gr.Checkbox(
                    cfg.TRAIN_WITH_FOCAL_LOSS,
                    label=loc.localize("training-tab-use-focal-loss-checkbox-label"),
                    info=loc.localize("training-tab-use-focal-loss-checkbox-info"),
                    show_label=True,
                )

            with gr.Row(visible=False) as focal_loss_params, gr.Row():
                focal_loss_gamma = gr.Slider(
                    minimum=0.5,
                    maximum=5.0,
                    value=cfg.FOCAL_LOSS_GAMMA,
                    step=0.1,
                    label=loc.localize("training-tab-focal-loss-gamma-slider-label"),
                    info=loc.localize("training-tab-focal-loss-gamma-slider-info"),
                    interactive=True,
                )
                focal_loss_alpha = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=cfg.FOCAL_LOSS_ALPHA,
                    step=0.05,
                    label=loc.localize("training-tab-focal-loss-alpha-slider-label"),
                    info=loc.localize("training-tab-focal-loss-alpha-slider-info"),
                    interactive=True,
                )

            def on_focal_loss_change(value):
                return gr.Row(visible=value)

            use_focal_loss.change(on_focal_loss_change, inputs=use_focal_loss, outputs=focal_loss_params, show_progress=False)

        def on_autotune_change(value, focal_loss_enabled):
            # The live checkbox state is received as an event input. Reading
            # `use_focal_loss.value` here would only yield the value the component
            # was constructed with, not what the user has toggled since.
            return (
                gr.Column(visible=not value),
                gr.Column(visible=value),
                gr.Row(visible=bool(not value and focal_loss_enabled)),
            )

        autotune_cb.change(
            on_autotune_change,
            inputs=[autotune_cb, use_focal_loss],
            outputs=[custom_params, autotune_params, focal_loss_params],
            show_progress=False,
        )

        model_save_mode = gr.Radio(
            [
                (loc.localize("training-tab-model-save-mode-radio-option-replace"), "replace"),
                (loc.localize("training-tab-model-save-mode-radio-option-append"), "append"),
            ],
            value=cfg.TRAINED_MODEL_SAVE_MODE,
            label=loc.localize("training-tab-model-save-mode-radio-label"),
            info=loc.localize("training-tab-model-save-mode-radio-info"),
        )

        train_history_plot = gr.Plot()
        metrics_table = gr.Dataframe(
            headers=["Class", "Precision", "Recall", "F1 Score", "AUPRC", "AUROC", "Samples"],
            visible=False,
            label="Model Performance Metrics (Default Threshold 0.5)",
        )
        start_training_button = gr.Button(loc.localize("training-tab-start-training-button-label"), variant="huggingface")

        def train_and_show_metrics(*args):
            # `start_training` returns the history figure and the metrics object.
            fig, metrics = start_training(*args)

            # No metrics available, just return the figure and hide the table
            if not metrics:
                return fig, gr.Dataframe(visible=False)

            # Overall (macro-averaged) metrics go into the first row
            table_data = [
                [
                    "OVERALL (Macro-avg)",
                    f"{metrics['macro_precision_default']:.4f}",
                    f"{metrics['macro_recall_default']:.4f}",
                    f"{metrics['macro_f1_default']:.4f}",
                    f"{metrics['macro_auprc']:.4f}",
                    f"{metrics['macro_auroc']:.4f}",
                    "",
                ]
            ]

            # Followed by one row per class
            for class_name, class_metrics in metrics["class_metrics"].items():
                distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0})
                table_data.append(
                    [
                        class_name,
                        f"{class_metrics['precision_default']:.4f}",
                        f"{class_metrics['recall_default']:.4f}",
                        f"{class_metrics['f1_default']:.4f}",
                        f"{class_metrics['auprc']:.4f}",
                        f"{class_metrics['auroc']:.4f}",
                        f"{distribution['count']} ({distribution['percentage']:.2f}%)",
                    ]
                )

            return fig, gr.Dataframe(visible=True, value=table_data)

        # The order of `inputs` must match the positional parameters of `start_training`.
        start_training_button.click(
            train_and_show_metrics,
            inputs=[
                input_directory_state,
                test_data_dir_state,
                crop_mode,
                crop_overlap,
                fmin_number,
                fmax_number,
                output_directory_state,
                classifier_name,
                model_save_mode,
                cache_mode,
                cache_file_state,
                cache_file_name,
                autotune_cb,
                autotune_trials,
                autotune_executions_per_trials,
                epoch_number,
                batch_size_number,
                learning_rate_number,
                use_focal_loss,
                focal_loss_gamma,
                focal_loss_alpha,
                hidden_units_number,
                dropout_number,
                use_label_smoothing,
                use_mixup,
                upsampling_ratio,
                upsampling_mode,
                output_format,
                audio_speed_slider,
            ],
            outputs=[train_history_plot, metrics_table],
        )
684
+
685
+
686
# When the module is executed directly, hand the tab builder to gu.open_window.
if __name__ == "__main__":
    gu.open_window(build_train_tab)