birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
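The per-file diff reproduced below is for birdnet_analyzer/train/utils.py (+877 -847). As a minimal sketch of how such a comparison can be reproduced locally — not part of either wheel — the following assumes both wheel files have already been downloaded (e.g. with `pip download birdnet-analyzer==2.0.0 --no-deps` and `pip download birdnet-analyzer==2.0.1 --no-deps`; the local filenames below are assumptions). Wheels are ordinary zip archives, so the standard library is enough:

# Illustrative sketch only: diff two locally downloaded wheels member by member.
import difflib
import zipfile

OLD_WHEEL = "birdnet_analyzer-2.0.0-py3-none-any.whl"  # assumed local filename
NEW_WHEEL = "birdnet_analyzer-2.0.1-py3-none-any.whl"  # assumed local filename

with zipfile.ZipFile(OLD_WHEEL) as old, zipfile.ZipFile(NEW_WHEEL) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())

    # Unified diff for text members present in both versions
    for name in sorted(old_names & new_names):
        if not name.endswith((".py", ".txt", ".json", ".css", ".js", ".svg")):
            continue  # skip members that are not plain text
        old_lines = old.read(name).decode("utf-8", "replace").splitlines(keepends=True)
        new_lines = new.read(name).decode("utf-8", "replace").splitlines(keepends=True)
        for line in difflib.unified_diff(
            old_lines, new_lines, fromfile=f"2.0.0/{name}", tofile=f"2.0.1/{name}"
        ):
            print(line, end="")

    # Members only present in one of the two wheels
    for name in sorted(new_names - old_names):
        print(f"added:   {name}")
    for name in sorted(old_names - new_names):
        print(f"removed: {name}")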
@@ -1,847 +1,877 @@
1
- """Module for training a custom classifier.
2
-
3
- Can be used to train a custom classifier with new training data.
4
- """
5
-
6
- import csv
7
- import os
8
- from functools import partial
9
- from multiprocessing.pool import Pool
10
-
11
- import numpy as np
12
- import tqdm
13
-
14
- import birdnet_analyzer.audio as audio
15
- import birdnet_analyzer.config as cfg
16
- import birdnet_analyzer.model as model
17
- import birdnet_analyzer.utils as utils
18
-
19
-
20
- def save_sample_counts(labels, y_train):
21
- """
22
- Saves the count of samples per label combination to a CSV file.
23
-
24
- The function creates a dictionary where the keys are label combinations (joined by '+') and the values are the counts of samples for each combination.
25
- It then writes this information to a CSV file named "<cfg.CUSTOM_CLASSIFIER>_sample_counts.csv" with two columns: "Label" and "Count".
26
-
27
- Args:
28
- labels (list of str): List of label names corresponding to the columns in y_train.
29
- y_train (numpy.ndarray): 2D array where each row is a binary vector indicating the presence (1) or absence (0) of each label.
30
- """
31
- samples_per_label = {}
32
- label_combinations = np.unique(y_train, axis=0)
33
-
34
- for label_combination in label_combinations:
35
- label = "+".join([labels[i] for i in range(len(label_combination)) if label_combination[i] == 1])
36
- samples_per_label[label] = np.sum(np.all(y_train == label_combination, axis=1))
37
-
38
- csv_file_path = cfg.CUSTOM_CLASSIFIER + "_sample_counts.csv"
39
-
40
- with open(csv_file_path, mode="w", newline="") as csv_file:
41
- writer = csv.writer(csv_file)
42
- writer.writerow(["Label", "Count"])
43
-
44
- for label, count in samples_per_label.items():
45
- writer.writerow([label, count])
46
-
47
-
48
- def _load_audio_file(f, label_vector, config):
49
- """Load an audio file and extract features.
50
- Args:
51
- f: Path to the audio file.
52
- label_vector: The label vector for the file.
53
- Returns:
54
- A tuple of (x_train, y_train).
55
- """
56
-
57
- x_train = []
58
- y_train = []
59
-
60
- # restore config in case we're on Windows to be thread safe
61
- cfg.set_config(config)
62
-
63
- # Try to load the audio file
64
- try:
65
- # Load audio
66
- sig, rate = audio.open_audio_file(
67
- f,
68
- duration=cfg.SIG_LENGTH if cfg.SAMPLE_CROP_MODE == "first" else None,
69
- fmin=cfg.BANDPASS_FMIN,
70
- fmax=cfg.BANDPASS_FMAX,
71
- speed=cfg.AUDIO_SPEED,
72
- )
73
-
74
- # if anything happens print the error and ignore the file
75
- except Exception as e:
76
- # Print Error
77
- print(f"\t Error when loading file {f}", flush=True)
78
- print(f"\t {e}", flush=True)
79
- return np.array([]), np.array([])
80
-
81
- # Crop training samples
82
- if cfg.SAMPLE_CROP_MODE == "center":
83
- sig_splits = [audio.crop_center(sig, rate, cfg.SIG_LENGTH)]
84
- elif cfg.SAMPLE_CROP_MODE == "first":
85
- sig_splits = [audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)[0]]
86
- elif cfg.SAMPLE_CROP_MODE == "smart":
87
- # Smart cropping - detect peaks in audio energy to identify potential signals
88
- sig_splits = audio.smart_crop_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
89
- else:
90
- sig_splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
91
-
92
- # Get feature embeddings
93
- batch_size = 1 # turns out that batch size 1 is the fastest, probably because of having to resize the model input when the number of samples in a batch changes
94
- for i in range(0, len(sig_splits), batch_size):
95
- batch_sig = sig_splits[i : i + batch_size]
96
- batch_label = [label_vector] * len(batch_sig)
97
- embeddings = model.embeddings(batch_sig)
98
-
99
- # Add to training data
100
- x_train.extend(embeddings)
101
- y_train.extend(batch_label)
102
-
103
- return x_train, y_train
104
-
105
- def _load_training_data(cache_mode=None, cache_file="", progress_callback=None):
106
- """Loads the data for training.
107
-
108
- Reads all subdirectories of "config.TRAIN_DATA_PATH" and uses their names as new labels.
109
-
110
- These directories should contain all the training data for each label.
111
-
112
- If a cache file is provided, the training data is loaded from there.
113
-
114
- Args:
115
- cache_mode: Cache mode. Can be 'load' or 'save'. Defaults to None.
116
- cache_file: Path to cache file.
117
-
118
- Returns:
119
- A tuple of (x_train, y_train, x_test, y_test, labels).
120
- """
121
- # Load from cache
122
- if cache_mode == "load":
123
- if os.path.isfile(cache_file):
124
- print(f"\t...loading from cache: {cache_file}", flush=True)
125
- x_train, y_train, x_test, y_test, labels, cfg.BINARY_CLASSIFICATION, cfg.MULTI_LABEL = (
126
- utils.load_from_cache(cache_file)
127
- )
128
- return x_train, y_train, x_test, y_test, labels
129
- else:
130
- print(f"\t...cache file not found: {cache_file}", flush=True)
131
-
132
- # Print train and test data path as confirmation
133
- print(f"\t...train data path: {cfg.TRAIN_DATA_PATH}", flush=True)
134
- print(f"\t...test data path: {cfg.TEST_DATA_PATH}", flush=True)
135
-
136
- # Get list of subfolders as labels
137
- train_folders = list(sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH)))
138
-
139
- # Read all individual labels from the folder names
140
- labels = []
141
-
142
- for folder in train_folders:
143
- labels_in_folder = folder.split(",")
144
- for label in labels_in_folder:
145
- if label not in labels:
146
- labels.append(label)
147
-
148
- # Sort labels
149
- labels = list(sorted(labels))
150
-
151
- # Get valid labels
152
- valid_labels = [
153
- label for label in labels if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-")
154
- ]
155
-
156
- # Check if binary classification
157
- cfg.BINARY_CLASSIFICATION = len(valid_labels) == 1
158
-
159
- # Validate the classes for binary classification
160
- if cfg.BINARY_CLASSIFICATION:
161
- if len([f for f in train_folders if f.startswith("-")]) > 0:
162
- raise Exception(
163
- "Negative labels can't be used with binary classification",
164
- "validation-no-negative-samples-in-binary-classification",
165
- )
166
- if len([f for f in train_folders if f.lower() in cfg.NON_EVENT_CLASSES]) == 0:
167
- raise Exception(
168
- "Non-event samples are required for binary classification",
169
- "validation-non-event-samples-required-in-binary-classification",
170
- )
171
-
172
- # Check if multi label
173
- cfg.MULTI_LABEL = len(valid_labels) > 1 and any("," in f for f in train_folders)
174
-
175
- # Check if multi-label and binary classification
176
- if cfg.BINARY_CLASSIFICATION and cfg.MULTI_LABEL:
177
- raise Exception("Error: Binary classfication and multi-label not possible at the same time")
178
-
179
- # Only allow repeat upsampling for multi-label setting
180
- if cfg.MULTI_LABEL and cfg.UPSAMPLING_RATIO > 0 and cfg.UPSAMPLING_MODE != "repeat":
181
- raise Exception(
182
- "Only repeat-upsampling ist available for multi-label", "validation-only-repeat-upsampling-for-multi-label"
183
- )
184
-
185
- # Load training data
186
- x_train = []
187
- y_train = []
188
- x_test = []
189
- y_test = []
190
-
191
- def load_data(data_path, allowed_folders):
192
- x = []
193
- y = []
194
- folders = list(sorted(utils.list_subdirectories(data_path)))
195
-
196
- for folder in folders:
197
- if folder not in allowed_folders:
198
- print(f"Skipping folder {folder} because it is not in the training data.", flush=True)
199
- continue
200
-
201
- # Get label vector
202
- label_vector = np.zeros((len(valid_labels),), dtype="float32")
203
- folder_labels = folder.split(",")
204
-
205
- for label in folder_labels:
206
- if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-"):
207
- label_vector[valid_labels.index(label)] = 1
208
- elif (
209
- label.startswith("-") and label[1:] in valid_labels
210
- ): # Negative labels need to be contained in the valid labels
211
- label_vector[valid_labels.index(label[1:])] = -1
212
-
213
- # Get list of files
214
- # Filter files that start with '.' because macOS seems to use them for temp files.
215
- files = filter(
216
- os.path.isfile,
217
- (
218
- os.path.join(data_path, folder, f)
219
- for f in sorted(os.listdir(os.path.join(data_path, folder)))
220
- if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in cfg.ALLOWED_FILETYPES
221
- ),
222
- )
223
-
224
- # Load files using a process pool
225
- with Pool(cfg.CPU_THREADS) as p:
226
- tasks = []
227
-
228
- for f in files:
229
- task = p.apply_async(
230
- partial(_load_audio_file, f=f, label_vector=label_vector, config=cfg.get_config())
231
- )
232
- tasks.append(task)
233
-
234
- # Wait for tasks to complete and monitor progress with tqdm
235
- num_files_processed = 0
236
-
237
- with tqdm.tqdm(total=len(tasks), desc=f" - loading '{folder}'", unit="f") as progress_bar:
238
- for task in tasks:
239
- result = task.get()
240
- # Make sure result is not empty
241
- # Empty results might be caused by errors when loading the audio file
242
- # TODO: We should check for embeddings size in result, otherwise we can't add them to the training data
243
- if len(result[0]) > 0:
244
- x += result[0]
245
- y += result[1]
246
-
247
- num_files_processed += 1
248
- progress_bar.update(1)
249
-
250
- if progress_callback:
251
- progress_callback(num_files_processed, len(tasks), folder)
252
- return np.array(x, dtype="float32"), np.array(y, dtype="float32")
253
-
254
- x_train, y_train = load_data(cfg.TRAIN_DATA_PATH, train_folders)
255
-
256
- if cfg.TEST_DATA_PATH and cfg.TEST_DATA_PATH != cfg.TRAIN_DATA_PATH:
257
- test_folders = list(sorted(utils.list_subdirectories(cfg.TEST_DATA_PATH)))
258
- allowed_test_folders = [
259
- folder for folder in test_folders if folder in train_folders and not folder.startswith("-")
260
- ]
261
- x_test, y_test = load_data(cfg.TEST_DATA_PATH, allowed_test_folders)
262
- else:
263
- x_test = np.array([])
264
- y_test = np.array([])
265
-
266
- # Save to cache?
267
- if cache_mode == "save":
268
- print(f"\t...saving training data to cache: {cache_file}", flush=True)
269
- try:
270
- # Only save the valid labels
271
- utils.save_to_cache(cache_file, x_train, y_train, x_test, y_test, valid_labels)
272
- except Exception as e:
273
- print(f"\t...error saving cache: {e}", flush=True)
274
-
275
- # Return only the valid labels for further use
276
- return x_train, y_train, x_test, y_test, valid_labels
277
-
278
-
279
- def normalize_embeddings(embeddings):
280
- """
281
- Normalize embeddings to improve training stability and performance.
282
-
283
- This applies L2 normalization to each embedding vector, which can help
284
- with convergence and model performance, especially when training on
285
- embeddings from different sources or domains.
286
-
287
- Args:
288
- embeddings: numpy array of embedding vectors
289
-
290
- Returns:
291
- Normalized embeddings array
292
- """
293
- # Calculate L2 norm of each embedding vector
294
- norms = np.sqrt(np.sum(embeddings**2, axis=1, keepdims=True))
295
- # Avoid division by zero
296
- norms[norms == 0] = 1.0
297
- # Normalize each embedding vector
298
- normalized = embeddings / norms
299
- return normalized
300
-
301
-
302
- def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, autotune_directory="autotune"):
303
- """Trains a custom classifier.
304
-
305
- Args:
306
- on_epoch_end: A callback function that takes two arguments `epoch`, `logs`.
307
- on_trial_result: A callback function for hyperparameter tuning.
308
- on_data_load_end: A callback function for data loading progress.
309
- autotune_directory: Directory for autotune results.
310
-
311
- Returns:
312
- A keras `History` object, whose `history` property contains all the metrics.
313
- """
314
-
315
- # Load training data
316
- print("Loading training data...", flush=True)
317
- x_train, y_train, x_test, y_test, labels = _load_training_data(cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE, on_data_load_end)
318
- print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)
319
- if len(x_test) > 0:
320
- print(f"...Loaded {x_test.shape[0]} test samples.", flush=True)
321
-
322
- # Normalize embeddings
323
- print("Normalizing embeddings...", flush=True)
324
- x_train = normalize_embeddings(x_train)
325
- if len(x_test) > 0:
326
- x_test = normalize_embeddings(x_test)
327
-
328
- if cfg.AUTOTUNE:
329
- import gc
330
-
331
- import keras
332
- import keras_tuner
333
-
334
- # Call callback to initialize progress bar
335
- if on_trial_result:
336
- on_trial_result(0)
337
-
338
- class BirdNetTuner(keras_tuner.BayesianOptimization):
339
- def __init__(self, x_train, y_train, x_test, y_test, max_trials, executions_per_trial, on_trial_result):
340
- super().__init__(
341
- max_trials=max_trials,
342
- executions_per_trial=executions_per_trial,
343
- overwrite=True,
344
- directory=autotune_directory,
345
- project_name="birdnet_analyzer",
346
- )
347
- self.x_train = x_train
348
- self.y_train = y_train
349
- self.x_test = x_test
350
- self.y_test = y_test
351
- self.on_trial_result = on_trial_result
352
-
353
- def run_trial(self, trial, *args, **kwargs):
354
- histories = []
355
- hp: keras_tuner.HyperParameters = trial.hyperparameters
356
- trial_number = len(self.oracle.trials)
357
-
358
- for execution in range(int(self.executions_per_trial)):
359
- print(f"Running Trial #{trial_number} execution #{execution + 1}", flush=True)
360
-
361
- # Build model
362
- print("Building model...", flush=True)
363
- classifier = model.build_linear_classifier(
364
- self.y_train.shape[1],
365
- self.x_train.shape[1],
366
- hidden_units=hp.Choice(
367
- "hidden_units", [0, 128, 256, 512, 1024, 2048], default=cfg.TRAIN_HIDDEN_UNITS
368
- ),
369
- dropout=hp.Choice("dropout", [0.0, 0.25, 0.33, 0.5, 0.75, 0.9], default=cfg.TRAIN_DROPOUT),
370
- )
371
- print("...Done.", flush=True)
372
-
373
- # Only allow repeat upsampling in multi-label setting
374
- upsampling_choices = ["repeat", "mean", "linear"] # SMOTE is too slow
375
-
376
- if cfg.MULTI_LABEL:
377
- upsampling_choices = ["repeat"]
378
-
379
- batch_size = hp.Choice("batch_size", [8, 16, 32, 64, 128], default=cfg.TRAIN_BATCH_SIZE)
380
-
381
- if batch_size == 8:
382
- learning_rate = hp.Choice(
383
- "learning_rate_8",
384
- [0.0005, 0.0002, 0.0001],
385
- default=0.0001,
386
- parent_name="batch_size",
387
- parent_values=[8],
388
- )
389
- elif batch_size == 16:
390
- learning_rate = hp.Choice(
391
- "learning_rate_16",
392
- [0.005, 0.002, 0.001, 0.0005, 0.0002],
393
- default=0.0005,
394
- parent_name="batch_size",
395
- parent_values=[16],
396
- )
397
- elif batch_size == 32:
398
- learning_rate = hp.Choice(
399
- "learning_rate_32",
400
- [0.01, 0.005, 0.001, 0.0005, 0.0001],
401
- default=0.0001,
402
- parent_name="batch_size",
403
- parent_values=[32],
404
- )
405
- elif batch_size == 64:
406
- learning_rate = hp.Choice(
407
- "learning_rate_64",
408
- [0.01, 0.005, 0.002, 0.001],
409
- default=0.001,
410
- parent_name="batch_size",
411
- parent_values=[64],
412
- )
413
- elif batch_size == 128:
414
- learning_rate = hp.Choice(
415
- "learning_rate_128",
416
- [0.1, 0.01, 0.005],
417
- default=0.005,
418
- parent_name="batch_size",
419
- parent_values=[128],
420
- )
421
-
422
- # Train model
423
- print("Training model...", flush=True)
424
- classifier, history = model.train_linear_classifier(
425
- classifier,
426
- self.x_train,
427
- self.y_train,
428
- self.x_test,
429
- self.y_test,
430
- epochs=cfg.TRAIN_EPOCHS,
431
- batch_size=batch_size,
432
- learning_rate=learning_rate,
433
- val_split=0.0 if len(self.x_test) > 0 else cfg.TRAIN_VAL_SPLIT,
434
- upsampling_ratio=hp.Choice(
435
- "upsampling_ratio", [0.0, 0.25, 0.33, 0.5, 0.75, 1.0], default=cfg.UPSAMPLING_RATIO
436
- ),
437
- upsampling_mode=hp.Choice(
438
- "upsampling_mode",
439
- upsampling_choices,
440
- default=cfg.UPSAMPLING_MODE,
441
- parent_name="upsampling_ratio",
442
- parent_values=[0.25, 0.33, 0.5, 0.75, 1.0],
443
- ),
444
- train_with_mixup=hp.Boolean("mixup", default=cfg.TRAIN_WITH_MIXUP),
445
- train_with_label_smoothing=hp.Boolean(
446
- "label_smoothing", default=cfg.TRAIN_WITH_LABEL_SMOOTHING
447
- ),
448
- train_with_focal_loss=hp.Boolean("focal_loss", default=cfg.TRAIN_WITH_FOCAL_LOSS),
449
- focal_loss_gamma=hp.Choice(
450
- "focal_loss_gamma",
451
- [0.5, 1.0, 2.0, 3.0, 4.0],
452
- default=cfg.FOCAL_LOSS_GAMMA,
453
- parent_name="focal_loss",
454
- parent_values=[True]
455
- ),
456
- focal_loss_alpha=hp.Choice(
457
- "focal_loss_alpha",
458
- [0.1, 0.25, 0.5, 0.75, 0.9],
459
- default=cfg.FOCAL_LOSS_ALPHA,
460
- parent_name="focal_loss",
461
- parent_values=[True]
462
- ),
463
- )
464
-
465
- # Get the best validation AUPRC instead of loss
466
- best_val_auprc = history.history["val_AUPRC"][np.argmax(history.history["val_AUPRC"])]
467
- histories.append(best_val_auprc)
468
-
469
- print(
470
- f"Finished Trial #{trial_number} execution #{execution + 1}. Best validation AUPRC: {best_val_auprc}",
471
- flush=True,
472
- )
473
-
474
- keras.backend.clear_session()
475
- del classifier
476
- del history
477
- gc.collect()
478
-
479
- # Call the on_trial_result callback
480
- if self.on_trial_result:
481
- self.on_trial_result(trial_number)
482
-
483
- # Return the negative AUPRC for minimization (keras-tuner minimizes by default)
484
- return [-h for h in histories]
485
-
486
- # Create the tuner instance
487
- tuner = BirdNetTuner(
488
- x_train=x_train,
489
- y_train=y_train,
490
- x_test=x_test,
491
- y_test=y_test,
492
- max_trials=cfg.AUTOTUNE_TRIALS,
493
- executions_per_trial=cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
494
- on_trial_result=on_trial_result,
495
- )
496
- try:
497
- tuner.search()
498
- except model.get_empty_class_exception() as e:
499
- e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
500
- e.args = (e.message,)
501
- raise e
502
-
503
- best_params = tuner.get_best_hyperparameters()[0]
504
-
505
- cfg.TRAIN_HIDDEN_UNITS = best_params["hidden_units"]
506
- cfg.TRAIN_DROPOUT = best_params["dropout"]
507
- cfg.TRAIN_BATCH_SIZE = best_params["batch_size"]
508
- cfg.TRAIN_LEARNING_RATE = best_params[f"learning_rate_{cfg.TRAIN_BATCH_SIZE}"]
509
- if cfg.UPSAMPLING_RATIO > 0:
510
- cfg.UPSAMPLING_MODE = best_params["upsampling_mode"]
511
- cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
512
- cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
513
- cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
514
-
515
- print("Best params: ")
516
- print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
517
- print("dropout: ", cfg.TRAIN_DROPOUT)
518
- print("batch_size: ", cfg.TRAIN_BATCH_SIZE)
519
- print("learning_rate: ", cfg.TRAIN_LEARNING_RATE)
520
- print("upsampling_ratio: ", cfg.UPSAMPLING_RATIO)
521
- if cfg.UPSAMPLING_RATIO > 0:
522
- print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
523
- print("mixup: ", cfg.TRAIN_WITH_MIXUP)
524
- print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
525
-
526
- # Build model
527
- print("Building model...", flush=True)
528
- classifier = model.build_linear_classifier(
529
- y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT
530
- )
531
- print("...Done.", flush=True)
532
-
533
- # Train model
534
- print("Training model...", flush=True)
535
- try:
536
- classifier, history = model.train_linear_classifier(
537
- classifier,
538
- x_train,
539
- y_train,
540
- x_test,
541
- y_test,
542
- epochs=cfg.TRAIN_EPOCHS,
543
- batch_size=cfg.TRAIN_BATCH_SIZE,
544
- learning_rate=cfg.TRAIN_LEARNING_RATE,
545
- val_split=cfg.TRAIN_VAL_SPLIT if len(x_test) == 0 else 0.0,
546
- upsampling_ratio=cfg.UPSAMPLING_RATIO,
547
- upsampling_mode=cfg.UPSAMPLING_MODE,
548
- train_with_mixup=cfg.TRAIN_WITH_MIXUP,
549
- train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
550
- train_with_focal_loss=cfg.TRAIN_WITH_FOCAL_LOSS,
551
- focal_loss_gamma=cfg.FOCAL_LOSS_GAMMA,
552
- focal_loss_alpha=cfg.FOCAL_LOSS_ALPHA,
553
- on_epoch_end=on_epoch_end,
554
- )
555
- except model.get_empty_class_exception() as e:
556
- e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
557
- e.args = (e.message,)
558
- raise e
559
- except Exception as e:
560
- raise Exception("Error training model") from e
561
-
562
- print("...Done.", flush=True)
563
-
564
- # Get best validation metrics based on AUPRC instead of loss for more reliable results with imbalanced data
565
- best_epoch = np.argmax(history.history["val_AUPRC"])
566
- best_val_auprc = history.history["val_AUPRC"][best_epoch]
567
- best_val_auroc = history.history["val_AUROC"][best_epoch]
568
- best_val_loss = history.history["val_loss"][best_epoch]
569
-
570
- print("Saving model...", flush=True)
571
-
572
- try:
573
- if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
574
- model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
575
- model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
576
- elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
577
- model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
578
- elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
579
- model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
580
- else:
581
- raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
582
- except Exception as e:
583
- raise Exception("Error saving model") from e
584
-
585
- save_sample_counts(labels, y_train)
586
-
587
- # Evaluate model on test data if available
588
- metrics = None
589
- if len(x_test) > 0:
590
- print("\nEvaluating model on test data...", flush=True)
591
- metrics = evaluate_model(classifier, x_test, y_test, labels)
592
-
593
- # Save evaluation results to file
594
- if metrics:
595
- import csv
596
- eval_file_path = cfg.CUSTOM_CLASSIFIER + "_evaluation.csv"
597
- with open(eval_file_path, 'w', newline='') as f:
598
- writer = csv.writer(f)
599
-
600
- # Define all the metrics as columns, including both default and optimized threshold metrics
601
- header = ['Class',
602
- 'Precision (0.5)', 'Recall (0.5)', 'F1 Score (0.5)',
603
- 'Precision (opt)', 'Recall (opt)', 'F1 Score (opt)',
604
- 'AUPRC', 'AUROC', 'Optimal Threshold',
605
- 'True Positives', 'False Positives', 'True Negatives', 'False Negatives',
606
- 'Samples', 'Percentage (%)']
607
- writer.writerow(header)
608
-
609
- # Write macro-averaged metrics (overall scores) first
610
- writer.writerow([
611
- 'OVERALL (Macro-avg)',
612
- f"{metrics['macro_precision_default']:.4f}",
613
- f"{metrics['macro_recall_default']:.4f}",
614
- f"{metrics['macro_f1_default']:.4f}",
615
- f"{metrics['macro_precision_opt']:.4f}",
616
- f"{metrics['macro_recall_opt']:.4f}",
617
- f"{metrics['macro_f1_opt']:.4f}",
618
- f"{metrics['macro_auprc']:.4f}",
619
- f"{metrics['macro_auroc']:.4f}",
620
- '', '', '', '', '', '', '' # Empty cells for Threshold, TP, FP, TN, FN, Samples, Percentage
621
- ])
622
-
623
- # Write per-class metrics (one row per species)
624
- for class_name, class_metrics in metrics['class_metrics'].items():
625
- distribution = metrics['class_distribution'].get(class_name, {'count': 0, 'percentage': 0.0})
626
- writer.writerow([
627
- class_name,
628
- f"{class_metrics['precision_default']:.4f}",
629
- f"{class_metrics['recall_default']:.4f}",
630
- f"{class_metrics['f1_default']:.4f}",
631
- f"{class_metrics['precision_opt']:.4f}",
632
- f"{class_metrics['recall_opt']:.4f}",
633
- f"{class_metrics['f1_opt']:.4f}",
634
- f"{class_metrics['auprc']:.4f}",
635
- f"{class_metrics['auroc']:.4f}",
636
- f"{class_metrics['threshold']:.2f}",
637
- class_metrics['tp'],
638
- class_metrics['fp'],
639
- class_metrics['tn'],
640
- class_metrics['fn'],
641
- distribution['count'],
642
- f"{distribution['percentage']:.2f}"
643
- ])
644
-
645
- print(f"Evaluation results saved to {eval_file_path}", flush=True)
646
- else:
647
- print("\nNo separate test data provided for evaluation. Using validation metrics.", flush=True)
648
-
649
- print(f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}, Best Loss: {best_val_loss} (epoch {best_epoch+1}/{len(history.epoch)})", flush=True)
650
-
651
- return history, metrics
652
-
653
- def find_optimal_threshold(y_true, y_pred_prob):
654
- """
655
- Find the optimal classification threshold using the F1 score.
656
-
657
- For imbalanced datasets, the default threshold of 0.5 may not be optimal.
658
- This function finds the threshold that maximizes the F1 score for each class.
659
-
660
- Args:
661
- y_true: Ground truth labels
662
- y_pred_prob: Predicted probabilities
663
-
664
- Returns:
665
- The optimal threshold value
666
- """
667
- from sklearn.metrics import f1_score
668
-
669
- # Try different thresholds and find the one that gives the best F1 score
670
- best_threshold = 0.5
671
- best_f1 = 0.0
672
-
673
- for threshold in np.arange(0.1, 0.9, 0.05):
674
- y_pred = (y_pred_prob >= threshold).astype(int)
675
- f1 = f1_score(y_true, y_pred)
676
-
677
- if f1 > best_f1:
678
- best_f1 = f1
679
- best_threshold = threshold
680
-
681
- return best_threshold
682
-
683
-
684
- def evaluate_model(classifier, x_test, y_test, labels, threshold=None):
685
- """
686
- Evaluates the trained model on test data and prints detailed metrics.
687
-
688
- Args:
689
- classifier: The trained model
690
- x_test: Test features (embeddings)
691
- y_test: Test labels
692
- labels: List of label names
693
- threshold: Classification threshold (if None, will find optimal threshold for each class)
694
-
695
- Returns:
696
- Dictionary with evaluation metrics
697
- """
698
- from sklearn.metrics import (
699
- precision_score, recall_score, f1_score,
700
- confusion_matrix, classification_report,
701
- average_precision_score, roc_auc_score
702
- )
703
-
704
- # Skip evaluation if test set is empty
705
- if len(x_test) == 0:
706
- print("No test data available for evaluation.")
707
- return {}
708
-
709
- # Make predictions
710
- y_pred_prob = classifier.predict(x_test)
711
-
712
- # Calculate metrics for each class
713
- metrics = {}
714
-
715
- print("\nModel Evaluation:")
716
- print("=================")
717
-
718
- # Calculate metrics for each class
719
- precisions_default = []
720
- recalls_default = []
721
- f1s_default = []
722
- precisions_opt = []
723
- recalls_opt = []
724
- f1s_opt = []
725
- auprcs = []
726
- aurocs = []
727
- class_metrics = {}
728
- optimal_thresholds = {}
729
-
730
- # Print the metric calculation method that's being used
731
- print("\nNote: The AUPRC and AUROC metrics calculated during post-training evaluation may differ")
732
- print("from training history values due to different calculation methods:")
733
- print(" - Training history uses Keras metrics calculated over batches")
734
- print(" - Evaluation uses scikit-learn metrics calculated over the entire dataset")
735
-
736
- for i in range(y_test.shape[1]):
737
- try:
738
- # Calculate metrics with default threshold (0.5)
739
- y_pred_default = (y_pred_prob[:, i] >= 0.5).astype(int)
740
-
741
- class_precision_default = precision_score(y_test[:, i], y_pred_default)
742
- class_recall_default = recall_score(y_test[:, i], y_pred_default)
743
- class_f1_default = f1_score(y_test[:, i], y_pred_default)
744
-
745
- precisions_default.append(class_precision_default)
746
- recalls_default.append(class_recall_default)
747
- f1s_default.append(class_f1_default)
748
-
749
- # Find optimal threshold for this class if needed
750
- if threshold is None:
751
- class_threshold = find_optimal_threshold(y_test[:, i], y_pred_prob[:, i])
752
- optimal_thresholds[labels[i]] = class_threshold
753
- else:
754
- class_threshold = threshold
755
-
756
- # Calculate metrics with optimized threshold
757
- y_pred_opt = (y_pred_prob[:, i] >= class_threshold).astype(int)
758
-
759
- class_precision_opt = precision_score(y_test[:, i], y_pred_opt)
760
- class_recall_opt = recall_score(y_test[:, i], y_pred_opt)
761
- class_f1_opt = f1_score(y_test[:, i], y_pred_opt)
762
- class_auprc = average_precision_score(y_test[:, i], y_pred_prob[:, i])
763
- class_auroc = roc_auc_score(y_test[:, i], y_pred_prob[:, i])
764
-
765
- precisions_opt.append(class_precision_opt)
766
- recalls_opt.append(class_recall_opt)
767
- f1s_opt.append(class_f1_opt)
768
- auprcs.append(class_auprc)
769
- aurocs.append(class_auroc)
770
-
771
- # Confusion matrix with optimized threshold
772
- tn, fp, fn, tp = confusion_matrix(y_test[:, i], y_pred_opt).ravel()
773
-
774
- class_metrics[labels[i]] = {
775
- 'precision_default': class_precision_default,
776
- 'recall_default': class_recall_default,
777
- 'f1_default': class_f1_default,
778
- 'precision_opt': class_precision_opt,
779
- 'recall_opt': class_recall_opt,
780
- 'f1_opt': class_f1_opt,
781
- 'auprc': class_auprc,
782
- 'auroc': class_auroc,
783
- 'tp': tp,
784
- 'fp': fp,
785
- 'tn': tn,
786
- 'fn': fn,
787
- 'threshold': class_threshold
788
- }
789
-
790
- print(f"\nClass: {labels[i]}")
791
- print(f" Default threshold (0.5):")
792
- print(f" Precision: {class_precision_default:.4f}")
793
- print(f" Recall: {class_recall_default:.4f}")
794
- print(f" F1 Score: {class_f1_default:.4f}")
795
- print(f" Optimized threshold ({class_threshold:.2f}):")
796
- print(f" Precision: {class_precision_opt:.4f}")
797
- print(f" Recall: {class_recall_opt:.4f}")
798
- print(f" F1 Score: {class_f1_opt:.4f}")
799
- print(f" AUPRC: {class_auprc:.4f}")
800
- print(f" AUROC: {class_auroc:.4f}")
801
- print(f" Confusion matrix (optimized threshold):")
802
- print(f" True Positives: {tp}")
803
- print(f" False Positives: {fp}")
804
- print(f" True Negatives: {tn}")
805
- print(f" False Negatives: {fn}")
806
-
807
- except Exception as e:
808
- print(f"Error calculating metrics for class {labels[i]}: {e}")
809
-
810
- # Calculate macro-averaged metrics for both default and optimized thresholds
811
- metrics['macro_precision_default'] = np.mean(precisions_default)
812
- metrics['macro_recall_default'] = np.mean(recalls_default)
813
- metrics['macro_f1_default'] = np.mean(f1s_default)
814
- metrics['macro_precision_opt'] = np.mean(precisions_opt)
815
- metrics['macro_recall_opt'] = np.mean(recalls_opt)
816
- metrics['macro_f1_opt'] = np.mean(f1s_opt)
817
- metrics['macro_auprc'] = np.mean(auprcs)
818
- metrics['macro_auroc'] = np.mean(aurocs)
819
- metrics['class_metrics'] = class_metrics
820
- metrics['optimal_thresholds'] = optimal_thresholds
821
-
822
- print("\nMacro-averaged metrics:")
823
- print(f" Default threshold (0.5):")
824
- print(f" Precision: {metrics['macro_precision_default']:.4f}")
825
- print(f" Recall: {metrics['macro_recall_default']:.4f}")
826
- print(f" F1 Score: {metrics['macro_f1_default']:.4f}")
827
- print(f" Optimized thresholds:")
828
- print(f" Precision: {metrics['macro_precision_opt']:.4f}")
829
- print(f" Recall: {metrics['macro_recall_opt']:.4f}")
830
- print(f" F1 Score: {metrics['macro_f1_opt']:.4f}")
831
- print(f" AUPRC: {metrics['macro_auprc']:.4f}")
832
- print(f" AUROC: {metrics['macro_auroc']:.4f}")
833
-
834
- # Calculate class distribution in test set
835
- class_counts = y_test.sum(axis=0)
836
- total_samples = len(y_test)
837
- class_distribution = {}
838
-
839
- print("\nClass distribution in test set:")
840
- for i, count in enumerate(class_counts):
841
- percentage = count / total_samples * 100
842
- class_distribution[labels[i]] = {'count': int(count), 'percentage': percentage}
843
- print(f" {labels[i]}: {int(count)} samples ({percentage:.2f}%)")
844
-
845
- metrics['class_distribution'] = class_distribution
846
-
847
- return metrics
1
+ """Module for training a custom classifier.
2
+
3
+ Can be used to train a custom classifier with new training data.
4
+ """
5
+
6
+ import csv
7
+ import os
8
+ from functools import partial
9
+ from multiprocessing.pool import Pool
10
+
11
+ import numpy as np
12
+ import tqdm
13
+
14
+ import birdnet_analyzer.config as cfg
15
+ from birdnet_analyzer import audio, model, utils
16
+
17
+
18
+ def save_sample_counts(labels, y_train):
19
+ """
20
+ Saves the count of samples per label combination to a CSV file.
21
+
22
+ The function creates a dictionary where the keys are label combinations (joined by '+') and the values are the counts of samples for each combination.
23
+ It then writes this information to a CSV file named "<cfg.CUSTOM_CLASSIFIER>_sample_counts.csv" with two columns: "Label" and "Count".
24
+
25
+ Args:
26
+ labels (list of str): List of label names corresponding to the columns in y_train.
27
+ y_train (numpy.ndarray): 2D array where each row is a binary vector indicating the presence (1) or absence (0) of each label.
28
+ """
29
+ samples_per_label = {}
30
+ label_combinations = np.unique(y_train, axis=0)
31
+
32
+ for label_combination in label_combinations:
33
+ label = "+".join([labels[i] for i in range(len(label_combination)) if label_combination[i] == 1])
34
+ samples_per_label[label] = np.sum(np.all(y_train == label_combination, axis=1))
35
+
36
+ csv_file_path = cfg.CUSTOM_CLASSIFIER + "_sample_counts.csv"
37
+
38
+ with open(csv_file_path, mode="w", newline="") as csv_file:
39
+ writer = csv.writer(csv_file)
40
+ writer.writerow(["Label", "Count"])
41
+
42
+ for label, count in samples_per_label.items():
43
+ writer.writerow([label, count])
44
+
45
+
46
+ def _load_audio_file(f, label_vector, config):
47
+ """Load an audio file and extract features.
48
+ Args:
49
+ f: Path to the audio file.
50
+ label_vector: The label vector for the file.
51
+ Returns:
52
+ A tuple of (x_train, y_train).
53
+ """
54
+
55
+ x_train = []
56
+ y_train = []
57
+
58
+ # restore config in case we're on Windows to be thread safe
59
+ cfg.set_config(config)
60
+
61
+ # Try to load the audio file
62
+ try:
63
+ # Load audio
64
+ sig, rate = audio.open_audio_file(
65
+ f,
66
+ duration=cfg.SIG_LENGTH if cfg.SAMPLE_CROP_MODE == "first" else None,
67
+ fmin=cfg.BANDPASS_FMIN,
68
+ fmax=cfg.BANDPASS_FMAX,
69
+ speed=cfg.AUDIO_SPEED,
70
+ )
71
+
72
+ # if anything happens print the error and ignore the file
73
+ except Exception as e:
74
+ # Print Error
75
+ print(f"\t Error when loading file {f}", flush=True)
76
+ print(f"\t {e}", flush=True)
77
+ return np.array([]), np.array([])
78
+
79
+ # Crop training samples
80
+ if cfg.SAMPLE_CROP_MODE == "center":
81
+ sig_splits = [audio.crop_center(sig, rate, cfg.SIG_LENGTH)]
82
+ elif cfg.SAMPLE_CROP_MODE == "first":
83
+ sig_splits = [audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)[0]]
84
+ elif cfg.SAMPLE_CROP_MODE == "smart":
85
+ # Smart cropping - detect peaks in audio energy to identify potential signals
86
+ sig_splits = audio.smart_crop_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
87
+ else:
88
+ sig_splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
89
+
90
+ # Get feature embeddings
91
+ batch_size = 1 # turns out that batch size 1 is the fastest, probably because of having to resize the model input when the number of samples in a batch changes
92
+ for i in range(0, len(sig_splits), batch_size):
93
+ batch_sig = sig_splits[i : i + batch_size]
94
+ batch_label = [label_vector] * len(batch_sig)
95
+ embeddings = model.embeddings(batch_sig)
96
+
97
+ # Add to training data
98
+ x_train.extend(embeddings)
99
+ y_train.extend(batch_label)
100
+
101
+ return x_train, y_train
102
+
103
+
104
+ def _load_training_data(cache_mode=None, cache_file="", progress_callback=None):
105
+ """Loads the data for training.
106
+
107
+ Reads all subdirectories of "config.TRAIN_DATA_PATH" and uses their names as new labels.
108
+
109
+ These directories should contain all the training data for each label.
110
+
111
+ If a cache file is provided, the training data is loaded from there.
112
+
113
+ Args:
114
+ cache_mode: Cache mode. Can be 'load' or 'save'. Defaults to None.
115
+ cache_file: Path to cache file.
116
+
117
+ Returns:
118
+ A tuple of (x_train, y_train, x_test, y_test, labels).
119
+ """
120
+ # Load from cache
121
+ if cache_mode == "load":
122
+ if os.path.isfile(cache_file):
123
+ print(f"\t...loading from cache: {cache_file}", flush=True)
124
+ x_train, y_train, x_test, y_test, labels, cfg.BINARY_CLASSIFICATION, cfg.MULTI_LABEL = (
125
+ utils.load_from_cache(cache_file)
126
+ )
127
+ return x_train, y_train, x_test, y_test, labels
128
+
129
+ print(f"\t...cache file not found: {cache_file}", flush=True)
130
+
131
+ # Print train and test data path as confirmation
132
+ print(f"\t...train data path: {cfg.TRAIN_DATA_PATH}", flush=True)
133
+ print(f"\t...test data path: {cfg.TEST_DATA_PATH}", flush=True)
134
+
135
+ # Get list of subfolders as labels
136
+ train_folders = sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH))
137
+
138
+ # Read all individual labels from the folder names
139
+ labels = []
140
+
141
+ for folder in train_folders:
142
+ labels_in_folder = folder.split(",")
143
+ for label in labels_in_folder:
144
+ if label not in labels:
145
+ labels.append(label)
146
+
147
+ # Sort labels
148
+ labels = sorted(labels)
149
+
150
+ # Get valid labels
151
+ valid_labels = [
152
+ label for label in labels if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-")
153
+ ]
154
+
155
+ # Check if binary classification
156
+ cfg.BINARY_CLASSIFICATION = len(valid_labels) == 1
157
+
158
+ # Validate the classes for binary classification
159
+ if cfg.BINARY_CLASSIFICATION:
160
+ if len([f for f in train_folders if f.startswith("-")]) > 0:
161
+ raise Exception(
162
+ "Negative labels can't be used with binary classification",
163
+ "validation-no-negative-samples-in-binary-classification",
164
+ )
165
+ if len([f for f in train_folders if f.lower() in cfg.NON_EVENT_CLASSES]) == 0:
166
+ raise Exception(
167
+ "Non-event samples are required for binary classification",
168
+ "validation-non-event-samples-required-in-binary-classification",
169
+ )
170
+
171
+ # Check if multi label
172
+ cfg.MULTI_LABEL = len(valid_labels) > 1 and any("," in f for f in train_folders)
173
+
174
+ # Check if multi-label and binary classification
175
+ if cfg.BINARY_CLASSIFICATION and cfg.MULTI_LABEL:
176
+ raise Exception("Error: Binary classfication and multi-label not possible at the same time")
177
+
178
+ # Only allow repeat upsampling for multi-label setting
179
+ if cfg.MULTI_LABEL and cfg.UPSAMPLING_RATIO > 0 and cfg.UPSAMPLING_MODE != "repeat":
180
+ raise Exception(
181
+ "Only repeat-upsampling ist available for multi-label", "validation-only-repeat-upsampling-for-multi-label"
182
+ )
183
+
184
+ # Load training data
185
+ x_train = []
186
+ y_train = []
187
+ x_test = []
188
+ y_test = []
189
+
190
+ def load_data(data_path, allowed_folders):
191
+ x = []
192
+ y = []
193
+ folders = sorted(utils.list_subdirectories(data_path))
194
+
195
+ for folder in folders:
196
+ if folder not in allowed_folders:
197
+ print(f"Skipping folder {folder} because it is not in the training data.", flush=True)
198
+ continue
199
+
200
+ # Get label vector
201
+ label_vector = np.zeros((len(valid_labels),), dtype="float32")
202
+ folder_labels = folder.split(",")
203
+
204
+ for label in folder_labels:
205
+ if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-"):
206
+ label_vector[valid_labels.index(label)] = 1
207
+ elif (
208
+ label.startswith("-") and label[1:] in valid_labels
209
+ ): # Negative labels need to be contained in the valid labels
210
+ label_vector[valid_labels.index(label[1:])] = -1
211
+
212
+ # Get list of files
213
+ # Filter files that start with '.' because macOS seems to use them for temp files.
214
+ files = filter(
215
+ os.path.isfile,
216
+ (
217
+ os.path.join(data_path, folder, f)
218
+ for f in sorted(os.listdir(os.path.join(data_path, folder)))
219
+ if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in cfg.ALLOWED_FILETYPES
220
+ ),
221
+ )
222
+
223
+ # Load files using a process pool
224
+ with Pool(cfg.CPU_THREADS) as p:
225
+ tasks = []
226
+
227
+ for f in files:
228
+ task = p.apply_async(
229
+ partial(_load_audio_file, f=f, label_vector=label_vector, config=cfg.get_config())
230
+ )
231
+ tasks.append(task)
232
+
233
+ # Wait for tasks to complete and monitor progress with tqdm
234
+ num_files_processed = 0
235
+
236
+ with tqdm.tqdm(total=len(tasks), desc=f" - loading '{folder}'", unit="f") as progress_bar:
237
+ for task in tasks:
238
+ result = task.get()
239
+ # Make sure result is not empty
240
+ # Empty results might be caused by errors when loading the audio file
241
+ # TODO: We should check for embeddings size in result, otherwise we can't add them to the training data
242
+ if len(result[0]) > 0:
243
+ x += result[0]
244
+ y += result[1]
245
+
246
+ num_files_processed += 1
247
+ progress_bar.update(1)
248
+
249
+ if progress_callback:
250
+ progress_callback(num_files_processed, len(tasks), folder)
251
+ return np.array(x, dtype="float32"), np.array(y, dtype="float32")
252
+
253
+ x_train, y_train = load_data(cfg.TRAIN_DATA_PATH, train_folders)
254
+
255
+ if cfg.TEST_DATA_PATH and cfg.TEST_DATA_PATH != cfg.TRAIN_DATA_PATH:
256
+ test_folders = sorted(utils.list_subdirectories(cfg.TEST_DATA_PATH))
257
+ allowed_test_folders = [
258
+ folder for folder in test_folders if folder in train_folders and not folder.startswith("-")
259
+ ]
260
+ x_test, y_test = load_data(cfg.TEST_DATA_PATH, allowed_test_folders)
261
+ else:
262
+ x_test = np.array([])
263
+ y_test = np.array([])
264
+
265
+ # Save to cache?
266
+ if cache_mode == "save":
267
+ print(f"\t...saving training data to cache: {cache_file}", flush=True)
268
+ try:
269
+ # Only save the valid labels
270
+ utils.save_to_cache(cache_file, x_train, y_train, x_test, y_test, valid_labels)
271
+ except Exception as e:
272
+ print(f"\t...error saving cache: {e}", flush=True)
273
+
274
+ # Return only the valid labels for further use
275
+ return x_train, y_train, x_test, y_test, valid_labels
276
+
277
+
278
+ def normalize_embeddings(embeddings):
279
+ """
280
+ Normalize embeddings to improve training stability and performance.
281
+
282
+ This applies L2 normalization to each embedding vector, which can help
283
+ with convergence and model performance, especially when training on
284
+ embeddings from different sources or domains.
285
+
286
+ Args:
287
+ embeddings: numpy array of embedding vectors
288
+
289
+ Returns:
290
+ Normalized embeddings array
291
+ """
292
+ # Calculate L2 norm of each embedding vector
293
+ norms = np.sqrt(np.sum(embeddings**2, axis=1, keepdims=True))
294
+ # Avoid division by zero
295
+ norms[norms == 0] = 1.0
296
+ # Normalize each embedding vector
297
+ return embeddings / norms
298
+
299
+
300
+ def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, autotune_directory="autotune"):
301
+ """Trains a custom classifier.
302
+
303
+ Args:
304
+ on_epoch_end: A callback function that takes two arguments `epoch`, `logs`.
305
+ on_trial_result: A callback function for hyperparameter tuning.
306
+ on_data_load_end: A callback function for data loading progress.
307
+ autotune_directory: Directory for autotune results.
308
+
309
+ Returns:
310
+ A keras `History` object, whose `history` property contains all the metrics.
311
+ """
312
+
313
+ # Load training data
314
+ print("Loading training data...", flush=True)
315
+ x_train, y_train, x_test, y_test, labels = _load_training_data(
316
+ cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE, on_data_load_end
317
+ )
318
+ print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)
319
+ if len(x_test) > 0:
320
+ print(f"...Loaded {x_test.shape[0]} test samples.", flush=True)
321
+
322
+ # Normalize embeddings
323
+ print("Normalizing embeddings...", flush=True)
324
+ x_train = normalize_embeddings(x_train)
325
+ if len(x_test) > 0:
326
+ x_test = normalize_embeddings(x_test)
327
+
328
+ if cfg.AUTOTUNE:
329
+ import gc
330
+
331
+ import keras
332
+ import keras_tuner
333
+
334
+ # Call callback to initialize progress bar
335
+ if on_trial_result:
336
+ on_trial_result(0)
337
+
338
+ class BirdNetTuner(keras_tuner.BayesianOptimization):
339
+ def __init__(self, x_train, y_train, x_test, y_test, max_trials, executions_per_trial, on_trial_result):
340
+ super().__init__(
341
+ max_trials=max_trials,
342
+ executions_per_trial=executions_per_trial,
343
+ overwrite=True,
344
+ directory=autotune_directory,
345
+ project_name="birdnet_analyzer",
346
+ )
347
+ self.x_train = x_train
348
+ self.y_train = y_train
349
+ self.x_test = x_test
350
+ self.y_test = y_test
351
+ self.on_trial_result = on_trial_result
352
+
353
+ def run_trial(self, trial, *args, **kwargs):
354
+ histories = []
355
+ hp: keras_tuner.HyperParameters = trial.hyperparameters
356
+ trial_number = len(self.oracle.trials)
357
+
358
+ for execution in range(int(self.executions_per_trial)):
359
+ print(f"Running Trial #{trial_number} execution #{execution + 1}", flush=True)
360
+
361
+ # Build model
362
+ print("Building model...", flush=True)
363
+                     classifier = model.build_linear_classifier(
+                         self.y_train.shape[1],
+                         self.x_train.shape[1],
+                         hidden_units=hp.Choice(
+                             "hidden_units", [0, 128, 256, 512, 1024, 2048], default=cfg.TRAIN_HIDDEN_UNITS
+                         ),
+                         dropout=hp.Choice("dropout", [0.0, 0.25, 0.33, 0.5, 0.75, 0.9], default=cfg.TRAIN_DROPOUT),
+                     )
+                     print("...Done.", flush=True)
+
+                     # Only allow repeat upsampling in multi-label setting
+                     upsampling_choices = ["repeat", "mean", "linear"]  # SMOTE is too slow
+
+                     if cfg.MULTI_LABEL:
+                         upsampling_choices = ["repeat"]
+
+                     batch_size = hp.Choice("batch_size", [8, 16, 32, 64, 128], default=cfg.TRAIN_BATCH_SIZE)
+
+                     if batch_size == 8:
+                         learning_rate = hp.Choice(
+                             "learning_rate_8",
+                             [0.0005, 0.0002, 0.0001],
+                             default=0.0001,
+                             parent_name="batch_size",
+                             parent_values=[8],
+                         )
+                     elif batch_size == 16:
+                         learning_rate = hp.Choice(
+                             "learning_rate_16",
+                             [0.005, 0.002, 0.001, 0.0005, 0.0002],
+                             default=0.0005,
+                             parent_name="batch_size",
+                             parent_values=[16],
+                         )
+                     elif batch_size == 32:
+                         learning_rate = hp.Choice(
+                             "learning_rate_32",
+                             [0.01, 0.005, 0.001, 0.0005, 0.0001],
+                             default=0.0001,
+                             parent_name="batch_size",
+                             parent_values=[32],
+                         )
+                     elif batch_size == 64:
+                         learning_rate = hp.Choice(
+                             "learning_rate_64",
+                             [0.01, 0.005, 0.002, 0.001],
+                             default=0.001,
+                             parent_name="batch_size",
+                             parent_values=[64],
+                         )
+                     elif batch_size == 128:
+                         learning_rate = hp.Choice(
+                             "learning_rate_128",
+                             [0.1, 0.01, 0.005],
+                             default=0.005,
+                             parent_name="batch_size",
+                             parent_values=[128],
+                         )
+
+                     # Train model
+                     print("Training model...", flush=True)
+                     classifier, history = model.train_linear_classifier(
+                         classifier,
+                         self.x_train,
+                         self.y_train,
+                         self.x_test,
+                         self.y_test,
+                         epochs=cfg.TRAIN_EPOCHS,
+                         batch_size=batch_size,
+                         learning_rate=learning_rate,
+                         val_split=0.0 if len(self.x_test) > 0 else cfg.TRAIN_VAL_SPLIT,
+                         upsampling_ratio=hp.Choice(
+                             "upsampling_ratio", [0.0, 0.25, 0.33, 0.5, 0.75, 1.0], default=cfg.UPSAMPLING_RATIO
+                         ),
+                         upsampling_mode=hp.Choice(
+                             "upsampling_mode",
+                             upsampling_choices,
+                             default=cfg.UPSAMPLING_MODE,
+                             parent_name="upsampling_ratio",
+                             parent_values=[0.25, 0.33, 0.5, 0.75, 1.0],
+                         ),
+                         train_with_mixup=hp.Boolean("mixup", default=cfg.TRAIN_WITH_MIXUP),
+                         train_with_label_smoothing=hp.Boolean(
+                             "label_smoothing", default=cfg.TRAIN_WITH_LABEL_SMOOTHING
+                         ),
+                         train_with_focal_loss=hp.Boolean("focal_loss", default=cfg.TRAIN_WITH_FOCAL_LOSS),
+                         focal_loss_gamma=hp.Choice(
+                             "focal_loss_gamma",
+                             [0.5, 1.0, 2.0, 3.0, 4.0],
+                             default=cfg.FOCAL_LOSS_GAMMA,
+                             parent_name="focal_loss",
+                             parent_values=[True],
+                         ),
+                         focal_loss_alpha=hp.Choice(
+                             "focal_loss_alpha",
+                             [0.1, 0.25, 0.5, 0.75, 0.9],
+                             default=cfg.FOCAL_LOSS_ALPHA,
+                             parent_name="focal_loss",
+                             parent_values=[True],
+                         ),
+                     )
+
+                     # Get the best validation AUPRC instead of loss
+                     best_val_auprc = history.history["val_AUPRC"][np.argmax(history.history["val_AUPRC"])]
+                     histories.append(best_val_auprc)
+
+                     print(
+                         f"Finished Trial #{trial_number} execution #{execution + 1}. Best validation AUPRC: {best_val_auprc}",
+                         flush=True,
+                     )
+
+                     keras.backend.clear_session()
+                     del classifier
+                     del history
+                     gc.collect()
+
+                     # Call the on_trial_result callback
+                     if self.on_trial_result:
+                         self.on_trial_result(trial_number)
+
+                 # Return the negative AUPRC for minimization (keras-tuner minimizes by default)
+                 return [-h for h in histories]
+
+         # Create the tuner instance
+         tuner = BirdNetTuner(
+             x_train=x_train,
+             y_train=y_train,
+             x_test=x_test,
+             y_test=y_test,
+             max_trials=cfg.AUTOTUNE_TRIALS,
+             executions_per_trial=cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
+             on_trial_result=on_trial_result,
+         )
+         try:
+             tuner.search()
+         except model.get_empty_class_exception() as e:
+             e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
+             e.args = (e.message,)
+             raise e
+
+         best_params = tuner.get_best_hyperparameters()[0]
+
+         cfg.TRAIN_HIDDEN_UNITS = best_params["hidden_units"]
+         cfg.TRAIN_DROPOUT = best_params["dropout"]
+         cfg.TRAIN_BATCH_SIZE = best_params["batch_size"]
+         cfg.TRAIN_LEARNING_RATE = best_params[f"learning_rate_{cfg.TRAIN_BATCH_SIZE}"]
+         if cfg.UPSAMPLING_RATIO > 0:
+             cfg.UPSAMPLING_MODE = best_params["upsampling_mode"]
+         cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
+         cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
+         cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
+
+         print("Best params: ")
+         print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
+         print("dropout: ", cfg.TRAIN_DROPOUT)
+         print("batch_size: ", cfg.TRAIN_BATCH_SIZE)
+         print("learning_rate: ", cfg.TRAIN_LEARNING_RATE)
+         print("upsampling_ratio: ", cfg.UPSAMPLING_RATIO)
+         if cfg.UPSAMPLING_RATIO > 0:
+             print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
+         print("mixup: ", cfg.TRAIN_WITH_MIXUP)
+         print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
+
+     # Build model
+     print("Building model...", flush=True)
+     classifier = model.build_linear_classifier(
+         y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT
+     )
+     print("...Done.", flush=True)
+
+     # Train model
+     print("Training model...", flush=True)
+     try:
+         classifier, history = model.train_linear_classifier(
+             classifier,
+             x_train,
+             y_train,
+             x_test,
+             y_test,
+             epochs=cfg.TRAIN_EPOCHS,
+             batch_size=cfg.TRAIN_BATCH_SIZE,
+             learning_rate=cfg.TRAIN_LEARNING_RATE,
+             val_split=cfg.TRAIN_VAL_SPLIT if len(x_test) == 0 else 0.0,
+             upsampling_ratio=cfg.UPSAMPLING_RATIO,
+             upsampling_mode=cfg.UPSAMPLING_MODE,
+             train_with_mixup=cfg.TRAIN_WITH_MIXUP,
+             train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
+             train_with_focal_loss=cfg.TRAIN_WITH_FOCAL_LOSS,
+             focal_loss_gamma=cfg.FOCAL_LOSS_GAMMA,
+             focal_loss_alpha=cfg.FOCAL_LOSS_ALPHA,
+             on_epoch_end=on_epoch_end,
+         )
+     except model.get_empty_class_exception() as e:
+         e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
+         e.args = (e.message,)
+         raise e
+     except Exception as e:
+         raise Exception("Error training model") from e
+
+     print("...Done.", flush=True)
+
+     # Get best validation metrics based on AUPRC instead of loss for more reliable results with imbalanced data
+     best_epoch = np.argmax(history.history["val_AUPRC"])
+     best_val_auprc = history.history["val_AUPRC"][best_epoch]
+     best_val_auroc = history.history["val_AUROC"][best_epoch]
+     best_val_loss = history.history["val_loss"][best_epoch]
+
+     print("Saving model...", flush=True)
+
+     try:
+         if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
+             model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+             model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+         elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
+             model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+         elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
+             model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+         else:
+             raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
+     except Exception as e:
+         raise Exception("Error saving model") from e
+
+     save_sample_counts(labels, y_train)
+
+     # Evaluate model on test data if available
+     metrics = None
+     if len(x_test) > 0:
+         print("\nEvaluating model on test data...", flush=True)
+         metrics = evaluate_model(classifier, x_test, y_test, labels)
+
+         # Save evaluation results to file
+         if metrics:
+             import csv
+
+             eval_file_path = cfg.CUSTOM_CLASSIFIER + "_evaluation.csv"
+             with open(eval_file_path, "w", newline="") as f:
+                 writer = csv.writer(f)
+
+                 # Define all the metrics as columns, including both default and optimized threshold metrics
+                 header = [
+                     "Class",
+                     "Precision (0.5)",
+                     "Recall (0.5)",
+                     "F1 Score (0.5)",
+                     "Precision (opt)",
+                     "Recall (opt)",
+                     "F1 Score (opt)",
+                     "AUPRC",
+                     "AUROC",
+                     "Optimal Threshold",
+                     "True Positives",
+                     "False Positives",
+                     "True Negatives",
+                     "False Negatives",
+                     "Samples",
+                     "Percentage (%)",
+                 ]
+                 writer.writerow(header)
+
+                 # Write macro-averaged metrics (overall scores) first
+                 writer.writerow(
+                     [
+                         "OVERALL (Macro-avg)",
+                         f"{metrics['macro_precision_default']:.4f}",
+                         f"{metrics['macro_recall_default']:.4f}",
+                         f"{metrics['macro_f1_default']:.4f}",
+                         f"{metrics['macro_precision_opt']:.4f}",
+                         f"{metrics['macro_recall_opt']:.4f}",
+                         f"{metrics['macro_f1_opt']:.4f}",
+                         f"{metrics['macro_auprc']:.4f}",
+                         f"{metrics['macro_auroc']:.4f}",
+                         "",
+                         "",
+                         "",
+                         "",
+                         "",
+                         "",
+                         "",  # Empty cells for Threshold, TP, FP, TN, FN, Samples, Percentage
+                     ]
+                 )
+
+                 # Write per-class metrics (one row per species)
+                 for class_name, class_metrics in metrics["class_metrics"].items():
+                     distribution = metrics["class_distribution"].get(class_name, {"count": 0, "percentage": 0.0})
+                     writer.writerow(
+                         [
+                             class_name,
+                             f"{class_metrics['precision_default']:.4f}",
+                             f"{class_metrics['recall_default']:.4f}",
+                             f"{class_metrics['f1_default']:.4f}",
+                             f"{class_metrics['precision_opt']:.4f}",
+                             f"{class_metrics['recall_opt']:.4f}",
+                             f"{class_metrics['f1_opt']:.4f}",
+                             f"{class_metrics['auprc']:.4f}",
+                             f"{class_metrics['auroc']:.4f}",
+                             f"{class_metrics['threshold']:.2f}",
+                             class_metrics["tp"],
+                             class_metrics["fp"],
+                             class_metrics["tn"],
+                             class_metrics["fn"],
+                             distribution["count"],
+                             f"{distribution['percentage']:.2f}",
+                         ]
+                     )
+
+             print(f"Evaluation results saved to {eval_file_path}", flush=True)
+     else:
+         print("\nNo separate test data provided for evaluation. Using validation metrics.", flush=True)
+
+     print(
+         f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}, Best Loss: {best_val_loss} (epoch {best_epoch + 1}/{len(history.epoch)})",
+         flush=True,
+     )
+
+     return history, metrics
+
+
+ def find_optimal_threshold(y_true, y_pred_prob):
+     """
+     Find the optimal classification threshold using the F1 score.
+
+     For imbalanced datasets, the default threshold of 0.5 may not be optimal.
+     This function finds the threshold that maximizes the F1 score for each class.
+
+     Args:
+         y_true: Ground truth labels
+         y_pred_prob: Predicted probabilities
+
+     Returns:
+         The optimal threshold value
+     """
+     from sklearn.metrics import f1_score
+
+     # Try different thresholds and find the one that gives the best F1 score
+     best_threshold = 0.5
+     best_f1 = 0.0
+
+     for threshold in np.arange(0.1, 0.9, 0.05):
+         y_pred = (y_pred_prob >= threshold).astype(int)
+         f1 = f1_score(y_true, y_pred)
+
+         if f1 > best_f1:
+             best_f1 = f1
+             best_threshold = threshold
+
+     return best_threshold
+
+
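For illustration, a minimal, self-contained sketch of the same threshold sweep on invented single-class data (the arrays below are made up for this note and are not part of the package; only the grid-search logic mirrors find_optimal_threshold above):

    import numpy as np
    from sklearn.metrics import f1_score

    # Invented, imbalanced single-class example: 7 negatives, 3 positives
    y_true = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])
    y_prob = np.array([0.02, 0.08, 0.12, 0.18, 0.22, 0.28, 0.33, 0.38, 0.62, 0.91])

    best_threshold, best_f1 = 0.5, 0.0
    for t in np.arange(0.1, 0.9, 0.05):  # same grid as find_optimal_threshold
        f1 = f1_score(y_true, (y_prob >= t).astype(int))
        if f1 > best_f1:
            best_threshold, best_f1 = t, f1

    print(best_threshold, best_f1)  # settles near 0.35 here; the default 0.5 would miss the positive at 0.38
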
+ def evaluate_model(classifier, x_test, y_test, labels, threshold=None):
+     """
+     Evaluates the trained model on test data and prints detailed metrics.
+
+     Args:
+         classifier: The trained model
+         x_test: Test features (embeddings)
+         y_test: Test labels
+         labels: List of label names
+         threshold: Classification threshold (if None, will find optimal threshold for each class)
+
+     Returns:
+         Dictionary with evaluation metrics
+     """
+     from sklearn.metrics import (
+         average_precision_score,
+         confusion_matrix,
+         f1_score,
+         precision_score,
+         recall_score,
+         roc_auc_score,
+     )
+
+     # Skip evaluation if test set is empty
+     if len(x_test) == 0:
+         print("No test data available for evaluation.")
+         return {}
+
+     # Make predictions
+     y_pred_prob = classifier.predict(x_test)
+
+     # Calculate metrics for each class
+     metrics = {}
+
+     print("\nModel Evaluation:")
+     print("=================")
+
+     # Calculate metrics for each class
+     precisions_default = []
+     recalls_default = []
+     f1s_default = []
+     precisions_opt = []
+     recalls_opt = []
+     f1s_opt = []
+     auprcs = []
+     aurocs = []
+     class_metrics = {}
+     optimal_thresholds = {}
+
+     # Print the metric calculation method that's being used
+     print("\nNote: The AUPRC and AUROC metrics calculated during post-training evaluation may differ")
+     print("from training history values due to different calculation methods:")
+     print(" - Training history uses Keras metrics calculated over batches")
+     print(" - Evaluation uses scikit-learn metrics calculated over the entire dataset")
+
+     for i in range(y_test.shape[1]):
+         try:
+             # Calculate metrics with default threshold (0.5)
+             y_pred_default = (y_pred_prob[:, i] >= 0.5).astype(int)
+
+             class_precision_default = precision_score(y_test[:, i], y_pred_default)
+             class_recall_default = recall_score(y_test[:, i], y_pred_default)
+             class_f1_default = f1_score(y_test[:, i], y_pred_default)
+
+             precisions_default.append(class_precision_default)
+             recalls_default.append(class_recall_default)
+             f1s_default.append(class_f1_default)
+
+             # Find optimal threshold for this class if needed
+             if threshold is None:
+                 class_threshold = find_optimal_threshold(y_test[:, i], y_pred_prob[:, i])
+                 optimal_thresholds[labels[i]] = class_threshold
+             else:
+                 class_threshold = threshold
+
+             # Calculate metrics with optimized threshold
+             y_pred_opt = (y_pred_prob[:, i] >= class_threshold).astype(int)
+
+             class_precision_opt = precision_score(y_test[:, i], y_pred_opt)
+             class_recall_opt = recall_score(y_test[:, i], y_pred_opt)
+             class_f1_opt = f1_score(y_test[:, i], y_pred_opt)
+             class_auprc = average_precision_score(y_test[:, i], y_pred_prob[:, i])
+             class_auroc = roc_auc_score(y_test[:, i], y_pred_prob[:, i])
+
+             precisions_opt.append(class_precision_opt)
+             recalls_opt.append(class_recall_opt)
+             f1s_opt.append(class_f1_opt)
+             auprcs.append(class_auprc)
+             aurocs.append(class_auroc)
+
+             # Confusion matrix with optimized threshold
+             tn, fp, fn, tp = confusion_matrix(y_test[:, i], y_pred_opt).ravel()
+
+             class_metrics[labels[i]] = {
+                 "precision_default": class_precision_default,
+                 "recall_default": class_recall_default,
+                 "f1_default": class_f1_default,
+                 "precision_opt": class_precision_opt,
+                 "recall_opt": class_recall_opt,
+                 "f1_opt": class_f1_opt,
+                 "auprc": class_auprc,
+                 "auroc": class_auroc,
+                 "tp": tp,
+                 "fp": fp,
+                 "tn": tn,
+                 "fn": fn,
+                 "threshold": class_threshold,
+             }
+
+             print(f"\nClass: {labels[i]}")
+             print(" Default threshold (0.5):")
+             print(f" Precision: {class_precision_default:.4f}")
+             print(f" Recall: {class_recall_default:.4f}")
+             print(f" F1 Score: {class_f1_default:.4f}")
+             print(f" Optimized threshold ({class_threshold:.2f}):")
+             print(f" Precision: {class_precision_opt:.4f}")
+             print(f" Recall: {class_recall_opt:.4f}")
+             print(f" F1 Score: {class_f1_opt:.4f}")
+             print(f" AUPRC: {class_auprc:.4f}")
+             print(f" AUROC: {class_auroc:.4f}")
+             print(" Confusion matrix (optimized threshold):")
+             print(f" True Positives: {tp}")
+             print(f" False Positives: {fp}")
+             print(f" True Negatives: {tn}")
+             print(f" False Negatives: {fn}")
+
+         except Exception as e:
+             print(f"Error calculating metrics for class {labels[i]}: {e}")
+
+     # Calculate macro-averaged metrics for both default and optimized thresholds
+     metrics["macro_precision_default"] = np.mean(precisions_default)
+     metrics["macro_recall_default"] = np.mean(recalls_default)
+     metrics["macro_f1_default"] = np.mean(f1s_default)
+     metrics["macro_precision_opt"] = np.mean(precisions_opt)
+     metrics["macro_recall_opt"] = np.mean(recalls_opt)
+     metrics["macro_f1_opt"] = np.mean(f1s_opt)
+     metrics["macro_auprc"] = np.mean(auprcs)
+     metrics["macro_auroc"] = np.mean(aurocs)
+     metrics["class_metrics"] = class_metrics
+     metrics["optimal_thresholds"] = optimal_thresholds
+
+     print("\nMacro-averaged metrics:")
+     print(" Default threshold (0.5):")
+     print(f" Precision: {metrics['macro_precision_default']:.4f}")
+     print(f" Recall: {metrics['macro_recall_default']:.4f}")
+     print(f" F1 Score: {metrics['macro_f1_default']:.4f}")
+     print(" Optimized thresholds:")
+     print(f" Precision: {metrics['macro_precision_opt']:.4f}")
+     print(f" Recall: {metrics['macro_recall_opt']:.4f}")
+     print(f" F1 Score: {metrics['macro_f1_opt']:.4f}")
+     print(f" AUPRC: {metrics['macro_auprc']:.4f}")
+     print(f" AUROC: {metrics['macro_auroc']:.4f}")
+
+     # Calculate class distribution in test set
+     class_counts = y_test.sum(axis=0)
+     total_samples = len(y_test)
+     class_distribution = {}
+
+     print("\nClass distribution in test set:")
+     for i, count in enumerate(class_counts):
+         percentage = count / total_samples * 100
+         class_distribution[labels[i]] = {"count": int(count), "percentage": percentage}
+         print(f" {labels[i]}: {int(count)} samples ({percentage:.2f}%)")
+
+     metrics["class_distribution"] = class_distribution
+
+     return metrics
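
As a usage note: evaluate_model returns a nested dictionary with macro-averaged scores plus per-class entries under "class_metrics", "optimal_thresholds", and "class_distribution". A hedged sketch of how a caller might consume it, assuming a trained classifier with a predict method and one-hot test labels (the surrounding variable names are illustrative, not part of the package API):

    # Illustrative consumer of the dictionary returned by evaluate_model
    metrics = evaluate_model(classifier, x_test, y_test, labels)

    if metrics:
        print(f"Macro F1 (optimized thresholds): {metrics['macro_f1_opt']:.4f}")
        # Per-class thresholds are only filled in when threshold=None (the default)
        for label, t in metrics["optimal_thresholds"].items():
            per_class = metrics["class_metrics"][label]
            print(f"{label}: threshold {t:.2f}, AUPRC {per_class['auprc']:.4f}")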