britekit 0.0.8__tar.gz → 0.0.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of britekit might be problematic. Click here for more details.

Files changed (125) hide show
  1. {britekit-0.0.8 → britekit-0.0.9}/PKG-INFO +1 -1
  2. {britekit-0.0.8 → britekit-0.0.9}/britekit/__about__.py +1 -1
  3. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/__init__.py +2 -0
  4. britekit-0.0.9/britekit/commands/_ensemble.py +237 -0
  5. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_reports.py +2 -2
  6. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_tune.py +2 -2
  7. {britekit-0.0.8 → britekit-0.0.9}/pyproject.toml +2 -0
  8. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/cli.py +2 -0
  9. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/trainer.py +3 -2
  10. {britekit-0.0.8 → britekit-0.0.9}/.gitignore +0 -0
  11. {britekit-0.0.8 → britekit-0.0.9}/LICENSE.txt +0 -0
  12. {britekit-0.0.8 → britekit-0.0.9}/README.md +0 -0
  13. {britekit-0.0.8 → britekit-0.0.9}/britekit/__init__.py +0 -0
  14. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_analyze.py +0 -0
  15. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_audioset.py +0 -0
  16. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_calibrate.py +0 -0
  17. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_ckpt_ops.py +0 -0
  18. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_db_add.py +0 -0
  19. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_db_delete.py +0 -0
  20. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_embed.py +0 -0
  21. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_extract.py +0 -0
  22. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_find_dup.py +0 -0
  23. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_inat.py +0 -0
  24. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_init.py +0 -0
  25. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_pickle.py +0 -0
  26. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_plot.py +0 -0
  27. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_reextract.py +0 -0
  28. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_search.py +0 -0
  29. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_train.py +0 -0
  30. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_wav2mp3.py +0 -0
  31. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_xeno.py +0 -0
  32. {britekit-0.0.8 → britekit-0.0.9}/britekit/commands/_youtube.py +0 -0
  33. {britekit-0.0.8 → britekit-0.0.9}/britekit/core/__init__.py +0 -0
  34. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/class_inclusion.csv +0 -0
  35. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/class_list.csv +0 -0
  36. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/aircraft.csv +0 -0
  37. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/car.csv +0 -0
  38. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/chainsaw.csv +0 -0
  39. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/cow.csv +0 -0
  40. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/cricket.csv +0 -0
  41. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/dog.csv +0 -0
  42. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/rain.csv +0 -0
  43. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/rooster.csv +0 -0
  44. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/sheep.csv +0 -0
  45. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/siren.csv +0 -0
  46. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/speech.csv +0 -0
  47. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/truck.csv +0 -0
  48. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/curated/wind.csv +0 -0
  49. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/audioset/unbalanced_train_segments.csv +0 -0
  50. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/classes.csv +0 -0
  51. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/data/ignore.txt +0 -0
  52. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/base_config.yaml +0 -0
  53. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/cfg_infer.yaml +0 -0
  54. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/train_dla.yaml +0 -0
  55. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/train_effnet.yaml +0 -0
  56. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/train_gernet.yaml +0 -0
  57. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/train_hgnet.yaml +0 -0
  58. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/train_timm.yaml +0 -0
  59. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/train_vovnet.yaml +0 -0
  60. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/tune_dropout.yaml +0 -0
  61. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/tune_learning_rate.yaml +0 -0
  62. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/tune_optimizer.yaml +0 -0
  63. {britekit-0.0.8 → britekit-0.0.9}/britekit/install/yaml/samples/tune_smooth.yaml +0 -0
  64. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/class_inclusion.csv +0 -0
  65. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/class_list.csv +0 -0
  66. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/aircraft.csv +0 -0
  67. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/car.csv +0 -0
  68. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/chainsaw.csv +0 -0
  69. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/cow.csv +0 -0
  70. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/cricket.csv +0 -0
  71. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/dog.csv +0 -0
  72. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/rain.csv +0 -0
  73. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/rooster.csv +0 -0
  74. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/sheep.csv +0 -0
  75. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/siren.csv +0 -0
  76. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/speech.csv +0 -0
  77. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/truck.csv +0 -0
  78. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/curated/wind.csv +0 -0
  79. {britekit-0.0.8 → britekit-0.0.9}/install/data/audioset/unbalanced_train_segments.csv +0 -0
  80. {britekit-0.0.8 → britekit-0.0.9}/install/data/classes.csv +0 -0
  81. {britekit-0.0.8 → britekit-0.0.9}/install/data/ignore.txt +0 -0
  82. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/base_config.yaml +0 -0
  83. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/cfg_infer.yaml +0 -0
  84. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/train_dla.yaml +0 -0
  85. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/train_effnet.yaml +0 -0
  86. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/train_gernet.yaml +0 -0
  87. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/train_hgnet.yaml +0 -0
  88. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/train_timm.yaml +0 -0
  89. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/train_vovnet.yaml +0 -0
  90. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/tune_dropout.yaml +0 -0
  91. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/tune_learning_rate.yaml +0 -0
  92. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/tune_optimizer.yaml +0 -0
  93. {britekit-0.0.8 → britekit-0.0.9}/install/yaml/samples/tune_smooth.yaml +0 -0
  94. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/analyzer.py +0 -0
  95. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/audio.py +0 -0
  96. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/augmentation.py +0 -0
  97. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/base_config.py +0 -0
  98. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/config_loader.py +0 -0
  99. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/data_module.py +0 -0
  100. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/dataset.py +0 -0
  101. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/exceptions.py +0 -0
  102. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/pickler.py +0 -0
  103. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/plot.py +0 -0
  104. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/predictor.py +0 -0
  105. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/reextractor.py +0 -0
  106. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/tuner.py +0 -0
  107. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/core/util.py +0 -0
  108. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/base_model.py +0 -0
  109. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/dla.py +0 -0
  110. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/effnet.py +0 -0
  111. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/gernet.py +0 -0
  112. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/head_factory.py +0 -0
  113. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/hgnet.py +0 -0
  114. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/model_loader.py +0 -0
  115. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/timm_model.py +0 -0
  116. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/models/vovnet.py +0 -0
  117. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/occurrence_db/occurrence_data_provider.py +0 -0
  118. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/occurrence_db/occurrence_db.py +0 -0
  119. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/testing/base_tester.py +0 -0
  120. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/testing/per_minute_tester.py +0 -0
  121. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/testing/per_recording_tester.py +0 -0
  122. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/testing/per_segment_tester.py +0 -0
  123. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/training_db/extractor.py +0 -0
  124. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/training_db/training_data_provider.py +0 -0
  125. {britekit-0.0.8 → britekit-0.0.9}/src/britekit/training_db/training_db.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: britekit
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Core functions for bioacoustic recognizers.
5
5
  Project-URL: Documentation, https://github.com/jhuus/BriteKit#readme
