birdnet-analyzer 2.0.0 (birdnet_analyzer-2.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. birdnet_analyzer/__init__.py +8 -0
  2. birdnet_analyzer/analyze/__init__.py +5 -0
  3. birdnet_analyzer/analyze/__main__.py +4 -0
  4. birdnet_analyzer/analyze/cli.py +25 -0
  5. birdnet_analyzer/analyze/core.py +245 -0
  6. birdnet_analyzer/analyze/utils.py +701 -0
  7. birdnet_analyzer/audio.py +372 -0
  8. birdnet_analyzer/cli.py +707 -0
  9. birdnet_analyzer/config.py +242 -0
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25280 -0
  11. birdnet_analyzer/embeddings/__init__.py +4 -0
  12. birdnet_analyzer/embeddings/__main__.py +3 -0
  13. birdnet_analyzer/embeddings/cli.py +13 -0
  14. birdnet_analyzer/embeddings/core.py +70 -0
  15. birdnet_analyzer/embeddings/utils.py +193 -0
  16. birdnet_analyzer/evaluation/__init__.py +195 -0
  17. birdnet_analyzer/evaluation/__main__.py +3 -0
  18. birdnet_analyzer/gui/__init__.py +23 -0
  19. birdnet_analyzer/gui/__main__.py +3 -0
  20. birdnet_analyzer/gui/analysis.py +174 -0
  21. birdnet_analyzer/gui/assets/arrow_down.svg +4 -0
  22. birdnet_analyzer/gui/assets/arrow_left.svg +4 -0
  23. birdnet_analyzer/gui/assets/arrow_right.svg +4 -0
  24. birdnet_analyzer/gui/assets/arrow_up.svg +4 -0
  25. birdnet_analyzer/gui/assets/gui.css +29 -0
  26. birdnet_analyzer/gui/assets/gui.js +94 -0
  27. birdnet_analyzer/gui/assets/img/birdnet-icon.ico +0 -0
  28. birdnet_analyzer/gui/assets/img/birdnet_logo.png +0 -0
  29. birdnet_analyzer/gui/assets/img/birdnet_logo_no_transparent.png +0 -0
  30. birdnet_analyzer/gui/assets/img/clo-logo-bird.svg +1 -0
  31. birdnet_analyzer/gui/embeddings.py +620 -0
  32. birdnet_analyzer/gui/evaluation.py +813 -0
  33. birdnet_analyzer/gui/localization.py +68 -0
  34. birdnet_analyzer/gui/multi_file.py +246 -0
  35. birdnet_analyzer/gui/review.py +527 -0
  36. birdnet_analyzer/gui/segments.py +191 -0
  37. birdnet_analyzer/gui/settings.py +129 -0
  38. birdnet_analyzer/gui/single_file.py +269 -0
  39. birdnet_analyzer/gui/species.py +95 -0
  40. birdnet_analyzer/gui/train.py +698 -0
  41. birdnet_analyzer/gui/utils.py +808 -0
  42. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -0
  43. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -0
  44. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -0
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -0
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -0
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -0
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -0
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -0
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -0
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -0
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -0
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -0
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -0
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -0
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -0
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -0
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -0
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -0
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -0
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -0
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -0
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -0
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -0
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -0
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -0
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -0
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -0
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -0
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -0
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -0
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -0
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -0
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -0
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -0
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -0
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -0
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -0
  79. birdnet_analyzer/lang/de.json +335 -0
  80. birdnet_analyzer/lang/en.json +335 -0
  81. birdnet_analyzer/lang/fi.json +335 -0
  82. birdnet_analyzer/lang/fr.json +335 -0
  83. birdnet_analyzer/lang/id.json +335 -0
  84. birdnet_analyzer/lang/pt-br.json +335 -0
  85. birdnet_analyzer/lang/ru.json +335 -0
  86. birdnet_analyzer/lang/se.json +335 -0
  87. birdnet_analyzer/lang/tlh.json +335 -0
  88. birdnet_analyzer/lang/zh_TW.json +335 -0
  89. birdnet_analyzer/model.py +1243 -0
  90. birdnet_analyzer/search/__init__.py +3 -0
  91. birdnet_analyzer/search/__main__.py +3 -0
  92. birdnet_analyzer/search/cli.py +12 -0
  93. birdnet_analyzer/search/core.py +78 -0
  94. birdnet_analyzer/search/utils.py +111 -0
  95. birdnet_analyzer/segments/__init__.py +3 -0
  96. birdnet_analyzer/segments/__main__.py +3 -0
  97. birdnet_analyzer/segments/cli.py +14 -0
  98. birdnet_analyzer/segments/core.py +78 -0
  99. birdnet_analyzer/segments/utils.py +394 -0
  100. birdnet_analyzer/species/__init__.py +3 -0
  101. birdnet_analyzer/species/__main__.py +3 -0
  102. birdnet_analyzer/species/cli.py +14 -0
  103. birdnet_analyzer/species/core.py +35 -0
  104. birdnet_analyzer/species/utils.py +75 -0
  105. birdnet_analyzer/train/__init__.py +3 -0
  106. birdnet_analyzer/train/__main__.py +3 -0
  107. birdnet_analyzer/train/cli.py +14 -0
  108. birdnet_analyzer/train/core.py +113 -0
  109. birdnet_analyzer/train/utils.py +847 -0
  110. birdnet_analyzer/translate.py +104 -0
  111. birdnet_analyzer/utils.py +419 -0
  112. birdnet_analyzer-2.0.0.dist-info/METADATA +129 -0
  113. birdnet_analyzer-2.0.0.dist-info/RECORD +117 -0
  114. birdnet_analyzer-2.0.0.dist-info/WHEEL +5 -0
  115. birdnet_analyzer-2.0.0.dist-info/entry_points.txt +11 -0
  116. birdnet_analyzer-2.0.0.dist-info/licenses/LICENSE +19 -0
  117. birdnet_analyzer-2.0.0.dist-info/top_level.txt +1 -0
birdnet_analyzer/gui/train.py
@@ -0,0 +1,698 @@
+ import multiprocessing
+ import os
+ from functools import partial
+ from pathlib import Path
+
+ import gradio as gr
+
+ import birdnet_analyzer.config as cfg
+ import birdnet_analyzer.gui.localization as loc
+ import birdnet_analyzer.gui.utils as gu
+ import birdnet_analyzer.utils as utils
+
+ _GRID_MAX_HEIGHT = 240
+
+
+ def select_subdirectories(state_key=None):
+     """Creates a directory selection dialog.
+
+     Returns:
+         A tuple of (directory, list of subdirectories) or (None, None) if the dialog was canceled.
+     """
+     dir_name = gu.select_folder(state_key=state_key)
+
+     if dir_name:
+         subdirs = utils.list_subdirectories(dir_name)
+         labels = []
+
+         for folder in subdirs:
+             labels_in_folder = folder.split(",")
+
+             for label in labels_in_folder:
+                 if label not in labels:
+                     labels.append(label)
+
+         return dir_name, [[label] for label in sorted(labels)]
+
+     return None, None
+
+
+ @gu.gui_runtime_error_handler
+ def start_training(
+     data_dir,
+     test_data_dir,
+     crop_mode,
+     crop_overlap,
+     fmin,
+     fmax,
+     output_dir,
+     classifier_name,
+     model_save_mode,
+     cache_mode,
+     cache_file,
+     cache_file_name,
+     autotune,
+     autotune_trials,
+     autotune_executions_per_trials,
+     epochs,
+     batch_size,
+     learning_rate,
+     focal_loss,
+     focal_loss_gamma,
+     focal_loss_alpha,
+     hidden_units,
+     dropout,
+     label_smoothing,
+     use_mixup,
+     upsampling_ratio,
+     upsampling_mode,
+     model_format,
+     audio_speed,
+     progress=gr.Progress(),
+ ):
+     """Starts the training of a custom classifier.
+
+     Args:
+         data_dir: Directory containing the training data.
+         test_data_dir: Directory containing the test data.
+         crop_mode: Mode for cropping audio samples.
+         crop_overlap: Overlap ratio for audio segments.
+         fmin: Minimum frequency for bandpass filtering.
+         fmax: Maximum frequency for bandpass filtering.
+         output_dir: Directory to save the trained model.
+         classifier_name: Name of the custom classifier.
+         model_save_mode: Save mode for the model (replace or append).
+         cache_mode: Cache mode for training data (load, save, or None).
+         cache_file: Path to the cache file.
+         cache_file_name: Name of the cache file.
+         autotune: Whether to use hyperparameter autotuning.
+         autotune_trials: Number of trials for autotuning.
+         autotune_executions_per_trials: Number of executions per autotuning trial.
+         epochs: Number of training epochs.
+         batch_size: Batch size for training.
+         learning_rate: Learning rate for the optimizer.
+         focal_loss: Whether to use focal loss for training.
+         focal_loss_gamma: Gamma parameter for focal loss.
+         focal_loss_alpha: Alpha parameter for focal loss.
+         hidden_units: Number of hidden units in the model.
+         dropout: Dropout rate for regularization.
+         label_smoothing: Whether to apply label smoothing for training.
+         use_mixup: Whether to use mixup data augmentation.
+         upsampling_ratio: Ratio for upsampling underrepresented classes.
+         upsampling_mode: Mode for upsampling (repeat, mean, smote).
+         model_format: Format to save the trained model (tflite, raven, both).
+         audio_speed: Speed factor for audio playback.
+
+     Returns:
+         A tuple of (matplotlib figure, metrics), where metrics are only computed when test data is provided.
+     """
+     import matplotlib
+     import matplotlib.pyplot as plt
+
+     from birdnet_analyzer.train.utils import train_model
+
+     # Skip training data validation when cache mode is "load"
+     if cache_mode != "load":
+         gu.validate(data_dir, loc.localize("validation-no-training-data-selected"))
+
+     gu.validate(output_dir, loc.localize("validation-no-directory-for-classifier-selected"))
+     gu.validate(classifier_name, loc.localize("validation-no-valid-classifier-name"))
+
+     if not epochs or epochs < 0:
+         raise gr.Error(loc.localize("validation-no-valid-epoch-number"))
+
+     if not batch_size or batch_size < 0:
+         raise gr.Error(loc.localize("validation-no-valid-batch-size"))
+
+     if not learning_rate or learning_rate < 0:
+         raise gr.Error(loc.localize("validation-no-valid-learning-rate"))
+
+     if fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax:
+         raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]")
+
+     cfg.TRAIN_WITH_FOCAL_LOSS = focal_loss
+     cfg.FOCAL_LOSS_GAMMA = max(0.0, float(focal_loss_gamma))
+     cfg.FOCAL_LOSS_ALPHA = max(0.0, min(1.0, float(focal_loss_alpha)))
+
+     if not hidden_units or hidden_units < 0:
+         hidden_units = 0
+
+     cfg.TRAIN_DROPOUT = max(0.0, min(1.0, float(dropout)))
+
+     if progress is not None:
+         progress((0, epochs), desc=loc.localize("progress-build-classifier"), unit="epochs")
+
+     cfg.TRAIN_DATA_PATH = data_dir
+     cfg.TEST_DATA_PATH = test_data_dir
+     cfg.SAMPLE_CROP_MODE = crop_mode
+     cfg.SIG_OVERLAP = max(0.0, min(2.9, float(crop_overlap)))
+     cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
+     cfg.TRAIN_EPOCHS = int(epochs)
+     cfg.TRAIN_BATCH_SIZE = int(batch_size)
+     cfg.TRAIN_LEARNING_RATE = learning_rate
+     cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
+     cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing
+     cfg.TRAIN_WITH_MIXUP = use_mixup
+     cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
+     cfg.UPSAMPLING_MODE = upsampling_mode
+     cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
+
+     cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
+     cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))
+
+     cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
+     cfg.TRAIN_CACHE_MODE = cache_mode
+     cfg.TRAIN_CACHE_FILE = os.path.join(cache_file, cache_file_name) if cache_mode == "save" else cache_file
+     cfg.TFLITE_THREADS = 1
+     cfg.CPU_THREADS = max(1, multiprocessing.cpu_count() - 1)  # let's use everything we have (well, almost)
+
+     if cache_mode == "load" and not os.path.isfile(cfg.TRAIN_CACHE_FILE):
+         raise gr.Error(loc.localize("validation-no-cache-file-selected"))
+
+     cfg.AUTOTUNE = autotune
+     cfg.AUTOTUNE_TRIALS = autotune_trials
+     cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = int(autotune_executions_per_trials)
+
+     cfg.AUDIO_SPEED = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed))
+
+     def data_load_progression(num_files, num_total_files, label):
+         if progress is not None:
+             progress(
+                 (num_files, num_total_files),
+                 total=num_total_files,
+                 unit="files",
+                 desc=f"{loc.localize('progress-loading-data')} '{label}'",
+             )
+
+     def epoch_progression(epoch, logs=None):
+         if progress is not None:
+             if epoch + 1 == epochs:
+                 progress(
+                     (epoch + 1, epochs),
+                     total=epochs,
+                     unit="epochs",
+                     desc=f"{loc.localize('progress-saving')} {cfg.CUSTOM_CLASSIFIER}",
+                 )
+             else:
+                 progress((epoch + 1, epochs), total=epochs, unit="epochs", desc=loc.localize("progress-training"))
+
+     def trial_progression(trial):
+         if progress is not None:
+             progress(
+                 (trial, autotune_trials), total=autotune_trials, unit="trials", desc=loc.localize("progress-autotune")
+             )
+
+     try:
+         history_result = train_model(
+             on_epoch_end=epoch_progression,
+             on_trial_result=trial_progression,
+             on_data_load_end=data_load_progression,
+             autotune_directory=gu.APPDIR if utils.FROZEN else "autotune",
+         )
+
+         # Unpack history and metrics
+         history, metrics = history_result
+     except Exception as e:
+         if e.args and len(e.args) > 1:
+             raise gr.Error(loc.localize(e.args[1]))
+         else:
+             raise gr.Error(f"{e}")
+
+     if len(history.epoch) < epochs:
+         gr.Info(loc.localize("training-tab-early-stoppage-msg"))
+
+     auprc = history.history["val_AUPRC"]
+     auroc = history.history["val_AUROC"]
+
+     matplotlib.use("agg")
+
+     fig = plt.figure()
+     plt.plot(auprc, label="AUPRC")
+     plt.plot(auroc, label="AUROC")
+     plt.legend()
+     plt.xlabel("Epoch")
+
+     return fig, metrics
+
+
+ def build_train_tab():
+     with gr.Tab(loc.localize("training-tab-title")):
+         input_directory_state = gr.State()
+         output_directory_state = gr.State()
+         test_data_dir_state = gr.State()
+
+         with gr.Row():
+             with gr.Column():
+                 select_directory_btn = gr.Button(loc.localize("training-tab-input-selection-button-label"))
+                 directory_input = gr.List(
+                     headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
+                     interactive=False,
+                     max_height=_GRID_MAX_HEIGHT,
+                 )
+                 select_directory_btn.click(
+                     partial(select_subdirectories, state_key="train-data-dir"),
+                     outputs=[input_directory_state, directory_input],
+                     show_progress=False,
+                 )
+
+                 select_test_directory_btn = gr.Button(loc.localize("training-tab-test-data-selection-button-label"))
+                 test_directory_input = gr.List(
+                     headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
+                     interactive=False,
+                     max_height=_GRID_MAX_HEIGHT,
+                 )
+                 select_test_directory_btn.click(
+                     partial(select_subdirectories, state_key="test-data-dir"),
+                     outputs=[test_data_dir_state, test_directory_input],
+                     show_progress=False,
+                 )
+
+             with gr.Column():
+                 select_classifier_directory_btn = gr.Button(loc.localize("training-tab-select-output-button-label"))
+
+                 with gr.Column():
+                     classifier_name = gr.Textbox(
+                         "CustomClassifier",
+                         visible=False,
+                         info=loc.localize("training-tab-classifier-textbox-info"),
+                     )
+                     output_format = gr.Radio(
+                         ["tflite", "raven", (loc.localize("training-tab-output-format-both"), "both")],
+                         value=cfg.TRAINED_MODEL_OUTPUT_FORMAT,
+                         label=loc.localize("training-tab-output-format-radio-label"),
+                         info=loc.localize("training-tab-output-format-radio-info"),
+                         visible=False,
+                     )
+
+                 def select_directory_and_update_tb():
+                     dir_name = gu.select_folder(state_key="train-output-dir")
+
+                     if dir_name:
+                         return (
+                             dir_name,
+                             gr.Textbox(label=dir_name, visible=True),
+                             gr.Radio(visible=True, interactive=True),
+                         )
+
+                     return None, None
+
+                 select_classifier_directory_btn.click(
+                     select_directory_and_update_tb,
+                     outputs=[output_directory_state, classifier_name, output_format],
+                     show_progress=False,
+                 )
+
+         with gr.Row():
+             cache_file_state = gr.State()
+             cache_mode = gr.Radio(
+                 [
+                     (loc.localize("training-tab-cache-mode-radio-option-none"), None),
+                     (loc.localize("training-tab-cache-mode-radio-option-load"), "load"),
+                     (loc.localize("training-tab-cache-mode-radio-option-save"), "save"),
+                 ],
+                 value=cfg.TRAIN_CACHE_MODE,
+                 label=loc.localize("training-tab-cache-mode-radio-label"),
+                 info=loc.localize("training-tab-cache-mode-radio-info"),
+             )
+             with gr.Column(visible=False) as new_cache_file_row:
+                 select_cache_file_directory_btn = gr.Button(
+                     loc.localize("training-tab-cache-select-directory-button-label")
+                 )
+
+                 with gr.Column():
+                     cache_file_name = gr.Textbox(
+                         "train_cache.npz",
+                         visible=False,
+                         info=loc.localize("training-tab-cache-file-name-textbox-info"),
+                     )
+
+                 def select_directory_and_update():
+                     dir_name = gu.select_folder(state_key="train-data-cache-file-output")
+
+                     if dir_name:
+                         return (
+                             dir_name,
+                             gr.Textbox(label=dir_name, visible=True),
+                         )
+
+                     return None, None
+
+                 select_cache_file_directory_btn.click(
+                     select_directory_and_update,
+                     outputs=[cache_file_state, cache_file_name],
+                     show_progress=False,
+                 )
+
+             with gr.Column(visible=False) as load_cache_file_row:
+                 selected_cache_file_btn = gr.Button(loc.localize("training-tab-cache-select-file-button-label"))
+                 cache_file_input = gr.File(file_types=[".npz"], visible=False, interactive=False)
+
+                 def on_cache_file_selection_click():
+                     file = gu.select_file(("NPZ file (*.npz)",), state_key="train_data_cache_file")
+
+                     if file:
+                         return file, gr.File(value=file, visible=True)
+
+                     return None, None
+
+                 selected_cache_file_btn.click(
+                     on_cache_file_selection_click,
+                     outputs=[cache_file_state, cache_file_input],
+                     show_progress=False,
+                 )
+
+         def on_cache_mode_change(value):
+             return (
+                 gr.update(visible=value == "save"),
+                 gr.update(visible=value == "load"),
+                 gr.update(interactive=value != "load"),
+                 [],
+                 gr.update(interactive=value != "load"),
+                 [],
+                 gr.update(interactive=value != "load"),
+                 gr.update(interactive=value != "load"),
+                 gr.update(interactive=value != "load"),
+                 gr.update(interactive=value != "load"),
+                 gr.update(interactive=value != "load"),
+             )
+
+         with gr.Row():
+             fmin_number = gr.Number(
+                 cfg.SIG_FMIN,
+                 minimum=0,
+                 label=loc.localize("inference-settings-fmin-number-label"),
+                 info=loc.localize("inference-settings-fmin-number-info"),
+             )
+
+             fmax_number = gr.Number(
+                 cfg.SIG_FMAX,
+                 minimum=0,
+                 label=loc.localize("inference-settings-fmax-number-label"),
+                 info=loc.localize("inference-settings-fmax-number-info"),
+             )
+
+         with gr.Row():
+             audio_speed_slider = gr.Slider(
+                 minimum=-10,
+                 maximum=10,
+                 value=cfg.AUDIO_SPEED,
+                 step=1,
+                 label=loc.localize("training-tab-audio-speed-slider-label"),
+                 info=loc.localize("training-tab-audio-speed-slider-info"),
+             )
+
+         with gr.Row():
+             crop_mode = gr.Radio(
+                 [
+                     (loc.localize("training-tab-crop-mode-radio-option-center"), "center"),
+                     (loc.localize("training-tab-crop-mode-radio-option-first"), "first"),
+                     (loc.localize("training-tab-crop-mode-radio-option-segments"), "segments"),
+                     (loc.localize("training-tab-crop-mode-radio-option-smart"), "smart"),
+                 ],
+                 value="center",
+                 label=loc.localize("training-tab-crop-mode-radio-label"),
+                 info=loc.localize("training-tab-crop-mode-radio-info"),
+             )
+
+             crop_overlap = gr.Slider(
+                 minimum=0,
+                 maximum=2.99,
+                 value=cfg.SIG_OVERLAP,
+                 step=0.01,
+                 label=loc.localize("training-tab-crop-overlap-number-label"),
+                 info=loc.localize("training-tab-crop-overlap-number-info"),
+                 visible=False,
+             )
+
+         def on_crop_select(new_crop_mode):
+             # Make overlap slider visible for both "segments" and "smart" crop modes
+             return gr.Number(
+                 visible=new_crop_mode in ["segments", "smart"], interactive=new_crop_mode in ["segments", "smart"]
+             )
+
+         crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)
+
+         cache_mode.change(
+             on_cache_mode_change,
+             inputs=cache_mode,
+             outputs=[
+                 new_cache_file_row,
+                 load_cache_file_row,
+                 select_directory_btn,
+                 directory_input,
+                 select_test_directory_btn,
+                 test_directory_input,
+                 fmin_number,
+                 fmax_number,
+                 audio_speed_slider,
+                 crop_mode,
+                 crop_overlap,
+             ],
+             show_progress=False,
+         )
+
+         autotune_cb = gr.Checkbox(
+             cfg.AUTOTUNE,
+             label=loc.localize("training-tab-autotune-checkbox-label"),
+             info=loc.localize("training-tab-autotune-checkbox-info"),
+         )
+
+         with gr.Column(visible=False) as autotune_params:
+             with gr.Row():
+                 autotune_trials = gr.Number(
+                     cfg.AUTOTUNE_TRIALS,
+                     label=loc.localize("training-tab-autotune-trials-number-label"),
+                     info=loc.localize("training-tab-autotune-trials-number-info"),
+                     minimum=1,
+                 )
+                 autotune_executions_per_trials = gr.Number(
+                     cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
+                     minimum=1,
+                     label=loc.localize("training-tab-autotune-executions-number-label"),
+                     info=loc.localize("training-tab-autotune-executions-number-info"),
+                 )
+
+         with gr.Column() as custom_params:
+             with gr.Row():
+                 epoch_number = gr.Number(
+                     cfg.TRAIN_EPOCHS,
+                     minimum=1,
+                     step=1,
+                     label=loc.localize("training-tab-epochs-number-label"),
+                     info=loc.localize("training-tab-epochs-number-info"),
+                 )
+                 batch_size_number = gr.Number(
+                     32,
+                     minimum=1,
+                     step=8,
+                     label=loc.localize("training-tab-batchsize-number-label"),
+                     info=loc.localize("training-tab-batchsize-number-info"),
+                 )
+                 learning_rate_number = gr.Number(
+                     cfg.TRAIN_LEARNING_RATE,
+                     minimum=0.0001,
+                     step=0.0001,
+                     label=loc.localize("training-tab-learningrate-number-label"),
+                     info=loc.localize("training-tab-learningrate-number-info"),
+                 )
+
+             with gr.Row():
+                 hidden_units_number = gr.Number(
+                     cfg.TRAIN_HIDDEN_UNITS,
+                     minimum=0,
+                     step=64,
+                     label=loc.localize("training-tab-hiddenunits-number-label"),
+                     info=loc.localize("training-tab-hiddenunits-number-info"),
+                 )
+                 dropout_number = gr.Number(
+                     cfg.TRAIN_DROPOUT,
+                     minimum=0.0,
+                     maximum=0.9,
+                     step=0.1,
+                     label=loc.localize("training-tab-dropout-number-label"),
+                     info=loc.localize("training-tab-dropout-number-info"),
+                 )
+                 use_label_smoothing = gr.Checkbox(
+                     cfg.TRAIN_WITH_LABEL_SMOOTHING,
+                     label=loc.localize("training-tab-use-labelsmoothing-checkbox-label"),
+                     info=loc.localize("training-tab-use-labelsmoothing-checkbox-info"),
+                     show_label=True,
+                 )
+
+             with gr.Row():
+                 upsampling_mode = gr.Radio(
+                     [
+                         (loc.localize("training-tab-upsampling-radio-option-repeat"), "repeat"),
+                         (loc.localize("training-tab-upsampling-radio-option-mean"), "mean"),
+                         (loc.localize("training-tab-upsampling-radio-option-linear"), "linear"),
+                         ("SMOTE", "smote"),
+                     ],
+                     value=cfg.UPSAMPLING_MODE,
+                     label=loc.localize("training-tab-upsampling-radio-label"),
+                     info=loc.localize("training-tab-upsampling-radio-info"),
+                 )
+                 upsampling_ratio = gr.Slider(
+                     0.0,
+                     1.0,
+                     cfg.UPSAMPLING_RATIO,
+                     step=0.05,
+                     label=loc.localize("training-tab-upsampling-ratio-slider-label"),
+                     info=loc.localize("training-tab-upsampling-ratio-slider-info"),
+                 )
+
+             with gr.Row():
+                 use_mixup = gr.Checkbox(
+                     cfg.TRAIN_WITH_MIXUP,
+                     label=loc.localize("training-tab-use-mixup-checkbox-label"),
+                     info=loc.localize("training-tab-use-mixup-checkbox-info"),
+                     show_label=True,
+                 )
+                 use_focal_loss = gr.Checkbox(
+                     cfg.TRAIN_WITH_FOCAL_LOSS,
+                     label=loc.localize("training-tab-use-focal-loss-checkbox-label"),
+                     info=loc.localize("training-tab-use-focal-loss-checkbox-info"),
+                     show_label=True,
+                 )
+
+             with gr.Row(visible=False) as focal_loss_params:
+                 with gr.Column():
+                     focal_loss_gamma = gr.Slider(
+                         minimum=0.5,
+                         maximum=5.0,
+                         value=cfg.FOCAL_LOSS_GAMMA,
+                         step=0.1,
+                         label=loc.localize("training-tab-focal-loss-gamma-slider-label"),
+                         info=loc.localize("training-tab-focal-loss-gamma-slider-info"),
+                         interactive=True,
+                     )
+                     focal_loss_alpha = gr.Slider(
+                         minimum=0.1,
+                         maximum=0.9,
+                         value=cfg.FOCAL_LOSS_ALPHA,
+                         step=0.05,
+                         label=loc.localize("training-tab-focal-loss-alpha-slider-label"),
+                         info=loc.localize("training-tab-focal-loss-alpha-slider-info"),
+                         interactive=True,
+                     )
+
+         def on_focal_loss_change(value):
+             return gr.Row(visible=value)
+
+         use_focal_loss.change(
+             on_focal_loss_change, inputs=use_focal_loss, outputs=focal_loss_params, show_progress=False
+         )
+
+         def on_autotune_change(value):
+             return (
+                 gr.Column(visible=not value),
+                 gr.Column(visible=value),
+                 gr.Row(visible=not value and use_focal_loss.value),
+             )
+
+         autotune_cb.change(
+             on_autotune_change,
+             inputs=autotune_cb,
+             outputs=[custom_params, autotune_params, focal_loss_params],
+             show_progress=False,
+         )
+
+         model_save_mode = gr.Radio(
+             [
+                 (loc.localize("training-tab-model-save-mode-radio-option-replace"), "replace"),
+                 (loc.localize("training-tab-model-save-mode-radio-option-append"), "append"),
+             ],
+             value=cfg.TRAINED_MODEL_SAVE_MODE,
+             label=loc.localize("training-tab-model-save-mode-radio-label"),
+             info=loc.localize("training-tab-model-save-mode-radio-info"),
+         )
+
+         train_history_plot = gr.Plot()
+         metrics_table = gr.Dataframe(
+             headers=["Class", "Precision", "Recall", "F1 Score", "AUPRC", "AUROC", "Samples"],
+             visible=False,
+             label="Model Performance Metrics (Default Threshold 0.5)",
+         )
+         start_training_button = gr.Button(
+             loc.localize("training-tab-start-training-button-label"), variant="huggingface"
+         )
+
+         def train_and_show_metrics(*args):
+             history, metrics = start_training(*args)
+
+             # If metrics are available (test data was provided), create table
+             if metrics:
+                 # Create dataframe data with metrics
+                 table_data = []
+
+                 # Add overall metrics row first
+                 table_data.append(
+                     [
+                         "OVERALL (Macro-avg)",
+                         f"{metrics['macro_precision_default']:.4f}",
+                         f"{metrics['macro_recall_default']:.4f}",
+                         f"{metrics['macro_f1_default']:.4f}",
+                         f"{metrics['macro_auprc']:.4f}",
+                         f"{metrics['macro_auroc']:.4f}",
+                         "",
+                     ]
+                 )
+
+                 # Add class-specific metrics
+                 for class_name, class_metrics in metrics["class_metrics"].items():
+                     distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0})
+                     table_data.append(
+                         [
+                             class_name,
+                             f"{class_metrics['precision_default']:.4f}",
+                             f"{class_metrics['recall_default']:.4f}",
+                             f"{class_metrics['f1_default']:.4f}",
+                             f"{class_metrics['auprc']:.4f}",
+                             f"{class_metrics['auroc']:.4f}",
+                             f"{distribution['count']} ({distribution['percentage']:.2f}%)",
+                         ]
+                     )
+
+                 return history, gr.Dataframe(visible=True, value=table_data)
+             else:
+                 # No metrics available, just return history and hide table
+                 return history, gr.Dataframe(visible=False)
+
+         start_training_button.click(
+             train_and_show_metrics,
+             inputs=[
+                 input_directory_state,
+                 test_data_dir_state,
+                 crop_mode,
+                 crop_overlap,
+                 fmin_number,
+                 fmax_number,
+                 output_directory_state,
+                 classifier_name,
+                 model_save_mode,
+                 cache_mode,
+                 cache_file_state,
+                 cache_file_name,
+                 autotune_cb,
+                 autotune_trials,
+                 autotune_executions_per_trials,
+                 epoch_number,
+                 batch_size_number,
+                 learning_rate_number,
+                 use_focal_loss,
+                 focal_loss_gamma,
+                 focal_loss_alpha,
+                 hidden_units_number,
+                 dropout_number,
+                 use_label_smoothing,
+                 use_mixup,
+                 upsampling_ratio,
+                 upsampling_mode,
+                 output_format,
+                 audio_speed_slider,
+             ],
+             outputs=[train_history_plot, metrics_table],
+         )
+
+
+ if __name__ == "__main__":
+     gu.open_window(build_train_tab)
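
Note: the module ends with an `if __name__ == "__main__"` guard that calls `gu.open_window(build_train_tab)`, so the training tab can be opened on its own. A minimal sketch of doing the same from a separate script, assuming the wheel and its GUI dependencies (e.g. gradio) are installed in the current environment:

    # Sketch only: mirrors the module's own __main__ guard to open just the training tab.
    # Assumes birdnet_analyzer 2.0.0 and its GUI extras (gradio) are installed.
    import birdnet_analyzer.gui.utils as gu
    from birdnet_analyzer.gui.train import build_train_tab

    if __name__ == "__main__":
        gu.open_window(build_train_tab)  # builds the Gradio window containing only this tab

Running `python -m birdnet_analyzer.gui.train` should behave the same way, since the guard above is part of the module itself.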