birdnet-analyzer 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. birdnet_analyzer/__init__.py +8 -0
  2. birdnet_analyzer/analyze/__init__.py +5 -0
  3. birdnet_analyzer/analyze/__main__.py +4 -0
  4. birdnet_analyzer/analyze/cli.py +25 -0
  5. birdnet_analyzer/analyze/core.py +245 -0
  6. birdnet_analyzer/analyze/utils.py +701 -0
  7. birdnet_analyzer/audio.py +372 -0
  8. birdnet_analyzer/cli.py +707 -0
  9. birdnet_analyzer/config.py +242 -0
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25280 -0
  11. birdnet_analyzer/embeddings/__init__.py +4 -0
  12. birdnet_analyzer/embeddings/__main__.py +3 -0
  13. birdnet_analyzer/embeddings/cli.py +13 -0
  14. birdnet_analyzer/embeddings/core.py +70 -0
  15. birdnet_analyzer/embeddings/utils.py +193 -0
  16. birdnet_analyzer/evaluation/__init__.py +195 -0
  17. birdnet_analyzer/evaluation/__main__.py +3 -0
  18. birdnet_analyzer/gui/__init__.py +23 -0
  19. birdnet_analyzer/gui/__main__.py +3 -0
  20. birdnet_analyzer/gui/analysis.py +174 -0
  21. birdnet_analyzer/gui/assets/arrow_down.svg +4 -0
  22. birdnet_analyzer/gui/assets/arrow_left.svg +4 -0
  23. birdnet_analyzer/gui/assets/arrow_right.svg +4 -0
  24. birdnet_analyzer/gui/assets/arrow_up.svg +4 -0
  25. birdnet_analyzer/gui/assets/gui.css +29 -0
  26. birdnet_analyzer/gui/assets/gui.js +94 -0
  27. birdnet_analyzer/gui/assets/img/birdnet-icon.ico +0 -0
  28. birdnet_analyzer/gui/assets/img/birdnet_logo.png +0 -0
  29. birdnet_analyzer/gui/assets/img/birdnet_logo_no_transparent.png +0 -0
  30. birdnet_analyzer/gui/assets/img/clo-logo-bird.svg +1 -0
  31. birdnet_analyzer/gui/embeddings.py +620 -0
  32. birdnet_analyzer/gui/evaluation.py +813 -0
  33. birdnet_analyzer/gui/localization.py +68 -0
  34. birdnet_analyzer/gui/multi_file.py +246 -0
  35. birdnet_analyzer/gui/review.py +527 -0
  36. birdnet_analyzer/gui/segments.py +191 -0
  37. birdnet_analyzer/gui/settings.py +129 -0
  38. birdnet_analyzer/gui/single_file.py +269 -0
  39. birdnet_analyzer/gui/species.py +95 -0
  40. birdnet_analyzer/gui/train.py +698 -0
  41. birdnet_analyzer/gui/utils.py +808 -0
  42. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -0
  43. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -0
  44. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -0
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -0
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -0
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -0
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -0
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -0
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -0
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -0
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -0
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -0
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -0
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -0
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -0
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -0
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -0
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -0
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -0
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -0
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -0
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -0
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -0
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -0
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -0
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -0
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -0
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -0
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -0
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -0
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -0
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -0
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -0
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -0
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -0
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -0
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -0
  79. birdnet_analyzer/lang/de.json +335 -0
  80. birdnet_analyzer/lang/en.json +335 -0
  81. birdnet_analyzer/lang/fi.json +335 -0
  82. birdnet_analyzer/lang/fr.json +335 -0
  83. birdnet_analyzer/lang/id.json +335 -0
  84. birdnet_analyzer/lang/pt-br.json +335 -0
  85. birdnet_analyzer/lang/ru.json +335 -0
  86. birdnet_analyzer/lang/se.json +335 -0
  87. birdnet_analyzer/lang/tlh.json +335 -0
  88. birdnet_analyzer/lang/zh_TW.json +335 -0
  89. birdnet_analyzer/model.py +1243 -0
  90. birdnet_analyzer/search/__init__.py +3 -0
  91. birdnet_analyzer/search/__main__.py +3 -0
  92. birdnet_analyzer/search/cli.py +12 -0
  93. birdnet_analyzer/search/core.py +78 -0
  94. birdnet_analyzer/search/utils.py +111 -0
  95. birdnet_analyzer/segments/__init__.py +3 -0
  96. birdnet_analyzer/segments/__main__.py +3 -0
  97. birdnet_analyzer/segments/cli.py +14 -0
  98. birdnet_analyzer/segments/core.py +78 -0
  99. birdnet_analyzer/segments/utils.py +394 -0
  100. birdnet_analyzer/species/__init__.py +3 -0
  101. birdnet_analyzer/species/__main__.py +3 -0
  102. birdnet_analyzer/species/cli.py +14 -0
  103. birdnet_analyzer/species/core.py +35 -0
  104. birdnet_analyzer/species/utils.py +75 -0
  105. birdnet_analyzer/train/__init__.py +3 -0
  106. birdnet_analyzer/train/__main__.py +3 -0
  107. birdnet_analyzer/train/cli.py +14 -0
  108. birdnet_analyzer/train/core.py +113 -0
  109. birdnet_analyzer/train/utils.py +847 -0
  110. birdnet_analyzer/translate.py +104 -0
  111. birdnet_analyzer/utils.py +419 -0
  112. birdnet_analyzer-2.0.0.dist-info/METADATA +129 -0
  113. birdnet_analyzer-2.0.0.dist-info/RECORD +117 -0
  114. birdnet_analyzer-2.0.0.dist-info/WHEEL +5 -0
  115. birdnet_analyzer-2.0.0.dist-info/entry_points.txt +11 -0
  116. birdnet_analyzer-2.0.0.dist-info/licenses/LICENSE +19 -0
  117. birdnet_analyzer-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,847 @@
+ """Module for training a custom classifier.
+
+ Can be used to train a custom classifier with new training data.
+ """
+
+ import csv
+ import os
+ from functools import partial
+ from multiprocessing.pool import Pool
+
+ import numpy as np
+ import tqdm
+
+ import birdnet_analyzer.audio as audio
+ import birdnet_analyzer.config as cfg
+ import birdnet_analyzer.model as model
+ import birdnet_analyzer.utils as utils
+
+
+ def save_sample_counts(labels, y_train):
+     """
+     Saves the count of samples per label combination to a CSV file.
+
+     The function creates a dictionary where the keys are label combinations (joined by '+') and the values are the counts of samples for each combination.
+     It then writes this information to a CSV file named "<cfg.CUSTOM_CLASSIFIER>_sample_counts.csv" with two columns: "Label" and "Count".
+
+     Args:
+         labels (list of str): List of label names corresponding to the columns in y_train.
+         y_train (numpy.ndarray): 2D array where each row is a binary vector indicating the presence (1) or absence (0) of each label.
+     """
+     samples_per_label = {}
+     label_combinations = np.unique(y_train, axis=0)
+
+     for label_combination in label_combinations:
+         label = "+".join([labels[i] for i in range(len(label_combination)) if label_combination[i] == 1])
+         samples_per_label[label] = np.sum(np.all(y_train == label_combination, axis=1))
+
+     csv_file_path = cfg.CUSTOM_CLASSIFIER + "_sample_counts.csv"
+
+     with open(csv_file_path, mode="w", newline="") as csv_file:
+         writer = csv.writer(csv_file)
+         writer.writerow(["Label", "Count"])
+
+         for label, count in samples_per_label.items():
+             writer.writerow([label, count])
+
+
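For orientation, a small, self-contained sketch (not part of the package diff) of what save_sample_counts would write for a toy multi-label y_train; the label names and classifier path are made up:

import numpy as np

labels = ["Species A", "Species B"]
y_train = np.array([
    [1, 0],  # Species A only
    [1, 0],
    [1, 1],  # one sample carrying both labels
    [0, 1],
], dtype="float32")

# With cfg.CUSTOM_CLASSIFIER = "my_classifier" (hypothetical), save_sample_counts(labels, y_train)
# would write my_classifier_sample_counts.csv containing:
#   Label,Count
#   Species B,1
#   Species A,2
#   Species A+Species B,1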
+ def _load_audio_file(f, label_vector, config):
+     """Load an audio file and extract features.
+     Args:
+         f: Path to the audio file.
+         label_vector: The label vector for the file.
+     Returns:
+         A tuple of (x_train, y_train).
+     """
+
+     x_train = []
+     y_train = []
+
+     # restore config in case we're on Windows to be thread safe
+     cfg.set_config(config)
+
+     # Try to load the audio file
+     try:
+         # Load audio
+         sig, rate = audio.open_audio_file(
+             f,
+             duration=cfg.SIG_LENGTH if cfg.SAMPLE_CROP_MODE == "first" else None,
+             fmin=cfg.BANDPASS_FMIN,
+             fmax=cfg.BANDPASS_FMAX,
+             speed=cfg.AUDIO_SPEED,
+         )
+
+     # if anything happens print the error and ignore the file
+     except Exception as e:
+         # Print Error
+         print(f"\t Error when loading file {f}", flush=True)
+         print(f"\t {e}", flush=True)
+         return np.array([]), np.array([])
+
+     # Crop training samples
+     if cfg.SAMPLE_CROP_MODE == "center":
+         sig_splits = [audio.crop_center(sig, rate, cfg.SIG_LENGTH)]
+     elif cfg.SAMPLE_CROP_MODE == "first":
+         sig_splits = [audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)[0]]
+     elif cfg.SAMPLE_CROP_MODE == "smart":
+         # Smart cropping - detect peaks in audio energy to identify potential signals
+         sig_splits = audio.smart_crop_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
+     else:
+         sig_splits = audio.split_signal(sig, rate, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
+
+     # Get feature embeddings
+     batch_size = 1  # turns out that batch size 1 is the fastest, probably because of having to resize the model input when the number of samples in a batch changes
+     for i in range(0, len(sig_splits), batch_size):
+         batch_sig = sig_splits[i : i + batch_size]
+         batch_label = [label_vector] * len(batch_sig)
+         embeddings = model.embeddings(batch_sig)
+
+         # Add to training data
+         x_train.extend(embeddings)
+         y_train.extend(batch_label)
+
+     return x_train, y_train
+
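As an aside (illustrative only, not taken from the package), the non-"smart" crop modes boil down to keeping either the centre window, the first window, or every window of a fixed-length split; a rough standalone sketch of such a split, with hypothetical length/overlap values:

import numpy as np

def split_signal_sketch(sig, rate, seconds=3.0, overlap=0.0, minlen=1.0):
    # Rough illustration only; the values and behaviour here are simplified
    # assumptions, not the package's audio.split_signal implementation.
    step = int((seconds - overlap) * rate)
    win = int(seconds * rate)
    chunks = []
    for start in range(0, len(sig), step):
        chunk = sig[start : start + win]
        if len(chunk) < int(minlen * rate):
            break
        chunks.append(chunk)
    return chunks

# 10 s of silence at 48 kHz, no overlap -> three full 3 s windows plus a 1 s remainder
print(len(split_signal_sketch(np.zeros(10 * 48000), 48000)))  # 4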
+ def _load_training_data(cache_mode=None, cache_file="", progress_callback=None):
+     """Loads the data for training.
+
+     Reads all subdirectories of "config.TRAIN_DATA_PATH" and uses their names as new labels.
+
+     These directories should contain all the training data for each label.
+
+     If a cache file is provided, the training data is loaded from there.
+
+     Args:
+         cache_mode: Cache mode. Can be 'load' or 'save'. Defaults to None.
+         cache_file: Path to cache file.
+
+     Returns:
+         A tuple of (x_train, y_train, x_test, y_test, labels).
+     """
+     # Load from cache
+     if cache_mode == "load":
+         if os.path.isfile(cache_file):
+             print(f"\t...loading from cache: {cache_file}", flush=True)
+             x_train, y_train, x_test, y_test, labels, cfg.BINARY_CLASSIFICATION, cfg.MULTI_LABEL = (
+                 utils.load_from_cache(cache_file)
+             )
+             return x_train, y_train, x_test, y_test, labels
+         else:
+             print(f"\t...cache file not found: {cache_file}", flush=True)
+
+     # Print train and test data path as confirmation
+     print(f"\t...train data path: {cfg.TRAIN_DATA_PATH}", flush=True)
+     print(f"\t...test data path: {cfg.TEST_DATA_PATH}", flush=True)
+
+     # Get list of subfolders as labels
+     train_folders = list(sorted(utils.list_subdirectories(cfg.TRAIN_DATA_PATH)))
+
+     # Read all individual labels from the folder names
+     labels = []
+
+     for folder in train_folders:
+         labels_in_folder = folder.split(",")
+         for label in labels_in_folder:
+             if label not in labels:
+                 labels.append(label)
+
+     # Sort labels
+     labels = list(sorted(labels))
+
+     # Get valid labels
+     valid_labels = [
+         label for label in labels if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-")
+     ]
+
+     # Check if binary classification
+     cfg.BINARY_CLASSIFICATION = len(valid_labels) == 1
+
+     # Validate the classes for binary classification
+     if cfg.BINARY_CLASSIFICATION:
+         if len([f for f in train_folders if f.startswith("-")]) > 0:
+             raise Exception(
+                 "Negative labels can't be used with binary classification",
+                 "validation-no-negative-samples-in-binary-classification",
+             )
+         if len([f for f in train_folders if f.lower() in cfg.NON_EVENT_CLASSES]) == 0:
+             raise Exception(
+                 "Non-event samples are required for binary classification",
+                 "validation-non-event-samples-required-in-binary-classification",
+             )
+
+     # Check if multi label
+     cfg.MULTI_LABEL = len(valid_labels) > 1 and any("," in f for f in train_folders)
+
+     # Check if multi-label and binary classification
+     if cfg.BINARY_CLASSIFICATION and cfg.MULTI_LABEL:
+         raise Exception("Error: Binary classification and multi-label not possible at the same time")
+
+     # Only allow repeat upsampling for multi-label setting
+     if cfg.MULTI_LABEL and cfg.UPSAMPLING_RATIO > 0 and cfg.UPSAMPLING_MODE != "repeat":
+         raise Exception(
+             "Only repeat-upsampling is available for multi-label", "validation-only-repeat-upsampling-for-multi-label"
+         )
+
+     # Load training data
+     x_train = []
+     y_train = []
+     x_test = []
+     y_test = []
+
+     def load_data(data_path, allowed_folders):
+         x = []
+         y = []
+         folders = list(sorted(utils.list_subdirectories(data_path)))
+
+         for folder in folders:
+             if folder not in allowed_folders:
+                 print(f"Skipping folder {folder} because it is not in the training data.", flush=True)
+                 continue
+
+             # Get label vector
+             label_vector = np.zeros((len(valid_labels),), dtype="float32")
+             folder_labels = folder.split(",")
+
+             for label in folder_labels:
+                 if label.lower() not in cfg.NON_EVENT_CLASSES and not label.startswith("-"):
+                     label_vector[valid_labels.index(label)] = 1
+                 elif (
+                     label.startswith("-") and label[1:] in valid_labels
+                 ):  # Negative labels need to be contained in the valid labels
+                     label_vector[valid_labels.index(label[1:])] = -1
+
+             # Get list of files
+             # Filter files that start with '.' because macOS seems to use them for temp files.
+             files = filter(
+                 os.path.isfile,
+                 (
+                     os.path.join(data_path, folder, f)
+                     for f in sorted(os.listdir(os.path.join(data_path, folder)))
+                     if not f.startswith(".") and f.rsplit(".", 1)[-1].lower() in cfg.ALLOWED_FILETYPES
+                 ),
+             )
+
+             # Load files using thread pool
+             with Pool(cfg.CPU_THREADS) as p:
+                 tasks = []
+
+                 for f in files:
+                     task = p.apply_async(
+                         partial(_load_audio_file, f=f, label_vector=label_vector, config=cfg.get_config())
+                     )
+                     tasks.append(task)
+
+                 # Wait for tasks to complete and monitor progress with tqdm
+                 num_files_processed = 0
+
+                 with tqdm.tqdm(total=len(tasks), desc=f" - loading '{folder}'", unit="f") as progress_bar:
+                     for task in tasks:
+                         result = task.get()
+                         # Make sure result is not empty
+                         # Empty results might be caused by errors when loading the audio file
+                         # TODO: We should check for embeddings size in result, otherwise we can't add them to the training data
+                         if len(result[0]) > 0:
+                             x += result[0]
+                             y += result[1]
+
+                         num_files_processed += 1
+                         progress_bar.update(1)
+
+                         if progress_callback:
+                             progress_callback(num_files_processed, len(tasks), folder)
+         return np.array(x, dtype="float32"), np.array(y, dtype="float32")
+
+     x_train, y_train = load_data(cfg.TRAIN_DATA_PATH, train_folders)
+
+     if cfg.TEST_DATA_PATH and cfg.TEST_DATA_PATH != cfg.TRAIN_DATA_PATH:
+         test_folders = list(sorted(utils.list_subdirectories(cfg.TEST_DATA_PATH)))
+         allowed_test_folders = [
+             folder for folder in test_folders if folder in train_folders and not folder.startswith("-")
+         ]
+         x_test, y_test = load_data(cfg.TEST_DATA_PATH, allowed_test_folders)
+     else:
+         x_test = np.array([])
+         y_test = np.array([])
+
+     # Save to cache?
+     if cache_mode == "save":
+         print(f"\t...saving training data to cache: {cache_file}", flush=True)
+         try:
+             # Only save the valid labels
+             utils.save_to_cache(cache_file, x_train, y_train, x_test, y_test, valid_labels)
+         except Exception as e:
+             print(f"\t...error saving cache: {e}", flush=True)
+
+     # Return only the valid labels for further use
+     return x_train, y_train, x_test, y_test, valid_labels
+
+
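To make the folder conventions above concrete, an illustrative training-data layout (names invented; "noise" is assumed here to be listed in cfg.NON_EVENT_CLASSES):

train_data/
├── speciesA/            # folder name becomes the label
├── speciesA,speciesB/   # comma-separated names -> multi-label samples
├── -speciesA/           # leading "-" marks negative samples for an existing label
└── noise/               # non-event/background samples (matched case-insensitively)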
+ def normalize_embeddings(embeddings):
+     """
+     Normalize embeddings to improve training stability and performance.
+
+     This applies L2 normalization to each embedding vector, which can help
+     with convergence and model performance, especially when training on
+     embeddings from different sources or domains.
+
+     Args:
+         embeddings: numpy array of embedding vectors
+
+     Returns:
+         Normalized embeddings array
+     """
+     # Calculate L2 norm of each embedding vector
+     norms = np.sqrt(np.sum(embeddings**2, axis=1, keepdims=True))
+     # Avoid division by zero
+     norms[norms == 0] = 1.0
+     # Normalize each embedding vector
+     normalized = embeddings / norms
+     return normalized
+
+
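A quick numerical check of the L2 normalization above (illustrative only):

import numpy as np

emb = np.array([[3.0, 4.0], [0.0, 0.0]])
norms = np.sqrt(np.sum(emb**2, axis=1, keepdims=True))
norms[norms == 0] = 1.0  # zero vectors stay unchanged instead of producing NaN
print(emb / norms)
# [[0.6 0.8]
#  [0.  0. ]]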
+ def train_model(on_epoch_end=None, on_trial_result=None, on_data_load_end=None, autotune_directory="autotune"):
+     """Trains a custom classifier.
+
+     Args:
+         on_epoch_end: A callback function that takes two arguments `epoch`, `logs`.
+         on_trial_result: A callback function for hyperparameter tuning.
+         on_data_load_end: A callback function for data loading progress.
+         autotune_directory: Directory for autotune results.
+
+     Returns:
+         A tuple of (history, metrics): the keras `History` object, whose `history` property contains all the metrics, and the evaluation metrics dict (or None if no separate test data was provided).
+     """
+
+     # Load training data
+     print("Loading training data...", flush=True)
+     x_train, y_train, x_test, y_test, labels = _load_training_data(cfg.TRAIN_CACHE_MODE, cfg.TRAIN_CACHE_FILE, on_data_load_end)
+     print(f"...Done. Loaded {x_train.shape[0]} training samples and {y_train.shape[1]} labels.", flush=True)
+     if len(x_test) > 0:
+         print(f"...Loaded {x_test.shape[0]} test samples.", flush=True)
+
+     # Normalize embeddings
+     print("Normalizing embeddings...", flush=True)
+     x_train = normalize_embeddings(x_train)
+     if len(x_test) > 0:
+         x_test = normalize_embeddings(x_test)
+
+     if cfg.AUTOTUNE:
+         import gc
+
+         import keras
+         import keras_tuner
+
+         # Call callback to initialize progress bar
+         if on_trial_result:
+             on_trial_result(0)
+
+         class BirdNetTuner(keras_tuner.BayesianOptimization):
+             def __init__(self, x_train, y_train, x_test, y_test, max_trials, executions_per_trial, on_trial_result):
+                 super().__init__(
+                     max_trials=max_trials,
+                     executions_per_trial=executions_per_trial,
+                     overwrite=True,
+                     directory=autotune_directory,
+                     project_name="birdnet_analyzer",
+                 )
+                 self.x_train = x_train
+                 self.y_train = y_train
+                 self.x_test = x_test
+                 self.y_test = y_test
+                 self.on_trial_result = on_trial_result
+
+             def run_trial(self, trial, *args, **kwargs):
+                 histories = []
+                 hp: keras_tuner.HyperParameters = trial.hyperparameters
+                 trial_number = len(self.oracle.trials)
+
+                 for execution in range(int(self.executions_per_trial)):
+                     print(f"Running Trial #{trial_number} execution #{execution + 1}", flush=True)
+
+                     # Build model
+                     print("Building model...", flush=True)
+                     classifier = model.build_linear_classifier(
+                         self.y_train.shape[1],
+                         self.x_train.shape[1],
+                         hidden_units=hp.Choice(
+                             "hidden_units", [0, 128, 256, 512, 1024, 2048], default=cfg.TRAIN_HIDDEN_UNITS
+                         ),
+                         dropout=hp.Choice("dropout", [0.0, 0.25, 0.33, 0.5, 0.75, 0.9], default=cfg.TRAIN_DROPOUT),
+                     )
+                     print("...Done.", flush=True)
+
+                     # Only allow repeat upsampling in multi-label setting
+                     upsampling_choices = ["repeat", "mean", "linear"]  # SMOTE is too slow
+
+                     if cfg.MULTI_LABEL:
+                         upsampling_choices = ["repeat"]
+
+                     batch_size = hp.Choice("batch_size", [8, 16, 32, 64, 128], default=cfg.TRAIN_BATCH_SIZE)
+
+                     if batch_size == 8:
+                         learning_rate = hp.Choice(
+                             "learning_rate_8",
+                             [0.0005, 0.0002, 0.0001],
+                             default=0.0001,
+                             parent_name="batch_size",
+                             parent_values=[8],
+                         )
+                     elif batch_size == 16:
+                         learning_rate = hp.Choice(
+                             "learning_rate_16",
+                             [0.005, 0.002, 0.001, 0.0005, 0.0002],
+                             default=0.0005,
+                             parent_name="batch_size",
+                             parent_values=[16],
+                         )
+                     elif batch_size == 32:
+                         learning_rate = hp.Choice(
+                             "learning_rate_32",
+                             [0.01, 0.005, 0.001, 0.0005, 0.0001],
+                             default=0.0001,
+                             parent_name="batch_size",
+                             parent_values=[32],
+                         )
+                     elif batch_size == 64:
+                         learning_rate = hp.Choice(
+                             "learning_rate_64",
+                             [0.01, 0.005, 0.002, 0.001],
+                             default=0.001,
+                             parent_name="batch_size",
+                             parent_values=[64],
+                         )
+                     elif batch_size == 128:
+                         learning_rate = hp.Choice(
+                             "learning_rate_128",
+                             [0.1, 0.01, 0.005],
+                             default=0.005,
+                             parent_name="batch_size",
+                             parent_values=[128],
+                         )
+
+                     # Train model
+                     print("Training model...", flush=True)
+                     classifier, history = model.train_linear_classifier(
+                         classifier,
+                         self.x_train,
+                         self.y_train,
+                         self.x_test,
+                         self.y_test,
+                         epochs=cfg.TRAIN_EPOCHS,
+                         batch_size=batch_size,
+                         learning_rate=learning_rate,
+                         val_split=0.0 if len(self.x_test) > 0 else cfg.TRAIN_VAL_SPLIT,
+                         upsampling_ratio=hp.Choice(
+                             "upsampling_ratio", [0.0, 0.25, 0.33, 0.5, 0.75, 1.0], default=cfg.UPSAMPLING_RATIO
+                         ),
+                         upsampling_mode=hp.Choice(
+                             "upsampling_mode",
+                             upsampling_choices,
+                             default=cfg.UPSAMPLING_MODE,
+                             parent_name="upsampling_ratio",
+                             parent_values=[0.25, 0.33, 0.5, 0.75, 1.0],
+                         ),
+                         train_with_mixup=hp.Boolean("mixup", default=cfg.TRAIN_WITH_MIXUP),
+                         train_with_label_smoothing=hp.Boolean(
+                             "label_smoothing", default=cfg.TRAIN_WITH_LABEL_SMOOTHING
+                         ),
+                         train_with_focal_loss=hp.Boolean("focal_loss", default=cfg.TRAIN_WITH_FOCAL_LOSS),
+                         focal_loss_gamma=hp.Choice(
+                             "focal_loss_gamma",
+                             [0.5, 1.0, 2.0, 3.0, 4.0],
+                             default=cfg.FOCAL_LOSS_GAMMA,
+                             parent_name="focal_loss",
+                             parent_values=[True],
+                         ),
+                         focal_loss_alpha=hp.Choice(
+                             "focal_loss_alpha",
+                             [0.1, 0.25, 0.5, 0.75, 0.9],
+                             default=cfg.FOCAL_LOSS_ALPHA,
+                             parent_name="focal_loss",
+                             parent_values=[True],
+                         ),
+                     )
+
+                     # Get the best validation AUPRC instead of loss
+                     best_val_auprc = history.history["val_AUPRC"][np.argmax(history.history["val_AUPRC"])]
+                     histories.append(best_val_auprc)
+
+                     print(
+                         f"Finished Trial #{trial_number} execution #{execution + 1}. Best validation AUPRC: {best_val_auprc}",
+                         flush=True,
+                     )
+
+                     keras.backend.clear_session()
+                     del classifier
+                     del history
+                     gc.collect()
+
+                 # Call the on_trial_result callback
+                 if self.on_trial_result:
+                     self.on_trial_result(trial_number)
+
+                 # Return the negative AUPRC for minimization (keras-tuner minimizes by default)
+                 return [-h for h in histories]
+
+         # Create the tuner instance
+         tuner = BirdNetTuner(
+             x_train=x_train,
+             y_train=y_train,
+             x_test=x_test,
+             y_test=y_test,
+             max_trials=cfg.AUTOTUNE_TRIALS,
+             executions_per_trial=cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL,
+             on_trial_result=on_trial_result,
+         )
+         try:
+             tuner.search()
+         except model.get_empty_class_exception() as e:
+             e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
+             e.args = (e.message,)
+             raise e
+
+         best_params = tuner.get_best_hyperparameters()[0]
+
+         cfg.TRAIN_HIDDEN_UNITS = best_params["hidden_units"]
+         cfg.TRAIN_DROPOUT = best_params["dropout"]
+         cfg.TRAIN_BATCH_SIZE = best_params["batch_size"]
+         cfg.TRAIN_LEARNING_RATE = best_params[f"learning_rate_{cfg.TRAIN_BATCH_SIZE}"]
+         if cfg.UPSAMPLING_RATIO > 0:
+             cfg.UPSAMPLING_MODE = best_params["upsampling_mode"]
+         cfg.UPSAMPLING_RATIO = best_params["upsampling_ratio"]
+         cfg.TRAIN_WITH_MIXUP = best_params["mixup"]
+         cfg.TRAIN_WITH_LABEL_SMOOTHING = best_params["label_smoothing"]
+
+         print("Best params: ")
+         print("hidden_units: ", cfg.TRAIN_HIDDEN_UNITS)
+         print("dropout: ", cfg.TRAIN_DROPOUT)
+         print("batch_size: ", cfg.TRAIN_BATCH_SIZE)
+         print("learning_rate: ", cfg.TRAIN_LEARNING_RATE)
+         print("upsampling_ratio: ", cfg.UPSAMPLING_RATIO)
+         if cfg.UPSAMPLING_RATIO > 0:
+             print("upsampling_mode: ", cfg.UPSAMPLING_MODE)
+         print("mixup: ", cfg.TRAIN_WITH_MIXUP)
+         print("label_smoothing: ", cfg.TRAIN_WITH_LABEL_SMOOTHING)
+
+     # Build model
+     print("Building model...", flush=True)
+     classifier = model.build_linear_classifier(
+         y_train.shape[1], x_train.shape[1], cfg.TRAIN_HIDDEN_UNITS, cfg.TRAIN_DROPOUT
+     )
+     print("...Done.", flush=True)
+
+     # Train model
+     print("Training model...", flush=True)
+     try:
+         classifier, history = model.train_linear_classifier(
+             classifier,
+             x_train,
+             y_train,
+             x_test,
+             y_test,
+             epochs=cfg.TRAIN_EPOCHS,
+             batch_size=cfg.TRAIN_BATCH_SIZE,
+             learning_rate=cfg.TRAIN_LEARNING_RATE,
+             val_split=cfg.TRAIN_VAL_SPLIT if len(x_test) == 0 else 0.0,
+             upsampling_ratio=cfg.UPSAMPLING_RATIO,
+             upsampling_mode=cfg.UPSAMPLING_MODE,
+             train_with_mixup=cfg.TRAIN_WITH_MIXUP,
+             train_with_label_smoothing=cfg.TRAIN_WITH_LABEL_SMOOTHING,
+             train_with_focal_loss=cfg.TRAIN_WITH_FOCAL_LOSS,
+             focal_loss_gamma=cfg.FOCAL_LOSS_GAMMA,
+             focal_loss_alpha=cfg.FOCAL_LOSS_ALPHA,
+             on_epoch_end=on_epoch_end,
+         )
+     except model.get_empty_class_exception() as e:
+         e.message = f"Class with label {labels[e.index]} is empty. Please remove it from the training data."
+         e.args = (e.message,)
+         raise e
+     except Exception as e:
+         raise Exception("Error training model") from e
+
+     print("...Done.", flush=True)
+
+     # Get best validation metrics based on AUPRC instead of loss for more reliable results with imbalanced data
+     best_epoch = np.argmax(history.history["val_AUPRC"])
+     best_val_auprc = history.history["val_AUPRC"][best_epoch]
+     best_val_auroc = history.history["val_AUROC"][best_epoch]
+     best_val_loss = history.history["val_loss"][best_epoch]
+
+     print("Saving model...", flush=True)
+
+     try:
+         if cfg.TRAINED_MODEL_OUTPUT_FORMAT == "both":
+             model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+             model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+         elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "tflite":
+             model.save_linear_classifier(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+         elif cfg.TRAINED_MODEL_OUTPUT_FORMAT == "raven":
+             model.save_raven_model(classifier, cfg.CUSTOM_CLASSIFIER, labels, mode=cfg.TRAINED_MODEL_SAVE_MODE)
+         else:
+             raise ValueError(f"Unknown model output format: {cfg.TRAINED_MODEL_OUTPUT_FORMAT}")
+     except Exception as e:
+         raise Exception("Error saving model") from e
+
+     save_sample_counts(labels, y_train)
+
+     # Evaluate model on test data if available
+     metrics = None
+     if len(x_test) > 0:
+         print("\nEvaluating model on test data...", flush=True)
+         metrics = evaluate_model(classifier, x_test, y_test, labels)
+
+         # Save evaluation results to file
+         if metrics:
+             import csv
+             eval_file_path = cfg.CUSTOM_CLASSIFIER + "_evaluation.csv"
+             with open(eval_file_path, 'w', newline='') as f:
+                 writer = csv.writer(f)
+
+                 # Define all the metrics as columns, including both default and optimized threshold metrics
+                 header = ['Class',
+                           'Precision (0.5)', 'Recall (0.5)', 'F1 Score (0.5)',
+                           'Precision (opt)', 'Recall (opt)', 'F1 Score (opt)',
+                           'AUPRC', 'AUROC', 'Optimal Threshold',
+                           'True Positives', 'False Positives', 'True Negatives', 'False Negatives',
+                           'Samples', 'Percentage (%)']
+                 writer.writerow(header)
+
+                 # Write macro-averaged metrics (overall scores) first
+                 writer.writerow([
+                     'OVERALL (Macro-avg)',
+                     f"{metrics['macro_precision_default']:.4f}",
+                     f"{metrics['macro_recall_default']:.4f}",
+                     f"{metrics['macro_f1_default']:.4f}",
+                     f"{metrics['macro_precision_opt']:.4f}",
+                     f"{metrics['macro_recall_opt']:.4f}",
+                     f"{metrics['macro_f1_opt']:.4f}",
+                     f"{metrics['macro_auprc']:.4f}",
+                     f"{metrics['macro_auroc']:.4f}",
+                     '', '', '', '', '', '', ''  # Empty cells for Threshold, TP, FP, TN, FN, Samples, Percentage
+                 ])
+
+                 # Write per-class metrics (one row per species)
+                 for class_name, class_metrics in metrics['class_metrics'].items():
+                     distribution = metrics['class_distribution'].get(class_name, {'count': 0, 'percentage': 0.0})
+                     writer.writerow([
+                         class_name,
+                         f"{class_metrics['precision_default']:.4f}",
+                         f"{class_metrics['recall_default']:.4f}",
+                         f"{class_metrics['f1_default']:.4f}",
+                         f"{class_metrics['precision_opt']:.4f}",
+                         f"{class_metrics['recall_opt']:.4f}",
+                         f"{class_metrics['f1_opt']:.4f}",
+                         f"{class_metrics['auprc']:.4f}",
+                         f"{class_metrics['auroc']:.4f}",
+                         f"{class_metrics['threshold']:.2f}",
+                         class_metrics['tp'],
+                         class_metrics['fp'],
+                         class_metrics['tn'],
+                         class_metrics['fn'],
+                         distribution['count'],
+                         f"{distribution['percentage']:.2f}"
+                     ])
+
+             print(f"Evaluation results saved to {eval_file_path}", flush=True)
+     else:
+         print("\nNo separate test data provided for evaluation. Using validation metrics.", flush=True)
+
+     print(f"...Done. Best AUPRC: {best_val_auprc}, Best AUROC: {best_val_auroc}, Best Loss: {best_val_loss} (epoch {best_epoch+1}/{len(history.epoch)})", flush=True)
+
+     return history, metrics
+
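For context, a minimal usage sketch (not part of the diff) of driving train_model through the config module; every path and value below is a placeholder, and in practice these settings normally come from the package's CLI or GUI:

import birdnet_analyzer.config as cfg
from birdnet_analyzer.train.utils import train_model

cfg.TRAIN_DATA_PATH = "path/to/train_data"      # one sub-folder per label (see the layout sketch above)
cfg.TEST_DATA_PATH = "path/to/test_data"        # optional; enables the evaluate_model() step
cfg.CUSTOM_CLASSIFIER = "path/to/MyClassifier"  # output prefix for the saved model and CSV files

history, metrics = train_model()  # metrics is None when no separate test data is provided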
+ def find_optimal_threshold(y_true, y_pred_prob):
+     """
+     Find the optimal classification threshold using the F1 score.
+
+     For imbalanced datasets, the default threshold of 0.5 may not be optimal.
+     This function finds the threshold that maximizes the F1 score for each class.
+
+     Args:
+         y_true: Ground truth labels
+         y_pred_prob: Predicted probabilities
+
+     Returns:
+         The optimal threshold value
+     """
+     from sklearn.metrics import f1_score
+
+     # Try different thresholds and find the one that gives the best F1 score
+     best_threshold = 0.5
+     best_f1 = 0.0
+
+     for threshold in np.arange(0.1, 0.9, 0.05):
+         y_pred = (y_pred_prob >= threshold).astype(int)
+         f1 = f1_score(y_true, y_pred)
+
+         if f1 > best_f1:
+             best_f1 = f1
+             best_threshold = threshold
+
+     return best_threshold
+
+
+ def evaluate_model(classifier, x_test, y_test, labels, threshold=None):
+     """
+     Evaluates the trained model on test data and prints detailed metrics.
+
+     Args:
+         classifier: The trained model
+         x_test: Test features (embeddings)
+         y_test: Test labels
+         labels: List of label names
+         threshold: Classification threshold (if None, will find optimal threshold for each class)
+
+     Returns:
+         Dictionary with evaluation metrics
+     """
+     from sklearn.metrics import (
+         precision_score, recall_score, f1_score,
+         confusion_matrix, classification_report,
+         average_precision_score, roc_auc_score
+     )
+
+     # Skip evaluation if test set is empty
+     if len(x_test) == 0:
+         print("No test data available for evaluation.")
+         return {}
+
+     # Make predictions
+     y_pred_prob = classifier.predict(x_test)
+
+     # Calculate metrics for each class
+     metrics = {}
+
+     print("\nModel Evaluation:")
+     print("=================")
+
+     # Calculate metrics for each class
+     precisions_default = []
+     recalls_default = []
+     f1s_default = []
+     precisions_opt = []
+     recalls_opt = []
+     f1s_opt = []
+     auprcs = []
+     aurocs = []
+     class_metrics = {}
+     optimal_thresholds = {}
+
+     # Print the metric calculation method that's being used
+     print("\nNote: The AUPRC and AUROC metrics calculated during post-training evaluation may differ")
+     print("from training history values due to different calculation methods:")
+     print("  - Training history uses Keras metrics calculated over batches")
+     print("  - Evaluation uses scikit-learn metrics calculated over the entire dataset")
+
+     for i in range(y_test.shape[1]):
+         try:
+             # Calculate metrics with default threshold (0.5)
+             y_pred_default = (y_pred_prob[:, i] >= 0.5).astype(int)
+
+             class_precision_default = precision_score(y_test[:, i], y_pred_default)
+             class_recall_default = recall_score(y_test[:, i], y_pred_default)
+             class_f1_default = f1_score(y_test[:, i], y_pred_default)
+
+             precisions_default.append(class_precision_default)
+             recalls_default.append(class_recall_default)
+             f1s_default.append(class_f1_default)
+
+             # Find optimal threshold for this class if needed
+             if threshold is None:
+                 class_threshold = find_optimal_threshold(y_test[:, i], y_pred_prob[:, i])
+                 optimal_thresholds[labels[i]] = class_threshold
+             else:
+                 class_threshold = threshold
+
+             # Calculate metrics with optimized threshold
+             y_pred_opt = (y_pred_prob[:, i] >= class_threshold).astype(int)
+
+             class_precision_opt = precision_score(y_test[:, i], y_pred_opt)
+             class_recall_opt = recall_score(y_test[:, i], y_pred_opt)
+             class_f1_opt = f1_score(y_test[:, i], y_pred_opt)
+             class_auprc = average_precision_score(y_test[:, i], y_pred_prob[:, i])
+             class_auroc = roc_auc_score(y_test[:, i], y_pred_prob[:, i])
+
+             precisions_opt.append(class_precision_opt)
+             recalls_opt.append(class_recall_opt)
+             f1s_opt.append(class_f1_opt)
+             auprcs.append(class_auprc)
+             aurocs.append(class_auroc)
+
+             # Confusion matrix with optimized threshold
+             tn, fp, fn, tp = confusion_matrix(y_test[:, i], y_pred_opt).ravel()
+
+             class_metrics[labels[i]] = {
+                 'precision_default': class_precision_default,
+                 'recall_default': class_recall_default,
+                 'f1_default': class_f1_default,
+                 'precision_opt': class_precision_opt,
+                 'recall_opt': class_recall_opt,
+                 'f1_opt': class_f1_opt,
+                 'auprc': class_auprc,
+                 'auroc': class_auroc,
+                 'tp': tp,
+                 'fp': fp,
+                 'tn': tn,
+                 'fn': fn,
+                 'threshold': class_threshold
+             }
+
+             print(f"\nClass: {labels[i]}")
+             print(f"  Default threshold (0.5):")
+             print(f"    Precision: {class_precision_default:.4f}")
+             print(f"    Recall: {class_recall_default:.4f}")
+             print(f"    F1 Score: {class_f1_default:.4f}")
+             print(f"  Optimized threshold ({class_threshold:.2f}):")
+             print(f"    Precision: {class_precision_opt:.4f}")
+             print(f"    Recall: {class_recall_opt:.4f}")
+             print(f"    F1 Score: {class_f1_opt:.4f}")
+             print(f"    AUPRC: {class_auprc:.4f}")
+             print(f"    AUROC: {class_auroc:.4f}")
+             print(f"  Confusion matrix (optimized threshold):")
+             print(f"    True Positives: {tp}")
+             print(f"    False Positives: {fp}")
+             print(f"    True Negatives: {tn}")
+             print(f"    False Negatives: {fn}")
+
+         except Exception as e:
+             print(f"Error calculating metrics for class {labels[i]}: {e}")
+
+     # Calculate macro-averaged metrics for both default and optimized thresholds
+     metrics['macro_precision_default'] = np.mean(precisions_default)
+     metrics['macro_recall_default'] = np.mean(recalls_default)
+     metrics['macro_f1_default'] = np.mean(f1s_default)
+     metrics['macro_precision_opt'] = np.mean(precisions_opt)
+     metrics['macro_recall_opt'] = np.mean(recalls_opt)
+     metrics['macro_f1_opt'] = np.mean(f1s_opt)
+     metrics['macro_auprc'] = np.mean(auprcs)
+     metrics['macro_auroc'] = np.mean(aurocs)
+     metrics['class_metrics'] = class_metrics
+     metrics['optimal_thresholds'] = optimal_thresholds
+
+     print("\nMacro-averaged metrics:")
+     print(f"  Default threshold (0.5):")
+     print(f"    Precision: {metrics['macro_precision_default']:.4f}")
+     print(f"    Recall: {metrics['macro_recall_default']:.4f}")
+     print(f"    F1 Score: {metrics['macro_f1_default']:.4f}")
+     print(f"  Optimized thresholds:")
+     print(f"    Precision: {metrics['macro_precision_opt']:.4f}")
+     print(f"    Recall: {metrics['macro_recall_opt']:.4f}")
+     print(f"    F1 Score: {metrics['macro_f1_opt']:.4f}")
+     print(f"    AUPRC: {metrics['macro_auprc']:.4f}")
+     print(f"    AUROC: {metrics['macro_auroc']:.4f}")
+
+     # Calculate class distribution in test set
+     class_counts = y_test.sum(axis=0)
+     total_samples = len(y_test)
+     class_distribution = {}
+
+     print("\nClass distribution in test set:")
+     for i, count in enumerate(class_counts):
+         percentage = count / total_samples * 100
+         class_distribution[labels[i]] = {'count': int(count), 'percentage': percentage}
+         print(f"  {labels[i]}: {int(count)} samples ({percentage:.2f}%)")
+
+     metrics['class_distribution'] = class_distribution
+
+     return metrics
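Finally, a brief sketch of consuming the metrics dictionary returned above (illustrative only; classifier, x_test, y_test and labels are assumed to exist already, and the keys are those defined in the function):

metrics = evaluate_model(classifier, x_test, y_test, labels)
if metrics:
    print("Macro F1 (optimized thresholds):", round(metrics["macro_f1_opt"], 4))
    for name, m in metrics["class_metrics"].items():
        print(name, "AUPRC:", round(m["auprc"], 4), "threshold:", m["threshold"])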