birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +5 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +25 -25
- birdnet_analyzer/analyze/core.py +241 -245
- birdnet_analyzer/analyze/utils.py +692 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +709 -707
- birdnet_analyzer/config.py +242 -242
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +69 -70
- birdnet_analyzer/embeddings/utils.py +179 -193
- birdnet_analyzer/evaluation/__init__.py +196 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +175 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +28 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +619 -620
- birdnet_analyzer/gui/evaluation.py +795 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +245 -246
- birdnet_analyzer/gui/review.py +519 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +128 -129
- birdnet_analyzer/gui/single_file.py +267 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +696 -698
- birdnet_analyzer/gui/utils.py +810 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +334 -334
- birdnet_analyzer/lang/en.json +334 -334
- birdnet_analyzer/lang/fi.json +334 -334
- birdnet_analyzer/lang/fr.json +334 -334
- birdnet_analyzer/lang/id.json +334 -334
- birdnet_analyzer/lang/pt-br.json +334 -334
- birdnet_analyzer/lang/ru.json +334 -334
- birdnet_analyzer/lang/se.json +334 -334
- birdnet_analyzer/lang/tlh.json +334 -334
- birdnet_analyzer/lang/zh_TW.json +334 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +426 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
- birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/gui/train.py
CHANGED
@@ -1,698 +1,696 @@
|
|
1
|
-
import multiprocessing
|
2
|
-
import os
|
3
|
-
from functools import partial
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
import gradio as gr
|
7
|
-
|
8
|
-
import birdnet_analyzer.config as cfg
|
9
|
-
import birdnet_analyzer.gui.localization as loc
|
10
|
-
import birdnet_analyzer.gui.utils as gu
|
11
|
-
|
12
|
-
|
13
|
-
_GRID_MAX_HEIGHT = 240
|
14
|
-
|
15
|
-
|
16
|
-
def select_subdirectories(state_key=None):
|
17
|
-
"""Creates a directory selection dialog.
|
18
|
-
|
19
|
-
Returns:
|
20
|
-
A tuples of (directory, list of subdirectories) or (None, None) if the dialog was canceled.
|
21
|
-
"""
|
22
|
-
dir_name = gu.select_folder(state_key=state_key)
|
23
|
-
|
24
|
-
if dir_name:
|
25
|
-
subdirs = utils.list_subdirectories(dir_name)
|
26
|
-
labels = []
|
27
|
-
|
28
|
-
for folder in subdirs:
|
29
|
-
labels_in_folder = folder.split(",")
|
30
|
-
|
31
|
-
for label in labels_in_folder:
|
32
|
-
if label not in labels:
|
33
|
-
labels.append(label)
|
34
|
-
|
35
|
-
return dir_name, [[label] for label in sorted(labels)]
|
36
|
-
|
37
|
-
return None, None
|
38
|
-
|
39
|
-
|
40
|
-
@gu.gui_runtime_error_handler
|
41
|
-
def start_training(
|
42
|
-
data_dir,
|
43
|
-
test_data_dir,
|
44
|
-
crop_mode,
|
45
|
-
crop_overlap,
|
46
|
-
fmin,
|
47
|
-
fmax,
|
48
|
-
output_dir,
|
49
|
-
classifier_name,
|
50
|
-
model_save_mode,
|
51
|
-
cache_mode,
|
52
|
-
cache_file,
|
53
|
-
cache_file_name,
|
54
|
-
autotune,
|
55
|
-
autotune_trials,
|
56
|
-
autotune_executions_per_trials,
|
57
|
-
epochs,
|
58
|
-
batch_size,
|
59
|
-
learning_rate,
|
60
|
-
focal_loss,
|
61
|
-
focal_loss_gamma,
|
62
|
-
focal_loss_alpha,
|
63
|
-
hidden_units,
|
64
|
-
dropout,
|
65
|
-
label_smoothing,
|
66
|
-
use_mixup,
|
67
|
-
upsampling_ratio,
|
68
|
-
upsampling_mode,
|
69
|
-
model_format,
|
70
|
-
audio_speed,
|
71
|
-
progress=gr.Progress(),
|
72
|
-
):
|
73
|
-
"""Starts the training of a custom classifier.
|
74
|
-
|
75
|
-
Args:
|
76
|
-
data_dir: Directory containing the training data.
|
77
|
-
test_data_dir: Directory containing the test data.
|
78
|
-
crop_mode: Mode for cropping audio samples.
|
79
|
-
crop_overlap: Overlap ratio for audio segments.
|
80
|
-
fmin: Minimum frequency for bandpass filtering.
|
81
|
-
fmax: Maximum frequency for bandpass filtering.
|
82
|
-
output_dir: Directory to save the trained model.
|
83
|
-
classifier_name: Name of the custom classifier.
|
84
|
-
model_save_mode: Save mode for the model (replace or append).
|
85
|
-
cache_mode: Cache mode for training data (load, save, or None).
|
86
|
-
cache_file: Path to the cache file.
|
87
|
-
cache_file_name: Name of the cache file.
|
88
|
-
autotune: Whether to use hyperparameter autotuning.
|
89
|
-
autotune_trials: Number of trials for autotuning.
|
90
|
-
autotune_executions_per_trials: Number of executions per autotuning trial.
|
91
|
-
epochs: Number of training epochs.
|
92
|
-
batch_size: Batch size for training.
|
93
|
-
learning_rate: Learning rate for the optimizer.
|
94
|
-
focal_loss: Whether to use focal loss for training.
|
95
|
-
focal_loss_gamma: Gamma parameter for focal loss.
|
96
|
-
focal_loss_alpha: Alpha parameter for focal loss.
|
97
|
-
hidden_units: Number of hidden units in the droput: Dropout rate for regularization.
|
98
|
-
dropout: Dropout rate for regularization.
|
99
|
-
label_smoothing: Whether to apply label smoothing for training.
|
100
|
-
use_mixup: Whether to use mixup data augmentation.
|
101
|
-
upsampling_ratio: Ratio for upsampling underrepresented classes.
|
102
|
-
upsampling_mode: Mode for upsampling (repeat, mean, smote).
|
103
|
-
model_format: Format to save the trained model (tflite, raven, both).
|
104
|
-
audio_speed: Speed factor for audio playback.
|
105
|
-
|
106
|
-
Returns:
|
107
|
-
Returns a matplotlib.pyplot figure.
|
108
|
-
"""
|
109
|
-
import matplotlib
|
110
|
-
import matplotlib.pyplot as plt
|
111
|
-
|
112
|
-
from birdnet_analyzer.train.utils import train_model
|
113
|
-
|
114
|
-
# Skip training data validation when cache mode is "load"
|
115
|
-
if cache_mode != "load":
|
116
|
-
gu.validate(data_dir, loc.localize("validation-no-training-data-selected"))
|
117
|
-
|
118
|
-
gu.validate(output_dir, loc.localize("validation-no-directory-for-classifier-selected"))
|
119
|
-
gu.validate(classifier_name, loc.localize("validation-no-valid-classifier-name"))
|
120
|
-
|
121
|
-
if not epochs or epochs < 0:
|
122
|
-
raise gr.Error(loc.localize("validation-no-valid-epoch-number"))
|
123
|
-
|
124
|
-
if not batch_size or batch_size < 0:
|
125
|
-
raise gr.Error(loc.localize("validation-no-valid-batch-size"))
|
126
|
-
|
127
|
-
if not learning_rate or learning_rate < 0:
|
128
|
-
raise gr.Error(loc.localize("validation-no-valid-learning-rate"))
|
129
|
-
|
130
|
-
if fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax:
|
131
|
-
raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]")
|
132
|
-
|
133
|
-
cfg.TRAIN_WITH_FOCAL_LOSS = focal_loss
|
134
|
-
cfg.FOCAL_LOSS_GAMMA = max(0.0, float(focal_loss_gamma))
|
135
|
-
cfg.FOCAL_LOSS_ALPHA = max(0.0, min(1.0, float(focal_loss_alpha)))
|
136
|
-
|
137
|
-
if not hidden_units or hidden_units < 0:
|
138
|
-
hidden_units = 0
|
139
|
-
|
140
|
-
cfg.TRAIN_DROPOUT = max(0.0, min(1.0, float(dropout)))
|
141
|
-
|
142
|
-
if progress is not None:
|
143
|
-
progress((0, epochs), desc=loc.localize("progress-build-classifier"), unit="epochs")
|
144
|
-
|
145
|
-
cfg.TRAIN_DATA_PATH = data_dir
|
146
|
-
cfg.TEST_DATA_PATH = test_data_dir
|
147
|
-
cfg.SAMPLE_CROP_MODE = crop_mode
|
148
|
-
cfg.SIG_OVERLAP = max(0.0, min(2.9, float(crop_overlap)))
|
149
|
-
cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
|
150
|
-
cfg.TRAIN_EPOCHS = int(epochs)
|
151
|
-
cfg.TRAIN_BATCH_SIZE = int(batch_size)
|
152
|
-
cfg.TRAIN_LEARNING_RATE = learning_rate
|
153
|
-
cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
|
154
|
-
cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing
|
155
|
-
cfg.TRAIN_WITH_MIXUP = use_mixup
|
156
|
-
cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
|
157
|
-
cfg.UPSAMPLING_MODE = upsampling_mode
|
158
|
-
cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
|
159
|
-
|
160
|
-
cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
|
161
|
-
cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))
|
162
|
-
|
163
|
-
cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
|
164
|
-
cfg.TRAIN_CACHE_MODE = cache_mode
|
165
|
-
cfg.TRAIN_CACHE_FILE = os.path.join(cache_file, cache_file_name) if cache_mode == "save" else cache_file
|
166
|
-
cfg.TFLITE_THREADS = 1
|
167
|
-
cfg.CPU_THREADS = max(1, multiprocessing.cpu_count() - 1) # let's use everything we have (well, almost)
|
168
|
-
|
169
|
-
if cache_mode == "load" and not os.path.isfile(cfg.TRAIN_CACHE_FILE):
|
170
|
-
raise gr.Error(loc.localize("validation-no-cache-file-selected"))
|
171
|
-
|
172
|
-
cfg.AUTOTUNE = autotune
|
173
|
-
cfg.AUTOTUNE_TRIALS = autotune_trials
|
174
|
-
cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = int(autotune_executions_per_trials)
|
175
|
-
|
176
|
-
cfg.AUDIO_SPEED = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed))
|
177
|
-
|
178
|
-
def data_load_progression(num_files, num_total_files, label):
|
179
|
-
if progress is not None:
|
180
|
-
progress(
|
181
|
-
(num_files, num_total_files),
|
182
|
-
total=num_total_files,
|
183
|
-
unit="files",
|
184
|
-
desc=f"{loc.localize('progress-loading-data')} '{label}'",
|
185
|
-
)
|
186
|
-
|
187
|
-
def epoch_progression(epoch, logs=None):
|
188
|
-
if progress is not None:
|
189
|
-
if epoch + 1 == epochs:
|
190
|
-
progress(
|
191
|
-
(epoch + 1, epochs),
|
192
|
-
total=epochs,
|
193
|
-
unit="epochs",
|
194
|
-
desc=f"{loc.localize('progress-saving')} {cfg.CUSTOM_CLASSIFIER}",
|
195
|
-
)
|
196
|
-
else:
|
197
|
-
progress((epoch + 1, epochs), total=epochs, unit="epochs", desc=loc.localize("progress-training"))
|
198
|
-
|
199
|
-
def trial_progression(trial):
|
200
|
-
if progress is not None:
|
201
|
-
progress(
|
202
|
-
(trial, autotune_trials), total=autotune_trials, unit="trials", desc=loc.localize("progress-autotune")
|
203
|
-
)
|
204
|
-
|
205
|
-
try:
|
206
|
-
history_result = train_model(
|
207
|
-
on_epoch_end=epoch_progression,
|
208
|
-
on_trial_result=trial_progression,
|
209
|
-
on_data_load_end=data_load_progression,
|
210
|
-
autotune_directory=gu.APPDIR if utils.FROZEN else "autotune",
|
211
|
-
)
|
212
|
-
|
213
|
-
# Unpack history and metrics
|
214
|
-
history, metrics = history_result
|
215
|
-
except Exception as e:
|
216
|
-
if e.args and len(e.args) > 1:
|
217
|
-
raise gr.Error(loc.localize(e.args[1]))
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
if len(history.epoch) < epochs:
|
222
|
-
gr.Info(loc.localize("training-tab-early-stoppage-msg"))
|
223
|
-
|
224
|
-
auprc = history.history["val_AUPRC"]
|
225
|
-
auroc = history.history["val_AUROC"]
|
226
|
-
|
227
|
-
matplotlib.use("agg")
|
228
|
-
|
229
|
-
fig = plt.figure()
|
230
|
-
plt.plot(auprc, label="AUPRC")
|
231
|
-
plt.plot(auroc, label="AUROC")
|
232
|
-
plt.legend()
|
233
|
-
plt.xlabel("Epoch")
|
234
|
-
|
235
|
-
return fig, metrics
|
236
|
-
|
237
|
-
|
238
|
-
def build_train_tab():
|
239
|
-
with gr.Tab(loc.localize("training-tab-title")):
|
240
|
-
input_directory_state = gr.State()
|
241
|
-
output_directory_state = gr.State()
|
242
|
-
test_data_dir_state = gr.State()
|
243
|
-
|
244
|
-
with gr.Row():
|
245
|
-
with gr.Column():
|
246
|
-
select_directory_btn = gr.Button(loc.localize("training-tab-input-selection-button-label"))
|
247
|
-
directory_input = gr.List(
|
248
|
-
headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
|
249
|
-
interactive=False,
|
250
|
-
max_height=_GRID_MAX_HEIGHT,
|
251
|
-
)
|
252
|
-
select_directory_btn.click(
|
253
|
-
partial(select_subdirectories, state_key="train-data-dir"),
|
254
|
-
outputs=[input_directory_state, directory_input],
|
255
|
-
show_progress=False,
|
256
|
-
)
|
257
|
-
|
258
|
-
select_test_directory_btn = gr.Button(loc.localize("training-tab-test-data-selection-button-label"))
|
259
|
-
test_directory_input = gr.List(
|
260
|
-
headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
|
261
|
-
interactive=False,
|
262
|
-
max_height=_GRID_MAX_HEIGHT,
|
263
|
-
)
|
264
|
-
select_test_directory_btn.click(
|
265
|
-
partial(select_subdirectories, state_key="test-data-dir"),
|
266
|
-
outputs=[test_data_dir_state, test_directory_input],
|
267
|
-
show_progress=False,
|
268
|
-
)
|
269
|
-
|
270
|
-
with gr.Column():
|
271
|
-
select_classifier_directory_btn = gr.Button(loc.localize("training-tab-select-output-button-label"))
|
272
|
-
|
273
|
-
with gr.Column():
|
274
|
-
classifier_name = gr.Textbox(
|
275
|
-
"CustomClassifier",
|
276
|
-
visible=False,
|
277
|
-
info=loc.localize("training-tab-classifier-textbox-info"),
|
278
|
-
)
|
279
|
-
output_format = gr.Radio(
|
280
|
-
["tflite", "raven", (loc.localize("training-tab-output-format-both"), "both")],
|
281
|
-
value=cfg.TRAINED_MODEL_OUTPUT_FORMAT,
|
282
|
-
label=loc.localize("training-tab-output-format-radio-label"),
|
283
|
-
info=loc.localize("training-tab-output-format-radio-info"),
|
284
|
-
visible=False,
|
285
|
-
)
|
286
|
-
|
287
|
-
def select_directory_and_update_tb():
|
288
|
-
dir_name = gu.select_folder(state_key="train-output-dir")
|
289
|
-
|
290
|
-
if dir_name:
|
291
|
-
return (
|
292
|
-
dir_name,
|
293
|
-
gr.Textbox(label=dir_name, visible=True),
|
294
|
-
gr.Radio(visible=True, interactive=True),
|
295
|
-
)
|
296
|
-
|
297
|
-
return None, None
|
298
|
-
|
299
|
-
select_classifier_directory_btn.click(
|
300
|
-
select_directory_and_update_tb,
|
301
|
-
outputs=[output_directory_state, classifier_name, output_format],
|
302
|
-
show_progress=False,
|
303
|
-
)
|
304
|
-
|
305
|
-
with gr.Row():
|
306
|
-
cache_file_state = gr.State()
|
307
|
-
cache_mode = gr.Radio(
|
308
|
-
[
|
309
|
-
(loc.localize("training-tab-cache-mode-radio-option-none"), None),
|
310
|
-
(loc.localize("training-tab-cache-mode-radio-option-load"), "load"),
|
311
|
-
(loc.localize("training-tab-cache-mode-radio-option-save"), "save"),
|
312
|
-
],
|
313
|
-
value=cfg.TRAIN_CACHE_MODE,
|
314
|
-
label=loc.localize("training-tab-cache-mode-radio-label"),
|
315
|
-
info=loc.localize("training-tab-cache-mode-radio-info"),
|
316
|
-
)
|
317
|
-
with gr.Column(visible=False) as new_cache_file_row:
|
318
|
-
select_cache_file_directory_btn = gr.Button(
|
319
|
-
loc.localize("training-tab-cache-select-directory-button-label")
|
320
|
-
)
|
321
|
-
|
322
|
-
with gr.Column():
|
323
|
-
cache_file_name = gr.Textbox(
|
324
|
-
"train_cache.npz",
|
325
|
-
visible=False,
|
326
|
-
info=loc.localize("training-tab-cache-file-name-textbox-info"),
|
327
|
-
)
|
328
|
-
|
329
|
-
def select_directory_and_update():
|
330
|
-
dir_name = gu.select_folder(state_key="train-data-cache-file-output")
|
331
|
-
|
332
|
-
if dir_name:
|
333
|
-
return (
|
334
|
-
dir_name,
|
335
|
-
gr.Textbox(label=dir_name, visible=True),
|
336
|
-
)
|
337
|
-
|
338
|
-
return None, None
|
339
|
-
|
340
|
-
select_cache_file_directory_btn.click(
|
341
|
-
select_directory_and_update,
|
342
|
-
outputs=[cache_file_state, cache_file_name],
|
343
|
-
show_progress=False,
|
344
|
-
)
|
345
|
-
|
346
|
-
with gr.Column(visible=False) as load_cache_file_row:
|
347
|
-
selected_cache_file_btn = gr.Button(loc.localize("training-tab-cache-select-file-button-label"))
|
348
|
-
cache_file_input = gr.File(file_types=[".npz"], visible=False, interactive=False)
|
349
|
-
|
350
|
-
def on_cache_file_selection_click():
|
351
|
-
file = gu.select_file(("NPZ file (*.npz)",), state_key="train_data_cache_file")
|
352
|
-
|
353
|
-
if file:
|
354
|
-
return file, gr.File(value=file, visible=True)
|
355
|
-
|
356
|
-
return None, None
|
357
|
-
|
358
|
-
selected_cache_file_btn.click(
|
359
|
-
on_cache_file_selection_click,
|
360
|
-
outputs=[cache_file_state, cache_file_input],
|
361
|
-
show_progress=False,
|
362
|
-
)
|
363
|
-
|
364
|
-
def on_cache_mode_change(value):
|
365
|
-
return (
|
366
|
-
gr.update(visible=value == "save"),
|
367
|
-
gr.update(visible=value == "load"),
|
368
|
-
gr.update(interactive=value != "load"),
|
369
|
-
[],
|
370
|
-
gr.update(interactive=value != "load"),
|
371
|
-
[],
|
372
|
-
gr.update(interactive=value != "load"),
|
373
|
-
gr.update(interactive=value != "load"),
|
374
|
-
gr.update(interactive=value != "load"),
|
375
|
-
gr.update(interactive=value != "load"),
|
376
|
-
gr.update(interactive=value != "load"),
|
377
|
-
)
|
378
|
-
|
379
|
-
with gr.Row():
|
380
|
-
fmin_number = gr.Number(
|
381
|
-
cfg.SIG_FMIN,
|
382
|
-
minimum=0,
|
383
|
-
label=loc.localize("inference-settings-fmin-number-label"),
|
384
|
-
info=loc.localize("inference-settings-fmin-number-info"),
|
385
|
-
)
|
386
|
-
|
387
|
-
fmax_number = gr.Number(
|
388
|
-
cfg.SIG_FMAX,
|
389
|
-
minimum=0,
|
390
|
-
label=loc.localize("inference-settings-fmax-number-label"),
|
391
|
-
info=loc.localize("inference-settings-fmax-number-info"),
|
392
|
-
)
|
393
|
-
|
394
|
-
with gr.Row():
|
395
|
-
audio_speed_slider = gr.Slider(
|
396
|
-
minimum=-10,
|
397
|
-
maximum=10,
|
398
|
-
value=cfg.AUDIO_SPEED,
|
399
|
-
step=1,
|
400
|
-
label=loc.localize("training-tab-audio-speed-slider-label"),
|
401
|
-
info=loc.localize("training-tab-audio-speed-slider-info"),
|
402
|
-
)
|
403
|
-
|
404
|
-
with gr.Row():
|
405
|
-
crop_mode = gr.Radio(
|
406
|
-
[
|
407
|
-
(loc.localize("training-tab-crop-mode-radio-option-center"), "center"),
|
408
|
-
(loc.localize("training-tab-crop-mode-radio-option-first"), "first"),
|
409
|
-
(loc.localize("training-tab-crop-mode-radio-option-segments"), "segments"),
|
410
|
-
(loc.localize("training-tab-crop-mode-radio-option-smart"), "smart"),
|
411
|
-
],
|
412
|
-
value="center",
|
413
|
-
label=loc.localize("training-tab-crop-mode-radio-label"),
|
414
|
-
info=loc.localize("training-tab-crop-mode-radio-info"),
|
415
|
-
)
|
416
|
-
|
417
|
-
crop_overlap = gr.Slider(
|
418
|
-
minimum=0,
|
419
|
-
maximum=2.99,
|
420
|
-
value=cfg.SIG_OVERLAP,
|
421
|
-
step=0.01,
|
422
|
-
label=loc.localize("training-tab-crop-overlap-number-label"),
|
423
|
-
info=loc.localize("training-tab-crop-overlap-number-info"),
|
424
|
-
visible=False,
|
425
|
-
)
|
426
|
-
|
427
|
-
def on_crop_select(new_crop_mode):
|
428
|
-
# Make overlap slider visible for both "segments" and "smart" crop modes
|
429
|
-
return gr.Number(
|
430
|
-
visible=new_crop_mode in ["segments", "smart"], interactive=new_crop_mode in ["segments", "smart"]
|
431
|
-
)
|
432
|
-
|
433
|
-
crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)
|
434
|
-
|
435
|
-
cache_mode.change(
|
436
|
-
on_cache_mode_change,
|
437
|
-
inputs=cache_mode,
|
438
|
-
outputs=[
|
439
|
-
new_cache_file_row,
|
440
|
-
load_cache_file_row,
|
441
|
-
select_directory_btn,
|
442
|
-
directory_input,
|
443
|
-
select_test_directory_btn,
|
444
|
-
test_directory_input,
|
445
|
-
fmin_number,
|
446
|
-
fmax_number,
|
447
|
-
audio_speed_slider,
|
448
|
-
crop_mode,
|
449
|
-
crop_overlap,
|
450
|
-
],
|
451
|
-
show_progress=False,
|
452
|
-
)
|
453
|
-
|
454
|
-
autotune_cb = gr.Checkbox(
|
455
|
-
cfg.AUTOTUNE,
|
456
|
-
label=loc.localize("training-tab-autotune-checkbox-label"),
|
457
|
-
info=loc.localize("training-tab-autotune-checkbox-info"),
|
458
|
-
)
|
459
|
-
|
460
|
-
with gr.Column(visible=False) as autotune_params:
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
(loc.localize("training-tab-upsampling-radio-option-
|
526
|
-
(loc.localize("training-tab-upsampling-radio-option-
|
527
|
-
(
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
)
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
gr.
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
)
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
)
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
"
|
631
|
-
f"{metrics['
|
632
|
-
f"{metrics['
|
633
|
-
f"{metrics['
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
f"{class_metrics['
|
647
|
-
f"{class_metrics['
|
648
|
-
f"{class_metrics['
|
649
|
-
f"{
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
if __name__ == "__main__":
|
698
|
-
gu.open_window(build_train_tab)
|
1
|
+
import multiprocessing
|
2
|
+
import os
|
3
|
+
from functools import partial
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import gradio as gr
|
7
|
+
|
8
|
+
import birdnet_analyzer.config as cfg
|
9
|
+
import birdnet_analyzer.gui.localization as loc
|
10
|
+
import birdnet_analyzer.gui.utils as gu
|
11
|
+
from birdnet_analyzer import utils
|
12
|
+
|
13
|
+
_GRID_MAX_HEIGHT = 240
|
14
|
+
|
15
|
+
|
16
|
+
def select_subdirectories(state_key=None):
|
17
|
+
"""Creates a directory selection dialog.
|
18
|
+
|
19
|
+
Returns:
|
20
|
+
A tuples of (directory, list of subdirectories) or (None, None) if the dialog was canceled.
|
21
|
+
"""
|
22
|
+
dir_name = gu.select_folder(state_key=state_key)
|
23
|
+
|
24
|
+
if dir_name:
|
25
|
+
subdirs = utils.list_subdirectories(dir_name)
|
26
|
+
labels = []
|
27
|
+
|
28
|
+
for folder in subdirs:
|
29
|
+
labels_in_folder = folder.split(",")
|
30
|
+
|
31
|
+
for label in labels_in_folder:
|
32
|
+
if label not in labels:
|
33
|
+
labels.append(label)
|
34
|
+
|
35
|
+
return dir_name, [[label] for label in sorted(labels)]
|
36
|
+
|
37
|
+
return None, None
|
38
|
+
|
39
|
+
|
40
|
+
@gu.gui_runtime_error_handler
|
41
|
+
def start_training(
|
42
|
+
data_dir,
|
43
|
+
test_data_dir,
|
44
|
+
crop_mode,
|
45
|
+
crop_overlap,
|
46
|
+
fmin,
|
47
|
+
fmax,
|
48
|
+
output_dir,
|
49
|
+
classifier_name,
|
50
|
+
model_save_mode,
|
51
|
+
cache_mode,
|
52
|
+
cache_file,
|
53
|
+
cache_file_name,
|
54
|
+
autotune,
|
55
|
+
autotune_trials,
|
56
|
+
autotune_executions_per_trials,
|
57
|
+
epochs,
|
58
|
+
batch_size,
|
59
|
+
learning_rate,
|
60
|
+
focal_loss,
|
61
|
+
focal_loss_gamma,
|
62
|
+
focal_loss_alpha,
|
63
|
+
hidden_units,
|
64
|
+
dropout,
|
65
|
+
label_smoothing,
|
66
|
+
use_mixup,
|
67
|
+
upsampling_ratio,
|
68
|
+
upsampling_mode,
|
69
|
+
model_format,
|
70
|
+
audio_speed,
|
71
|
+
progress=gr.Progress(),
|
72
|
+
):
|
73
|
+
"""Starts the training of a custom classifier.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
data_dir: Directory containing the training data.
|
77
|
+
test_data_dir: Directory containing the test data.
|
78
|
+
crop_mode: Mode for cropping audio samples.
|
79
|
+
crop_overlap: Overlap ratio for audio segments.
|
80
|
+
fmin: Minimum frequency for bandpass filtering.
|
81
|
+
fmax: Maximum frequency for bandpass filtering.
|
82
|
+
output_dir: Directory to save the trained model.
|
83
|
+
classifier_name: Name of the custom classifier.
|
84
|
+
model_save_mode: Save mode for the model (replace or append).
|
85
|
+
cache_mode: Cache mode for training data (load, save, or None).
|
86
|
+
cache_file: Path to the cache file.
|
87
|
+
cache_file_name: Name of the cache file.
|
88
|
+
autotune: Whether to use hyperparameter autotuning.
|
89
|
+
autotune_trials: Number of trials for autotuning.
|
90
|
+
autotune_executions_per_trials: Number of executions per autotuning trial.
|
91
|
+
epochs: Number of training epochs.
|
92
|
+
batch_size: Batch size for training.
|
93
|
+
learning_rate: Learning rate for the optimizer.
|
94
|
+
focal_loss: Whether to use focal loss for training.
|
95
|
+
focal_loss_gamma: Gamma parameter for focal loss.
|
96
|
+
focal_loss_alpha: Alpha parameter for focal loss.
|
97
|
+
hidden_units: Number of hidden units in the droput: Dropout rate for regularization.
|
98
|
+
dropout: Dropout rate for regularization.
|
99
|
+
label_smoothing: Whether to apply label smoothing for training.
|
100
|
+
use_mixup: Whether to use mixup data augmentation.
|
101
|
+
upsampling_ratio: Ratio for upsampling underrepresented classes.
|
102
|
+
upsampling_mode: Mode for upsampling (repeat, mean, smote).
|
103
|
+
model_format: Format to save the trained model (tflite, raven, both).
|
104
|
+
audio_speed: Speed factor for audio playback.
|
105
|
+
|
106
|
+
Returns:
|
107
|
+
Returns a matplotlib.pyplot figure.
|
108
|
+
"""
|
109
|
+
import matplotlib
|
110
|
+
import matplotlib.pyplot as plt
|
111
|
+
|
112
|
+
from birdnet_analyzer.train.utils import train_model
|
113
|
+
|
114
|
+
# Skip training data validation when cache mode is "load"
|
115
|
+
if cache_mode != "load":
|
116
|
+
gu.validate(data_dir, loc.localize("validation-no-training-data-selected"))
|
117
|
+
|
118
|
+
gu.validate(output_dir, loc.localize("validation-no-directory-for-classifier-selected"))
|
119
|
+
gu.validate(classifier_name, loc.localize("validation-no-valid-classifier-name"))
|
120
|
+
|
121
|
+
if not epochs or epochs < 0:
|
122
|
+
raise gr.Error(loc.localize("validation-no-valid-epoch-number"))
|
123
|
+
|
124
|
+
if not batch_size or batch_size < 0:
|
125
|
+
raise gr.Error(loc.localize("validation-no-valid-batch-size"))
|
126
|
+
|
127
|
+
if not learning_rate or learning_rate < 0:
|
128
|
+
raise gr.Error(loc.localize("validation-no-valid-learning-rate"))
|
129
|
+
|
130
|
+
if fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax:
|
131
|
+
raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]")
|
132
|
+
|
133
|
+
cfg.TRAIN_WITH_FOCAL_LOSS = focal_loss
|
134
|
+
cfg.FOCAL_LOSS_GAMMA = max(0.0, float(focal_loss_gamma))
|
135
|
+
cfg.FOCAL_LOSS_ALPHA = max(0.0, min(1.0, float(focal_loss_alpha)))
|
136
|
+
|
137
|
+
if not hidden_units or hidden_units < 0:
|
138
|
+
hidden_units = 0
|
139
|
+
|
140
|
+
cfg.TRAIN_DROPOUT = max(0.0, min(1.0, float(dropout)))
|
141
|
+
|
142
|
+
if progress is not None:
|
143
|
+
progress((0, epochs), desc=loc.localize("progress-build-classifier"), unit="epochs")
|
144
|
+
|
145
|
+
cfg.TRAIN_DATA_PATH = data_dir
|
146
|
+
cfg.TEST_DATA_PATH = test_data_dir
|
147
|
+
cfg.SAMPLE_CROP_MODE = crop_mode
|
148
|
+
cfg.SIG_OVERLAP = max(0.0, min(2.9, float(crop_overlap)))
|
149
|
+
cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
|
150
|
+
cfg.TRAIN_EPOCHS = int(epochs)
|
151
|
+
cfg.TRAIN_BATCH_SIZE = int(batch_size)
|
152
|
+
cfg.TRAIN_LEARNING_RATE = learning_rate
|
153
|
+
cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
|
154
|
+
cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing
|
155
|
+
cfg.TRAIN_WITH_MIXUP = use_mixup
|
156
|
+
cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
|
157
|
+
cfg.UPSAMPLING_MODE = upsampling_mode
|
158
|
+
cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
|
159
|
+
|
160
|
+
cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
|
161
|
+
cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))
|
162
|
+
|
163
|
+
cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
|
164
|
+
cfg.TRAIN_CACHE_MODE = cache_mode
|
165
|
+
cfg.TRAIN_CACHE_FILE = os.path.join(cache_file, cache_file_name) if cache_mode == "save" else cache_file
|
166
|
+
cfg.TFLITE_THREADS = 1
|
167
|
+
cfg.CPU_THREADS = max(1, multiprocessing.cpu_count() - 1) # let's use everything we have (well, almost)
|
168
|
+
|
169
|
+
if cache_mode == "load" and not os.path.isfile(cfg.TRAIN_CACHE_FILE):
|
170
|
+
raise gr.Error(loc.localize("validation-no-cache-file-selected"))
|
171
|
+
|
172
|
+
cfg.AUTOTUNE = autotune
|
173
|
+
cfg.AUTOTUNE_TRIALS = autotune_trials
|
174
|
+
cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = int(autotune_executions_per_trials)
|
175
|
+
|
176
|
+
cfg.AUDIO_SPEED = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed))
|
177
|
+
|
178
|
+
def data_load_progression(num_files, num_total_files, label):
|
179
|
+
if progress is not None:
|
180
|
+
progress(
|
181
|
+
(num_files, num_total_files),
|
182
|
+
total=num_total_files,
|
183
|
+
unit="files",
|
184
|
+
desc=f"{loc.localize('progress-loading-data')} '{label}'",
|
185
|
+
)
|
186
|
+
|
187
|
+
def epoch_progression(epoch, logs=None):
|
188
|
+
if progress is not None:
|
189
|
+
if epoch + 1 == epochs:
|
190
|
+
progress(
|
191
|
+
(epoch + 1, epochs),
|
192
|
+
total=epochs,
|
193
|
+
unit="epochs",
|
194
|
+
desc=f"{loc.localize('progress-saving')} {cfg.CUSTOM_CLASSIFIER}",
|
195
|
+
)
|
196
|
+
else:
|
197
|
+
progress((epoch + 1, epochs), total=epochs, unit="epochs", desc=loc.localize("progress-training"))
|
198
|
+
|
199
|
+
def trial_progression(trial):
|
200
|
+
if progress is not None:
|
201
|
+
progress(
|
202
|
+
(trial, autotune_trials), total=autotune_trials, unit="trials", desc=loc.localize("progress-autotune")
|
203
|
+
)
|
204
|
+
|
205
|
+
try:
|
206
|
+
history_result = train_model(
|
207
|
+
on_epoch_end=epoch_progression,
|
208
|
+
on_trial_result=trial_progression,
|
209
|
+
on_data_load_end=data_load_progression,
|
210
|
+
autotune_directory=gu.APPDIR if utils.FROZEN else "autotune",
|
211
|
+
)
|
212
|
+
|
213
|
+
# Unpack history and metrics
|
214
|
+
history, metrics = history_result
|
215
|
+
except Exception as e:
|
216
|
+
if e.args and len(e.args) > 1:
|
217
|
+
raise gr.Error(loc.localize(e.args[1])) from e
|
218
|
+
|
219
|
+
raise gr.Error(f"{e}") from e
|
220
|
+
|
221
|
+
if len(history.epoch) < epochs:
|
222
|
+
gr.Info(loc.localize("training-tab-early-stoppage-msg"))
|
223
|
+
|
224
|
+
auprc = history.history["val_AUPRC"]
|
225
|
+
auroc = history.history["val_AUROC"]
|
226
|
+
|
227
|
+
matplotlib.use("agg")
|
228
|
+
|
229
|
+
fig = plt.figure()
|
230
|
+
plt.plot(auprc, label="AUPRC")
|
231
|
+
plt.plot(auroc, label="AUROC")
|
232
|
+
plt.legend()
|
233
|
+
plt.xlabel("Epoch")
|
234
|
+
|
235
|
+
return fig, metrics
|
236
|
+
|
237
|
+
|
238
|
+
def build_train_tab():
|
239
|
+
with gr.Tab(loc.localize("training-tab-title")):
|
240
|
+
input_directory_state = gr.State()
|
241
|
+
output_directory_state = gr.State()
|
242
|
+
test_data_dir_state = gr.State()
|
243
|
+
|
244
|
+
with gr.Row():
|
245
|
+
with gr.Column():
|
246
|
+
select_directory_btn = gr.Button(loc.localize("training-tab-input-selection-button-label"))
|
247
|
+
directory_input = gr.List(
|
248
|
+
headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
|
249
|
+
interactive=False,
|
250
|
+
max_height=_GRID_MAX_HEIGHT,
|
251
|
+
)
|
252
|
+
select_directory_btn.click(
|
253
|
+
partial(select_subdirectories, state_key="train-data-dir"),
|
254
|
+
outputs=[input_directory_state, directory_input],
|
255
|
+
show_progress=False,
|
256
|
+
)
|
257
|
+
|
258
|
+
select_test_directory_btn = gr.Button(loc.localize("training-tab-test-data-selection-button-label"))
|
259
|
+
test_directory_input = gr.List(
|
260
|
+
headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
|
261
|
+
interactive=False,
|
262
|
+
max_height=_GRID_MAX_HEIGHT,
|
263
|
+
)
|
264
|
+
select_test_directory_btn.click(
|
265
|
+
partial(select_subdirectories, state_key="test-data-dir"),
|
266
|
+
outputs=[test_data_dir_state, test_directory_input],
|
267
|
+
show_progress=False,
|
268
|
+
)
|
269
|
+
|
270
|
+
with gr.Column():
|
271
|
+
select_classifier_directory_btn = gr.Button(loc.localize("training-tab-select-output-button-label"))
|
272
|
+
|
273
|
+
with gr.Column():
|
274
|
+
classifier_name = gr.Textbox(
|
275
|
+
"CustomClassifier",
|
276
|
+
visible=False,
|
277
|
+
info=loc.localize("training-tab-classifier-textbox-info"),
|
278
|
+
)
|
279
|
+
output_format = gr.Radio(
|
280
|
+
["tflite", "raven", (loc.localize("training-tab-output-format-both"), "both")],
|
281
|
+
value=cfg.TRAINED_MODEL_OUTPUT_FORMAT,
|
282
|
+
label=loc.localize("training-tab-output-format-radio-label"),
|
283
|
+
info=loc.localize("training-tab-output-format-radio-info"),
|
284
|
+
visible=False,
|
285
|
+
)
|
286
|
+
|
287
|
+
def select_directory_and_update_tb():
|
288
|
+
dir_name = gu.select_folder(state_key="train-output-dir")
|
289
|
+
|
290
|
+
if dir_name:
|
291
|
+
return (
|
292
|
+
dir_name,
|
293
|
+
gr.Textbox(label=dir_name, visible=True),
|
294
|
+
gr.Radio(visible=True, interactive=True),
|
295
|
+
)
|
296
|
+
|
297
|
+
return None, None
|
298
|
+
|
299
|
+
select_classifier_directory_btn.click(
|
300
|
+
select_directory_and_update_tb,
|
301
|
+
outputs=[output_directory_state, classifier_name, output_format],
|
302
|
+
show_progress=False,
|
303
|
+
)
|
304
|
+
|
305
|
+
with gr.Row():
|
306
|
+
cache_file_state = gr.State()
|
307
|
+
cache_mode = gr.Radio(
|
308
|
+
[
|
309
|
+
(loc.localize("training-tab-cache-mode-radio-option-none"), None),
|
310
|
+
(loc.localize("training-tab-cache-mode-radio-option-load"), "load"),
|
311
|
+
(loc.localize("training-tab-cache-mode-radio-option-save"), "save"),
|
312
|
+
],
|
313
|
+
value=cfg.TRAIN_CACHE_MODE,
|
314
|
+
label=loc.localize("training-tab-cache-mode-radio-label"),
|
315
|
+
info=loc.localize("training-tab-cache-mode-radio-info"),
|
316
|
+
)
|
317
|
+
with gr.Column(visible=False) as new_cache_file_row:
|
318
|
+
select_cache_file_directory_btn = gr.Button(
|
319
|
+
loc.localize("training-tab-cache-select-directory-button-label")
|
320
|
+
)
|
321
|
+
|
322
|
+
with gr.Column():
|
323
|
+
cache_file_name = gr.Textbox(
|
324
|
+
"train_cache.npz",
|
325
|
+
visible=False,
|
326
|
+
info=loc.localize("training-tab-cache-file-name-textbox-info"),
|
327
|
+
)
|
328
|
+
|
329
|
+
def select_directory_and_update():
|
330
|
+
dir_name = gu.select_folder(state_key="train-data-cache-file-output")
|
331
|
+
|
332
|
+
if dir_name:
|
333
|
+
return (
|
334
|
+
dir_name,
|
335
|
+
gr.Textbox(label=dir_name, visible=True),
|
336
|
+
)
|
337
|
+
|
338
|
+
return None, None
|
339
|
+
|
340
|
+
select_cache_file_directory_btn.click(
|
341
|
+
select_directory_and_update,
|
342
|
+
outputs=[cache_file_state, cache_file_name],
|
343
|
+
show_progress=False,
|
344
|
+
)
|
345
|
+
|
346
|
+
with gr.Column(visible=False) as load_cache_file_row:
|
347
|
+
selected_cache_file_btn = gr.Button(loc.localize("training-tab-cache-select-file-button-label"))
|
348
|
+
cache_file_input = gr.File(file_types=[".npz"], visible=False, interactive=False)
|
349
|
+
|
350
|
+
def on_cache_file_selection_click():
|
351
|
+
file = gu.select_file(("NPZ file (*.npz)",), state_key="train_data_cache_file")
|
352
|
+
|
353
|
+
if file:
|
354
|
+
return file, gr.File(value=file, visible=True)
|
355
|
+
|
356
|
+
return None, None
|
357
|
+
|
358
|
+
selected_cache_file_btn.click(
|
359
|
+
on_cache_file_selection_click,
|
360
|
+
outputs=[cache_file_state, cache_file_input],
|
361
|
+
show_progress=False,
|
362
|
+
)
|
363
|
+
|
364
|
+
def on_cache_mode_change(value):
|
365
|
+
return (
|
366
|
+
gr.update(visible=value == "save"),
|
367
|
+
gr.update(visible=value == "load"),
|
368
|
+
gr.update(interactive=value != "load"),
|
369
|
+
[],
|
370
|
+
gr.update(interactive=value != "load"),
|
371
|
+
[],
|
372
|
+
gr.update(interactive=value != "load"),
|
373
|
+
gr.update(interactive=value != "load"),
|
374
|
+
gr.update(interactive=value != "load"),
|
375
|
+
gr.update(interactive=value != "load"),
|
376
|
+
gr.update(interactive=value != "load"),
|
377
|
+
)
|
378
|
+
|
379
|
+
with gr.Row():
|
380
|
+
fmin_number = gr.Number(
|
381
|
+
cfg.SIG_FMIN,
|
382
|
+
minimum=0,
|
383
|
+
label=loc.localize("inference-settings-fmin-number-label"),
|
384
|
+
info=loc.localize("inference-settings-fmin-number-info"),
|
385
|
+
)
|
386
|
+
|
387
|
+
fmax_number = gr.Number(
|
388
|
+
cfg.SIG_FMAX,
|
389
|
+
minimum=0,
|
390
|
+
label=loc.localize("inference-settings-fmax-number-label"),
|
391
|
+
info=loc.localize("inference-settings-fmax-number-info"),
|
392
|
+
)
|
393
|
+
|
394
|
+
with gr.Row():
|
395
|
+
audio_speed_slider = gr.Slider(
|
396
|
+
minimum=-10,
|
397
|
+
maximum=10,
|
398
|
+
value=cfg.AUDIO_SPEED,
|
399
|
+
step=1,
|
400
|
+
label=loc.localize("training-tab-audio-speed-slider-label"),
|
401
|
+
info=loc.localize("training-tab-audio-speed-slider-info"),
|
402
|
+
)
|
403
|
+
|
404
|
+
with gr.Row():
|
405
|
+
crop_mode = gr.Radio(
|
406
|
+
[
|
407
|
+
(loc.localize("training-tab-crop-mode-radio-option-center"), "center"),
|
408
|
+
(loc.localize("training-tab-crop-mode-radio-option-first"), "first"),
|
409
|
+
(loc.localize("training-tab-crop-mode-radio-option-segments"), "segments"),
|
410
|
+
(loc.localize("training-tab-crop-mode-radio-option-smart"), "smart"),
|
411
|
+
],
|
412
|
+
value="center",
|
413
|
+
label=loc.localize("training-tab-crop-mode-radio-label"),
|
414
|
+
info=loc.localize("training-tab-crop-mode-radio-info"),
|
415
|
+
)
|
416
|
+
|
417
|
+
crop_overlap = gr.Slider(
|
418
|
+
minimum=0,
|
419
|
+
maximum=2.99,
|
420
|
+
value=cfg.SIG_OVERLAP,
|
421
|
+
step=0.01,
|
422
|
+
label=loc.localize("training-tab-crop-overlap-number-label"),
|
423
|
+
info=loc.localize("training-tab-crop-overlap-number-info"),
|
424
|
+
visible=False,
|
425
|
+
)
|
426
|
+
|
427
|
+
def on_crop_select(new_crop_mode):
|
428
|
+
# Make overlap slider visible for both "segments" and "smart" crop modes
|
429
|
+
return gr.Number(
|
430
|
+
visible=new_crop_mode in ["segments", "smart"], interactive=new_crop_mode in ["segments", "smart"]
|
431
|
+
)
|
432
|
+
|
433
|
+
crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)
|
434
|
+
|
435
|
+
cache_mode.change(
|
436
|
+
on_cache_mode_change,
|
437
|
+
inputs=cache_mode,
|
438
|
+
outputs=[
|
439
|
+
new_cache_file_row,
|
440
|
+
load_cache_file_row,
|
441
|
+
select_directory_btn,
|
442
|
+
directory_input,
|
443
|
+
select_test_directory_btn,
|
444
|
+
test_directory_input,
|
445
|
+
fmin_number,
|
446
|
+
fmax_number,
|
447
|
+
audio_speed_slider,
|
448
|
+
crop_mode,
|
449
|
+
crop_overlap,
|
450
|
+
],
|
451
|
+
show_progress=False,
|
452
|
+
)
|
453
|
+
|
454
|
+
autotune_cb = gr.Checkbox(
|
455
|
+
cfg.AUTOTUNE,
|
456
|
+
label=loc.localize("training-tab-autotune-checkbox-label"),
|
457
|
+
info=loc.localize("training-tab-autotune-checkbox-info"),
|
458
|
+
)
|
459
|
+
|
460
|
+
with gr.Column(visible=False) as autotune_params, gr.Row():
|
461
|
+
autotune_trials = gr.Number(
|
462
|
+
cfg.AUTOTUNE_TRIALS,
|
463
|
+
label=loc.localize("training-tab-autotune-trials-number-label"),
|
464
|
+
info=loc.localize("training-tab-autotune-trials-number-info"),
|
465
|
+
minimum=1,
|
466
|
+
)
|
467
|
+
autotune_executions_per_trials = gr.Number(
|
468
|
+
cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
|
469
|
+
minimum=1,
|
470
|
+
label=loc.localize("training-tab-autotune-executions-number-label"),
|
471
|
+
info=loc.localize("training-tab-autotune-executions-number-info"),
|
472
|
+
)
|
473
|
+
|
474
|
+
with gr.Column() as custom_params:
|
475
|
+
with gr.Row():
|
476
|
+
epoch_number = gr.Number(
|
477
|
+
cfg.TRAIN_EPOCHS,
|
478
|
+
minimum=1,
|
479
|
+
step=1,
|
480
|
+
label=loc.localize("training-tab-epochs-number-label"),
|
481
|
+
info=loc.localize("training-tab-epochs-number-info"),
|
482
|
+
)
|
483
|
+
batch_size_number = gr.Number(
|
484
|
+
32,
|
485
|
+
minimum=1,
|
486
|
+
step=8,
|
487
|
+
label=loc.localize("training-tab-batchsize-number-label"),
|
488
|
+
info=loc.localize("training-tab-batchsize-number-info"),
|
489
|
+
)
|
490
|
+
learning_rate_number = gr.Number(
|
491
|
+
cfg.TRAIN_LEARNING_RATE,
|
492
|
+
minimum=0.0001,
|
493
|
+
step=0.0001,
|
494
|
+
label=loc.localize("training-tab-learningrate-number-label"),
|
495
|
+
info=loc.localize("training-tab-learningrate-number-info"),
|
496
|
+
)
|
497
|
+
|
498
|
+
with gr.Row():
|
499
|
+
hidden_units_number = gr.Number(
|
500
|
+
cfg.TRAIN_HIDDEN_UNITS,
|
501
|
+
minimum=0,
|
502
|
+
step=64,
|
503
|
+
label=loc.localize("training-tab-hiddenunits-number-label"),
|
504
|
+
info=loc.localize("training-tab-hiddenunits-number-info"),
|
505
|
+
)
|
506
|
+
dropout_number = gr.Number(
|
507
|
+
cfg.TRAIN_DROPOUT,
|
508
|
+
minimum=0.0,
|
509
|
+
maximum=0.9,
|
510
|
+
step=0.1,
|
511
|
+
label=loc.localize("training-tab-dropout-number-label"),
|
512
|
+
info=loc.localize("training-tab-dropout-number-info"),
|
513
|
+
)
|
514
|
+
use_label_smoothing = gr.Checkbox(
|
515
|
+
cfg.TRAIN_WITH_LABEL_SMOOTHING,
|
516
|
+
label=loc.localize("training-tab-use-labelsmoothing-checkbox-label"),
|
517
|
+
info=loc.localize("training-tab-use-labelsmoothing-checkbox-info"),
|
518
|
+
show_label=True,
|
519
|
+
)
|
520
|
+
|
521
|
+
with gr.Row():
|
522
|
+
upsampling_mode = gr.Radio(
|
523
|
+
[
|
524
|
+
(loc.localize("training-tab-upsampling-radio-option-repeat"), "repeat"),
|
525
|
+
(loc.localize("training-tab-upsampling-radio-option-mean"), "mean"),
|
526
|
+
(loc.localize("training-tab-upsampling-radio-option-linear"), "linear"),
|
527
|
+
("SMOTE", "smote"),
|
528
|
+
],
|
529
|
+
value=cfg.UPSAMPLING_MODE,
|
530
|
+
label=loc.localize("training-tab-upsampling-radio-label"),
|
531
|
+
info=loc.localize("training-tab-upsampling-radio-info"),
|
532
|
+
)
|
533
|
+
upsampling_ratio = gr.Slider(
|
534
|
+
0.0,
|
535
|
+
1.0,
|
536
|
+
cfg.UPSAMPLING_RATIO,
|
537
|
+
step=0.05,
|
538
|
+
label=loc.localize("training-tab-upsampling-ratio-slider-label"),
|
539
|
+
info=loc.localize("training-tab-upsampling-ratio-slider-info"),
|
540
|
+
)
|
541
|
+
|
542
|
+
with gr.Row():
|
543
|
+
use_mixup = gr.Checkbox(
|
544
|
+
cfg.TRAIN_WITH_MIXUP,
|
545
|
+
label=loc.localize("training-tab-use-mixup-checkbox-label"),
|
546
|
+
info=loc.localize("training-tab-use-mixup-checkbox-info"),
|
547
|
+
show_label=True,
|
548
|
+
)
|
549
|
+
use_focal_loss = gr.Checkbox(
|
550
|
+
cfg.TRAIN_WITH_FOCAL_LOSS,
|
551
|
+
label=loc.localize("training-tab-use-focal-loss-checkbox-label"),
|
552
|
+
info=loc.localize("training-tab-use-focal-loss-checkbox-info"),
|
553
|
+
show_label=True,
|
554
|
+
)
|
555
|
+
|
556
|
+
with gr.Row(visible=False) as focal_loss_params, gr.Row():
|
557
|
+
focal_loss_gamma = gr.Slider(
|
558
|
+
minimum=0.5,
|
559
|
+
maximum=5.0,
|
560
|
+
value=cfg.FOCAL_LOSS_GAMMA,
|
561
|
+
step=0.1,
|
562
|
+
label=loc.localize("training-tab-focal-loss-gamma-slider-label"),
|
563
|
+
info=loc.localize("training-tab-focal-loss-gamma-slider-info"),
|
564
|
+
interactive=True,
|
565
|
+
)
|
566
|
+
focal_loss_alpha = gr.Slider(
|
567
|
+
minimum=0.1,
|
568
|
+
maximum=0.9,
|
569
|
+
value=cfg.FOCAL_LOSS_ALPHA,
|
570
|
+
step=0.05,
|
571
|
+
label=loc.localize("training-tab-focal-loss-alpha-slider-label"),
|
572
|
+
info=loc.localize("training-tab-focal-loss-alpha-slider-info"),
|
573
|
+
interactive=True,
|
574
|
+
)
|
575
|
+
|
576
|
+
def on_focal_loss_change(value):
|
577
|
+
return gr.Row(visible=value)
|
578
|
+
|
579
|
+
use_focal_loss.change(
|
580
|
+
on_focal_loss_change, inputs=use_focal_loss, outputs=focal_loss_params, show_progress=False
|
581
|
+
)
|
582
|
+
|
583
|
+
def on_autotune_change(value):
|
584
|
+
return (
|
585
|
+
gr.Column(visible=not value),
|
586
|
+
gr.Column(visible=value),
|
587
|
+
gr.Row(visible=not value and use_focal_loss.value),
|
588
|
+
)
|
589
|
+
|
590
|
+
autotune_cb.change(
|
591
|
+
on_autotune_change,
|
592
|
+
inputs=autotune_cb,
|
593
|
+
outputs=[custom_params, autotune_params, focal_loss_params],
|
594
|
+
show_progress=False,
|
595
|
+
)
|
596
|
+
|
597
|
+
model_save_mode = gr.Radio(
|
598
|
+
[
|
599
|
+
(loc.localize("training-tab-model-save-mode-radio-option-replace"), "replace"),
|
600
|
+
(loc.localize("training-tab-model-save-mode-radio-option-append"), "append"),
|
601
|
+
],
|
602
|
+
value=cfg.TRAINED_MODEL_SAVE_MODE,
|
603
|
+
label=loc.localize("training-tab-model-save-mode-radio-label"),
|
604
|
+
info=loc.localize("training-tab-model-save-mode-radio-info"),
|
605
|
+
)
|
606
|
+
|
607
|
+
train_history_plot = gr.Plot()
|
608
|
+
metrics_table = gr.Dataframe(
|
609
|
+
headers=["Class", "Precision", "Recall", "F1 Score", "AUPRC", "AUROC", "Samples"],
|
610
|
+
visible=False,
|
611
|
+
label="Model Performance Metrics (Default Threshold 0.5)",
|
612
|
+
)
|
613
|
+
start_training_button = gr.Button(
|
614
|
+
loc.localize("training-tab-start-training-button-label"), variant="huggingface"
|
615
|
+
)
|
616
|
+
|
617
|
+
def train_and_show_metrics(*args):
|
618
|
+
history, metrics = start_training(*args)
|
619
|
+
|
620
|
+
# If metrics are available (test data was provided), create table
|
621
|
+
if metrics:
|
622
|
+
# Create dataframe data with metrics
|
623
|
+
table_data = []
|
624
|
+
|
625
|
+
# Add overall metrics row first
|
626
|
+
table_data.append(
|
627
|
+
[
|
628
|
+
"OVERALL (Macro-avg)",
|
629
|
+
f"{metrics['macro_precision_default']:.4f}",
|
630
|
+
f"{metrics['macro_recall_default']:.4f}",
|
631
|
+
f"{metrics['macro_f1_default']:.4f}",
|
632
|
+
f"{metrics['macro_auprc']:.4f}",
|
633
|
+
f"{metrics['macro_auroc']:.4f}",
|
634
|
+
"",
|
635
|
+
]
|
636
|
+
)
|
637
|
+
|
638
|
+
# Add class-specific metrics
|
639
|
+
for class_name, class_metrics in metrics["class_metrics"].items():
|
640
|
+
distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0})
|
641
|
+
table_data.append(
|
642
|
+
[
|
643
|
+
class_name,
|
644
|
+
f"{class_metrics['precision_default']:.4f}",
|
645
|
+
f"{class_metrics['recall_default']:.4f}",
|
646
|
+
f"{class_metrics['f1_default']:.4f}",
|
647
|
+
f"{class_metrics['auprc']:.4f}",
|
648
|
+
f"{class_metrics['auroc']:.4f}",
|
649
|
+
f"{distribution['count']} ({distribution['percentage']:.2f}%)",
|
650
|
+
]
|
651
|
+
)
|
652
|
+
|
653
|
+
return history, gr.Dataframe(visible=True, value=table_data)
|
654
|
+
|
655
|
+
# No metrics available, just return history and hide table
|
656
|
+
return history, gr.Dataframe(visible=False)
|
657
|
+
|
658
|
+
start_training_button.click(
|
659
|
+
train_and_show_metrics,
|
660
|
+
inputs=[
|
661
|
+
input_directory_state,
|
662
|
+
test_data_dir_state,
|
663
|
+
crop_mode,
|
664
|
+
crop_overlap,
|
665
|
+
fmin_number,
|
666
|
+
fmax_number,
|
667
|
+
output_directory_state,
|
668
|
+
classifier_name,
|
669
|
+
model_save_mode,
|
670
|
+
cache_mode,
|
671
|
+
cache_file_state,
|
672
|
+
cache_file_name,
|
673
|
+
autotune_cb,
|
674
|
+
autotune_trials,
|
675
|
+
autotune_executions_per_trials,
|
676
|
+
epoch_number,
|
677
|
+
batch_size_number,
|
678
|
+
learning_rate_number,
|
679
|
+
use_focal_loss,
|
680
|
+
focal_loss_gamma,
|
681
|
+
focal_loss_alpha,
|
682
|
+
hidden_units_number,
|
683
|
+
dropout_number,
|
684
|
+
use_label_smoothing,
|
685
|
+
use_mixup,
|
686
|
+
upsampling_ratio,
|
687
|
+
upsampling_mode,
|
688
|
+
output_format,
|
689
|
+
audio_speed_slider,
|
690
|
+
],
|
691
|
+
outputs=[train_history_plot, metrics_table],
|
692
|
+
)
|
693
|
+
|
694
|
+
|
695
|
+
def _main():
    """Open a standalone GUI window that hosts only the training tab."""
    gu.open_window(build_train_tab)


if __name__ == "__main__":
    _main()
|