britekit 0.0.8__tar.gz → 0.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of britekit might be problematic. Click here for more details.

Files changed (125)
  1. {britekit-0.0.8 → britekit-0.0.10}/PKG-INFO +2 -2
  2. {britekit-0.0.8 → britekit-0.0.10}/README.md +1 -1
  3. {britekit-0.0.8 → britekit-0.0.10}/britekit/__about__.py +1 -1
  4. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/__init__.py +2 -0
  5. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_analyze.py +1 -1
  6. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_ckpt_ops.py +1 -1
  7. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_db_add.py +4 -4
  8. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_db_delete.py +7 -7
  9. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_embed.py +1 -1
  10. britekit-0.0.10/britekit/commands/_ensemble.py +238 -0
  11. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_extract.py +2 -2
  12. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_find_dup.py +1 -1
  13. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_pickle.py +1 -1
  14. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_plot.py +3 -3
  15. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_reextract.py +1 -1
  16. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_reports.py +6 -6
  17. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_search.py +1 -1
  18. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_train.py +2 -2
  19. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_tune.py +3 -3
  20. {britekit-0.0.8 → britekit-0.0.10}/pyproject.toml +2 -0
  21. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/cli.py +2 -0
  22. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/analyzer.py +1 -1
  23. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/audio.py +1 -1
  24. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/config_loader.py +13 -7
  25. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/data_module.py +1 -1
  26. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/dataset.py +1 -1
  27. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/pickler.py +1 -1
  28. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/plot.py +1 -1
  29. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/predictor.py +1 -1
  30. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/reextractor.py +1 -1
  31. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/trainer.py +4 -3
  32. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/tuner.py +1 -1
  33. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/util.py +3 -3
  34. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/base_model.py +1 -1
  35. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/model_loader.py +1 -1
  36. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/timm_model.py +1 -1
  37. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/testing/per_minute_tester.py +1 -1
  38. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/testing/per_recording_tester.py +1 -1
  39. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/testing/per_segment_tester.py +1 -1
  40. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/training_db/extractor.py +1 -1
  41. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/training_db/training_db.py +0 -2
  42. {britekit-0.0.8 → britekit-0.0.10}/.gitignore +0 -0
  43. {britekit-0.0.8 → britekit-0.0.10}/LICENSE.txt +0 -0
  44. {britekit-0.0.8 → britekit-0.0.10}/britekit/__init__.py +0 -0
  45. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_audioset.py +0 -0
  46. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_calibrate.py +0 -0
  47. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_inat.py +0 -0
  48. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_init.py +0 -0
  49. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_wav2mp3.py +0 -0
  50. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_xeno.py +0 -0
  51. {britekit-0.0.8 → britekit-0.0.10}/britekit/commands/_youtube.py +0 -0
  52. {britekit-0.0.8 → britekit-0.0.10}/britekit/core/__init__.py +0 -0
  53. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/class_inclusion.csv +0 -0
  54. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/class_list.csv +0 -0
  55. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/aircraft.csv +0 -0
  56. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/car.csv +0 -0
  57. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/chainsaw.csv +0 -0
  58. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/cow.csv +0 -0
  59. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/cricket.csv +0 -0
  60. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/dog.csv +0 -0
  61. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/rain.csv +0 -0
  62. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/rooster.csv +0 -0
  63. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/sheep.csv +0 -0
  64. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/siren.csv +0 -0
  65. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/speech.csv +0 -0
  66. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/truck.csv +0 -0
  67. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/curated/wind.csv +0 -0
  68. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/audioset/unbalanced_train_segments.csv +0 -0
  69. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/classes.csv +0 -0
  70. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/data/ignore.txt +0 -0
  71. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/base_config.yaml +0 -0
  72. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/cfg_infer.yaml +0 -0
  73. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/train_dla.yaml +0 -0
  74. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/train_effnet.yaml +0 -0
  75. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/train_gernet.yaml +0 -0
  76. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/train_hgnet.yaml +0 -0
  77. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/train_timm.yaml +0 -0
  78. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/train_vovnet.yaml +0 -0
  79. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/tune_dropout.yaml +0 -0
  80. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/tune_learning_rate.yaml +0 -0
  81. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/tune_optimizer.yaml +0 -0
  82. {britekit-0.0.8 → britekit-0.0.10}/britekit/install/yaml/samples/tune_smooth.yaml +0 -0
  83. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/class_inclusion.csv +0 -0
  84. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/class_list.csv +0 -0
  85. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/aircraft.csv +0 -0
  86. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/car.csv +0 -0
  87. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/chainsaw.csv +0 -0
  88. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/cow.csv +0 -0
  89. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/cricket.csv +0 -0
  90. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/dog.csv +0 -0
  91. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/rain.csv +0 -0
  92. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/rooster.csv +0 -0
  93. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/sheep.csv +0 -0
  94. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/siren.csv +0 -0
  95. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/speech.csv +0 -0
  96. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/truck.csv +0 -0
  97. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/curated/wind.csv +0 -0
  98. {britekit-0.0.8 → britekit-0.0.10}/install/data/audioset/unbalanced_train_segments.csv +0 -0
  99. {britekit-0.0.8 → britekit-0.0.10}/install/data/classes.csv +0 -0
  100. {britekit-0.0.8 → britekit-0.0.10}/install/data/ignore.txt +0 -0
  101. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/base_config.yaml +0 -0
  102. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/cfg_infer.yaml +0 -0
  103. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/train_dla.yaml +0 -0
  104. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/train_effnet.yaml +0 -0
  105. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/train_gernet.yaml +0 -0
  106. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/train_hgnet.yaml +0 -0
  107. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/train_timm.yaml +0 -0
  108. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/train_vovnet.yaml +0 -0
  109. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/tune_dropout.yaml +0 -0
  110. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/tune_learning_rate.yaml +0 -0
  111. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/tune_optimizer.yaml +0 -0
  112. {britekit-0.0.8 → britekit-0.0.10}/install/yaml/samples/tune_smooth.yaml +0 -0
  113. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/augmentation.py +0 -0
  114. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/base_config.py +0 -0
  115. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/core/exceptions.py +0 -0
  116. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/dla.py +0 -0
  117. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/effnet.py +0 -0
  118. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/gernet.py +0 -0
  119. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/head_factory.py +0 -0
  120. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/hgnet.py +0 -0
  121. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/models/vovnet.py +0 -0
  122. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/occurrence_db/occurrence_data_provider.py +0 -0
  123. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/occurrence_db/occurrence_db.py +0 -0
  124. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/testing/base_tester.py +0 -0
  125. {britekit-0.0.8 → britekit-0.0.10}/src/britekit/training_db/training_data_provider.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: britekit
