birdnet-analyzer 2.0.0__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +19 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +30 -25
- birdnet_analyzer/analyze/core.py +246 -245
- birdnet_analyzer/analyze/utils.py +694 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +732 -707
- birdnet_analyzer/config.py +243 -242
- birdnet_analyzer/eBird_taxonomy_codes_2024E.json +13046 -0
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +70 -70
- birdnet_analyzer/embeddings/utils.py +220 -193
- birdnet_analyzer/evaluation/__init__.py +189 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +364 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +378 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +179 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +36 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +638 -620
- birdnet_analyzer/gui/evaluation.py +801 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +265 -246
- birdnet_analyzer/gui/review.py +472 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +149 -129
- birdnet_analyzer/gui/single_file.py +264 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +687 -698
- birdnet_analyzer/gui/utils.py +797 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +341 -334
- birdnet_analyzer/lang/en.json +341 -334
- birdnet_analyzer/lang/fi.json +341 -334
- birdnet_analyzer/lang/fr.json +341 -334
- birdnet_analyzer/lang/id.json +341 -334
- birdnet_analyzer/lang/pt-br.json +341 -334
- birdnet_analyzer/lang/ru.json +341 -334
- birdnet_analyzer/lang/se.json +341 -334
- birdnet_analyzer/lang/tlh.json +341 -334
- birdnet_analyzer/lang/zh_TW.json +341 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +425 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/METADATA +146 -129
- birdnet_analyzer-2.1.0.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +0 -25280
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.1.0.dist-info}/top_level.txt +0 -0
birdnet_analyzer/gui/train.py
CHANGED
@@ -1,698 +1,687 @@
|
|
1
|
-
import multiprocessing
|
2
|
-
import os
|
3
|
-
from functools import partial
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
import gradio as gr
|
7
|
-
|
8
|
-
import birdnet_analyzer.config as cfg
|
9
|
-
import birdnet_analyzer.gui.localization as loc
|
10
|
-
import birdnet_analyzer.gui.utils as gu
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
)
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
import matplotlib
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
gu.validate(
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
cfg.
|
135
|
-
cfg.
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
cfg.
|
147
|
-
cfg.
|
148
|
-
cfg.
|
149
|
-
cfg.
|
150
|
-
cfg.
|
151
|
-
cfg.
|
152
|
-
cfg.
|
153
|
-
cfg.
|
154
|
-
cfg.
|
155
|
-
cfg.
|
156
|
-
cfg.
|
157
|
-
cfg.
|
158
|
-
cfg.
|
159
|
-
|
160
|
-
|
161
|
-
cfg.
|
162
|
-
|
163
|
-
|
164
|
-
cfg.
|
165
|
-
cfg.
|
166
|
-
cfg.
|
167
|
-
cfg.
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
cfg.
|
174
|
-
cfg.
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
plt.plot(
|
231
|
-
plt.
|
232
|
-
plt.
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
dir_name,
|
293
|
-
gr.
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
(loc.localize("training-tab-cache-mode-radio-option-
|
310
|
-
(loc.localize("training-tab-cache-mode-radio-option-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
)
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
)
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
)
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
gr.update(
|
368
|
-
|
369
|
-
|
370
|
-
gr.update(interactive=value != "load"),
|
371
|
-
|
372
|
-
gr.update(interactive=value != "load"),
|
373
|
-
gr.update(interactive=value != "load"),
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
(loc.localize("training-tab-crop-mode-radio-option-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
)
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
)
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
cfg.
|
546
|
-
label=loc.localize("training-tab-use-
|
547
|
-
info=loc.localize("training-tab-use-
|
548
|
-
show_label=True,
|
549
|
-
)
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
)
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
upsampling_ratio,
|
689
|
-
upsampling_mode,
|
690
|
-
output_format,
|
691
|
-
audio_speed_slider,
|
692
|
-
],
|
693
|
-
outputs=[train_history_plot, metrics_table],
|
694
|
-
)
|
695
|
-
|
696
|
-
|
697
|
-
if __name__ == "__main__":
|
698
|
-
gu.open_window(build_train_tab)
|
1
|
+
import multiprocessing
|
2
|
+
import os
|
3
|
+
from functools import partial
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import gradio as gr
|
7
|
+
|
8
|
+
import birdnet_analyzer.config as cfg
|
9
|
+
import birdnet_analyzer.gui.localization as loc
|
10
|
+
import birdnet_analyzer.gui.utils as gu
|
11
|
+
from birdnet_analyzer import utils
|
12
|
+
from birdnet_analyzer.gui.settings import APPDIR
|
13
|
+
|
14
|
+
_GRID_MAX_HEIGHT = 240
|
15
|
+
|
16
|
+
|
17
|
+
def select_subdirectories(state_key=None):
    """Open a folder picker and list the class labels found in the chosen folder.

    Every subdirectory name reported by ``utils.list_subdirectories`` is split
    on commas, so a single folder may contribute several labels. The labels of
    all folders are merged, de-duplicated and sorted.

    Args:
        state_key: Forwarded unchanged to ``gu.select_folder``
            (presumably the key under which the last used folder is
            remembered -- confirm against ``gu.select_folder``).

    Returns:
        A tuple ``(directory, rows)`` where ``rows`` is a sorted list of
        one-element lists ``[label]``, or ``(None, None)`` if no directory
        was selected.
    """
    dir_name = gu.select_folder(state_key=state_key)

    if not dir_name:
        return None, None

    # A set removes duplicates in one pass instead of scanning a list for every label.
    labels = {label for folder in utils.list_subdirectories(dir_name) for label in folder.split(",")}

    return dir_name, [[label] for label in sorted(labels)]
|
39
|
+
|
40
|
+
|
41
|
+
@gu.gui_runtime_error_handler
def start_training(
    data_dir,
    test_data_dir,
    crop_mode,
    crop_overlap,
    fmin,
    fmax,
    output_dir,
    classifier_name,
    model_save_mode,
    cache_mode,
    cache_file,
    cache_file_name,
    autotune,
    autotune_trials,
    autotune_executions_per_trials,
    epochs,
    batch_size,
    learning_rate,
    focal_loss,
    focal_loss_gamma,
    focal_loss_alpha,
    hidden_units,
    dropout,
    label_smoothing,
    use_mixup,
    upsampling_ratio,
    upsampling_mode,
    model_format,
    audio_speed,
    progress=gr.Progress(),
):
    """Starts the training of a custom classifier.

    Validates the GUI inputs, copies them into the global ``cfg`` module,
    runs ``train_model`` and plots the validation AUPRC/AUROC per epoch.

    Args:
        data_dir: Directory containing the training data. Not validated when
            ``cache_mode`` is "load".
        test_data_dir: Directory containing the test data.
        crop_mode: Mode for cropping audio samples.
        crop_overlap: Overlap for audio segments; clamped to [0.0, 2.9].
        fmin: Minimum frequency for bandpass filtering.
        fmax: Maximum frequency for bandpass filtering.
        output_dir: Directory to save the trained model.
        classifier_name: Name of the custom classifier; joined with
            ``output_dir`` to form the classifier path.
        model_save_mode: Save mode for the model (replace or append).
        cache_mode: Cache mode for training data ("load", "save", or None).
        cache_file: In "save" mode the directory that receives the cache
            file; otherwise the path of the cache file itself.
        cache_file_name: Name of the cache file (only used in "save" mode).
        autotune: Whether to use hyperparameter autotuning.
        autotune_trials: Number of trials for autotuning.
        autotune_executions_per_trials: Number of executions per autotuning trial.
        epochs: Number of training epochs.
        batch_size: Batch size for training.
        learning_rate: Learning rate for the optimizer.
        focal_loss: Whether to use focal loss for training.
        focal_loss_gamma: Gamma parameter for focal loss; negative values become 0.
        focal_loss_alpha: Alpha parameter for focal loss; clamped to [0, 1].
        hidden_units: Number of hidden units; empty or negative values become 0.
        dropout: Dropout rate for regularization; clamped to [0, 1].
        label_smoothing: Whether to apply label smoothing for training.
        use_mixup: Whether to use mixup data augmentation.
        upsampling_ratio: Ratio for upsampling underrepresented classes; clamped to [0, 1].
        upsampling_mode: Mode for upsampling (repeat, mean, smote).
        model_format: Format to save the trained model (tflite, raven, both).
        audio_speed: Speed setting. Negative values ``-n`` are converted to the
            factor ``1/n`` (at least 0.1); other values are used directly,
            with a minimum of 1.0.
        progress: Gradio progress tracker, or None to disable progress updates.

    Returns:
        A tuple ``(fig, metrics)``: the matplotlib figure with the validation
        AUPRC/AUROC curves and the metrics object returned by ``train_model``.

    Raises:
        gr.Error: If an input fails validation or training raises an exception.
    """
    import matplotlib
    import matplotlib.pyplot as plt

    from birdnet_analyzer.train.utils import train_model

    # Skip training data validation when cache mode is "load"
    if cache_mode != "load":
        gu.validate(data_dir, loc.localize("validation-no-training-data-selected"))

    gu.validate(output_dir, loc.localize("validation-no-directory-for-classifier-selected"))
    gu.validate(classifier_name, loc.localize("validation-no-valid-classifier-name"))

    if not epochs or epochs < 0:
        raise gr.Error(loc.localize("validation-no-valid-epoch-number"))

    if not batch_size or batch_size < 0:
        raise gr.Error(loc.localize("validation-no-valid-batch-size"))

    if not learning_rate or learning_rate < 0:
        raise gr.Error(loc.localize("validation-no-valid-learning-rate"))

    # A missing frequency (presumably an emptied number field) must produce the
    # validation message instead of a TypeError from comparing None.
    if fmin is None or fmax is None or fmin < cfg.SIG_FMIN or fmax > cfg.SIG_FMAX or fmin > fmax:
        raise gr.Error(f"{loc.localize('validation-no-valid-frequency')} [{cfg.SIG_FMIN}, {cfg.SIG_FMAX}]")

    cfg.TRAIN_WITH_FOCAL_LOSS = focal_loss
    cfg.FOCAL_LOSS_GAMMA = max(0.0, float(focal_loss_gamma))
    cfg.FOCAL_LOSS_ALPHA = max(0.0, min(1.0, float(focal_loss_alpha)))

    if not hidden_units or hidden_units < 0:
        hidden_units = 0

    cfg.TRAIN_DROPOUT = max(0.0, min(1.0, float(dropout)))

    if progress is not None:
        progress((0, epochs), desc=loc.localize("progress-build-classifier"), unit="epochs")

    cfg.TRAIN_DATA_PATH = data_dir
    cfg.TEST_DATA_PATH = test_data_dir
    cfg.SAMPLE_CROP_MODE = crop_mode
    cfg.SIG_OVERLAP = max(0.0, min(2.9, float(crop_overlap)))
    cfg.CUSTOM_CLASSIFIER = str(Path(output_dir) / classifier_name)
    cfg.TRAIN_EPOCHS = int(epochs)
    cfg.TRAIN_BATCH_SIZE = int(batch_size)
    cfg.TRAIN_LEARNING_RATE = learning_rate
    cfg.TRAIN_HIDDEN_UNITS = int(hidden_units)
    cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing
    cfg.TRAIN_WITH_MIXUP = use_mixup
    cfg.UPSAMPLING_RATIO = min(max(0, upsampling_ratio), 1)
    cfg.UPSAMPLING_MODE = upsampling_mode
    cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format

    cfg.BANDPASS_FMIN = max(0, min(cfg.SIG_FMAX, int(fmin)))
    cfg.BANDPASS_FMAX = max(cfg.SIG_FMIN, min(cfg.SIG_FMAX, int(fmax)))

    cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
    cfg.TRAIN_CACHE_MODE = cache_mode

    if cache_mode == "save":
        # os.path.join would raise a bare TypeError for a missing directory or
        # name, so report the problem as a validation error instead.
        if not cache_file or not cache_file_name:
            raise gr.Error(loc.localize("validation-no-cache-file-selected"))

        cfg.TRAIN_CACHE_FILE = os.path.join(cache_file, cache_file_name)
    else:
        cfg.TRAIN_CACHE_FILE = cache_file

    cfg.TFLITE_THREADS = 1
    cfg.CPU_THREADS = max(1, multiprocessing.cpu_count() - 1)  # let's use everything we have (well, almost)

    # Check for an unset path first: os.path.isfile(None) can raise TypeError.
    if cache_mode == "load" and not (cfg.TRAIN_CACHE_FILE and os.path.isfile(cfg.TRAIN_CACHE_FILE)):
        raise gr.Error(loc.localize("validation-no-cache-file-selected"))

    cfg.AUTOTUNE = autotune
    cfg.AUTOTUNE_TRIALS = autotune_trials
    cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = int(autotune_executions_per_trials)

    cfg.AUDIO_SPEED = max(0.1, 1.0 / (audio_speed * -1)) if audio_speed < 0 else max(1.0, float(audio_speed))

    def data_load_progression(num_files, num_total_files, label):
        """Report how many files have been loaded for ``label``."""
        if progress is not None:
            progress(
                (num_files, num_total_files),
                total=num_total_files,
                unit="files",
                desc=f"{loc.localize('progress-loading-data')} '{label}'",
            )

    def epoch_progression(epoch, logs=None):
        """Report the finished epoch; the last epoch switches to the saving message."""
        if progress is not None:
            if epoch + 1 == epochs:
                progress(
                    (epoch + 1, epochs),
                    total=epochs,
                    unit="epochs",
                    desc=f"{loc.localize('progress-saving')} {cfg.CUSTOM_CLASSIFIER}",
                )
            else:
                progress((epoch + 1, epochs), total=epochs, unit="epochs", desc=loc.localize("progress-training"))

    def trial_progression(trial):
        """Report the autotune trial that just produced a result."""
        if progress is not None:
            progress((trial, autotune_trials), total=autotune_trials, unit="trials", desc=loc.localize("progress-autotune"))

    try:
        history_result = train_model(
            on_epoch_end=epoch_progression,
            on_trial_result=trial_progression,
            on_data_load_end=data_load_progression,
            autotune_directory=APPDIR if utils.FROZEN else "autotune",
        )

        # Unpack history and metrics
        history, metrics = history_result
    except Exception as e:
        # A second exception argument is treated as a localization key.
        if len(e.args) > 1:
            raise gr.Error(loc.localize(e.args[1])) from e

        raise gr.Error(f"{e}") from e

    # Fewer recorded epochs than requested means training stopped early.
    if len(history.epoch) < epochs:
        gr.Info(loc.localize("training-tab-early-stoppage-msg"))

    auprc = history.history["val_AUPRC"]
    auroc = history.history["val_AUROC"]

    matplotlib.use("agg")

    fig = plt.figure()
    plt.plot(auprc, label="AUPRC")
    plt.plot(auroc, label="AUROC")
    plt.legend()
    plt.xlabel("Epoch")

    return fig, metrics
|
235
|
+
|
236
|
+
|
237
|
+
def build_train_tab():
|
238
|
+
with gr.Tab(loc.localize("training-tab-title")):
|
239
|
+
input_directory_state = gr.State()
|
240
|
+
output_directory_state = gr.State()
|
241
|
+
test_data_dir_state = gr.State()
|
242
|
+
|
243
|
+
with gr.Row():
|
244
|
+
with gr.Column():
|
245
|
+
select_directory_btn = gr.Button(loc.localize("training-tab-input-selection-button-label"))
|
246
|
+
directory_input = gr.List(
|
247
|
+
headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
|
248
|
+
interactive=False,
|
249
|
+
max_height=_GRID_MAX_HEIGHT,
|
250
|
+
)
|
251
|
+
select_directory_btn.click(
|
252
|
+
partial(select_subdirectories, state_key="train-data-dir"),
|
253
|
+
outputs=[input_directory_state, directory_input],
|
254
|
+
show_progress=False,
|
255
|
+
)
|
256
|
+
|
257
|
+
select_test_directory_btn = gr.Button(loc.localize("training-tab-test-data-selection-button-label"))
|
258
|
+
test_directory_input = gr.List(
|
259
|
+
headers=[loc.localize("training-tab-classes-dataframe-column-classes-header")],
|
260
|
+
interactive=False,
|
261
|
+
max_height=_GRID_MAX_HEIGHT,
|
262
|
+
)
|
263
|
+
select_test_directory_btn.click(
|
264
|
+
partial(select_subdirectories, state_key="test-data-dir"),
|
265
|
+
outputs=[test_data_dir_state, test_directory_input],
|
266
|
+
show_progress=False,
|
267
|
+
)
|
268
|
+
|
269
|
+
with gr.Column():
|
270
|
+
select_classifier_directory_btn = gr.Button(loc.localize("training-tab-select-output-button-label"))
|
271
|
+
|
272
|
+
with gr.Column():
|
273
|
+
classifier_name = gr.Textbox(
|
274
|
+
"CustomClassifier",
|
275
|
+
visible=False,
|
276
|
+
info=loc.localize("training-tab-classifier-textbox-info"),
|
277
|
+
)
|
278
|
+
output_format = gr.Radio(
|
279
|
+
["tflite", "raven", (loc.localize("training-tab-output-format-both"), "both")],
|
280
|
+
value=cfg.TRAINED_MODEL_OUTPUT_FORMAT,
|
281
|
+
label=loc.localize("training-tab-output-format-radio-label"),
|
282
|
+
info=loc.localize("training-tab-output-format-radio-info"),
|
283
|
+
visible=False,
|
284
|
+
)
|
285
|
+
|
286
|
+
def select_directory_and_update_tb():
    """Let the user pick the classifier output folder and reveal the name/format inputs.

    Returns:
        Three values, one per wired output (directory state, classifier name
        textbox, output format radio). After a selection the textbox is
        labelled with the folder and both widgets are shown. If no folder is
        selected, the state is cleared and both widgets are hidden.
    """
    dir_name = gu.select_folder(state_key="train-output-dir")

    if not dir_name:
        # The click handler is wired to three outputs, so a two-value return
        # would be rejected by Gradio for providing too few output values.
        return None, gr.update(visible=False), gr.update(visible=False)

    return (
        dir_name,
        gr.Textbox(label=dir_name, visible=True),
        gr.Radio(visible=True, interactive=True),
    )
|
297
|
+
|
298
|
+
select_classifier_directory_btn.click(
|
299
|
+
select_directory_and_update_tb,
|
300
|
+
outputs=[output_directory_state, classifier_name, output_format],
|
301
|
+
show_progress=False,
|
302
|
+
)
|
303
|
+
|
304
|
+
with gr.Row():
|
305
|
+
cache_file_state = gr.State()
|
306
|
+
cache_mode = gr.Radio(
|
307
|
+
[
|
308
|
+
(loc.localize("training-tab-cache-mode-radio-option-none"), None),
|
309
|
+
(loc.localize("training-tab-cache-mode-radio-option-load"), "load"),
|
310
|
+
(loc.localize("training-tab-cache-mode-radio-option-save"), "save"),
|
311
|
+
],
|
312
|
+
value=cfg.TRAIN_CACHE_MODE,
|
313
|
+
label=loc.localize("training-tab-cache-mode-radio-label"),
|
314
|
+
info=loc.localize("training-tab-cache-mode-radio-info"),
|
315
|
+
)
|
316
|
+
with gr.Column(visible=False) as new_cache_file_row:
|
317
|
+
select_cache_file_directory_btn = gr.Button(loc.localize("training-tab-cache-select-directory-button-label"))
|
318
|
+
|
319
|
+
with gr.Column():
|
320
|
+
cache_file_name = gr.Textbox(
|
321
|
+
"train_cache.npz",
|
322
|
+
visible=False,
|
323
|
+
info=loc.localize("training-tab-cache-file-name-textbox-info"),
|
324
|
+
)
|
325
|
+
|
326
|
+
def select_directory_and_update():
    """Ask for the folder that will receive the cache file.

    Returns:
        ``(folder, textbox_update)`` where the textbox is made visible and
        labelled with the folder, or ``(None, None)`` if nothing was selected.
    """
    chosen_dir = gu.select_folder(state_key="train-data-cache-file-output")

    if not chosen_dir:
        return None, None

    name_box = gr.Textbox(label=chosen_dir, visible=True)
    return chosen_dir, name_box
|
336
|
+
|
337
|
+
select_cache_file_directory_btn.click(
|
338
|
+
select_directory_and_update,
|
339
|
+
outputs=[cache_file_state, cache_file_name],
|
340
|
+
show_progress=False,
|
341
|
+
)
|
342
|
+
|
343
|
+
with gr.Column(visible=False) as load_cache_file_row:
|
344
|
+
selected_cache_file_btn = gr.Button(loc.localize("training-tab-cache-select-file-button-label"))
|
345
|
+
cache_file_input = gr.File(file_types=[".npz"], visible=False, interactive=False)
|
346
|
+
|
347
|
+
def on_cache_file_selection_click():
    """Ask for an existing ``.npz`` cache file.

    Returns:
        ``(path, file_update)`` with the file component shown and set to the
        path, or ``(None, None)`` if nothing was selected.
    """
    selected = gu.select_file(("NPZ file (*.npz)",), state_key="train_data_cache_file")

    if not selected:
        return None, None

    preview = gr.File(value=selected, visible=True)
    return selected, preview
|
354
|
+
|
355
|
+
selected_cache_file_btn.click(
|
356
|
+
on_cache_file_selection_click,
|
357
|
+
outputs=[cache_file_state, cache_file_input],
|
358
|
+
show_progress=False,
|
359
|
+
)
|
360
|
+
|
361
|
+
def on_cache_mode_change(value):
    """Build the widget updates for a newly selected cache mode.

    Args:
        value: The selected cache mode; compared against "save" and "load".

    Returns:
        An 11-tuple: visibility of the "save" column, visibility of the
        "load" column, then interactivity updates for the remaining widgets
        (disabled only in "load" mode), with an empty list in the fourth and
        sixth positions.
    """
    loading = value == "load"

    def toggle():
        # A fresh update object per output; widgets are locked only while loading.
        return gr.update(interactive=not loading)

    return (
        gr.update(visible=value == "save"),
        gr.update(visible=loading),
        toggle(),
        [],
        toggle(),
        [],
        toggle(),
        toggle(),
        toggle(),
        toggle(),
        toggle(),
    )
|
375
|
+
|
376
|
+
with gr.Row():
|
377
|
+
fmin_number = gr.Number(
|
378
|
+
cfg.SIG_FMIN,
|
379
|
+
minimum=0,
|
380
|
+
label=loc.localize("inference-settings-fmin-number-label"),
|
381
|
+
info=loc.localize("inference-settings-fmin-number-info"),
|
382
|
+
)
|
383
|
+
|
384
|
+
fmax_number = gr.Number(
|
385
|
+
cfg.SIG_FMAX,
|
386
|
+
minimum=0,
|
387
|
+
label=loc.localize("inference-settings-fmax-number-label"),
|
388
|
+
info=loc.localize("inference-settings-fmax-number-info"),
|
389
|
+
)
|
390
|
+
|
391
|
+
with gr.Row():
|
392
|
+
audio_speed_slider = gr.Slider(
|
393
|
+
minimum=-10,
|
394
|
+
maximum=10,
|
395
|
+
value=cfg.AUDIO_SPEED,
|
396
|
+
step=1,
|
397
|
+
label=loc.localize("training-tab-audio-speed-slider-label"),
|
398
|
+
info=loc.localize("training-tab-audio-speed-slider-info"),
|
399
|
+
)
|
400
|
+
|
401
|
+
with gr.Row():
|
402
|
+
crop_mode = gr.Radio(
|
403
|
+
[
|
404
|
+
(loc.localize("training-tab-crop-mode-radio-option-center"), "center"),
|
405
|
+
(loc.localize("training-tab-crop-mode-radio-option-first"), "first"),
|
406
|
+
(loc.localize("training-tab-crop-mode-radio-option-segments"), "segments"),
|
407
|
+
(loc.localize("training-tab-crop-mode-radio-option-smart"), "smart"),
|
408
|
+
],
|
409
|
+
value="center",
|
410
|
+
label=loc.localize("training-tab-crop-mode-radio-label"),
|
411
|
+
info=loc.localize("training-tab-crop-mode-radio-info"),
|
412
|
+
)
|
413
|
+
|
414
|
+
crop_overlap = gr.Slider(
|
415
|
+
minimum=0,
|
416
|
+
maximum=2.99,
|
417
|
+
value=cfg.SIG_OVERLAP,
|
418
|
+
step=0.01,
|
419
|
+
label=loc.localize("training-tab-crop-overlap-number-label"),
|
420
|
+
info=loc.localize("training-tab-crop-overlap-number-info"),
|
421
|
+
visible=False,
|
422
|
+
)
|
423
|
+
|
424
|
+
def on_crop_select(new_crop_mode):
    """Show and enable the crop-overlap slider only for modes that use it.

    Args:
        new_crop_mode: The value currently selected in the ``crop_mode`` radio
            ("center", "first", "segments" or "smart").

    Returns:
        A property update for ``crop_overlap`` that makes the slider visible
        and interactive for the "segments" and "smart" modes and hides it
        otherwise.
    """
    # Evaluate the membership test once; both properties follow the same rule.
    uses_overlap = new_crop_mode in ("segments", "smart")

    # ``crop_overlap`` is a Slider, so return a type-agnostic property update
    # (as on_cache_mode_change does for the same output) rather than a
    # ``gr.Number`` instance that does not match the target component.
    return gr.update(visible=uses_overlap, interactive=uses_overlap)
|
427
|
+
|
428
|
+
crop_mode.change(on_crop_select, inputs=crop_mode, outputs=crop_overlap)
|
429
|
+
|
430
|
+
cache_mode.change(
|
431
|
+
on_cache_mode_change,
|
432
|
+
inputs=cache_mode,
|
433
|
+
outputs=[
|
434
|
+
new_cache_file_row,
|
435
|
+
load_cache_file_row,
|
436
|
+
select_directory_btn,
|
437
|
+
directory_input,
|
438
|
+
select_test_directory_btn,
|
439
|
+
test_directory_input,
|
440
|
+
fmin_number,
|
441
|
+
fmax_number,
|
442
|
+
audio_speed_slider,
|
443
|
+
crop_mode,
|
444
|
+
crop_overlap,
|
445
|
+
],
|
446
|
+
show_progress=False,
|
447
|
+
)
|
448
|
+
|
449
|
+
autotune_cb = gr.Checkbox(
|
450
|
+
cfg.AUTOTUNE,
|
451
|
+
label=loc.localize("training-tab-autotune-checkbox-label"),
|
452
|
+
info=loc.localize("training-tab-autotune-checkbox-info"),
|
453
|
+
)
|
454
|
+
|
455
|
+
with gr.Column(visible=False) as autotune_params, gr.Row():
|
456
|
+
autotune_trials = gr.Number(
|
457
|
+
cfg.AUTOTUNE_TRIALS,
|
458
|
+
label=loc.localize("training-tab-autotune-trials-number-label"),
|
459
|
+
info=loc.localize("training-tab-autotune-trials-number-info"),
|
460
|
+
minimum=1,
|
461
|
+
)
|
462
|
+
autotune_executions_per_trials = gr.Number(
|
463
|
+
cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
|
464
|
+
minimum=1,
|
465
|
+
label=loc.localize("training-tab-autotune-executions-number-label"),
|
466
|
+
info=loc.localize("training-tab-autotune-executions-number-info"),
|
467
|
+
)
|
468
|
+
|
469
|
+
with gr.Column() as custom_params:
|
470
|
+
with gr.Row():
|
471
|
+
epoch_number = gr.Number(
|
472
|
+
cfg.TRAIN_EPOCHS,
|
473
|
+
minimum=1,
|
474
|
+
step=1,
|
475
|
+
label=loc.localize("training-tab-epochs-number-label"),
|
476
|
+
info=loc.localize("training-tab-epochs-number-info"),
|
477
|
+
)
|
478
|
+
batch_size_number = gr.Number(
|
479
|
+
32,
|
480
|
+
minimum=1,
|
481
|
+
step=8,
|
482
|
+
label=loc.localize("training-tab-batchsize-number-label"),
|
483
|
+
info=loc.localize("training-tab-batchsize-number-info"),
|
484
|
+
)
|
485
|
+
learning_rate_number = gr.Number(
|
486
|
+
cfg.TRAIN_LEARNING_RATE,
|
487
|
+
minimum=0.0001,
|
488
|
+
step=0.0001,
|
489
|
+
label=loc.localize("training-tab-learningrate-number-label"),
|
490
|
+
info=loc.localize("training-tab-learningrate-number-info"),
|
491
|
+
)
|
492
|
+
|
493
|
+
with gr.Row():
|
494
|
+
hidden_units_number = gr.Number(
|
495
|
+
cfg.TRAIN_HIDDEN_UNITS,
|
496
|
+
minimum=0,
|
497
|
+
step=64,
|
498
|
+
label=loc.localize("training-tab-hiddenunits-number-label"),
|
499
|
+
info=loc.localize("training-tab-hiddenunits-number-info"),
|
500
|
+
)
|
501
|
+
dropout_number = gr.Number(
|
502
|
+
cfg.TRAIN_DROPOUT,
|
503
|
+
minimum=0.0,
|
504
|
+
maximum=0.9,
|
505
|
+
step=0.1,
|
506
|
+
label=loc.localize("training-tab-dropout-number-label"),
|
507
|
+
info=loc.localize("training-tab-dropout-number-info"),
|
508
|
+
)
|
509
|
+
use_label_smoothing = gr.Checkbox(
|
510
|
+
cfg.TRAIN_WITH_LABEL_SMOOTHING,
|
511
|
+
label=loc.localize("training-tab-use-labelsmoothing-checkbox-label"),
|
512
|
+
info=loc.localize("training-tab-use-labelsmoothing-checkbox-info"),
|
513
|
+
show_label=True,
|
514
|
+
)
|
515
|
+
|
516
|
+
with gr.Row():
|
517
|
+
upsampling_mode = gr.Radio(
|
518
|
+
[
|
519
|
+
(loc.localize("training-tab-upsampling-radio-option-repeat"), "repeat"),
|
520
|
+
(loc.localize("training-tab-upsampling-radio-option-mean"), "mean"),
|
521
|
+
(loc.localize("training-tab-upsampling-radio-option-linear"), "linear"),
|
522
|
+
("SMOTE", "smote"),
|
523
|
+
],
|
524
|
+
value=cfg.UPSAMPLING_MODE,
|
525
|
+
label=loc.localize("training-tab-upsampling-radio-label"),
|
526
|
+
info=loc.localize("training-tab-upsampling-radio-info"),
|
527
|
+
)
|
528
|
+
upsampling_ratio = gr.Slider(
|
529
|
+
0.0,
|
530
|
+
1.0,
|
531
|
+
cfg.UPSAMPLING_RATIO,
|
532
|
+
step=0.05,
|
533
|
+
label=loc.localize("training-tab-upsampling-ratio-slider-label"),
|
534
|
+
info=loc.localize("training-tab-upsampling-ratio-slider-info"),
|
535
|
+
)
|
536
|
+
|
537
|
+
with gr.Row():
|
538
|
+
use_mixup = gr.Checkbox(
|
539
|
+
cfg.TRAIN_WITH_MIXUP,
|
540
|
+
label=loc.localize("training-tab-use-mixup-checkbox-label"),
|
541
|
+
info=loc.localize("training-tab-use-mixup-checkbox-info"),
|
542
|
+
show_label=True,
|
543
|
+
)
|
544
|
+
use_focal_loss = gr.Checkbox(
|
545
|
+
cfg.TRAIN_WITH_FOCAL_LOSS,
|
546
|
+
label=loc.localize("training-tab-use-focal-loss-checkbox-label"),
|
547
|
+
info=loc.localize("training-tab-use-focal-loss-checkbox-info"),
|
548
|
+
show_label=True,
|
549
|
+
)
|
550
|
+
|
551
|
+
with gr.Row(visible=False) as focal_loss_params, gr.Row():
|
552
|
+
focal_loss_gamma = gr.Slider(
|
553
|
+
minimum=0.5,
|
554
|
+
maximum=5.0,
|
555
|
+
value=cfg.FOCAL_LOSS_GAMMA,
|
556
|
+
step=0.1,
|
557
|
+
label=loc.localize("training-tab-focal-loss-gamma-slider-label"),
|
558
|
+
info=loc.localize("training-tab-focal-loss-gamma-slider-info"),
|
559
|
+
interactive=True,
|
560
|
+
)
|
561
|
+
focal_loss_alpha = gr.Slider(
|
562
|
+
minimum=0.1,
|
563
|
+
maximum=0.9,
|
564
|
+
value=cfg.FOCAL_LOSS_ALPHA,
|
565
|
+
step=0.05,
|
566
|
+
label=loc.localize("training-tab-focal-loss-alpha-slider-label"),
|
567
|
+
info=loc.localize("training-tab-focal-loss-alpha-slider-info"),
|
568
|
+
interactive=True,
|
569
|
+
)
|
570
|
+
|
571
|
+
def on_focal_loss_change(value):
    """Make the focal-loss parameter row follow the focal-loss checkbox.

    Args:
        value: Current state of the ``use_focal_loss`` checkbox.

    Returns:
        A ``gr.Row`` update whose visibility equals ``value``.
    """
    focal_row_update = gr.Row(visible=value)
    return focal_row_update
|
573
|
+
|
574
|
+
use_focal_loss.change(on_focal_loss_change, inputs=use_focal_loss, outputs=focal_loss_params, show_progress=False)
|
575
|
+
|
576
|
+
def on_autotune_change(value, focal_loss_enabled=None):
    """Switch between the manual hyperparameter panel and the autotune panel.

    Args:
        value: Current state of the autotune checkbox.
        focal_loss_enabled: Current state of the focal-loss checkbox as
            delivered by the event. When omitted, the checkbox component's
            configured value is used, which keeps single-argument calls
            working.

    Returns:
        A tuple of updates for ``custom_params``, ``autotune_params`` and
        ``focal_loss_params`` (in that order).
    """
    if focal_loss_enabled is None:
        # NOTE: ``use_focal_loss.value`` is the value the component was built
        # with, not the live UI state; it is only a fallback for callers that
        # do not pass the checkbox state explicitly.
        focal_loss_enabled = use_focal_loss.value

    show_custom = not value

    return (
        gr.Column(visible=show_custom),
        gr.Column(visible=value),
        # The focal-loss sliders only make sense while manual parameters are
        # shown AND the user currently has focal loss switched on.
        gr.Row(visible=show_custom and bool(focal_loss_enabled)),
    )

# Pass the focal-loss checkbox as an input so the handler sees its live state.
autotune_cb.change(
    on_autotune_change,
    inputs=[autotune_cb, use_focal_loss],
    outputs=[custom_params, autotune_params, focal_loss_params],
    show_progress=False,
)
|
589
|
+
|
590
|
+
model_save_mode = gr.Radio(
|
591
|
+
[
|
592
|
+
(loc.localize("training-tab-model-save-mode-radio-option-replace"), "replace"),
|
593
|
+
(loc.localize("training-tab-model-save-mode-radio-option-append"), "append"),
|
594
|
+
],
|
595
|
+
value=cfg.TRAINED_MODEL_SAVE_MODE,
|
596
|
+
label=loc.localize("training-tab-model-save-mode-radio-label"),
|
597
|
+
info=loc.localize("training-tab-model-save-mode-radio-info"),
|
598
|
+
)
|
599
|
+
|
600
|
+
train_history_plot = gr.Plot()
|
601
|
+
metrics_table = gr.Dataframe(
|
602
|
+
headers=["Class", "Precision", "Recall", "F1 Score", "AUPRC", "AUROC", "Samples"],
|
603
|
+
visible=False,
|
604
|
+
label="Model Performance Metrics (Default Threshold 0.5)",
|
605
|
+
)
|
606
|
+
start_training_button = gr.Button(loc.localize("training-tab-start-training-button-label"), variant="huggingface")
|
607
|
+
|
608
|
+
def train_and_show_metrics(*args):
    """Run training and build the outputs for the history plot and metrics table.

    All positional arguments are forwarded unchanged to ``start_training``,
    whose result is unpacked as a ``(history, metrics)`` pair.

    Returns:
        A tuple ``(history, dataframe_update)``. When ``metrics`` is truthy the
        dataframe update is visible and holds one macro-average row followed by
        one row per class; otherwise the dataframe is hidden.
    """
    history, metrics = start_training(*args)

    # Nothing to tabulate (presumably no test data was supplied): hide the table.
    if not metrics:
        return history, gr.Dataframe(visible=False)

    # The macro-averaged summary is always the first row; it has no sample count.
    rows = [
        [
            "OVERALL (Macro-avg)",
            f"{metrics['macro_precision_default']:.4f}",
            f"{metrics['macro_recall_default']:.4f}",
            f"{metrics['macro_f1_default']:.4f}",
            f"{metrics['macro_auprc']:.4f}",
            f"{metrics['macro_auroc']:.4f}",
            "",
        ]
    ]

    # Placeholder used for classes that have no entry in the class distribution.
    no_samples = {"count": 0, "percentage": 0.0}

    # One row per class, in the order the metrics dictionary yields them.
    for class_name, class_metrics in metrics["class_metrics"].items():
        distribution = metrics["class_distribution"].get(class_name, no_samples)
        sample_cell = f"{distribution['count']} ({distribution['percentage']:.2f}%)"
        rows.append(
            [
                class_name,
                f"{class_metrics['precision_default']:.4f}",
                f"{class_metrics['recall_default']:.4f}",
                f"{class_metrics['f1_default']:.4f}",
                f"{class_metrics['auprc']:.4f}",
                f"{class_metrics['auroc']:.4f}",
                sample_cell,
            ]
        )

    return history, gr.Dataframe(visible=True, value=rows)
|
648
|
+
|
649
|
+
start_training_button.click(
|
650
|
+
train_and_show_metrics,
|
651
|
+
inputs=[
|
652
|
+
input_directory_state,
|
653
|
+
test_data_dir_state,
|
654
|
+
crop_mode,
|
655
|
+
crop_overlap,
|
656
|
+
fmin_number,
|
657
|
+
fmax_number,
|
658
|
+
output_directory_state,
|
659
|
+
classifier_name,
|
660
|
+
model_save_mode,
|
661
|
+
cache_mode,
|
662
|
+
cache_file_state,
|
663
|
+
cache_file_name,
|
664
|
+
autotune_cb,
|
665
|
+
autotune_trials,
|
666
|
+
autotune_executions_per_trials,
|
667
|
+
epoch_number,
|
668
|
+
batch_size_number,
|
669
|
+
learning_rate_number,
|
670
|
+
use_focal_loss,
|
671
|
+
focal_loss_gamma,
|
672
|
+
focal_loss_alpha,
|
673
|
+
hidden_units_number,
|
674
|
+
dropout_number,
|
675
|
+
use_label_smoothing,
|
676
|
+
use_mixup,
|
677
|
+
upsampling_ratio,
|
678
|
+
upsampling_mode,
|
679
|
+
output_format,
|
680
|
+
audio_speed_slider,
|
681
|
+
],
|
682
|
+
outputs=[train_history_plot, metrics_table],
|
683
|
+
)
|
684
|
+
|
685
|
+
|
686
|
+
# Script entry point: when this module is executed directly, hand the tab
# builder to gu.open_window (presumably shows the training tab in its own
# window -- confirm against the gui utils module).
if __name__ == "__main__":
|
686
|
+
gu.open_window(build_train_tab)
|