6
6
  Project-URL: Issues, https://github.com/jhuus/BriteKit/issues
@@ -1,4 +1,4 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Jan Huus <jhuus1@gmail.com>
2
2
  #
3
3
  # SPDX-License-Identifier: MIT
4
- __version__ = "0.0.8"
4
+ __version__ = "0.0.9"
@@ -13,6 +13,7 @@ from ._db_delete import (
13
13
  del_stype,
14
14
  )
15
15
  from ._embed import embed
16
+ from ._ensemble import ensemble
16
17
  from ._extract import extract_all, extract_by_image
17
18
  from ._find_dup import find_dup
18
19
  from ._inat import inat
@@ -54,6 +55,7 @@ __all__ = [
54
55
  "del_src",
55
56
  "del_stype",
56
57
  "embed",
58
+ "ensemble",
57
59
  "extract_all",
58
60
  "extract_by_image",
59
61
  "find_dup",
@@ -0,0 +1,237 @@
1
+ # File name starts with _ to keep it out of typeahead for API users.
2
+ # Defer some imports to improve --help performance.
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import tempfile
7
+ from typing import Optional
8
+
9
+ import click
10
+
11
+ from britekit.core.config_loader import get_config
12
+ from britekit.core import util
13
+
14
def _eval_ensemble(ensemble, temp_dir, annotations_path, recording_dir):
    """
    Evaluate one candidate ensemble and return its scores.

    The checkpoints in ``ensemble`` are copied into ``temp_dir`` (the caller points
    ``cfg.misc.ckpt_folder`` at this directory), inference is run on
    ``recording_dir``, and the generated labels are scored against
    ``annotations_path`` with PerSegmentTester.

    Args:
        ensemble (Iterable[str]): Paths of the checkpoint files that form the ensemble.
        temp_dir (str): Scratch checkpoint directory. Every file already in it is deleted.
        annotations_path (str): Path to CSV file containing ground truth annotations.
        recording_dir (str): Directory containing the audio recordings to analyze.

    Returns:
        dict: Scores keyed by "macro_pr", "micro_pr", "macro_roc" and "micro_roc".
    """
    import shutil

    from britekit.core.analyzer import Analyzer
    from britekit.testing.per_segment_tester import PerSegmentTester

    # Remove the previous candidate's checkpoints so only this ensemble is loaded.
    for filename in os.listdir(temp_dir):
        os.remove(os.path.join(temp_dir, filename))

    # Copy this ensemble's checkpoints into the scratch checkpoint folder.
    for file_path in ensemble:
        shutil.copyfile(file_path, os.path.join(temp_dir, Path(file_path).name))

    label_dir = "ensemble_evaluation_labels"
    inference_output_dir = str(Path(recording_dir) / label_dir)

    # Suppress logging during inference and analysis. The finally clause restores
    # logging and removes the generated labels even if inference or testing fails,
    # so a failed candidate never leaves logging muted or stray label files in the
    # user's recordings directory.
    util.set_logging(level=logging.ERROR)
    try:
        Analyzer().run(recording_dir, inference_output_dir)

        min_score = 0.8  # irrelevant really: only AUC statistics are read below
        with tempfile.TemporaryDirectory() as output_dir:
            tester = PerSegmentTester(
                annotations_path,
                recording_dir,
                inference_output_dir,
                output_dir,
                min_score,
            )
            tester.initialize()

            pr_stats = tester.get_pr_auc_stats()
            roc_stats = tester.get_roc_auc_stats()
    finally:
        shutil.rmtree(inference_output_dir, ignore_errors=True)
        util.set_logging()  # restore logging

    return {
        "macro_pr": pr_stats["macro_pr_auc"],
        "micro_pr": pr_stats["micro_pr_auc_trained"],
        "macro_roc": roc_stats["macro_roc_auc"],
        "micro_roc": roc_stats["micro_roc_auc_trained"],
    }
62
+
63
def _choose_candidates(ckpt_paths, ensemble_size, num_tries):
    """
    Pick the ensembles to evaluate: all combinations if there are at most
    ``num_tries`` of them, otherwise ``num_tries`` distinct random ones.

    Each ensemble is a tuple of checkpoint paths in sorted order, so that two
    draws of the same checkpoints compare equal.
    """
    import itertools
    import math
    import random

    total_combinations = math.comb(len(ckpt_paths), ensemble_size)
    if total_combinations <= num_tries:
        logging.info("Doing exhaustive search")
        return list(itertools.combinations(ckpt_paths, ensemble_size))

    # Random sampling without replacement. This terminates because this branch
    # only runs when more than num_tries distinct ensembles exist.
    logging.info("Doing random sampling")
    seen: set = set()
    candidates = []
    while len(candidates) < num_tries:
        candidate = tuple(sorted(random.sample(ckpt_paths, ensemble_size)))
        if candidate not in seen:
            seen.add(candidate)
            candidates.append(candidate)
    return candidates


def _write_report(output_path, results, metric):
    """Write every evaluated ensemble and its scores to ensembles.csv, best first."""
    import csv
    import math

    def sort_key(item):
        # NaN scores sort last instead of corrupting the ordering.
        score = item[1][metric]
        return -math.inf if math.isnan(score) else score

    metric_names = ["macro_pr", "micro_pr", "macro_roc", "micro_roc"]
    os.makedirs(output_path, exist_ok=True)
    report_path = os.path.join(output_path, "ensembles.csv")
    with open(report_path, "w", newline="", encoding="utf-8") as report:
        writer = csv.writer(report)
        writer.writerow(["ensemble"] + metric_names)
        for candidate, scores in sorted(results, key=sort_key, reverse=True):
            names = "; ".join(Path(path).name for path in candidate)
            writer.writerow([names] + [f"{scores[name]:.4f}" for name in metric_names])
    logging.info(f"Report saved to {report_path}")


def ensemble(
    cfg_path: Optional[str] = None,
    ckpt_path: str = "",
    ensemble_size: int = 3,
    num_tries: int = 100,
    metric: str = "micro_roc",
    annotations_path: str = "",
    recordings_path: Optional[str] = None,
    output_path: str = "",
) -> None:
    """
    Find the best ensemble of a given size from a group of checkpoints.

    Given a directory containing checkpoints, and an ensemble size (default=3), select random
    ensembles of the given size and test each one to identify the best ensemble. If there are
    no more than num_tries possible ensembles, every one of them is tested.

    Args:
        cfg_path (str, optional): Path to YAML file defining configuration overrides.
        ckpt_path (str): Path to directory containing checkpoints.
        ensemble_size (int): Number of checkpoints in ensemble (default=3).
        num_tries (int): Maximum number of ensembles to try (default=100).
        metric (str): Metric to use to compare ensembles (default=micro_roc).
        annotations_path (str): Path to CSV file containing ground truth annotations.
        recordings_path (str, optional): Directory containing audio recordings. Defaults to annotations directory.
        output_path (str): Directory where reports will be saved. If empty, no report is written.
    """
    import glob
    import math

    # Cheap argument checks first, before any config loading or inference.
    if metric not in ["macro_pr", "micro_pr", "macro_roc", "micro_roc"]:
        logging.error(f"Error: invalid metric ({metric})")
        return
    if ensemble_size < 1:
        logging.error(f"Error: ensemble size must be at least 1 ({ensemble_size})")
        return
    if num_tries < 1:
        logging.error(f"Error: number of tries must be at least 1 ({num_tries})")
        return

    ckpt_paths = sorted(glob.glob(os.path.join(ckpt_path, "*.ckpt")))
    num_ckpts = len(ckpt_paths)
    if num_ckpts == 0:
        logging.error(f"Error: no checkpoints found in {ckpt_path}")
        return
    elif num_ckpts < ensemble_size:
        logging.error(f"Error: number of checkpoints ({num_ckpts}) is less than requested ensemble size ({ensemble_size})")
        return

    cfg, _ = get_config(cfg_path)
    if not recordings_path:
        recordings_path = str(Path(annotations_path).parent)

    candidates = _choose_candidates(ckpt_paths, ensemble_size, num_tries)
    results = []
    # Start below any real score so the first valid ensemble is always recorded;
    # NaN never compares greater, so an undefined score cannot become the best.
    best_score = -math.inf
    best_ensemble = None
    with tempfile.TemporaryDirectory() as temp_dir:
        # _eval_ensemble copies each candidate's checkpoints into temp_dir.
        cfg.misc.ckpt_folder = temp_dir
        cfg.infer.min_score = 0

        for count, candidate in enumerate(candidates, start=1):
            scores = _eval_ensemble(candidate, temp_dir, annotations_path, recordings_path)
            logging.info(f"For ensemble {count} of {len(candidates)}, score = {scores[metric]:.4f}")
            results.append((candidate, scores))
            if scores[metric] > best_score:
                best_score = scores[metric]
                best_ensemble = candidate

    if output_path:
        _write_report(output_path, results, metric)

    if best_ensemble is None:
        logging.error("Error: no ensemble produced a valid score")
        return

    logging.info(f"Best score = {best_score:.4f}")
    best_names = [Path(path).name for path in best_ensemble]
    logging.info(f"Best ensemble = {best_names}")
150
+
151
@click.command(
    name="ensemble",
    short_help="Find the best ensemble of a given size from a group of checkpoints.",
    help=util.cli_help_from_doc(ensemble.__doc__),
)
@click.option(
    "-c",
    "--cfg",
    "cfg_path",
    type=click.Path(exists=True),
    required=False,
    help="Path to YAML file defining config overrides.",
)
@click.option(
    "--ckpt_path",
    "ckpt_path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    required=True,
    help="Directory containing checkpoints.",
)
@click.option(
    "-e",
    "--ensemble_size",
    "ensemble_size",
    type=int,
    default=3,
    help="Number of checkpoints in ensemble (default=3).",
)
@click.option(
    "-n",
    "--num_tries",
    "num_tries",
    type=int,
    default=100,
    help="Maximum number of ensembles to try (default=100).",
)
@click.option(
    "-m",
    "--metric",
    "metric",
    type=click.Choice(
        [
            "macro_pr",
            "micro_pr",
            "macro_roc",
            "micro_roc",
        ]
    ),
    default="micro_roc",
    help="Metric used to compare ensembles (default=micro_roc). Macro-averaging uses annotated classes only, but micro-averaging uses all classes.",
)
@click.option(
    "-a",
    "--annotations",
    "annotations_path",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
    required=True,
    help="Path to CSV file containing annotations (ground truth).",
)
@click.option(
    "-r",
    "--recordings",
    "recordings_path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    required=False,
    help="Recordings directory. Default is directory containing annotations file.",
)
@click.option(
    "-o",
    "--output",
    "output_path",
    type=click.Path(file_okay=False, dir_okay=True),
    required=True,
    help="Path to output directory.",
)
def _ensemble_cmd(
    cfg_path: Optional[str],
    ckpt_path: str,
    ensemble_size: int,
    num_tries: int,
    metric: str,
    annotations_path: str,
    recordings_path: Optional[str],
    output_path: str,
) -> None:
    """CLI entry point: configure logging and delegate to ensemble()."""
    util.set_logging()
    # Pass by keyword so a reordering of ensemble()'s parameters cannot silently
    # shift values into the wrong arguments.
    ensemble(
        cfg_path=cfg_path,
        ckpt_path=ckpt_path,
        ensemble_size=ensemble_size,
        num_tries=num_tries,
        metric=metric,
        annotations_path=annotations_path,
        recordings_path=recordings_path,
        output_path=output_path,
    )
@@ -276,14 +276,14 @@ def rpt_epochs(
276
276
  tester.initialize()
277
277
 
278
278
  pr_stats = tester.get_pr_auc_stats()
279
- pr_score = pr_stats["micro_pr_auc"]
279
+ pr_score = pr_stats["micro_pr_auc_trained"]
280
280
  pr_scores.append(pr_score)
281
281
  if pr_score > max_pr_score:
282
282
  max_pr_score = pr_score
283
283
  max_pr_epoch = epoch_num
284
284
 
285
285
  roc_stats = tester.get_roc_auc_stats()
286
- roc_score = roc_stats["micro_roc_auc"]
286
+ roc_score = roc_stats["micro_roc_auc_trained"]
287
287
  roc_scores.append(roc_score)
288
288
  if roc_score > max_roc_score:
289
289
  max_roc_score = roc_score
@@ -18,7 +18,7 @@ def tune(
18
18
  param_path: Optional[str] = None,
19
19
  output_path: str = "",
20
20
  annotations_path: str = "",
21
- metric: str = "macro_roc",
21
+ metric: str = "micro_roc",
22
22
  recordings_path: str = "",
23
23
  train_log_path: str = "",
24
24
  num_trials: int = 0,
@@ -159,7 +159,7 @@ def tune(
159
159
  "micro_roc",
160
160
  ]
161
161
  ),
162
- default="macro_roc",
162
+ default="micro_roc",
163
163
  help="Metric used to compare runs. Macro-averaging uses annotated classes only, but micro-averaging uses all classes.",
164
164
  )
165
165
  @click.option(
@@ -86,6 +86,7 @@ packages = ["src/britekit"]
86
86
  "src/britekit/commands/_db_add.py" = "britekit/commands/_db_add.py"
87
87
  "src/britekit/commands/_db_delete.py" = "britekit/commands/_db_delete.py"
88
88
  "src/britekit/commands/_embed.py" = "britekit/commands/_embed.py"
89
+ "src/britekit/commands/_ensemble.py" = "britekit/commands/_ensemble.py"
89
90
  "src/britekit/commands/_extract.py" = "britekit/commands/_extract.py"
90
91
  "src/britekit/commands/_find_dup.py" = "britekit/commands/_find_dup.py"
91
92
  "src/britekit/commands/_inat.py" = "britekit/commands/_inat.py"
@@ -120,6 +121,7 @@ only-include = [
120
121
  "src/britekit/commands/_db_add.py" = "britekit/commands/_db_add.py"
121
122
  "src/britekit/commands/_db_delete.py" = "britekit/commands/_db_delete.py"
122
123
  "src/britekit/commands/_embed.py" = "britekit/commands/_embed.py"
124
+ "src/britekit/commands/_ensemble.py" = "britekit/commands/_ensemble.py"
123
125
  "src/britekit/commands/_extract.py" = "britekit/commands/_extract.py"
124
126
  "src/britekit/commands/_find_dup.py" = "britekit/commands/_find_dup.py"
125
127
  "src/britekit/commands/_inat.py" = "britekit/commands/_inat.py"
@@ -30,6 +30,7 @@ from .commands._db_delete import (
30
30
  _del_stype_cmd,
31
31
  )
32
32
  from .commands._embed import _embed_cmd
33
+ from .commands._ensemble import _ensemble_cmd
33
34
  from .commands._extract import _extract_all_cmd, _extract_by_image_cmd
34
35
  from .commands._find_dup import _find_dup_cmd
35
36
  from .commands._inat import _inat_cmd
@@ -80,6 +81,7 @@ cli.add_command(_del_src_cmd)
80
81
  cli.add_command(_del_stype_cmd)
81
82
 
82
83
  cli.add_command(_embed_cmd)
84
+ cli.add_command(_ensemble_cmd)
83
85
  cli.add_command(_extract_all_cmd)
84
86
  cli.add_command(_extract_by_image_cmd)
85
87
 
@@ -125,11 +125,12 @@ class Trainer:
125
125
  if val_rocs:
126
126
  import math
127
127
  import numpy as np
128
+
128
129
  mean = float(np.mean(val_rocs))
129
- std = float(np.std(val_rocs, ddof=1)) if len(val_rocs) > 1 else 0.0
130
+ std = float(np.std(val_rocs, ddof=1)) if len(val_rocs) > 1 else 0.0
130
131
  n = len(val_rocs)
131
132
  se = std / math.sqrt(n) if n > 1 else 0.0
132
- ci95 = 1.96 * se # 95% CI using normal approximation
133
+ ci95 = 1.96 * se # 95% CI using normal approximation
133
134
 
134
135
  logging.info("Using micro-averaged ROC AUC")
135
136
  scores_str = ", ".join(f"{v:.4f}" for v in val_rocs)
File without changes
File without changes
File without changes
File without changes