3
- Version: 0.0.8
3
+ Version: 0.0.10
4
4
  Summary: Core functions for bioacoustic recognizers.
5
5
  Project-URL: Documentation, https://github.com/jhuus/BriteKit#readme
6
6
  Project-URL: Issues, https://github.com/jhuus/BriteKit/issues
@@ -120,7 +120,7 @@ train:
120
120
  This overrides the default values for model_type, learning_rate, drop_rate and num_epochs. When using the API, you can update configuration parameters like this:
121
121
  ```
122
122
  import britekit as bk
123
- cfg, _ = bk.get_config()
123
+ cfg = bk.get_config()
124
124
  cfg.train.model_type = "effnet.4"
125
125
  ```
126
126
  ## Downloading Recordings
@@ -81,7 +81,7 @@ train:
81
81
  This overrides the default values for model_type, learning_rate, drop_rate and num_epochs. When using the API, you can update configuration parameters like this:
82
82
  ```
83
83
  import britekit as bk
84
- cfg, _ = bk.get_config()
84
+ cfg = bk.get_config()
85
85
  cfg.train.model_type = "effnet.4"
86
86
  ```
87
87
  ## Downloading Recordings
@@ -1,4 +1,4 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Jan Huus <jhuus1@gmail.com>
2
2
  #
3
3
  # SPDX-License-Identifier: MIT
4
- __version__ = "0.0.8"
4
+ __version__ = "0.0.10"
@@ -13,6 +13,7 @@ from ._db_delete import (
13
13
  del_stype,
14
14
  )
15
15
  from ._embed import embed
16
+ from ._ensemble import ensemble
16
17
  from ._extract import extract_all, extract_by_image
17
18
  from ._find_dup import find_dup
18
19
  from ._inat import inat
