birdnet-analyzer 2.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122):
  1. birdnet_analyzer/__init__.py +9 -8
  2. birdnet_analyzer/analyze/__init__.py +5 -5
  3. birdnet_analyzer/analyze/__main__.py +3 -4
  4. birdnet_analyzer/analyze/cli.py +25 -25
  5. birdnet_analyzer/analyze/core.py +241 -245
  6. birdnet_analyzer/analyze/utils.py +692 -701
  7. birdnet_analyzer/audio.py +368 -372
  8. birdnet_analyzer/cli.py +709 -707
  9. birdnet_analyzer/config.py +242 -242
  10. birdnet_analyzer/eBird_taxonomy_codes_2021E.json +25279 -25279
  11. birdnet_analyzer/embeddings/__init__.py +3 -4
  12. birdnet_analyzer/embeddings/__main__.py +3 -3
  13. birdnet_analyzer/embeddings/cli.py +12 -13
  14. birdnet_analyzer/embeddings/core.py +69 -70
  15. birdnet_analyzer/embeddings/utils.py +179 -193
  16. birdnet_analyzer/evaluation/__init__.py +196 -195
  17. birdnet_analyzer/evaluation/__main__.py +3 -3
  18. birdnet_analyzer/evaluation/assessment/__init__.py +0 -0
  19. birdnet_analyzer/evaluation/assessment/metrics.py +388 -0
  20. birdnet_analyzer/evaluation/assessment/performance_assessor.py +409 -0
  21. birdnet_analyzer/evaluation/assessment/plotting.py +379 -0
  22. birdnet_analyzer/evaluation/preprocessing/__init__.py +0 -0
  23. birdnet_analyzer/evaluation/preprocessing/data_processor.py +631 -0
  24. birdnet_analyzer/evaluation/preprocessing/utils.py +98 -0
  25. birdnet_analyzer/gui/__init__.py +19 -23
  26. birdnet_analyzer/gui/__main__.py +3 -3
  27. birdnet_analyzer/gui/analysis.py +175 -174
  28. birdnet_analyzer/gui/assets/arrow_down.svg +4 -4
  29. birdnet_analyzer/gui/assets/arrow_left.svg +4 -4
  30. birdnet_analyzer/gui/assets/arrow_right.svg +4 -4
  31. birdnet_analyzer/gui/assets/arrow_up.svg +4 -4
  32. birdnet_analyzer/gui/assets/gui.css +28 -28
  33. birdnet_analyzer/gui/assets/gui.js +93 -93
  34. birdnet_analyzer/gui/embeddings.py +619 -620
  35. birdnet_analyzer/gui/evaluation.py +795 -813
  36. birdnet_analyzer/gui/localization.py +75 -68
  37. birdnet_analyzer/gui/multi_file.py +245 -246
  38. birdnet_analyzer/gui/review.py +519 -527
  39. birdnet_analyzer/gui/segments.py +191 -191
  40. birdnet_analyzer/gui/settings.py +128 -129
  41. birdnet_analyzer/gui/single_file.py +267 -269
  42. birdnet_analyzer/gui/species.py +95 -95
  43. birdnet_analyzer/gui/train.py +696 -698
  44. birdnet_analyzer/gui/utils.py +810 -808
  45. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_af.txt +6522 -6522
  46. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ar.txt +6522 -6522
  47. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_bg.txt +6522 -6522
  48. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ca.txt +6522 -6522
  49. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_cs.txt +6522 -6522
  50. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_da.txt +6522 -6522
  51. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_de.txt +6522 -6522
  52. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_el.txt +6522 -6522
  53. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_en_uk.txt +6522 -6522
  54. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_es.txt +6522 -6522
  55. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fi.txt +6522 -6522
  56. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_fr.txt +6522 -6522
  57. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_he.txt +6522 -6522
  58. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hr.txt +6522 -6522
  59. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_hu.txt +6522 -6522
  60. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_in.txt +6522 -6522
  61. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_is.txt +6522 -6522
  62. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_it.txt +6522 -6522
  63. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ja.txt +6522 -6522
  64. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ko.txt +6522 -6522
  65. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_lt.txt +6522 -6522
  66. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ml.txt +6522 -6522
  67. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_nl.txt +6522 -6522
  68. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_no.txt +6522 -6522
  69. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pl.txt +6522 -6522
  70. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_BR.txt +6522 -6522
  71. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_pt_PT.txt +6522 -6522
  72. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ro.txt +6522 -6522
  73. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_ru.txt +6522 -6522
  74. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sk.txt +6522 -6522
  75. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sl.txt +6522 -6522
  76. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sr.txt +6522 -6522
  77. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_sv.txt +6522 -6522
  78. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_th.txt +6522 -6522
  79. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_tr.txt +6522 -6522
  80. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_uk.txt +6522 -6522
  81. birdnet_analyzer/labels/V2.4/BirdNET_GLOBAL_6K_V2.4_Labels_zh.txt +6522 -6522
  82. birdnet_analyzer/lang/de.json +334 -334
  83. birdnet_analyzer/lang/en.json +334 -334
  84. birdnet_analyzer/lang/fi.json +334 -334
  85. birdnet_analyzer/lang/fr.json +334 -334
  86. birdnet_analyzer/lang/id.json +334 -334
  87. birdnet_analyzer/lang/pt-br.json +334 -334
  88. birdnet_analyzer/lang/ru.json +334 -334
  89. birdnet_analyzer/lang/se.json +334 -334
  90. birdnet_analyzer/lang/tlh.json +334 -334
  91. birdnet_analyzer/lang/zh_TW.json +334 -334
  92. birdnet_analyzer/model.py +1212 -1243
  93. birdnet_analyzer/playground.py +5 -0
  94. birdnet_analyzer/search/__init__.py +3 -3
  95. birdnet_analyzer/search/__main__.py +3 -3
  96. birdnet_analyzer/search/cli.py +11 -12
  97. birdnet_analyzer/search/core.py +78 -78
  98. birdnet_analyzer/search/utils.py +107 -111
  99. birdnet_analyzer/segments/__init__.py +3 -3
  100. birdnet_analyzer/segments/__main__.py +3 -3
  101. birdnet_analyzer/segments/cli.py +13 -14
  102. birdnet_analyzer/segments/core.py +81 -78
  103. birdnet_analyzer/segments/utils.py +383 -394
  104. birdnet_analyzer/species/__init__.py +3 -3
  105. birdnet_analyzer/species/__main__.py +3 -3
  106. birdnet_analyzer/species/cli.py +13 -14
  107. birdnet_analyzer/species/core.py +35 -35
  108. birdnet_analyzer/species/utils.py +74 -75
  109. birdnet_analyzer/train/__init__.py +3 -3
  110. birdnet_analyzer/train/__main__.py +3 -3
  111. birdnet_analyzer/train/cli.py +13 -14
  112. birdnet_analyzer/train/core.py +113 -113
  113. birdnet_analyzer/train/utils.py +877 -847
  114. birdnet_analyzer/translate.py +133 -104
  115. birdnet_analyzer/utils.py +426 -419
  116. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/METADATA +137 -129
  117. birdnet_analyzer-2.0.1.dist-info/RECORD +125 -0
  118. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/WHEEL +1 -1
  119. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/licenses/LICENSE +18 -18
  120. birdnet_analyzer-2.0.0.dist-info/RECORD +0 -117
  121. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/entry_points.txt +0 -0
  122. {birdnet_analyzer-2.0.0.dist-info → birdnet_analyzer-2.0.1.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,3 @@
1
- from birdnet_analyzer.species.core import species
2
-
3
- __all__ = ["species"]
1
+ from birdnet_analyzer.species.core import species
2
+
3
+ __all__ = ["species"]
@@ -1,3 +1,3 @@
1
- from birdnet_analyzer.species.cli import main
2
-
3
- main()
1
+ from birdnet_analyzer.species.cli import main
2
+
3
+ main()
@@ -1,14 +1,13 @@
1
- from birdnet_analyzer.utils import runtime_error_handler
2
-
3
-
4
- @runtime_error_handler
5
- def main():
6
- import birdnet_analyzer.cli as cli
7
- from birdnet_analyzer import species
8
-
9
- # Parse arguments
10
- parser = cli.species_parser()
11
-
12
- args = parser.parse_args()
13
-
14
- species(**vars(args))
1
+ from birdnet_analyzer.utils import runtime_error_handler
2
+
3
+
4
+ @runtime_error_handler
5
+ def main():
6
+ from birdnet_analyzer import cli, species
7
+
8
+ # Parse arguments
9
+ parser = cli.species_parser()
10
+
11
+ args = parser.parse_args()
12
+
13
+ species(**vars(args))
@@ -1,35 +1,35 @@
1
- from typing import Literal
2
-
3
-
4
- def species(
5
- output: str,
6
- *,
7
- lat: float = -1,
8
- lon: float = -1,
9
- week: int = -1,
10
- sf_thresh: float = 0.03,
11
- sortby: Literal["freq", "alpha"] = "freq",
12
- ):
13
- """
14
- Retrieves and processes species data based on the provided parameters.
15
- Args:
16
- output (str): The output directory or file path where the results will be stored.
17
- lat (float, optional): Latitude of the location for species filtering. Defaults to -1 (no filtering by location).
18
- lon (float, optional): Longitude of the location for species filtering. Defaults to -1 (no filtering by location).
19
- week (int, optional): Week of the year for species filtering. Defaults to -1 (no filtering by time).
20
- sf_thresh (float, optional): Species frequency threshold for filtering. Defaults to 0.03.
21
- sortby (Literal["freq", "alpha"], optional): Sorting method for the species list.
22
- "freq" sorts by frequency, and "alpha" sorts alphabetically. Defaults to "freq".
23
- Raises:
24
- FileNotFoundError: If the required model files are not found.
25
- ValueError: If invalid parameters are provided.
26
- Notes:
27
- This function ensures that the required model files exist before processing.
28
- It delegates the main processing to the `run` function from `birdnet_analyzer.species.utils`.
29
- """
30
- from birdnet_analyzer.species.utils import run
31
- from birdnet_analyzer.utils import ensure_model_exists
32
-
33
- ensure_model_exists()
34
-
35
- run(output, lat, lon, week, sf_thresh, sortby)
1
+ from typing import Literal
2
+
3
+
4
+ def species(
5
+ output: str,
6
+ *,
7
+ lat: float = -1,
8
+ lon: float = -1,
9
+ week: int = -1,
10
+ sf_thresh: float = 0.03,
11
+ sortby: Literal["freq", "alpha"] = "freq",
12
+ ):
13
+ """
14
+ Retrieves and processes species data based on the provided parameters.
15
+ Args:
16
+ output (str): The output directory or file path where the results will be stored.
17
+ lat (float, optional): Latitude of the location for species filtering. Defaults to -1 (no filtering by location).
18
+ lon (float, optional): Longitude of the location for species filtering. Defaults to -1 (no filtering by location).
19
+ week (int, optional): Week of the year for species filtering. Defaults to -1 (no filtering by time).
20
+ sf_thresh (float, optional): Species frequency threshold for filtering. Defaults to 0.03.
21
+ sortby (Literal["freq", "alpha"], optional): Sorting method for the species list.
22
+ "freq" sorts by frequency, and "alpha" sorts alphabetically. Defaults to "freq".
23
+ Raises:
24
+ FileNotFoundError: If the required model files are not found.
25
+ ValueError: If invalid parameters are provided.
26
+ Notes:
27
+ This function ensures that the required model files exist before processing.
28
+ It delegates the main processing to the `run` function from `birdnet_analyzer.species.utils`.
29
+ """
30
+ from birdnet_analyzer.species.utils import run
31
+ from birdnet_analyzer.utils import ensure_model_exists
32
+
33
+ ensure_model_exists()
34
+
35
+ run(output, lat, lon, week, sf_thresh, sortby)
@@ -1,75 +1,74 @@
1
- """Module for predicting a species list.
2
-
3
- Can be used to predict a species list using coordinates and weeks.
4
- """
5
-
6
- import os
7
-
8
- import birdnet_analyzer.config as cfg
9
- import birdnet_analyzer.model as model
10
- import birdnet_analyzer.utils as utils
11
-
12
-
13
- def get_species_list(lat: float, lon: float, week: int, threshold=0.05, sort=False) -> list[str]:
14
- """Predict a species list.
15
-
16
- Uses the model to predict the species list for the given coordinates and filters by threshold.
17
-
18
- Args:
19
- lat: The latitude.
20
- lon: The longitude.
21
- week: The week of the year [1-48]. Use -1 for year-round.
22
- threshold: Only values above or equal to threshold will be shown.
23
- sort: If the species list should be sorted.
24
-
25
- Returns:
26
- A list of all eligible species.
27
- """
28
- # Extract species from model
29
- pred = model.explore(lat, lon, week)
30
-
31
- # Make species list
32
- slist = [p[1] for p in pred if p[0] >= threshold]
33
-
34
- return sorted(slist) if sort else slist
35
-
36
-
37
- def run(output_path, lat, lon, week, threshold, sortby):
38
- """
39
- Generates a species list for a given location and time, and saves it to the specified output path.
40
- Args:
41
- output_path (str): The path where the species list will be saved. If it's a directory, the list will be saved as "species_list.txt" inside it.
42
- lat (float): Latitude of the location.
43
- lon (float): Longitude of the location.
44
- week (int): Week of the year (1-52) for which the species list is generated.
45
- threshold (float): Threshold for location filtering.
46
- sortby (str): Sorting criteria for the species list. Can be "freq" for frequency or any other value for alphabetical sorting.
47
- Returns:
48
- None
49
- """
50
- # Load eBird codes, labels
51
- cfg.LABELS = utils.read_lines(cfg.LABELS_FILE)
52
-
53
- # Set output path
54
- cfg.OUTPUT_PATH = output_path
55
-
56
- if os.path.isdir(cfg.OUTPUT_PATH):
57
- cfg.OUTPUT_PATH = os.path.join(cfg.OUTPUT_PATH, "species_list.txt")
58
-
59
- # Set config
60
- cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = lat, lon, week
61
- cfg.LOCATION_FILTER_THRESHOLD = threshold
62
-
63
- print(f"Getting species list for {cfg.LATITUDE}/{cfg.LONGITUDE}, Week {cfg.WEEK}...", end="", flush=True)
64
-
65
- # Get species list
66
- species_list = get_species_list(
67
- cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD, False if sortby == "freq" else True
68
- )
69
-
70
- print(f"Done. {len(species_list)} species on list.", flush=True)
71
-
72
- # Save species list
73
- with open(cfg.OUTPUT_PATH, "w") as f:
74
- for s in species_list:
75
- f.write(s + "\n")
1
+ """Module for predicting a species list.
2
+
3
+ Can be used to predict a species list using coordinates and weeks.
4
+ """
5
+
6
+ import os
7
+
8
+ import birdnet_analyzer.config as cfg
9
+ from birdnet_analyzer import model, utils
10
+
11
+
12
+ def get_species_list(lat: float, lon: float, week: int, threshold=0.05, sort=False) -> list[str]:
13
+ """Predict a species list.
14
+
15
+ Uses the model to predict the species list for the given coordinates and filters by threshold.
16
+
17
+ Args:
18
+ lat: The latitude.
19
+ lon: The longitude.
20
+ week: The week of the year [1-48]. Use -1 for year-round.
21
+ threshold: Only values above or equal to threshold will be shown.
22
+ sort: If the species list should be sorted.
23
+
24
+ Returns:
25
+ A list of all eligible species.
26
+ """
27
+ # Extract species from model
28
+ pred = model.explore(lat, lon, week)
29
+
30
+ # Make species list
31
+ slist = [p[1] for p in pred if p[0] >= threshold]
32
+
33
+ return sorted(slist) if sort else slist
34
+
35
+
36
+ def run(output_path, lat, lon, week, threshold, sortby):
37
+ """
38
+ Generates a species list for a given location and time, and saves it to the specified output path.
39
+ Args:
40
+ output_path (str): The path where the species list will be saved. If it's a directory, the list will be saved as "species_list.txt" inside it.
41
+ lat (float): Latitude of the location.
42
+ lon (float): Longitude of the location.
43
+ week (int): Week of the year (1-52) for which the species list is generated.
44
+ threshold (float): Threshold for location filtering.
45
+ sortby (str): Sorting criteria for the species list. Can be "freq" for frequency or any other value for alphabetical sorting.
46
+ Returns:
47
+ None
48
+ """
49
+ # Load eBird codes, labels
50
+ cfg.LABELS = utils.read_lines(cfg.LABELS_FILE)
51
+
52
+ # Set output path
53
+ cfg.OUTPUT_PATH = output_path
54
+
55
+ if os.path.isdir(cfg.OUTPUT_PATH):
56
+ cfg.OUTPUT_PATH = os.path.join(cfg.OUTPUT_PATH, "species_list.txt")
57
+
58
+ # Set config
59
+ cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = lat, lon, week
60
+ cfg.LOCATION_FILTER_THRESHOLD = threshold
61
+
62
+ print(f"Getting species list for {cfg.LATITUDE}/{cfg.LONGITUDE}, Week {cfg.WEEK}...", end="", flush=True)
63
+
64
+ # Get species list
65
+ species_list = get_species_list(
66
+ cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD, sortby != "freq"
67
+ )
68
+
69
+ print(f"Done. {len(species_list)} species on list.", flush=True)
70
+
71
+ # Save species list
72
+ with open(cfg.OUTPUT_PATH, "w") as f:
73
+ for s in species_list:
74
+ f.write(s + "\n")
@@ -1,3 +1,3 @@
1
- from birdnet_analyzer.train.core import train
2
-
3
- __all__ = ["train"]
1
+ from birdnet_analyzer.train.core import train
2
+
3
+ __all__ = ["train"]
@@ -1,3 +1,3 @@
1
- from birdnet_analyzer.train.cli import main
2
-
3
- main()
1
+ from birdnet_analyzer.train.cli import main
2
+
3
+ main()
@@ -1,14 +1,13 @@
1
- from birdnet_analyzer.utils import runtime_error_handler
2
-
3
-
4
- @runtime_error_handler
5
- def main():
6
- import birdnet_analyzer.cli as cli
7
- from birdnet_analyzer import train
8
-
9
- # Parse arguments
10
- parser = cli.train_parser()
11
-
12
- args = parser.parse_args()
13
-
14
- train(**vars(args))
1
+ from birdnet_analyzer.utils import runtime_error_handler
2
+
3
+
4
+ @runtime_error_handler
5
+ def main():
6
+ from birdnet_analyzer import cli, train
7
+
8
+ # Parse arguments
9
+ parser = cli.train_parser()
10
+
11
+ args = parser.parse_args()
12
+
13
+ train(**vars(args))
@@ -1,113 +1,113 @@
1
- from typing import Literal
2
-
3
-
4
- def train(
5
- input: str,
6
- output: str = "checkpoints/custom/Custom_Classifier",
7
- test_data: str = None,
8
- *,
9
- crop_mode: Literal["center", "first", "segments"] = "center",
10
- overlap: float = 0.0,
11
- epochs: int = 50,
12
- batch_size: int = 32,
13
- val_split: float = 0.2,
14
- learning_rate: float = 0.0001,
15
- use_focal_loss: bool = False,
16
- focal_loss_gamma: float = 2.0,
17
- focal_loss_alpha: float = 0.25,
18
- hidden_units: int = 0,
19
- dropout: float = 0.0,
20
- label_smoothing: bool = False,
21
- mixup: bool = False,
22
- upsampling_ratio: float = 0.0,
23
- upsampling_mode: Literal["repeat", "mean", "smote"] = "repeat",
24
- model_format: Literal["tflite", "raven", "both"] = "tflite",
25
- model_save_mode: Literal["replace", "append"] = "replace",
26
- cache_mode: Literal["load", "save"] | None = None,
27
- cache_file: str = "train_cache.npz",
28
- threads: int = 1,
29
- fmin: float = 0.0,
30
- fmax: float = 15000.0,
31
- audio_speed: float = 1.0,
32
- autotune: bool = False,
33
- autotune_trials: int = 50,
34
- autotune_executions_per_trial: int = 1,
35
- ):
36
- """
37
- Trains a custom classifier model using the BirdNET-Analyzer framework.
38
- Args:
39
- input (str): Path to the training data directory.
40
- test_data (str, optional): Path to the test data directory. Defaults to None. If not specified, a validation split will be used.
41
- output (str, optional): Path to save the trained model. Defaults to "checkpoints/custom/Custom_Classifier".
42
- crop_mode (Literal["center", "first", "segments", "smart"], optional): Mode for cropping audio samples. Defaults to "center".
43
- overlap (float, optional): Overlap ratio for audio segments. Defaults to 0.0.
44
- epochs (int, optional): Number of training epochs. Defaults to 50.
45
- batch_size (int, optional): Batch size for training. Defaults to 32.
46
- val_split (float, optional): Fraction of data to use for validation. Defaults to 0.2.
47
- learning_rate (float, optional): Learning rate for the optimizer. Defaults to 0.0001.
48
- use_focal_loss (bool, optional): Whether to use focal loss for training. Defaults to False.
49
- focal_loss_gamma (float, optional): Gamma parameter for focal loss. Defaults to 2.0.
50
- focal_loss_alpha (float, optional): Alpha parameter for focal loss. Defaults to 0.25.
51
- hidden_units (int, optional): Number of hidden units in the model. Defaults to 0.
52
- dropout (float, optional): Dropout rate for regularization. Defaults to 0.0.
53
- label_smoothing (bool, optional): Whether to use label smoothing. Defaults to False.
54
- mixup (bool, optional): Whether to use mixup data augmentation. Defaults to False.
55
- upsampling_ratio (float, optional): Ratio for upsampling underrepresented classes. Defaults to 0.0.
56
- upsampling_mode (Literal["repeat", "mean", "smote"], optional): Mode for upsampling. Defaults to "repeat".
57
- model_format (Literal["tflite", "raven", "both"], optional): Format to save the trained model. Defaults to "tflite".
58
- model_save_mode (Literal["replace", "append"], optional): Save mode for the model. Defaults to "replace".
59
- cache_mode (Literal["load", "save"] | None, optional): Cache mode for training data. Defaults to None.
60
- cache_file (str, optional): Path to the cache file. Defaults to "train_cache.npz".
61
- threads (int, optional): Number of CPU threads to use. Defaults to 1.
62
- fmin (float, optional): Minimum frequency for bandpass filtering. Defaults to 0.0.
63
- fmax (float, optional): Maximum frequency for bandpass filtering. Defaults to 15000.0.
64
- audio_speed (float, optional): Speed factor for audio playback. Defaults to 1.0.
65
- autotune (bool, optional): Whether to use hyperparameter autotuning. Defaults to False.
66
- autotune_trials (int, optional): Number of trials for autotuning. Defaults to 50.
67
- autotune_executions_per_trial (int, optional): Number of executions per autotuning trial. Defaults to 1.
68
- Returns:
69
- None
70
- """
71
- from birdnet_analyzer.train.utils import train_model
72
- import birdnet_analyzer.config as cfg
73
- from birdnet_analyzer.utils import ensure_model_exists
74
-
75
- ensure_model_exists()
76
-
77
- # Config
78
- cfg.TRAIN_DATA_PATH = input
79
- cfg.TEST_DATA_PATH = test_data
80
- cfg.SAMPLE_CROP_MODE = crop_mode
81
- cfg.SIG_OVERLAP = overlap
82
- cfg.CUSTOM_CLASSIFIER = output
83
- cfg.TRAIN_EPOCHS = epochs
84
- cfg.TRAIN_BATCH_SIZE = batch_size
85
- cfg.TRAIN_VAL_SPLIT = val_split
86
- cfg.TRAIN_LEARNING_RATE = learning_rate
87
- cfg.TRAIN_WITH_FOCAL_LOSS = use_focal_loss if use_focal_loss is not None else cfg.TRAIN_WITH_FOCAL_LOSS
88
- cfg.FOCAL_LOSS_GAMMA = focal_loss_gamma
89
- cfg.FOCAL_LOSS_ALPHA = focal_loss_alpha
90
- cfg.TRAIN_HIDDEN_UNITS = hidden_units
91
- cfg.TRAIN_DROPOUT = dropout
92
- cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing if label_smoothing is not None else cfg.TRAIN_WITH_LABEL_SMOOTHING
93
- cfg.TRAIN_WITH_MIXUP = mixup if mixup is not None else cfg.TRAIN_WITH_MIXUP
94
- cfg.UPSAMPLING_RATIO = upsampling_ratio
95
- cfg.UPSAMPLING_MODE = upsampling_mode
96
- cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
97
- cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
98
- cfg.TRAIN_CACHE_MODE = cache_mode
99
- cfg.TRAIN_CACHE_FILE = cache_file
100
- cfg.TFLITE_THREADS = 1
101
- cfg.CPU_THREADS = threads
102
-
103
- cfg.BANDPASS_FMIN = fmin
104
- cfg.BANDPASS_FMAX = fmax
105
-
106
- cfg.AUDIO_SPEED = audio_speed
107
-
108
- cfg.AUTOTUNE = autotune
109
- cfg.AUTOTUNE_TRIALS = autotune_trials
110
- cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = autotune_executions_per_trial
111
-
112
- # Train model
113
- train_model()
1
+ from typing import Literal
2
+
3
+
4
+ def train(
5
+ audio_input: str,
6
+ output: str = "checkpoints/custom/Custom_Classifier",
7
+ test_data: str | None = None,
8
+ *,
9
+ crop_mode: Literal["center", "first", "segments"] = "center",
10
+ overlap: float = 0.0,
11
+ epochs: int = 50,
12
+ batch_size: int = 32,
13
+ val_split: float = 0.2,
14
+ learning_rate: float = 0.0001,
15
+ use_focal_loss: bool = False,
16
+ focal_loss_gamma: float = 2.0,
17
+ focal_loss_alpha: float = 0.25,
18
+ hidden_units: int = 0,
19
+ dropout: float = 0.0,
20
+ label_smoothing: bool = False,
21
+ mixup: bool = False,
22
+ upsampling_ratio: float = 0.0,
23
+ upsampling_mode: Literal["repeat", "mean", "smote"] = "repeat",
24
+ model_format: Literal["tflite", "raven", "both"] = "tflite",
25
+ model_save_mode: Literal["replace", "append"] = "replace",
26
+ cache_mode: Literal["load", "save"] | None = None,
27
+ cache_file: str = "train_cache.npz",
28
+ threads: int = 1,
29
+ fmin: float = 0.0,
30
+ fmax: float = 15000.0,
31
+ audio_speed: float = 1.0,
32
+ autotune: bool = False,
33
+ autotune_trials: int = 50,
34
+ autotune_executions_per_trial: int = 1,
35
+ ):
36
+ """
37
+ Trains a custom classifier model using the BirdNET-Analyzer framework.
38
+ Args:
39
+ audio_input (str): Path to the training data directory.
40
+ test_data (str, optional): Path to the test data directory. Defaults to None. If not specified, a validation split will be used.
41
+ output (str, optional): Path to save the trained model. Defaults to "checkpoints/custom/Custom_Classifier".
42
+ crop_mode (Literal["center", "first", "segments", "smart"], optional): Mode for cropping audio samples. Defaults to "center".
43
+ overlap (float, optional): Overlap ratio for audio segments. Defaults to 0.0.
44
+ epochs (int, optional): Number of training epochs. Defaults to 50.
45
+ batch_size (int, optional): Batch size for training. Defaults to 32.
46
+ val_split (float, optional): Fraction of data to use for validation. Defaults to 0.2.
47
+ learning_rate (float, optional): Learning rate for the optimizer. Defaults to 0.0001.
48
+ use_focal_loss (bool, optional): Whether to use focal loss for training. Defaults to False.
49
+ focal_loss_gamma (float, optional): Gamma parameter for focal loss. Defaults to 2.0.
50
+ focal_loss_alpha (float, optional): Alpha parameter for focal loss. Defaults to 0.25.
51
+ hidden_units (int, optional): Number of hidden units in the model. Defaults to 0.
52
+ dropout (float, optional): Dropout rate for regularization. Defaults to 0.0.
53
+ label_smoothing (bool, optional): Whether to use label smoothing. Defaults to False.
54
+ mixup (bool, optional): Whether to use mixup data augmentation. Defaults to False.
55
+ upsampling_ratio (float, optional): Ratio for upsampling underrepresented classes. Defaults to 0.0.
56
+ upsampling_mode (Literal["repeat", "mean", "smote"], optional): Mode for upsampling. Defaults to "repeat".
57
+ model_format (Literal["tflite", "raven", "both"], optional): Format to save the trained model. Defaults to "tflite".
58
+ model_save_mode (Literal["replace", "append"], optional): Save mode for the model. Defaults to "replace".
59
+ cache_mode (Literal["load", "save"] | None, optional): Cache mode for training data. Defaults to None.
60
+ cache_file (str, optional): Path to the cache file. Defaults to "train_cache.npz".
61
+ threads (int, optional): Number of CPU threads to use. Defaults to 1.
62
+ fmin (float, optional): Minimum frequency for bandpass filtering. Defaults to 0.0.
63
+ fmax (float, optional): Maximum frequency for bandpass filtering. Defaults to 15000.0.
64
+ audio_speed (float, optional): Speed factor for audio playback. Defaults to 1.0.
65
+ autotune (bool, optional): Whether to use hyperparameter autotuning. Defaults to False.
66
+ autotune_trials (int, optional): Number of trials for autotuning. Defaults to 50.
67
+ autotune_executions_per_trial (int, optional): Number of executions per autotuning trial. Defaults to 1.
68
+ Returns:
69
+ None
70
+ """
71
+ import birdnet_analyzer.config as cfg
72
+ from birdnet_analyzer.train.utils import train_model
73
+ from birdnet_analyzer.utils import ensure_model_exists
74
+
75
+ ensure_model_exists()
76
+
77
+ # Config
78
+ cfg.TRAIN_DATA_PATH = audio_input
79
+ cfg.TEST_DATA_PATH = test_data
80
+ cfg.SAMPLE_CROP_MODE = crop_mode
81
+ cfg.SIG_OVERLAP = overlap
82
+ cfg.CUSTOM_CLASSIFIER = output
83
+ cfg.TRAIN_EPOCHS = epochs
84
+ cfg.TRAIN_BATCH_SIZE = batch_size
85
+ cfg.TRAIN_VAL_SPLIT = val_split
86
+ cfg.TRAIN_LEARNING_RATE = learning_rate
87
+ cfg.TRAIN_WITH_FOCAL_LOSS = use_focal_loss if use_focal_loss is not None else cfg.TRAIN_WITH_FOCAL_LOSS
88
+ cfg.FOCAL_LOSS_GAMMA = focal_loss_gamma
89
+ cfg.FOCAL_LOSS_ALPHA = focal_loss_alpha
90
+ cfg.TRAIN_HIDDEN_UNITS = hidden_units
91
+ cfg.TRAIN_DROPOUT = dropout
92
+ cfg.TRAIN_WITH_LABEL_SMOOTHING = label_smoothing if label_smoothing is not None else cfg.TRAIN_WITH_LABEL_SMOOTHING
93
+ cfg.TRAIN_WITH_MIXUP = mixup if mixup is not None else cfg.TRAIN_WITH_MIXUP
94
+ cfg.UPSAMPLING_RATIO = upsampling_ratio
95
+ cfg.UPSAMPLING_MODE = upsampling_mode
96
+ cfg.TRAINED_MODEL_OUTPUT_FORMAT = model_format
97
+ cfg.TRAINED_MODEL_SAVE_MODE = model_save_mode
98
+ cfg.TRAIN_CACHE_MODE = cache_mode
99
+ cfg.TRAIN_CACHE_FILE = cache_file
100
+ cfg.TFLITE_THREADS = 1
101
+ cfg.CPU_THREADS = threads
102
+
103
+ cfg.BANDPASS_FMIN = fmin
104
+ cfg.BANDPASS_FMAX = fmax
105
+
106
+ cfg.AUDIO_SPEED = audio_speed
107
+
108
+ cfg.AUTOTUNE = autotune
109
+ cfg.AUTOTUNE_TRIALS = autotune_trials
110
+ cfg.AUTOTUNE_EXECUTIONS_PER_TRIAL = autotune_executions_per_trial
111
+
112
+ # Train model
113
+ train_model()