birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- birdnet_analyzer/__init__.py +9 -8
- birdnet_analyzer/analyze/__init__.py +5 -5
- birdnet_analyzer/analyze/__main__.py +3 -4
- birdnet_analyzer/analyze/cli.py +25 -25
- birdnet_analyzer/analyze/core.py +241 -245
- birdnet_analyzer/analyze/utils.py +692 -701
- birdnet_analyzer/audio.py +368 -372
- birdnet_analyzer/cli.py +709 -707
- birdnet_analyzer/config.py +242 -242
- birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
- birdnet_analyzer/embeddings/__init__.py +3 -4
- birdnet_analyzer/embeddings/__main__.py +3 -3
- birdnet_analyzer/embeddings/cli.py +12 -13
- birdnet_analyzer/embeddings/core.py +69 -70
- birdnet_analyzer/embeddings/utils.py +179 -193
- birdnet_analyzer/evaluation/__init__.py +196 -195
- birdnet_analyzer/evaluation/__main__.py +3 -3
- birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
- birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
- birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
- birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
- birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
- birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
- birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
- birdnet_analyzer/gui/__init__.py +19 -23
- birdnet_analyzer/gui/__main__.py +3 -3
- birdnet_analyzer/gui/analysis.py +175 -174
- birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
- birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
- birdnet_analyzer/gui/assets/gui.css +28 -28
- birdnet_analyzer/gui/assets/gui.js +93 -93
- birdnet_analyzer/gui/embeddings.py +619 -620
- birdnet_analyzer/gui/evaluation.py +795 -813
- birdnet_analyzer/gui/localization.py +75 -68
- birdnet_analyzer/gui/multi_file.py +245 -246
- birdnet_analyzer/gui/review.py +519 -527
- birdnet_analyzer/gui/segments.py +191 -191
- birdnet_analyzer/gui/settings.py +128 -129
- birdnet_analyzer/gui/single_file.py +267 -269
- birdnet_analyzer/gui/species.py +95 -95
- birdnet_analyzer/gui/train.py +696 -698
- birdnet_analyzer/gui/utils.py +810 -808
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
- birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
- birdnet_analyzer/lang/de.json +334 -334
- birdnet_analyzer/lang/en.json +334 -334
- birdnet_analyzer/lang/fi.json +334 -334
- birdnet_analyzer/lang/fr.json +334 -334
- birdnet_analyzer/lang/id.json +334 -334
- birdnet_analyzer/lang/pt-br.json +334 -334
- birdnet_analyzer/lang/ru.json +334 -334
- birdnet_analyzer/lang/se.json +334 -334
- birdnet_analyzer/lang/tlh.json +334 -334
- birdnet_analyzer/lang/zh_TW.json +334 -334
- birdnet_analyzer/model.py +1212 -1243
- birdnet_analyzer/playground.py +5 -0
- birdnet_analyzer/search/__init__.py +3 -3
- birdnet_analyzer/search/__main__.py +3 -3
- birdnet_analyzer/search/cli.py +11 -12
- birdnet_analyzer/search/core.py +78 -78
- birdnet_analyzer/search/utils.py +107 -111
- birdnet_analyzer/segments/__init__.py +3 -3
- birdnet_analyzer/segments/__main__.py +3 -3
- birdnet_analyzer/segments/cli.py +13 -14
- birdnet_analyzer/segments/core.py +81 -78
- birdnet_analyzer/segments/utils.py +383 -394
- birdnet_analyzer/species/__init__.py +3 -3
- birdnet_analyzer/species/__main__.py +3 -3
- birdnet_analyzer/species/cli.py +13 -14
- birdnet_analyzer/species/core.py +35 -35
- birdnet_analyzer/species/utils.py +74 -75
- birdnet_analyzer/train/__init__.py +3 -3
- birdnet_analyzer/train/__main__.py +3 -3
- birdnet_analyzer/train/cli.py +13 -14
- birdnet_analyzer/train/core.py +113 -113
- birdnet_analyzer/train/utils.py +877 -847
- birdnet_analyzer/translate.py +133 -104
- birdnet_analyzer/utils.py +426 -419
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
- birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
- birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
- {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
birdnet_analyzer/train/utils.py
CHANGED
@@ -1,847 +1,877 @@
|
|
1
|
-
"""Module for training a custom classifier.
|
2
|
-
|
3
|
-
Can be used to train a custom classifier with new training data.
|
4
|
-
"""
|
5
|
-
|
6
|
-
import csv
|
7
|
-
import os
|
8
|
-
from functools import partial
|
9
|
-
from multiprocessing.pool import Pool
|
10
|
-
|
11
|
-
import numpy as np
|
12
|
-
import tqdm
|
13
|
-
|
14
|
-
import birdnet_analyzer.
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
csv_file_path =
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
sig_splits = [audio.
|
84
|
-
elif cfg.SAMPLE_CROP_MODE == "
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
sig_splits = audio.
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
print(f"\t...
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
"
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
"
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
#
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
os.path.join(data_path, folder
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
)
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
result
|
240
|
-
#
|
241
|
-
#
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)
|
319
|
-
if len(x_test) > 0:
|
320
|
-
print(f"...Loaded {x_test.shape[0]} test samples.", flush=True)
|
321
|
-
|
322
|
-
# Normalize embeddings
|
323
|
-
print("Normalizing embeddings...", flush=True)
|
324
|
-
x_train = normalize_embeddings(x_train)
|
325
|
-
if len(x_test) > 0:
|
326
|
-
x_test = normalize_embeddings(x_test)
|
327
|
-
|
328
|
-
if cfg.AUTOTUNE:
|
329
|
-
import gc
|
330
|
-
|
331
|
-
import keras
|
332
|
-
import keras_tuner
|
333
|
-
|
334
|
-
# Call callback to initialize progress bar
|
335
|
-
if on_trial_result:
|
336
|
-
on_trial_result(0)
|
337
|
-
|
338
|
-
class BirdNetTuner(keras_tuner.BayesianOptimization):
    """Bayesian hyperparameter search over the linear classifier's training setup.

    Every trial builds and trains a fresh classifier ``executions_per_trial`` times and
    reports the negated best validation AUPRC of each execution as its objective values.
    """

    def __init__(self, x_train, y_train, x_test, y_test, max_trials, executions_per_trial, on_trial_result):
        """Configure the search and keep the data sets used by every trial.

        Args:
            x_train, y_train: Training embeddings and label vectors.
            x_test, y_test: Optional test embeddings and label vectors; when ``x_test`` is
                empty, a validation split of the training data is used instead.
            max_trials: Maximum number of trials for the search.
            executions_per_trial: Number of training runs per trial.
            on_trial_result: Optional callback invoked with the trial number after each trial.
        """
        super().__init__(
            max_trials=max_trials,
            executions_per_trial=executions_per_trial,
            overwrite=True,
            directory=autotune_directory,
            project_name="birdnet_analyzer",
        )
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test
        self.on_trial_result = on_trial_result

    def run_trial(self, trial, *args, **kwargs):
        """Train and score the classifier with the hyperparameters of ``trial``.

        Returns:
            list of float: The negated best validation AUPRC of every execution.
        """
        hp: keras_tuner.HyperParameters = trial.hyperparameters
        trial_number = len(self.oracle.trials)
        scores = []

        # Candidate learning rates and their defaults per batch size. Each batch size owns a
        # conditional hyperparameter named "learning_rate_<batch size>".
        lr_grid = {
            8: ([0.0005, 0.0002, 0.0001], 0.0001),
            16: ([0.005, 0.002, 0.001, 0.0005, 0.0002], 0.0005),
            32: ([0.01, 0.005, 0.001, 0.0005, 0.0001], 0.0001),
            64: ([0.01, 0.005, 0.002, 0.001], 0.001),
            128: ([0.1, 0.01, 0.005], 0.005),
        }

        # NOTE(review): the cfg.* values below are passed as Choice/Boolean defaults; keras_tuner
        # presumably rejects a default that is not among the listed values — confirm if users can
        # configure values outside these lists.
        for execution in range(int(self.executions_per_trial)):
            print(f"Running Trial #{trial_number} execution #{execution + 1}", flush=True)

            print("Building model...", flush=True)
            n_classes = self.y_train.shape[1]
            n_features = self.x_train.shape[1]
            hidden_units = hp.Choice(
                "hidden_units", [0, 128, 256, 512, 1024, 2048], default=cfg.TRAIN_HIDDEN_UNITS
            )
            dropout = hp.Choice("dropout", [0.0, 0.25, 0.33, 0.5, 0.75, 0.9], default=cfg.TRAIN_DROPOUT)
            classifier = model.build_linear_classifier(
                n_classes, n_features, hidden_units=hidden_units, dropout=dropout
            )
            print("...Done.", flush=True)

            # Multi-label data may only use repeat upsampling; SMOTE is left out as too slow.
            upsampling_choices = ["repeat"] if cfg.MULTI_LABEL else ["repeat", "mean", "linear"]

            batch_size = hp.Choice("batch_size", [8, 16, 32, 64, 128], default=cfg.TRAIN_BATCH_SIZE)
            for size, (rates, default_rate) in lr_grid.items():
                if batch_size == size:
                    learning_rate = hp.Choice(
                        f"learning_rate_{size}",
                        rates,
                        default=default_rate,
                        parent_name="batch_size",
                        parent_values=[size],
                    )
                    break

            print("Training model...", flush=True)
            val_split = 0.0 if len(self.x_test) > 0 else cfg.TRAIN_VAL_SPLIT
            upsampling_ratio = hp.Choice(
                "upsampling_ratio", [0.0, 0.25, 0.33, 0.5, 0.75, 1.0], default=cfg.UPSAMPLING_RATIO
            )
            upsampling_mode = hp.Choice(
                "upsampling_mode",
                upsampling_choices,
                default=cfg.UPSAMPLING_MODE,
                parent_name="upsampling_ratio",
                parent_values=[0.25, 0.33, 0.5, 0.75, 1.0],
            )
            use_mixup = hp.Boolean("mixup", default=cfg.TRAIN_WITH_MIXUP)
            use_label_smoothing = hp.Boolean("label_smoothing", default=cfg.TRAIN_WITH_LABEL_SMOOTHING)
            use_focal_loss = hp.Boolean("focal_loss", default=cfg.TRAIN_WITH_FOCAL_LOSS)
            gamma = hp.Choice(
                "focal_loss_gamma",
                [0.5, 1.0, 2.0, 3.0, 4.0],
                default=cfg.FOCAL_LOSS_GAMMA,
                parent_name="focal_loss",
                parent_values=[True],
            )
            alpha = hp.Choice(
                "focal_loss_alpha",
                [0.1, 0.25, 0.5, 0.75, 0.9],
                default=cfg.FOCAL_LOSS_ALPHA,
                parent_name="focal_loss",
                parent_values=[True],
            )
            classifier, history = model.train_linear_classifier(
                classifier,
                self.x_train,
                self.y_train,
                self.x_test,
                self.y_test,
                epochs=cfg.TRAIN_EPOCHS,
                batch_size=batch_size,
                learning_rate=learning_rate,
                val_split=val_split,
                upsampling_ratio=upsampling_ratio,
                upsampling_mode=upsampling_mode,
                train_with_mixup=use_mixup,
                train_with_label_smoothing=use_label_smoothing,
                train_with_focal_loss=use_focal_loss,
                focal_loss_gamma=gamma,
                focal_loss_alpha=alpha,
            )

            # Score this execution by its best validation AUPRC rather than by its loss.
            auprc_per_epoch = history.history["val_AUPRC"]
            best_val_auprc = auprc_per_epoch[np.argmax(auprc_per_epoch)]
            scores.append(best_val_auprc)

            print(
                f"Finished Trial #{trial_number} execution #{execution + 1}. Best validation AUPRC: {best_val_auprc}",
                flush=True,
            )

            # Drop the Keras session and model objects before the next execution.
            keras.backend.clear_session()
            del classifier, history
            gc.collect()

        if self.on_trial_result:
            self.on_trial_result(trial_number)

        # keras-tuner minimizes by default, so the AUPRC values are reported negated.
        return [-score for score in scores]
|
485
|
-
|
486
|
-
# Create the tuner instance
|
487
|
-
tuner = BirdNetTuner(
|
488
|
-
x_train=x_train,
|
489
|
-
y_train=y_train,
|
490
|
-
x_test=x_test,
|
491
|
-
y_test=y_test,
|
492
|
-
max_trials=cfg.AUTOTUNE_TRIALS,
|
493
|
-
executions_per_trial=cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
|
494
|
-
on_trial_result=on_trial_result,
|
495
|
-
)
|
496
|
-
try:
|
497
|
-
tuner.search()
|
498
|
-
except model.get_empty_class_exception() as e:
|
499
|
-
e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
|
500
|
-
e.args = (e.message,)
|
501
|
-
raise e
|
502
|
-
|
503
|
-
best_params = tuner.get_best_hyperparameters()[0]
|
504
|
-
|
505
|
-
cfg.TRAIN_HIDDEN_UNITS = best_params["hidden_units"]
|
506
|
-
cfg.TRAIN_DROPOUT = best_params["dropout"]
|
507
|
-
cfg.TRAIN_BATCH_SIZE = best_params["batch_size"]
|
508
|
-
cfg.TRAIN_LEARNING_RATE = best_params[f"learning_rate_{cfg.TRAIN_BATCH_SIZE}"]
|
509
|
-
if cfg.UPSAMPLING_RATIO > 0:
|
510
|
-
cfg.UPSAMPLING_MODE = best_params["upsampling_mode"]
|
511
|
-
cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
|
512
|
-
cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
|
513
|
-
cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
|
514
|
-
|
515
|
-
print("Best params: ")
|
516
|
-
print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
|
517
|
-
print("dropout: ", cfg.TRAIN_DROPOUT)
|
518
|
-
print("batch_size: ", cfg.TRAIN_BATCH_SIZE)
|
519
|
-
print("learning_rate: ", cfg.TRAIN_LEARNING_RATE)
|
520
|
-
print("upsampling_ratio: ", cfg.UPSAMPLING_RATIO)
|
521
|
-
if cfg.UPSAMPLING_RATIO > 0:
|
522
|
-
print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
|
523
|
-
print("mixup: ", cfg.TRAIN_WITH_MIXUP)
|
524
|
-
print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
|
525
|
-
|
526
|
-
# Build model
|
527
|
-
print("Building model...", flush=True)
|
528
|
-
classifier = model.build_linear_classifier(
|
529
|
-
y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT
|
530
|
-
)
|
531
|
-
print("...Done.", flush=True)
|
532
|
-
|
533
|
-
# Train model
|
534
|
-
print("Training model...", flush=True)
|
535
|
-
try:
|
536
|
-
classifier, history = model.train_linear_classifier(
|
537
|
-
classifier,
|
538
|
-
x_train,
|
539
|
-
y_train,
|
540
|
-
x_test,
|
541
|
-
y_test,
|
542
|
-
epochs=cfg.TRAIN_EPOCHS,
|
543
|
-
batch_size=cfg.TRAIN_BATCH_SIZE,
|
544
|
-
learning_rate=cfg.TRAIN_LEARNING_RATE,
|
545
|
-
val_split=cfg.TRAIN_VAL_SPLIT if len(x_test) == 0 else 0.0,
|
546
|
-
upsampling_ratio=cfg.UPSAMPLING_RATIO,
|
547
|
-
upsampling_mode=cfg.UPSAMPLING_MODE,
|
548
|
-
train_with_mixup=cfg.TRAIN_WITH_MIXUP,
|
549
|
-
train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
|
550
|
-
train_with_focal_loss=cfg.TRAIN_WITH_FOCAL_LOSS,
|
551
|
-
focal_loss_gamma=cfg.FOCAL_LOSS_GAMMA,
|
552
|
-
focal_loss_alpha=cfg.FOCAL_LOSS_ALPHA,
|
553
|
-
on_epoch_end=on_epoch_end,
|
554
|
-
)
|
555
|
-
except model.get_empty_class_exception() as e:
|
556
|
-
e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
|
557
|
-
e.args = (e.message,)
|
558
|
-
raise e
|
559
|
-
except Exception as e:
|
560
|
-
raise Exception("Error training model") from e
|
561
|
-
|
562
|
-
print("...Done.", flush=True)
|
563
|
-
|
564
|
-
# Get best validation metrics based on AUPRC instead of loss for more reliable results with imbalanced data
|
565
|
-
best_epoch = np.argmax(history.history["val_AUPRC"])
|
566
|
-
best_val_auprc = history.history["val_AUPRC"][best_epoch]
|
567
|
-
best_val_auroc = history.history["val_AUROC"][best_epoch]
|
568
|
-
best_val_loss = history.history["val_loss"][best_epoch]
|
569
|
-
|
570
|
-
print("Saving model...", flush=True)
|
571
|
-
|
572
|
-
try:
|
573
|
-
if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
|
574
|
-
model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
|
575
|
-
model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
|
576
|
-
elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
|
577
|
-
model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
|
578
|
-
elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
|
579
|
-
model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
|
580
|
-
else:
|
581
|
-
raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
|
582
|
-
except Exception as e:
|
583
|
-
raise Exception("Error saving model") from e
|
584
|
-
|
585
|
-
save_sample_counts(labels, y_train)
|
586
|
-
|
587
|
-
# Evaluate model on test data if available
|
588
|
-
metrics = None
|
589
|
-
if len(x_test) > 0:
|
590
|
-
print("\nEvaluating model on test data...", flush=True)
|
591
|
-
metrics = evaluate_model(classifier, x_test, y_test, labels)
|
592
|
-
|
593
|
-
# Save evaluation results to file
|
594
|
-
if metrics:
|
595
|
-
import csv
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
f"{
|
629
|
-
f"{
|
630
|
-
f"{
|
631
|
-
f"{
|
632
|
-
f"{
|
633
|
-
f"{
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
for
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
metrics[
|
846
|
-
|
847
|
-
|
1
|
+
"""Module for training a custom classifier.
|
2
|
+
|
3
|
+
Can be used to train a custom classifier with new training data.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import csv
|
7
|
+
import os
|
8
|
+
from functools import partial
|
9
|
+
from multiprocessing.pool import Pool
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
import tqdm
|
13
|
+
|
14
|
+
import birdnet_analyzer.config as cfg
|
15
|
+
from birdnet_analyzer import audio, model, utils
|
16
|
+
|
17
|
+
|
18
|
+
def _count_samples_per_label(labels, y_train):
    """Count the training samples for every label combination.

    Args:
        labels (list of str): Label names corresponding to the columns in y_train.
        y_train (numpy.ndarray): 2D array with one label vector per row.

    Returns:
        dict: Maps the '+'-joined names of the labels set to 1 in a row to the number of
        rows sharing that name. Keys appear in the order in which ``np.unique`` yields the rows.
    """
    samples_per_label = {}
    label_combinations, counts = np.unique(y_train, axis=0, return_counts=True)

    for label_combination, count in zip(label_combinations, counts):
        label = "+".join(labels[i] for i, value in enumerate(label_combination) if value == 1)
        # Only entries equal to 1 contribute to the name, so different rows (e.g. an all-zero
        # row and a row containing -1 entries) can produce the same key. Accumulate rather than
        # assign, so that no samples are dropped from the counts.
        samples_per_label[label] = samples_per_label.get(label, 0) + int(count)

    return samples_per_label


def save_sample_counts(labels, y_train):
    """
    Saves the count of samples per label combination to a CSV file.

    The function builds a dictionary whose keys are label combinations (the names of the labels
    set to 1, joined by '+') and whose values are the number of samples with that combination.
    Rows that produce the same key (for example rows without any label set to 1) are counted
    together. It then writes this information to a UTF-8 CSV file named
    "<cfg.CUSTOM_CLASSIFIER>_sample_counts.csv" with two columns: "Label" and "Count".

    Args:
        labels (list of str): List of label names corresponding to the columns in y_train.
        y_train (numpy.ndarray): 2D array where each row is a label vector; entries equal to 1
            mark the labels present in that sample.
    """
    samples_per_label = _count_samples_per_label(labels, y_train)

    csv_file_path = cfg.CUSTOM_CLASSIFIER + "_sample_counts.csv"

    # Label names come from folder names and may contain non-ASCII characters, so use an
    # explicit encoding instead of the platform default.
    with open(csv_file_path, mode="w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Label", "Count"])
        writer.writerows([label, count] for label, count in samples_per_label.items())
|
44
|
+
|
45
|
+
|
46
|
+
def _load_audio_file(f, label_vector, config):
    """Load an audio file and compute embeddings for its training segments.

    Args:
        f: Path to the audio file.
        label_vector: The label vector assigned to every segment taken from this file.
        config: Configuration snapshot (as produced by ``cfg.get_config()``) that is applied
            with ``cfg.set_config`` before any other work is done.

    Returns:
        A tuple (x_train, y_train) of lists: the items returned by ``model.embeddings`` for
        each segment, and one copy of ``label_vector`` per item. If the file cannot be opened,
        a tuple of two empty numpy arrays is returned instead.
    """

    x_train = []
    y_train = []

    # This function is executed in multiprocessing pool workers. Re-apply the parent's
    # configuration there, presumably because workers started with "spawn" (Windows) do not
    # see configuration changes made at runtime in the parent process.
    cfg.set_config(config)

    # Try to load the audio file
    try:
        # Load audio
        sig, rate = audio.open_audio_file(
            f,
            duration=cfg.SIG_LENGTH if cfg.SAMPLE_CROP_MODE == "first" else None,
            fmin=cfg.BANDPASS_FMIN,
            fmax=cfg.BANDPASS_FMAX,
            speed=cfg.AUDIO_SPEED,
        )

    # If anything goes wrong, report the error and skip the file.
    except Exception as e:
        # Print Error
        print(f"\t Error when loading file {f}", flush=True)
        print(f"\t {e}", flush=True)
        return np.array([]), np.array([])

    # Crop training samples
    if cfg.SAMPLE_CROP_MODE == "center":
        sig_splits = [audio.crop_center(sig, rate, cfg.SIG_LENGTH)]
    elif cfg.SAMPLE_CROP_MODE == "first":
        # Keep only the first segment. Indexing [0] on an empty result would raise IndexError
        # outside the try block above and fail the whole pool task, so guard against it.
        splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
        sig_splits = [splits[0]] if len(splits) > 0 else []
    elif cfg.SAMPLE_CROP_MODE == "smart":
        # Smart cropping - detect peaks in audio energy to identify potential signals
        sig_splits = audio.smart_crop_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
    else:
        sig_splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)

    # Get feature embeddings
    batch_size = 1  # turns out that batch size 1 is the fastest, probably because of having to resize the model input when the number of samples in a batch changes
    for i in range(0, len(sig_splits), batch_size):
        batch_sig = sig_splits[i : i + batch_size]
        batch_label = [label_vector] * len(batch_sig)
        embeddings = model.embeddings(batch_sig)

        # Add to training data
        x_train.extend(embeddings)
        y_train.extend(batch_label)

    return x_train, y_train
|
102
|
+
|
103
|
+
|
104
|
+
def _load_training_data(cache_mode=None, cache_file="", progress_callback=None):
|
105
|
+
"""Loads the data for training.
|
106
|
+
|
107
|
+
Reads all subdirectories of "config.TRAIN_DATA_PATH" and uses their names as new labels.
|
108
|
+
|
109
|
+
These directories should contain all the training data for each label.
|
110
|
+
|
111
|
+
If a cache file is provided, the training data is loaded from there.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
cache_mode: Cache mode. Can be 'load' or 'save'. Defaults to None.
|
115
|
+
cache_file: Path to cache file.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
A tuple of (x_train, y_train, x_test, y_test, labels).
|
119
|
+
"""
|
120
|
+
# Load from cache
|
121
|
+
if cache_mode == "load":
|
122
|
+
if os.path.isfile(cache_file):
|
123
|
+
print(f"\t...loading from cache: {cache_file}", flush=True)
|
124
|
+
x_train, y_train, x_test, y_test, labels, cfg.BINARY_CLASSIFICATION, cfg.MULTI_LABEL = (
|
125
|
+
utils.load_from_cache(cache_file)
|
126
|
+
)
|
127
|
+
return x_train, y_train, x_test, y_test, labels
|
128
|
+
|
129
|
+
print(f"\t...cache file not found: {cache_file}", flush=True)
|
130
|
+
|
131
|
+
# Print train and test data path as confirmation
|
132
|
+
print(f"\t...train data path: {cfg.TRAIN_DATA_PATH}", flush=True)
|
133
|
+
print(f"\t...test data path: {cfg.TEST_DATA_PATH}", flush=True)
|
134
|
+
|
135
|
+
# Get list of subfolders as labels
|
136
|
+
train_folders = sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH))
|
137
|
+
|
138
|
+
# Read all individual labels from the folder names
|
139
|
+
labels = []
|
140
|
+
|
141
|
+
for folder in train_folders:
|
142
|
+
labels_in_folder = folder.split(",")
|
143
|
+
for label in labels_in_folder:
|
144
|
+
if label not in labels:
|
145
|
+
labels.append(label)
|
146
|
+
|
147
|
+
# Sort labels
|
148
|
+
labels = sorted(labels)
|
149
|
+
|
150
|
+
# Get valid labels
|
151
|
+
valid_labels = [
|
152
|
+
label for label in labels if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-")
|
153
|
+
]
|
154
|
+
|
155
|
+
# Check if binary classification
|
156
|
+
cfg.BINARY_CLASSIFICATION = len(valid_labels) == 1
|
157
|
+
|
158
|
+
# Validate the classes for binary classification
|
159
|
+
if cfg.BINARY_CLASSIFICATION:
|
160
|
+
if len([f for f in train_folders if f.startswith("-")]) > 0:
|
161
|
+
raise Exception(
|
162
|
+
"Negative labels can't be used with binary classification",
|
163
|
+
"validation-no-negative-samples-in-binary-classification",
|
164
|
+
)
|
165
|
+
if len([f for f in train_folders if f.lower() in cfg.NON_EVENT_CLASSES]) == 0:
|
166
|
+
raise Exception(
|
167
|
+
"Non-event samples are required for binary classification",
|
168
|
+
"validation-non-event-samples-required-in-binary-classification",
|
169
|
+
)
|
170
|
+
|
171
|
+
# Check if multi label
|
172
|
+
cfg.MULTI_LABEL = len(valid_labels) > 1 and any("," in f for f in train_folders)
|
173
|
+
|
174
|
+
# Check if multi-label and binary classficication
|
175
|
+
if cfg.BINARY_CLASSIFICATION and cfg.MULTI_LABEL:
|
176
|
+
raise Exception("Error: Binary classfication and multi-label not possible at the same time")
|
177
|
+
|
178
|
+
# Only allow repeat upsampling for multi-label setting
|
179
|
+
if cfg.MULTI_LABEL and cfg.UPSAMPLING_RATIO > 0 and cfg.UPSAMPLING_MODE != "repeat":
|
180
|
+
raise Exception(
|
181
|
+
"Only repeat-upsampling ist available for multi-label", "validation-only-repeat-upsampling-for-multi-label"
|
182
|
+
)
|
183
|
+
|
184
|
+
# Load training data
|
185
|
+
x_train = []
|
186
|
+
y_train = []
|
187
|
+
x_test = []
|
188
|
+
y_test = []
|
189
|
+
|
190
|
+
def load_data(data_path, allowed_folders):
|
191
|
+
x = []
|
192
|
+
y = []
|
193
|
+
folders = sorted(utils.list_subdirectories(data_path))
|
194
|
+
|
195
|
+
for folder in folders:
|
196
|
+
if folder not in allowed_folders:
|
197
|
+
print(f"Skipping folder {folder} because it is not in the training data.", flush=True)
|
198
|
+
continue
|
199
|
+
|
200
|
+
# Get label vector
|
201
|
+
label_vector = np.zeros((len(valid_labels),), dtype="float32")
|
202
|
+
folder_labels = folder.split(",")
|
203
|
+
|
204
|
+
for label in folder_labels:
|
205
|
+
if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-"):
|
206
|
+
label_vector[valid_labels.index(label)] = 1
|
207
|
+
elif (
|
208
|
+
label.startswith("-") and label[1:] in valid_labels
|
209
|
+
): # Negative labels need to be contained in the valid labels
|
210
|
+
label_vector[valid_labels.index(label[1:])] = -1
|
211
|
+
|
212
|
+
# Get list of files
|
213
|
+
# Filter files that start with '.' because macOS seems to them for temp files.
|
214
|
+
files = filter(
|
215
|
+
os.path.isfile,
|
216
|
+
(
|
217
|
+
os.path.join(data_path, folder, f)
|
218
|
+
for f in sorted(os.listdir(os.path.join(data_path, folder)))
|
219
|
+
if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in cfg.ALLOWED_FILETYPES
|
220
|
+
),
|
221
|
+
)
|
222
|
+
|
223
|
+
# Load files using thread pool
|
224
|
+
with Pool(cfg.CPU_THREADS) as p:
|
225
|
+
tasks = []
|
226
|
+
|
227
|
+
for f in files:
|
228
|
+
task = p.apply_async(
|
229
|
+
partial(_load_audio_file, f=f, label_vector=label_vector, config=cfg.get_config())
|
230
|
+
)
|
231
|
+
tasks.append(task)
|
232
|
+
|
233
|
+
# Wait for tasks to complete and monitor progress with tqdm
|
234
|
+
num_files_processed = 0
|
235
|
+
|
236
|
+
with tqdm.tqdm(total=len(tasks), desc=f" - loading '{folder}'", unit="f") as progress_bar:
|
237
|
+
for task in tasks:
|
238
|
+
result = task.get()
|
239
|
+
# Make sure result is not empty
|
240
|
+
# Empty results might be caused by errors when loading the audio file
|
241
|
+
# TODO: We should check for embeddings size in result, otherwise we can't add them to the training data
|
242
|
+
if len(result[0]) > 0:
|
243
|
+
x += result[0]
|
244
|
+
y += result[1]
|
245
|
+
|
246
|
+
num_files_processed += 1
|
247
|
+
progress_bar.update(1)
|
248
|
+
|
249
|
+
if progress_callback:
|
250
|
+
progress_callback(num_files_processed, len(tasks), folder)
|
251
|
+
return np.array(x, dtype="float32"), np.array(y, dtype="float32")
|
252
|
+
|
253
|
+
x_train, y_train = load_data(cfg.TRAIN_DATA_PATH, train_folders)
|
254
|
+
|
255
|
+
if cfg.TEST_DATA_PATH and cfg.TEST_DATA_PATH != cfg.TRAIN_DATA_PATH:
|
256
|
+
test_folders = sorted(utils.list_subdirectories(cfg.TEST_DATA_PATH))
|
257
|
+
allowed_test_folders = [
|
258
|
+
folder for folder in test_folders if folder in train_folders and not folder.startswith("-")
|
259
|
+
]
|
260
|
+
x_test, y_test = load_data(cfg.TEST_DATA_PATH, allowed_test_folders)
|
261
|
+
else:
|
262
|
+
x_test = np.array([])
|
263
|
+
y_test = np.array([])
|
264
|
+
|
265
|
+
# Save to cache?
|
266
|
+
if cache_mode == "save":
|
267
|
+
print(f"\t...saving training data to cache: {cache_file}", flush=True)
|
268
|
+
try:
|
269
|
+
# Only save the valid labels
|
270
|
+
utils.save_to_cache(cache_file, x_train, y_train, x_test, y_test, valid_labels)
|
271
|
+
except Exception as e:
|
272
|
+
print(f"\t...error saving cache: {e}", flush=True)
|
273
|
+
|
274
|
+
# Return only the valid labels for further use
|
275
|
+
return x_train, y_train, x_test, y_test, valid_labels
|
276
|
+
|
277
|
+
|
278
|
+
def normalize_embeddings(embeddings):
    """
    L2-normalize every row of an embedding matrix.

    Each embedding vector (one per row) is scaled to unit Euclidean length,
    which puts embeddings coming from different sources or domains on a common
    scale and can help training stability and convergence.

    Args:
        embeddings: numpy array with one embedding vector per row (reduced
            along axis 1, so it must be at least 2-D).

    Returns:
        Array of the same shape in which every non-zero row has unit L2 norm;
        all-zero rows stay all-zero.
    """
    row_lengths = np.linalg.norm(embeddings, axis=1, keepdims=True)

    # A zero-length row would divide by zero; dividing it by 1 keeps it at zero.
    row_lengths[row_lengths == 0] = 1.0

    return embeddings / row_lengths
|
298
|
+
|
299
|
+
|
300
|
+
def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, autotune_directory="autotune"):
    """Trains a custom classifier.

    Loads the training (and optional test) embeddings via `_load_training_data`,
    L2-normalizes them, and optionally runs a keras-tuner Bayesian
    hyperparameter search when ``cfg.AUTOTUNE`` is set. The best values found
    are written back into ``cfg``.

    It then trains the final classifier and saves it in the format selected by
    ``cfg.TRAINED_MODEL_OUTPUT_FORMAT``, and saves the per-label sample counts.
    When separate test data was loaded, it evaluates the model and writes the
    results to ``cfg.CUSTOM_CLASSIFIER + "_evaluation.csv"``.

    Args:
        on_epoch_end: A callback function that takes two arguments `epoch`, `logs`.
        on_trial_result: A callback function for hyperparameter tuning; called with 0
            before the search starts and with the trial number after each trial.
        on_data_load_end: A callback function for data loading progress.
        autotune_directory: Directory for autotune results.

    Returns:
        A tuple ``(history, metrics)``:
            history: the keras `History` object of the final training run;
                its `history` property contains all the metrics.
            metrics: the dict returned by `evaluate_model`, or None when no
                separate test data was loaded.

    Raises:
        The exception type returned by `model.get_empty_class_exception()`,
        re-raised with a message that names the empty class.
        Exception: "Error training model" or "Error saving model", chained from
        the underlying error.
    """

    # Load training data
    print("Loading training data...", flush=True)
    x_train, y_train, x_test, y_test, labels = _load_training_data(
        cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE, on_data_load_end
    )
    print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)
    if len(x_test) > 0:
        print(f"...Loaded {x_test.shape[0]} test samples.", flush=True)

    # Normalize embeddings
    print("Normalizing embeddings...", flush=True)
    x_train = normalize_embeddings(x_train)
    if len(x_test) > 0:
        x_test = normalize_embeddings(x_test)

    if cfg.AUTOTUNE:
        import gc

        import keras
        import keras_tuner

        # Call callback to initialize progress bar
        if on_trial_result:
            on_trial_result(0)

        def _choice_default(value, choices):
            """Return `value` if it is one of `choices`, else the closest numeric choice (or the first choice).

            keras-tuner rejects a `Choice` whose default is not among its values, so a user
            setting outside the search grid (e.g. 64 hidden units) would otherwise abort the search.
            """
            if value in choices:
                return value
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return min(choices, key=lambda choice: abs(choice - value))
            return choices[0]

        # Search space
        hidden_units_choices = [0, 128, 256, 512, 1024, 2048]
        dropout_choices = [0.0, 0.25, 0.33, 0.5, 0.75, 0.9]
        upsampling_ratio_choices = [0.0, 0.25, 0.33, 0.5, 0.75, 1.0]
        focal_loss_gamma_choices = [0.5, 1.0, 2.0, 3.0, 4.0]
        focal_loss_alpha_choices = [0.1, 0.25, 0.5, 0.75, 0.9]
        # Learning-rate candidates and default per batch size. Each entry is registered as the
        # conditional hyperparameter "learning_rate_<batch size>", active only for that batch size.
        learning_rate_choices = {
            8: ([0.0005, 0.0002, 0.0001], 0.0001),
            16: ([0.005, 0.002, 0.001, 0.0005, 0.0002], 0.0005),
            32: ([0.01, 0.005, 0.001, 0.0005, 0.0001], 0.0001),
            64: ([0.01, 0.005, 0.002, 0.001], 0.001),
            128: ([0.1, 0.01, 0.005], 0.005),
        }
        batch_size_choices = list(learning_rate_choices)

        class BirdNetTuner(keras_tuner.BayesianOptimization):
            def __init__(self, x_train, y_train, x_test, y_test, max_trials, executions_per_trial, on_trial_result):
                super().__init__(
                    max_trials=max_trials,
                    executions_per_trial=executions_per_trial,
                    overwrite=True,
                    directory=autotune_directory,
                    project_name="birdnet_analyzer",
                )
                self.x_train = x_train
                self.y_train = y_train
                self.x_test = x_test
                self.y_test = y_test
                self.on_trial_result = on_trial_result

            def run_trial(self, trial, *args, **kwargs):
                histories = []
                hp: keras_tuner.HyperParameters = trial.hyperparameters
                trial_number = len(self.oracle.trials)

                for execution in range(int(self.executions_per_trial)):
                    print(f"Running Trial #{trial_number} execution #{execution + 1}", flush=True)

                    # Build model
                    print("Building model...", flush=True)
                    classifier = model.build_linear_classifier(
                        self.y_train.shape[1],
                        self.x_train.shape[1],
                        hidden_units=hp.Choice(
                            "hidden_units",
                            hidden_units_choices,
                            default=_choice_default(cfg.TRAIN_HIDDEN_UNITS, hidden_units_choices),
                        ),
                        dropout=hp.Choice(
                            "dropout",
                            dropout_choices,
                            default=_choice_default(cfg.TRAIN_DROPOUT, dropout_choices),
                        ),
                    )
                    print("...Done.", flush=True)

                    # Only allow repeat upsampling in multi-label setting (SMOTE is too slow to tune)
                    upsampling_choices = ["repeat"] if cfg.MULTI_LABEL else ["repeat", "mean", "linear"]

                    batch_size = hp.Choice(
                        "batch_size",
                        batch_size_choices,
                        default=_choice_default(cfg.TRAIN_BATCH_SIZE, batch_size_choices),
                    )
                    lr_values, lr_default = learning_rate_choices[batch_size]
                    learning_rate = hp.Choice(
                        f"learning_rate_{batch_size}",
                        lr_values,
                        default=lr_default,
                        parent_name="batch_size",
                        parent_values=[batch_size],
                    )

                    # Train model
                    print("Training model...", flush=True)
                    classifier, history = model.train_linear_classifier(
                        classifier,
                        self.x_train,
                        self.y_train,
                        self.x_test,
                        self.y_test,
                        epochs=cfg.TRAIN_EPOCHS,
                        batch_size=batch_size,
                        learning_rate=learning_rate,
                        val_split=0.0 if len(self.x_test) > 0 else cfg.TRAIN_VAL_SPLIT,
                        upsampling_ratio=hp.Choice(
                            "upsampling_ratio",
                            upsampling_ratio_choices,
                            default=_choice_default(cfg.UPSAMPLING_RATIO, upsampling_ratio_choices),
                        ),
                        upsampling_mode=hp.Choice(
                            "upsampling_mode",
                            upsampling_choices,
                            default=_choice_default(cfg.UPSAMPLING_MODE, upsampling_choices),
                            parent_name="upsampling_ratio",
                            parent_values=[0.25, 0.33, 0.5, 0.75, 1.0],
                        ),
                        train_with_mixup=hp.Boolean("mixup", default=cfg.TRAIN_WITH_MIXUP),
                        train_with_label_smoothing=hp.Boolean(
                            "label_smoothing", default=cfg.TRAIN_WITH_LABEL_SMOOTHING
                        ),
                        train_with_focal_loss=hp.Boolean("focal_loss", default=cfg.TRAIN_WITH_FOCAL_LOSS),
                        focal_loss_gamma=hp.Choice(
                            "focal_loss_gamma",
                            focal_loss_gamma_choices,
                            default=_choice_default(cfg.FOCAL_LOSS_GAMMA, focal_loss_gamma_choices),
                            parent_name="focal_loss",
                            parent_values=[True],
                        ),
                        focal_loss_alpha=hp.Choice(
                            "focal_loss_alpha",
                            focal_loss_alpha_choices,
                            default=_choice_default(cfg.FOCAL_LOSS_ALPHA, focal_loss_alpha_choices),
                            parent_name="focal_loss",
                            parent_values=[True],
                        ),
                    )

                    # Get the best validation AUPRC instead of loss
                    best_val_auprc = history.history["val_AUPRC"][np.argmax(history.history["val_AUPRC"])]
                    histories.append(best_val_auprc)

                    print(
                        f"Finished Trial #{trial_number} execution #{execution + 1}. Best validation AUPRC: {best_val_auprc}",
                        flush=True,
                    )

                    keras.backend.clear_session()
                    del classifier
                    del history
                    gc.collect()

                # Call the on_trial_result callback
                if self.on_trial_result:
                    self.on_trial_result(trial_number)

                # Return the negative AUPRC for minimization (keras-tuner minimizes by default)
                return [-h for h in histories]

        # Create the tuner instance
        tuner = BirdNetTuner(
            x_train=x_train,
            y_train=y_train,
            x_test=x_test,
            y_test=y_test,
            max_trials=cfg.AUTOTUNE_TRIALS,
            executions_per_trial=cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
            on_trial_result=on_trial_result,
        )
        try:
            tuner.search()
        except model.get_empty_class_exception() as e:
            e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
            e.args = (e.message,)
            raise e

        best_params = tuner.get_best_hyperparameters()[0]

        cfg.TRAIN_HIDDEN_UNITS = best_params["hidden_units"]
        cfg.TRAIN_DROPOUT = best_params["dropout"]
        cfg.TRAIN_BATCH_SIZE = best_params["batch_size"]
        cfg.TRAIN_LEARNING_RATE = best_params[f"learning_rate_{cfg.TRAIN_BATCH_SIZE}"]
        # Apply the tuned ratio BEFORE testing it: "upsampling_mode" is a conditional
        # hyperparameter that only exists (is active) when the tuned ratio is > 0.
        cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
        if cfg.UPSAMPLING_RATIO > 0:
            cfg.UPSAMPLING_MODE = best_params["upsampling_mode"]
        cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
        cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
        # Focal loss was part of the search, so the final model must use the tuned choice too;
        # gamma/alpha are only active (readable) when focal loss was selected.
        cfg.TRAIN_WITH_FOCAL_LOSS = best_params["focal_loss"]
        if cfg.TRAIN_WITH_FOCAL_LOSS:
            cfg.FOCAL_LOSS_GAMMA = best_params["focal_loss_gamma"]
            cfg.FOCAL_LOSS_ALPHA = best_params["focal_loss_alpha"]

        print("Best params: ")
        print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
        print("dropout: ", cfg.TRAIN_DROPOUT)
        print("batch_size: ", cfg.TRAIN_BATCH_SIZE)
        print("learning_rate: ", cfg.TRAIN_LEARNING_RATE)
        print("upsampling_ratio: ", cfg.UPSAMPLING_RATIO)
        if cfg.UPSAMPLING_RATIO > 0:
            print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
        print("mixup: ", cfg.TRAIN_WITH_MIXUP)
        print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
        print("focal_loss: ", cfg.TRAIN_WITH_FOCAL_LOSS)
        if cfg.TRAIN_WITH_FOCAL_LOSS:
            print("focal_loss_gamma: ", cfg.FOCAL_LOSS_GAMMA)
            print("focal_loss_alpha: ", cfg.FOCAL_LOSS_ALPHA)

    # Build model
    print("Building model...", flush=True)
    classifier = model.build_linear_classifier(
        y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT
    )
    print("...Done.", flush=True)

    # Train model
    print("Training model...", flush=True)
    try:
        classifier, history = model.train_linear_classifier(
            classifier,
            x_train,
            y_train,
            x_test,
            y_test,
            epochs=cfg.TRAIN_EPOCHS,
            batch_size=cfg.TRAIN_BATCH_SIZE,
            learning_rate=cfg.TRAIN_LEARNING_RATE,
            val_split=cfg.TRAIN_VAL_SPLIT if len(x_test) == 0 else 0.0,
            upsampling_ratio=cfg.UPSAMPLING_RATIO,
            upsampling_mode=cfg.UPSAMPLING_MODE,
            train_with_mixup=cfg.TRAIN_WITH_MIXUP,
            train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
            train_with_focal_loss=cfg.TRAIN_WITH_FOCAL_LOSS,
            focal_loss_gamma=cfg.FOCAL_LOSS_GAMMA,
            focal_loss_alpha=cfg.FOCAL_LOSS_ALPHA,
            on_epoch_end=on_epoch_end,
        )
    except model.get_empty_class_exception() as e:
        e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
        e.args = (e.message,)
        raise e
    except Exception as e:
        raise Exception("Error training model") from e

    print("...Done.", flush=True)

    # Get best validation metrics based on AUPRC instead of loss for more reliable results with imbalanced data
    best_epoch = np.argmax(history.history["val_AUPRC"])
    best_val_auprc = history.history["val_AUPRC"][best_epoch]
    best_val_auroc = history.history["val_AUROC"][best_epoch]
    best_val_loss = history.history["val_loss"][best_epoch]

    print("Saving model...", flush=True)

    try:
        if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
            model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
            model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
        elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
            model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
        elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
            model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
        else:
            raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
    except Exception as e:
        raise Exception("Error saving model") from e

    save_sample_counts(labels, y_train)

    # Evaluate model on test data if available
    metrics = None
    if len(x_test) > 0:
        print("\nEvaluating model on test data...", flush=True)
        metrics = evaluate_model(classifier, x_test, y_test, labels)

        # Save evaluation results to file
        if metrics:
            import csv

            eval_file_path = cfg.CUSTOM_CLASSIFIER + "_evaluation.csv"
            # Explicit UTF-8: label (species) names may contain non-ASCII characters that the
            # platform default encoding (e.g. cp1252 on Windows) cannot represent.
            with open(eval_file_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)

                # Define all the metrics as columns, including both default and optimized threshold metrics
                header = [
                    "Class",
                    "Precision (0.5)",
                    "Recall (0.5)",
                    "F1 Score (0.5)",
                    "Precision (opt)",
                    "Recall (opt)",
                    "F1 Score (opt)",
                    "AUPRC",
                    "AUROC",
                    "Optimal Threshold",
                    "True Positives",
                    "False Positives",
                    "True Negatives",
                    "False Negatives",
                    "Samples",
                    "Percentage (%)",
                ]
                writer.writerow(header)

                # Write macro-averaged metrics (overall scores) first; the last 7 cells
                # (Threshold, TP, FP, TN, FN, Samples, Percentage) stay empty.
                writer.writerow(
                    [
                        "OVERALL (Macro-avg)",
                        f"{metrics['macro_precision_default']:.4f}",
                        f"{metrics['macro_recall_default']:.4f}",
                        f"{metrics['macro_f1_default']:.4f}",
                        f"{metrics['macro_precision_opt']:.4f}",
                        f"{metrics['macro_recall_opt']:.4f}",
                        f"{metrics['macro_f1_opt']:.4f}",
                        f"{metrics['macro_auprc']:.4f}",
                        f"{metrics['macro_auroc']:.4f}",
                        *[""] * 7,
                    ]
                )

                # Write per-class metrics (one row per species)
                for class_name, class_metrics in metrics["class_metrics"].items():
                    distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0})
                    writer.writerow(
                        [
                            class_name,
                            f"{class_metrics['precision_default']:.4f}",
                            f"{class_metrics['recall_default']:.4f}",
                            f"{class_metrics['f1_default']:.4f}",
                            f"{class_metrics['precision_opt']:.4f}",
                            f"{class_metrics['recall_opt']:.4f}",
                            f"{class_metrics['f1_opt']:.4f}",
                            f"{class_metrics['auprc']:.4f}",
                            f"{class_metrics['auroc']:.4f}",
                            f"{class_metrics['threshold']:.2f}",
                            class_metrics["tp"],
                            class_metrics["fp"],
                            class_metrics["tn"],
                            class_metrics["fn"],
                            distribution["count"],
                            f"{distribution['percentage']:.2f}",
                        ]
                    )

            print(f"Evaluation results saved to {eval_file_path}", flush=True)
    else:
        print("\nNo separate test data provided for evaluation. Using validation metrics.", flush=True)

    print(
        f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}, Best Loss: {best_val_loss} (epoch {best_epoch + 1}/{len(history.epoch)})",
        flush=True,
    )

    return history, metrics
|
678
|
+
|
679
|
+
|
680
|
+
def find_optimal_threshold(y_true, y_pred_prob, thresholds=None):
    """
    Find the classification threshold that maximizes the F1 score for one class.

    For imbalanced datasets the default threshold of 0.5 may not be optimal, so a
    grid of candidate thresholds is scanned and the one with the highest F1 score
    is returned.

    Args:
        y_true: Ground truth labels for a single class. Values greater than 0 count
            as positives; everything else (0, or a -1 negative-label marker) counts
            as a negative.
        y_pred_prob: Predicted probabilities for the same samples.
        thresholds: Optional iterable of candidate thresholds. Defaults to the grid
            0.10, 0.15, ..., 0.85.

    Returns:
        The optimal threshold value as a float. Among equally good candidates the
        earliest one wins; if no candidate reaches an F1 score above 0, 0.5 is
        returned.
    """
    from sklearn.metrics import f1_score

    if thresholds is None:
        # 16 evenly spaced candidates 0.10..0.85 (same grid as np.arange(0.1, 0.9, 0.05));
        # rounding removes float drift such as 0.15000000000000002.
        thresholds = np.round(np.linspace(0.1, 0.85, 16), 2)

    # sklearn's binary F1 raises on a third label value, so map -1 markers to negatives.
    positives = (np.asarray(y_true) > 0).astype(int)
    scores = np.asarray(y_pred_prob)

    best_threshold = 0.5
    best_f1 = 0.0

    for candidate in thresholds:
        predicted = (scores >= candidate).astype(int)
        # zero_division=0 yields the same 0.0 as sklearn's default ("warn") without the warning spam.
        score = f1_score(positives, predicted, zero_division=0)

        if score > best_f1:
            best_f1 = score
            best_threshold = float(candidate)

    return best_threshold
|
709
|
+
|
710
|
+
|
711
|
+
def evaluate_model(classifier, x_test, y_test, labels, threshold=None):
    """
    Evaluates the trained model on test data and prints detailed metrics.

    For every class it computes:
      - precision, recall and F1 at the default threshold (0.5);
      - the same three metrics at an "optimized" threshold, which is either the
        per-class F1-optimal threshold from `find_optimal_threshold` or the
        fixed `threshold` passed in;
      - AUPRC, AUROC and the confusion matrix at the optimized threshold.

    A class whose metrics cannot be computed is reported and excluded from
    every macro average and from the per-class dictionaries. This happens, for
    example, when sklearn's `roc_auc_score` rejects a test column containing
    only one class. As a result, all macro averages cover the same set of
    classes.

    Label values greater than 0 count as positives. Everything else (0, or the
    -1 marker used for negative labels) counts as a negative.

    Args:
        classifier: The trained model; `classifier.predict(x_test)` must return one
            probability column per label.
        x_test: Test features (embeddings)
        y_test: Test labels, one column per entry of `labels`
        labels: List of label names
        threshold: Classification threshold (if None, will find optimal threshold for each class)

    Returns:
        Dictionary with evaluation metrics, or an empty dict when `x_test` is empty.
        It contains:
          - the macro averages `macro_precision_default`, `macro_recall_default`,
            `macro_f1_default`, `macro_precision_opt`, `macro_recall_opt`,
            `macro_f1_opt`, `macro_auprc` and `macro_auroc` (NaN if no class
            could be evaluated);
          - `class_metrics` and `optimal_thresholds`, keyed by label name;
          - `class_distribution`: positive count and percentage per label.
    """
    from sklearn.metrics import (
        average_precision_score,
        confusion_matrix,
        f1_score,
        precision_score,
        recall_score,
        roc_auc_score,
    )

    # Skip evaluation if test set is empty
    if len(x_test) == 0:
        print("No test data available for evaluation.")
        return {}

    # Make predictions
    y_pred_prob = classifier.predict(x_test)

    # Binarize ground truth: sklearn's binary metrics raise on the -1 negative-label marker,
    # which simply means "this class is absent" for evaluation purposes.
    y_true_all = (np.asarray(y_test) > 0).astype(int)

    metrics = {}

    print("\nModel Evaluation:")
    print("=================")

    # Per-class results; a class is appended to all of them or to none
    precisions_default = []
    recalls_default = []
    f1s_default = []
    precisions_opt = []
    recalls_opt = []
    f1s_opt = []
    auprcs = []
    aurocs = []
    class_metrics = {}
    optimal_thresholds = {}

    # Print the metric calculation method that's being used
    print("\nNote: The AUPRC and AUROC metrics calculated during post-training evaluation may differ")
    print("from training history values due to different calculation methods:")
    print(" - Training history uses Keras metrics calculated over batches")
    print(" - Evaluation uses scikit-learn metrics calculated over the entire dataset")

    for i in range(y_true_all.shape[1]):
        y_true = y_true_all[:, i]
        y_prob = y_pred_prob[:, i]

        try:
            # Metrics with default threshold (0.5); zero_division=0 returns the same 0.0 as
            # sklearn's default but without a warning for every class with no predicted positives.
            y_pred_default = (y_prob >= 0.5).astype(int)
            class_precision_default = precision_score(y_true, y_pred_default, zero_division=0)
            class_recall_default = recall_score(y_true, y_pred_default, zero_division=0)
            class_f1_default = f1_score(y_true, y_pred_default, zero_division=0)

            # Find optimal threshold for this class if needed
            class_threshold = find_optimal_threshold(y_true, y_prob) if threshold is None else threshold

            # Metrics with optimized threshold
            y_pred_opt = (y_prob >= class_threshold).astype(int)
            class_precision_opt = precision_score(y_true, y_pred_opt, zero_division=0)
            class_recall_opt = recall_score(y_true, y_pred_opt, zero_division=0)
            class_f1_opt = f1_score(y_true, y_pred_opt, zero_division=0)
            class_auprc = average_precision_score(y_true, y_prob)
            class_auroc = roc_auc_score(y_true, y_prob)

            # Confusion matrix with optimized threshold; labels=[0, 1] guarantees a 2x2 matrix
            # (4 values to unpack) even when only one class occurs in y_true and y_pred.
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred_opt, labels=[0, 1]).ravel()
        except Exception as e:
            print(f"Error calculating metrics for class {labels[i]}: {e}")
            continue

        # Record the class only after every metric succeeded, so all macro averages
        # below are taken over the same set of classes.
        precisions_default.append(class_precision_default)
        recalls_default.append(class_recall_default)
        f1s_default.append(class_f1_default)
        precisions_opt.append(class_precision_opt)
        recalls_opt.append(class_recall_opt)
        f1s_opt.append(class_f1_opt)
        auprcs.append(class_auprc)
        aurocs.append(class_auroc)

        if threshold is None:
            optimal_thresholds[labels[i]] = class_threshold

        class_metrics[labels[i]] = {
            "precision_default": class_precision_default,
            "recall_default": class_recall_default,
            "f1_default": class_f1_default,
            "precision_opt": class_precision_opt,
            "recall_opt": class_recall_opt,
            "f1_opt": class_f1_opt,
            "auprc": class_auprc,
            "auroc": class_auroc,
            "tp": tp,
            "fp": fp,
            "tn": tn,
            "fn": fn,
            "threshold": class_threshold,
        }

        print(f"\nClass: {labels[i]}")
        print(" Default threshold (0.5):")
        print(f" Precision: {class_precision_default:.4f}")
        print(f" Recall: {class_recall_default:.4f}")
        print(f" F1 Score: {class_f1_default:.4f}")
        print(f" Optimized threshold ({class_threshold:.2f}):")
        print(f" Precision: {class_precision_opt:.4f}")
        print(f" Recall: {class_recall_opt:.4f}")
        print(f" F1 Score: {class_f1_opt:.4f}")
        print(f" AUPRC: {class_auprc:.4f}")
        print(f" AUROC: {class_auroc:.4f}")
        print(" Confusion matrix (optimized threshold):")
        print(f" True Positives: {tp}")
        print(f" False Positives: {fp}")
        print(f" True Negatives: {tn}")
        print(f" False Negatives: {fn}")

    def _macro(values):
        # np.mean([]) is NaN plus a RuntimeWarning; return NaN quietly when no class succeeded.
        return np.mean(values) if values else np.nan

    # Calculate macro-averaged metrics for both default and optimized thresholds
    metrics["macro_precision_default"] = _macro(precisions_default)
    metrics["macro_recall_default"] = _macro(recalls_default)
    metrics["macro_f1_default"] = _macro(f1s_default)
    metrics["macro_precision_opt"] = _macro(precisions_opt)
    metrics["macro_recall_opt"] = _macro(recalls_opt)
    metrics["macro_f1_opt"] = _macro(f1s_opt)
    metrics["macro_auprc"] = _macro(auprcs)
    metrics["macro_auroc"] = _macro(aurocs)
    metrics["class_metrics"] = class_metrics
    metrics["optimal_thresholds"] = optimal_thresholds

    print("\nMacro-averaged metrics:")
    print(" Default threshold (0.5):")
    print(f" Precision: {metrics['macro_precision_default']:.4f}")
    print(f" Recall: {metrics['macro_recall_default']:.4f}")
    print(f" F1 Score: {metrics['macro_f1_default']:.4f}")
    print(" Optimized thresholds:")
    print(f" Precision: {metrics['macro_precision_opt']:.4f}")
    print(f" Recall: {metrics['macro_recall_opt']:.4f}")
    print(f" F1 Score: {metrics['macro_f1_opt']:.4f}")
    print(f" AUPRC: {metrics['macro_auprc']:.4f}")
    print(f" AUROC: {metrics['macro_auroc']:.4f}")

    # Calculate class distribution (positive samples per class) in test set
    class_counts = y_true_all.sum(axis=0)
    total_samples = len(y_test)
    class_distribution = {}

    print("\nClass distribution in test set:")
    for i, count in enumerate(class_counts):
        percentage = count / total_samples * 100
        class_distribution[labels[i]] = {"count": int(count), "percentage": percentage}
        print(f" {labels[i]}: {int(count)} samples ({percentage:.2f}%)")

    metrics["class_distribution"] = class_distribution

    return metrics
|