@@ -54,6 +55,7 @@ __all__ = [
54
55
  "del_src",
55
56
  "del_stype",
56
57
  "embed",
58
+ "ensemble",
57
59
  "extract_all",
58
60
  "extract_by_image",
59
61
  "find_dup",
@@ -47,7 +47,7 @@ def analyze(
47
47
  from britekit.core.analyzer import Analyzer
48
48
 
49
49
  util.set_logging()
50
- cfg, _ = get_config(cfg_path)
50
+ cfg = get_config(cfg_path)
51
51
  try:
52
52
  if rtype not in {"audacity", "csv", "both"}:
53
53
  logging.error(f"Error. invalid rtype value: {rtype}")
@@ -142,7 +142,7 @@ def ckpt_onnx(
142
142
  import torch
143
143
  from britekit.models.model_loader import load_from_checkpoint
144
144
 
145
- cfg, _ = get_config(cfg_path)
145
+ cfg = get_config(cfg_path)
146
146
  base, _ = os.path.splitext(input_path)
147
147
  output_path = base + ".onnx"
148
148
  model = load_from_checkpoint(input_path)
@@ -23,7 +23,7 @@ def add_cat(db_path: Optional[str]=None, name: str="") -> None:
23
23
  """
24
24
  from britekit.training_db.training_db import TrainingDatabase
25
25
 
26
- cfg, _ = get_config()
26
+ cfg = get_config()
27
27
  if db_path is None:
28
28
  db_path = cfg.train.train_db
29
29
 
@@ -63,7 +63,7 @@ def add_stype(db_path: Optional[str]=None, name: str="") -> None:
63
63
  """
64
64
  from britekit.training_db.training_db import TrainingDatabase
65
65
 
66
- cfg, _ = get_config()
66
+ cfg = get_config()
67
67
  if db_path is None:
68
68
  db_path = cfg.train.train_db
69
69
 
@@ -103,7 +103,7 @@ def add_src(db_path: Optional[str]=None, name: str="") -> None:
103
103
  """
104
104
  from britekit.training_db.training_db import TrainingDatabase
105
105
 
106
- cfg, _ = get_config()
106
+ cfg = get_config()
107
107
  if db_path is None:
108
108
  db_path = cfg.train.train_db
109
109
 
@@ -154,7 +154,7 @@ def add_class(
154
154
  """
155
155
  from britekit.training_db.training_db import TrainingDatabase
156
156
 
157
- cfg, _ = get_config()
157
+ cfg = get_config()
158
158
  if db_path is None:
159
159
  db_path = cfg.train.train_db
160
160
 
@@ -25,7 +25,7 @@ def del_cat(db_path: Optional[str]=None, name: Optional[str]=None) -> None:
25
25
  """
26
26
  from britekit.training_db.training_db import TrainingDatabase
27
27
 
28
- cfg, _ = get_config()
28
+ cfg = get_config()
29
29
  if db_path is None:
30
30
  db_path = cfg.train.train_db
31
31
 
@@ -78,7 +78,7 @@ def del_class(db_path: Optional[str]=None, name: Optional[str]=None) -> None:
78
78
  """
79
79
  from britekit.training_db.training_db import TrainingDatabase
80
80
 
81
- cfg, _ = get_config()
81
+ cfg = get_config()
82
82
  if db_path is None:
83
83
  db_path = cfg.train.train_db
84
84
 
@@ -128,7 +128,7 @@ def del_rec(db_path: Optional[str]=None, file_name: Optional[str]=None) -> None:
128
128
  """
129
129
  from britekit.training_db.training_db import TrainingDatabase
130
130
 
131
- cfg, _ = get_config()
131
+ cfg = get_config()
132
132
  if db_path is None:
133
133
  db_path = cfg.train.train_db
134
134
 
@@ -172,7 +172,7 @@ def del_sgroup(db_path: Optional[str]=None, name: Optional[str]=None) -> None:
172
172
  """
173
173
  from britekit.training_db.training_db import TrainingDatabase
174
174
 
175
- cfg, _ = get_config()
175
+ cfg = get_config()
176
176
  if db_path is None:
177
177
  db_path = cfg.train.train_db
178
178
 
@@ -217,7 +217,7 @@ def del_stype(db_path: Optional[str]=None, name: Optional[str]=None) -> None:
217
217
  """
218
218
  from britekit.training_db.training_db import TrainingDatabase
219
219
 
220
- cfg, _ = get_config()
220
+ cfg = get_config()
221
221
  if db_path is None:
222
222
  db_path = cfg.train.train_db
223
223
 
@@ -262,7 +262,7 @@ def del_src(db_path: Optional[str]=None, name: Optional[str]=None) -> None:
262
262
  """
263
263
  from britekit.training_db.training_db import TrainingDatabase
264
264
 
265
- cfg, _ = get_config()
265
+ cfg = get_config()
266
266
  if db_path is None:
267
267
  db_path = cfg.train.train_db
268
268
 
@@ -311,7 +311,7 @@ def del_seg(db_path: Optional[str]=None, class_name: Optional[str]=None, dir_pat
311
311
  """
312
312
  from britekit.training_db.training_db import TrainingDatabase
313
313
 
314
- cfg, _ = get_config()
314
+ cfg = get_config()
315
315
  if db_path is None:
316
316
  db_path = cfg.train.train_db
317
317
 
@@ -57,7 +57,7 @@ def embed(
57
57
  from britekit.models.model_loader import load_from_checkpoint
58
58
  from britekit.training_db.training_db import TrainingDatabase
59
59
 
60
- cfg, _ = get_config(cfg_path)
60
+ cfg = get_config(cfg_path)
61
61
  if db_path is None:
62
62
  db_path = cfg.train.train_db
63
63
 
@@ -0,0 +1,238 @@
1
+ # File name starts with _ to keep it out of typeahead for API users.
2
+ # Defer some imports to improve --help performance.
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import tempfile
7
+ from typing import Optional
8
+
9
+ import click
10
+
11
+ from britekit.core.config_loader import get_config
12
+ from britekit.core import util
13
+
14
def _eval_ensemble(ensemble, temp_dir, annotations_path, recording_dir):
    """
    Evaluate one candidate ensemble and return its scores.

    Any files already in ``temp_dir`` are deleted and the checkpoints of
    ``ensemble`` are copied there. Inference is then run on ``recording_dir``
    and the resulting labels are compared with the annotations by a
    PerSegmentTester.

    Args:
        ensemble: Iterable of checkpoint file paths forming the ensemble.
        temp_dir: Directory whose files are replaced by this ensemble's checkpoints
            (the caller points the checkpoint folder in the config at it).
        annotations_path: Path to the ground-truth annotations CSV.
        recording_dir: Directory containing the recordings to analyze.

    Returns:
        dict with keys "macro_pr", "micro_pr", "macro_roc" and "micro_roc".
    """
    import shutil

    from britekit.core.analyzer import Analyzer
    from britekit.testing.per_segment_tester import PerSegmentTester

    # Replace the previous ensemble's checkpoints with this one's.
    for filename in os.listdir(temp_dir):
        os.remove(os.path.join(temp_dir, filename))
    for file_path in ensemble:
        shutil.copyfile(file_path, os.path.join(temp_dir, Path(file_path).name))

    # Suppress logging during inference and analysis, and always restore it,
    # even if inference or testing raises.
    util.set_logging(level=logging.ERROR)
    try:
        # Write inference labels to a scratch directory instead of the user's
        # recordings directory, so nothing is left behind (or deleted) there.
        with tempfile.TemporaryDirectory() as work_dir:
            inference_output_dir = str(Path(work_dir) / "ensemble_evaluation_labels")
            Analyzer().run(recording_dir, inference_output_dir)

            output_dir = str(Path(work_dir) / "test_output")
            os.makedirs(output_dir)
            min_score = 0.8  # original author notes this is irrelevant for the AUC stats below
            tester = PerSegmentTester(
                annotations_path,
                recording_dir,
                inference_output_dir,
                output_dir,
                min_score,
            )
            tester.initialize()

            pr_stats = tester.get_pr_auc_stats()
            roc_stats = tester.get_roc_auc_stats()
    finally:
        util.set_logging()  # restore logging

    return {
        "macro_pr": pr_stats["macro_pr_auc"],
        "micro_pr": pr_stats["micro_pr_auc_trained"],
        "macro_roc": roc_stats["macro_roc_auc"],
        "micro_roc": roc_stats["micro_roc_auc_trained"],
    }
62
+
63
_ENSEMBLE_METRICS = ("macro_pr", "micro_pr", "macro_roc", "micro_roc")


def _sample_ensembles(ckpt_paths, ensemble_size, num_tries):
    """
    Yield ``num_tries`` distinct random ensembles drawn from ``ckpt_paths``.

    Each ensemble is a sorted tuple of ``ensemble_size`` paths. The caller must
    ensure at least ``num_tries`` distinct combinations exist; otherwise this
    generator never finishes.
    """
    import random

    seen: set = set()
    while len(seen) < num_tries:
        candidate = tuple(sorted(random.sample(ckpt_paths, ensemble_size)))
        if candidate not in seen:
            seen.add(candidate)
            yield candidate


def _write_report(results, metric, output_path):
    """
    Write every evaluated ensemble and its scores to ``output_path/ensembles.csv``.

    ``results`` is a list of ``(ensemble, scores)`` pairs, where ``scores`` has a
    key for each name in ``_ENSEMBLE_METRICS``. Rows are ordered by ``metric``,
    best first.
    """
    import csv

    os.makedirs(output_path, exist_ok=True)
    report_path = os.path.join(output_path, "ensembles.csv")
    ranked = sorted(results, key=lambda item: item[1][metric], reverse=True)
    with open(report_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["ensemble", *_ENSEMBLE_METRICS])
        for candidate, scores in ranked:
            names = " ".join(Path(p).name for p in candidate)
            writer.writerow([names, *(f"{scores[m]:.4f}" for m in _ENSEMBLE_METRICS)])
    logging.info(f"Report saved to {report_path}")


def ensemble(
    cfg_path: Optional[str] = None,
    ckpt_path: str = "",
    ensemble_size: int = 3,
    num_tries: int = 100,
    metric: str = "micro_roc",
    annotations_path: str = "",
    recordings_path: Optional[str] = None,
    output_path: str = "",
) -> None:
    """
    Find the best ensemble of a given size from a group of checkpoints.

    Given a directory containing checkpoints, and an ensemble size (default=3), select random
    ensembles of the given size and test each one to identify the best ensemble.
    If the number of possible ensembles does not exceed num_tries, all of them are tested.

    Args:
        cfg_path (str, optional): Path to YAML file defining configuration overrides.
        ckpt_path (str): Path to directory containing checkpoints.
        ensemble_size (int): Number of checkpoints in ensemble (default=3).
        num_tries (int): Maximum number of ensembles to try (default=100).
        metric (str): Metric to use to compare ensembles (default=micro_roc).
        annotations_path (str): Path to CSV file containing ground truth annotations.
        recordings_path (str, optional): Directory containing audio recordings. Defaults to annotations directory.
        output_path (str): Directory where a CSV report of all tested ensembles will be saved.
    """
    import glob
    import itertools
    import math

    # Validate arguments before loading config or touching the filesystem.
    if metric not in _ENSEMBLE_METRICS:
        logging.error(f"Error: invalid metric ({metric})")
        return
    if ensemble_size < 1:
        logging.error(f"Error: ensemble size must be at least 1 ({ensemble_size})")
        return
    if num_tries < 1:
        logging.error(f"Error: number of tries must be at least 1 ({num_tries})")
        return

    cfg = get_config(cfg_path)
    ckpt_paths = sorted(glob.glob(os.path.join(ckpt_path, "*.ckpt")))
    num_ckpts = len(ckpt_paths)
    if num_ckpts == 0:
        logging.error(f"Error: no checkpoints found in {ckpt_path}")
        return
    elif num_ckpts < ensemble_size:
        logging.error(f"Error: number of checkpoints ({num_ckpts}) is less than requested ensemble size ({ensemble_size})")
        return

    if not recordings_path:
        recordings_path = str(Path(annotations_path).parent)

    best_score = -math.inf
    best_ensemble = None
    results = []
    with tempfile.TemporaryDirectory() as temp_dir:
        # Inference loads checkpoints from this folder; _eval_ensemble fills it per candidate.
        cfg.misc.ckpt_folder = temp_dir
        cfg.infer.min_score = 0

        total_combinations = math.comb(num_ckpts, ensemble_size)
        if total_combinations <= num_tries:
            logging.info("Doing exhaustive search")
            num_planned = total_combinations
            candidates = itertools.combinations(ckpt_paths, ensemble_size)
        else:
            # Random sampling without replacement
            logging.info("Doing random sampling")
            num_planned = num_tries
            candidates = _sample_ensembles(ckpt_paths, ensemble_size, num_tries)

        for count, candidate in enumerate(candidates, start=1):
            scores = _eval_ensemble(candidate, temp_dir, annotations_path, recordings_path)
            score = scores[metric]
            logging.info(f"For ensemble {count} of {num_planned}, score = {score:.4f}")
            results.append((candidate, scores))
            # Accept the first candidate unconditionally so a best ensemble is
            # always reported, even when every score is 0 or NaN.
            if best_ensemble is None or score > best_score:
                best_score = score
                best_ensemble = candidate

    logging.info(f"Best score = {best_score:.4f}")

    # Invariant: validation above guarantees at least one candidate was evaluated.
    assert best_ensemble is not None
    best_names = [Path(p).name for p in best_ensemble]
    logging.info(f"Best ensemble = {best_names}")

    if output_path:
        _write_report(results, metric, output_path)
151
+
152
@click.command(
    name="ensemble",
    short_help="Find the best ensemble of a given size from a group of checkpoints.",
    help=util.cli_help_from_doc(ensemble.__doc__),
)
@click.option(
    "-c",
    "--cfg",
    "cfg_path",
    type=click.Path(exists=True),
    required=False,
    help="Path to YAML file defining config overrides.",
)
@click.option(
    "--ckpt_path",
    "ckpt_path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    required=True,
    help="Directory containing checkpoints.",
)
@click.option(
    "-e",
    "--ensemble_size",
    "ensemble_size",
    # Reject 0/negative sizes at parse time instead of failing later.
    type=click.IntRange(min=1),
    default=3,
    help="Number of checkpoints in ensemble (default=3).",
)
@click.option(
    "-n",
    "--num_tries",
    "num_tries",
    type=click.IntRange(min=1),
    default=100,
    help="Maximum number of ensembles to try (default=100).",
)
@click.option(
    "-m",
    "--metric",
    "metric",
    type=click.Choice(
        [
            "macro_pr",
            "micro_pr",
            "macro_roc",
            "micro_roc",
        ]
    ),
    default="micro_roc",
    help="Metric used to compare ensembles (default=micro_roc). Macro-averaging uses annotated classes only, but micro-averaging uses all classes.",
)
@click.option(
    "-a",
    "--annotations",
    "annotations_path",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
    required=True,
    help="Path to CSV file containing annotations or ground truth.",
)
@click.option(
    "-r",
    "--recordings",
    "recordings_path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    required=False,
    help="Recordings directory. Default is directory containing annotations file.",
)
@click.option(
    "-o",
    "--output",
    "output_path",
    type=click.Path(file_okay=False, dir_okay=True),
    required=True,
    help="Path to output directory.",
)
def _ensemble_cmd(
    cfg_path: Optional[str],
    ckpt_path: str,
    ensemble_size: int,
    num_tries: int,
    metric: str,
    annotations_path: str,
    recordings_path: Optional[str],
    output_path: str,
) -> None:
    """CLI entry point: configure logging and delegate to ensemble()."""
    util.set_logging()
    # Keyword arguments keep this call correct if ensemble()'s parameter order changes.
    ensemble(
        cfg_path=cfg_path,
        ckpt_path=ckpt_path,
        ensemble_size=ensemble_size,
        num_tries=num_tries,
        metric=metric,
        annotations_path=annotations_path,
        recordings_path=recordings_path,
        output_path=output_path,
    )
@@ -42,7 +42,7 @@ def extract_all(
42
42
  from britekit.training_db.extractor import Extractor
43
43
  from britekit.training_db.training_db import TrainingDatabase
44
44
 
45
- cfg, _ = get_config(cfg_path)
45
+ cfg = get_config(cfg_path)
46
46
  if db_path is not None:
47
47
  cfg.train.train_db = db_path
48
48
 
@@ -172,7 +172,7 @@ def extract_by_image(
172
172
  from britekit.training_db.extractor import Extractor
173
173
  from britekit.training_db.training_db import TrainingDatabase
174
174
 
175
- cfg, _ = get_config(cfg_path)
175
+ cfg = get_config(cfg_path)
176
176
  if db_path is not None:
177
177
  cfg.train.train_db = db_path
178
178
 
@@ -105,7 +105,7 @@ def find_dup(
105
105
  else:
106
106
  return False
107
107
 
108
- cfg, _ = get_config(cfg_path)
108
+ cfg = get_config(cfg_path)
109
109
  if db_path is None:
110
110
  db_path = cfg.train.train_db
111
111
 
@@ -37,7 +37,7 @@ def pickle(
37
37
  """
38
38
  from britekit.core.pickler import Pickler
39
39
 
40
- cfg, _ = get_config(cfg_path)
40
+ cfg = get_config(cfg_path)
41
41
  if db_path is None:
42
42
  db_path = cfg.train.train_db
43
43
 
@@ -89,7 +89,7 @@ def plot_db(
89
89
  from britekit.core.plot import plot_spec
90
90
  from britekit.training_db.training_db import TrainingDatabase
91
91
 
92
- cfg, _ = get_config(cfg_path)
92
+ cfg = get_config(cfg_path)
93
93
  if power is not None:
94
94
  cfg.audio.power = power
95
95
 
@@ -247,7 +247,7 @@ def plot_dir(
247
247
  """
248
248
  from britekit.core.audio import Audio
249
249
 
250
- cfg, _ = get_config(cfg_path)
250
+ cfg = get_config(cfg_path)
251
251
  if power is not None:
252
252
  cfg.audio.power = power
253
253
 
@@ -363,7 +363,7 @@ def plot_rec(
363
363
  """
364
364
  from britekit.core.audio import Audio
365
365
 
366
- cfg, _ = get_config(cfg_path)
366
+ cfg = get_config(cfg_path)
367
367
  if power is not None:
368
368
  cfg.audio.power = power
369
369
 
@@ -38,7 +38,7 @@ def reextract(
38
38
  spec_group (str): Spectrogram group name for storing the extracted spectrograms. Defaults to 'default'.
39
39
  """
40
40
  from britekit.core.reextractor import Reextractor
41
- cfg, _ = get_config(cfg_path)
41
+ cfg = get_config(cfg_path)
42
42
 
43
43
  if class_name and classes_path:
44
44
  logging.error("Only one of --name and --classes may be specified.")
@@ -143,7 +143,7 @@ def rpt_db(cfg_path: Optional[str] = None,
143
143
  from britekit.training_db.training_db import TrainingDatabase
144
144
  from britekit.training_db.training_data_provider import TrainingDataProvider
145
145
 
146
- cfg, _ = get_config(cfg_path)
146
+ cfg = get_config(cfg_path)
147
147
  if db_path is not None:
148
148
  cfg.train.train_db = db_path
149
149
 
@@ -214,7 +214,7 @@ def rpt_epochs(
214
214
  from britekit.core.analyzer import Analyzer
215
215
  from britekit.testing.per_segment_tester import PerSegmentTester
216
216
 
217
- cfg, _ = get_config(cfg_path)
217
+ cfg = get_config(cfg_path)
218
218
  ckpt_paths = glob.glob(str(Path(input_path) / "*.ckpt"))
219
219
  if len(ckpt_paths) == 0:
220
220
  logging.error(f"No checkpoint files found in {input_path}")
@@ -276,14 +276,14 @@ def rpt_epochs(
276
276
  tester.initialize()
277
277
 
278
278
  pr_stats = tester.get_pr_auc_stats()
279
- pr_score = pr_stats["micro_pr_auc"]
279
+ pr_score = pr_stats["micro_pr_auc_trained"]
280
280
  pr_scores.append(pr_score)
281
281
  if pr_score > max_pr_score:
282
282
  max_pr_score = pr_score
283
283
  max_pr_epoch = epoch_num
284
284
 
285
285
  roc_stats = tester.get_roc_auc_stats()
286
- roc_score = roc_stats["micro_roc_auc"]
286
+ roc_score = roc_stats["micro_roc_auc_trained"]
287
287
  roc_scores.append(roc_score)
288
288
  if roc_score > max_roc_score:
289
289
  max_roc_score = roc_score
@@ -403,7 +403,7 @@ def rpt_labels(
403
403
  """
404
404
  import pandas as pd
405
405
 
406
- cfg, _ = get_config()
406
+ cfg = get_config()
407
407
  if min_score is None:
408
408
  min_score = cfg.infer.min_score
409
409
 
@@ -556,7 +556,7 @@ def rpt_test(
556
556
  from britekit.testing.per_recording_tester import PerRecordingTester
557
557
  from britekit.testing.per_segment_tester import PerSegmentTester
558
558
 
559
- cfg, _ = get_config()
559
+ cfg = get_config()
560
560
  try:
561
561
  if not recordings_path:
562
562
  recordings_path = str(Path(annotations_path).parent)
@@ -62,7 +62,7 @@ def search(
62
62
  from britekit.models.model_loader import load_from_checkpoint
63
63
  from britekit.training_db.training_db import TrainingDatabase
64
64
 
65
- cfg, _ = get_config(cfg_path)
65
+ cfg = get_config(cfg_path)
66
66
 
67
67
  if not os.path.exists(output_path):
68
68
  os.makedirs(output_path)
@@ -31,7 +31,7 @@ def train(
31
31
  """
32
32
  from britekit.core.trainer import Trainer
33
33
 
34
- cfg, _ = get_config(cfg_path)
34
+ cfg = get_config(cfg_path)
35
35
  try:
36
36
  start_time = time.time()
37
37
  Trainer().run()
@@ -89,7 +89,7 @@ def find_lr(cfg_path: str, num_batches: int):
89
89
  """
90
90
  from britekit.core.trainer import Trainer
91
91
 
92
- cfg, _ = get_config(cfg_path)
92
+ cfg = get_config(cfg_path)
93
93
  try:
94
94
  suggested_lr, fig = Trainer().find_lr(num_batches)
95
95
  fig.savefig("learning_rates.jpeg")
@@ -18,7 +18,7 @@ def tune(
18
18
  param_path: Optional[str] = None,
19
19
  output_path: str = "",
20
20
  annotations_path: str = "",
21
- metric: str = "macro_roc",
21
+ metric: str = "micro_roc",
22
22
  recordings_path: str = "",
23
23
  train_log_path: str = "",
24
24
  num_trials: int = 0,
@@ -58,7 +58,7 @@ def tune(
58
58
  from britekit.core.tuner import Tuner
59
59
 
60
60
  try:
61
- cfg, _ = get_config(cfg_path)
61
+ cfg = get_config(cfg_path)
62
62
  if extract and skip_training:
63
63
  logging.error(
64
64
  "Performing spectrogram extract is incompatible with skipping training."
@@ -159,7 +159,7 @@ def tune(
159
159
  "micro_roc",
160
160
  ]
161
161
  ),
162
- default="macro_roc",
162
+ default="micro_roc",
163
163
  help="Metric used to compare runs. Macro-averaging uses annotated classes only, but micro-averaging uses all classes.",
164
164
  )
165
165
  @click.option(
@@ -86,6 +86,7 @@ packages = ["src/britekit"]
86
86
  "src/britekit/commands/_db_add.py" = "britekit/commands/_db_add.py"
87
87
  "src/britekit/commands/_db_delete.py" = "britekit/commands/_db_delete.py"
88
88
  "src/britekit/commands/_embed.py" = "britekit/commands/_embed.py"
89
+ "src/britekit/commands/_ensemble.py" = "britekit/commands/_ensemble.py"
89
90
  "src/britekit/commands/_extract.py" = "britekit/commands/_extract.py"
90
91
  "src/britekit/commands/_find_dup.py" = "britekit/commands/_find_dup.py"
91
92
  "src/britekit/commands/_inat.py" = "britekit/commands/_inat.py"
@@ -120,6 +121,7 @@ only-include = [
120
121
  "src/britekit/commands/_db_add.py" = "britekit/commands/_db_add.py"
121
122
  "src/britekit/commands/_db_delete.py" = "britekit/commands/_db_delete.py"
122
123
  "src/britekit/commands/_embed.py" = "britekit/commands/_embed.py"
124
+ "src/britekit/commands/_ensemble.py" = "britekit/commands/_ensemble.py"
123
125
  "src/britekit/commands/_extract.py" = "britekit/commands/_extract.py"
124
126
  "src/britekit/commands/_find_dup.py" = "britekit/commands/_find_dup.py"
125
127
  "src/britekit/commands/_inat.py" = "britekit/commands/_inat.py"
@@ -30,6 +30,7 @@ from .commands._db_delete import (
30
30
  _del_stype_cmd,
31
31
  )
32
32
  from .commands._embed import _embed_cmd
33
+ from .commands._ensemble import _ensemble_cmd
33
34
  from .commands._extract import _extract_all_cmd, _extract_by_image_cmd
34
35
  from .commands._find_dup import _find_dup_cmd
35
36
  from .commands._inat import _inat_cmd
@@ -80,6 +81,7 @@ cli.add_command(_del_src_cmd)
80
81
  cli.add_command(_del_stype_cmd)
81
82
 
82
83
  cli.add_command(_embed_cmd)
84
+ cli.add_command(_ensemble_cmd)
83
85
  cli.add_command(_extract_all_cmd)
84
86
  cli.add_command(_extract_by_image_cmd)
85
87
 
@@ -16,7 +16,7 @@ class Analyzer:
16
16
  """
17
17
 
18
18
  def __init__(self):
19
- self.cfg, self.fn_cfg = get_config()
19
+ self.cfg = get_config()
20
20
  self.dataframes = []
21
21
 
22
22
  def _save_manifest(self, output_path: str, predictor):
@@ -51,7 +51,7 @@ class Audio:
51
51
  import torchaudio as ta
52
52
 
53
53
  if cfg is None:
54
- self.cfg, _ = get_config()
54
+ self.cfg = get_config()
55
55
  else:
56
56
  self.cfg = cfg
57
57