britekit 0.0.7__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of britekit might be problematic. Click here for more details.
- {britekit-0.0.7 → britekit-0.0.9}/PKG-INFO +1 -1
- {britekit-0.0.7 → britekit-0.0.9}/britekit/__about__.py +1 -1
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/__init__.py +2 -0
- britekit-0.0.9/britekit/commands/_ensemble.py +237 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_reports.py +4 -4
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_tune.py +2 -2
- {britekit-0.0.7 → britekit-0.0.9}/pyproject.toml +2 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/cli.py +2 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/data_module.py +1 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/trainer.py +23 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/tuner.py +1 -1
- {britekit-0.0.7 → britekit-0.0.9}/.gitignore +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/LICENSE.txt +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/README.md +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/__init__.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_analyze.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_audioset.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_calibrate.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_ckpt_ops.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_db_add.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_db_delete.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_embed.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_extract.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_find_dup.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_inat.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_init.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_pickle.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_plot.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_reextract.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_search.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_train.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_wav2mp3.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_xeno.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/commands/_youtube.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/core/__init__.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/class_inclusion.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/class_list.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/aircraft.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/car.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/chainsaw.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/cow.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/cricket.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/dog.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/rain.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/rooster.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/sheep.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/siren.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/speech.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/truck.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/curated/wind.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/unbalanced_train_segments.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/classes.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/ignore.txt +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/base_config.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/cfg_infer.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/train_dla.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/train_effnet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/train_gernet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/train_hgnet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/train_timm.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/train_vovnet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/tune_dropout.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/tune_learning_rate.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/tune_optimizer.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/britekit/install/yaml/samples/tune_smooth.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/class_inclusion.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/class_list.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/aircraft.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/car.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/chainsaw.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/cow.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/cricket.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/dog.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/rain.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/rooster.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/sheep.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/siren.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/speech.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/truck.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/curated/wind.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/audioset/unbalanced_train_segments.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/classes.csv +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/data/ignore.txt +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/base_config.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/cfg_infer.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/train_dla.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/train_effnet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/train_gernet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/train_hgnet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/train_timm.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/train_vovnet.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/tune_dropout.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/tune_learning_rate.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/tune_optimizer.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/install/yaml/samples/tune_smooth.yaml +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/analyzer.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/audio.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/augmentation.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/base_config.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/config_loader.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/dataset.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/exceptions.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/pickler.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/plot.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/predictor.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/reextractor.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/core/util.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/base_model.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/dla.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/effnet.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/gernet.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/head_factory.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/hgnet.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/model_loader.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/timm_model.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/models/vovnet.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/occurrence_db/occurrence_data_provider.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/occurrence_db/occurrence_db.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/testing/base_tester.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/testing/per_minute_tester.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/testing/per_recording_tester.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/testing/per_segment_tester.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/training_db/extractor.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/training_db/training_data_provider.py +0 -0
- {britekit-0.0.7 → britekit-0.0.9}/src/britekit/training_db/training_db.py +0 -0
|
@@ -13,6 +13,7 @@ from ._db_delete import (
|
|
|
13
13
|
del_stype,
|
|
14
14
|
)
|
|
15
15
|
from ._embed import embed
|
|
16
|
+
from ._ensemble import ensemble
|
|
16
17
|
from ._extract import extract_all, extract_by_image
|
|
17
18
|
from ._find_dup import find_dup
|
|
18
19
|
from ._inat import inat
|
|
@@ -54,6 +55,7 @@ __all__ = [
|
|
|
54
55
|
"del_src",
|
|
55
56
|
"del_stype",
|
|
56
57
|
"embed",
|
|
58
|
+
"ensemble",
|
|
57
59
|
"extract_all",
|
|
58
60
|
"extract_by_image",
|
|
59
61
|
"find_dup",
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
# File name starts with _ to keep it out of typeahead for API users.
|
|
2
|
+
# Defer some imports to improve --help performance.
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import tempfile
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import click
|
|
10
|
+
|
|
11
|
+
from britekit.core.config_loader import get_config
|
|
12
|
+
from britekit.core import util
|
|
13
|
+
|
|
14
|
+
def _eval_ensemble(ensemble, temp_dir, annotations_path, recording_dir):
    """
    Score one candidate ensemble against the annotated test recordings.

    The checkpoints are copied into temp_dir (after emptying it), inference is run
    on recording_dir, and the resulting labels are compared to the annotations with
    a PerSegmentTester.

    Args:
        ensemble: Iterable of checkpoint file paths making up the ensemble.
        temp_dir (str): Scratch directory that receives the checkpoint copies.
            Presumably the folder Analyzer loads checkpoints from (the caller
            points cfg.misc.ckpt_folder at it) -- confirm against Analyzer.
        annotations_path (str): Path to CSV file containing ground truth annotations.
        recording_dir (str): Directory containing audio recordings. Inference labels
            are written to an "ensemble_evaluation_labels" sub-directory of it, which
            is removed again before returning.

    Returns:
        dict: Scores keyed by "macro_pr", "micro_pr", "macro_roc" and "micro_roc".
    """
    import shutil

    from britekit.core.analyzer import Analyzer
    from britekit.testing.per_segment_tester import PerSegmentTester

    # delete any checkpoints left in the temp dir by a previous evaluation
    for filename in os.listdir(temp_dir):
        os.remove(os.path.join(temp_dir, filename))

    # copy this ensemble's checkpoints to the temp dir
    for file_path in ensemble:
        shutil.copyfile(file_path, os.path.join(temp_dir, Path(file_path).name))

    label_dir = "ensemble_evaluation_labels"
    inference_output_dir = str(Path(recording_dir) / label_dir)
    min_score = 0.8  # irrelevant really

    # suppress logging during inference and analysis
    util.set_logging(level=logging.ERROR)
    try:
        # run inference on the given test
        Analyzer().run(recording_dir, inference_output_dir)

        with tempfile.TemporaryDirectory() as output_dir:
            tester = PerSegmentTester(
                annotations_path,
                recording_dir,
                inference_output_dir,
                output_dir,
                min_score,
            )
            tester.initialize()

            pr_stats = tester.get_pr_auc_stats()
            roc_stats = tester.get_roc_auc_stats()

        return {
            "macro_pr": pr_stats["macro_pr_auc"],
            "micro_pr": pr_stats["micro_pr_auc_trained"],
            "macro_roc": roc_stats["macro_roc_auc"],
            "micro_roc": roc_stats["micro_roc_auc_trained"],
        }
    finally:
        # Always clean up, even when inference or testing fails, so a failed run
        # neither leaves label files in the recordings directory nor leaves
        # logging stuck at ERROR. ignore_errors keeps a cleanup problem from
        # masking the original exception.
        shutil.rmtree(inference_output_dir, ignore_errors=True)
        util.set_logging()  # restore logging
|
|
62
|
+
|
|
63
|
+
def ensemble(
    cfg_path: Optional[str] = None,
    ckpt_path: str = "",
    ensemble_size: int = 3,
    num_tries: int = 100,
    metric: str = "micro_roc",
    annotations_path: str = "",
    recordings_path: Optional[str] = None,
    output_path: str = "",
) -> None:
    """
    Find the best ensemble of a given size from a group of checkpoints.

    Given a directory containing checkpoints, and an ensemble size (default=3), select random
    ensembles of the given size and test each one to identify the best ensemble.

    Args:
        cfg_path (str, optional): Path to YAML file defining configuration overrides.
        ckpt_path (str): Path to directory containing checkpoints.
        ensemble_size (int): Number of checkpoints in ensemble (default=3).
        num_tries (int): Maximum number of ensembles to try (default=100).
        metric (str): Metric to use to compare ensembles (default=micro_roc).
        annotations_path (str): Path to CSV file containing ground truth annotations.
        recordings_path (str, optional): Directory containing audio recordings. Defaults to annotations directory.
        output_path (str): Directory where reports will be saved.
    """
    # NOTE(review): the docstring above is also the CLI help text (see
    # util.cli_help_from_doc in the command decorator), so it is kept verbatim.
    # output_path is accepted for interface compatibility but nothing in this
    # function writes to it; results are only logged.
    import glob
    import itertools
    import math
    import random

    # Validate plain arguments first so bad input is rejected before any
    # configuration or checkpoint work is done.
    if metric not in ["macro_pr", "micro_pr", "macro_roc", "micro_roc"]:
        logging.error(f"Error: invalid metric ({metric})")
        return

    if ensemble_size < 1:
        # math.comb raises ValueError for negative sizes, and an empty ensemble
        # is meaningless
        logging.error(f"Error: ensemble size ({ensemble_size}) must be at least 1")
        return

    if num_tries < 1:
        # with no tries there would be no best ensemble to report
        logging.error(f"Error: number of tries ({num_tries}) must be at least 1")
        return

    cfg, _ = get_config(cfg_path)
    ckpt_paths = sorted(glob.glob(os.path.join(ckpt_path, "*.ckpt")))
    num_ckpts = len(ckpt_paths)
    if num_ckpts == 0:
        logging.error(f"Error: no checkpoints found in {ckpt_path}")
        return
    elif num_ckpts < ensemble_size:
        logging.error(f"Error: number of checkpoints ({num_ckpts}) is less than requested ensemble size ({ensemble_size})")
        return

    if not recordings_path:
        recordings_path = str(Path(annotations_path).parent)

    def _random_ensembles():
        # Yield num_tries distinct ensembles. Each sample is sorted so the same
        # set of checkpoints is always represented by the same tuple.
        seen: set = set()
        while len(seen) < num_tries:
            candidate = tuple(sorted(random.sample(ckpt_paths, ensemble_size)))
            if candidate not in seen:
                seen.add(candidate)
                yield candidate

    total_combinations = math.comb(num_ckpts, ensemble_size)
    if total_combinations <= num_tries:
        # Exhaustive search
        logging.info("Doing exhaustive search")
        candidates = itertools.combinations(ckpt_paths, ensemble_size)
        num_to_try = total_combinations
    else:
        # Random sampling without replacement
        logging.info("Doing random sampling")
        candidates = _random_ensembles()
        num_to_try = num_tries

    best_score = 0
    best_ensemble = None

    # The config is pointed at a temporary checkpoint folder for the duration of
    # the search. Remember the previous values so the config (which may be shared
    # with other callers -- not visible from here) is not left referring to a
    # directory that no longer exists.
    saved_ckpt_folder = cfg.misc.ckpt_folder
    saved_min_score = cfg.infer.min_score
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            cfg.misc.ckpt_folder = temp_dir
            cfg.infer.min_score = 0

            for count, candidate in enumerate(candidates, start=1):
                scores = _eval_ensemble(candidate, temp_dir, annotations_path, recordings_path)
                logging.info(f"For ensemble {count} of {num_to_try}, score = {scores[metric]:.4f}")
                # Always accept the first candidate so that a run in which every
                # score is 0 still reports an ensemble instead of failing.
                if best_ensemble is None or scores[metric] > best_score:
                    best_score = scores[metric]
                    best_ensemble = candidate
    finally:
        cfg.misc.ckpt_folder = saved_ckpt_folder
        cfg.infer.min_score = saved_min_score

    if best_ensemble is None:
        logging.error("Error: no ensembles were evaluated")
        return

    logging.info(f"Best score = {best_score:.4f}")

    best_names = [Path(path).name for path in best_ensemble]
    logging.info(f"Best ensemble = {best_names}")
|
|
150
|
+
|
|
151
|
+
@click.command(
    name="ensemble",
    short_help="Find the best ensemble of a given size from a group of checkpoints.",
    help=util.cli_help_from_doc(ensemble.__doc__),
)
@click.option(
    "-c",
    "--cfg",
    "cfg_path",
    type=click.Path(exists=True),
    required=False,
    help="Path to YAML file defining config overrides.",
)
@click.option(
    "--ckpt_path",
    "ckpt_path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    required=True,
    help="Directory containing checkpoints.",
)
@click.option(
    "-e",
    "--ensemble_size",
    "ensemble_size",
    type=int,
    default=3,
    help="Number of checkpoints in ensemble (default=3).",
)
@click.option(
    "-n",
    "--num_tries",
    "num_tries",
    type=int,
    default=100,
    help="Maximum number of ensembles to try (default=100).",
)
@click.option(
    "-m",
    "--metric",
    "metric",
    type=click.Choice(
        [
            "macro_pr",
            "micro_pr",
            "macro_roc",
            "micro_roc",
        ]
    ),
    default="micro_roc",
    help="Metric used to compare ensembles (default=micro_roc). Macro-averaging uses annotated classes only, but micro-averaging uses all classes.",
)
@click.option(
    "-a",
    "--annotations",
    "annotations_path",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
    required=True,
    # fixed: the help text previously ended with an unbalanced ")"
    help="Path to CSV file containing annotations or ground truth.",
)
@click.option(
    "-r",
    "--recordings",
    "recordings_path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True),
    required=False,
    help="Recordings directory. Default is directory containing annotations file.",
)
@click.option(
    "-o",
    "--output",
    "output_path",
    type=click.Path(file_okay=False, dir_okay=True),
    required=True,
    help="Path to output directory.",
)
def _ensemble_cmd(
    cfg_path: Optional[str],
    ckpt_path: str,
    ensemble_size: int,
    num_tries: int,
    metric: str,
    annotations_path: str,
    recordings_path: Optional[str],
    output_path: str,
) -> None:
    # CLI entry point for the "ensemble" command: set up logging, then hand the
    # parsed options to ensemble(). A comment is used rather than a docstring so
    # nothing new is attached to the click command object.
    util.set_logging()
    # Pass by keyword so the eight same-typed arguments cannot be silently
    # transposed if either signature is reordered.
    ensemble(
        cfg_path=cfg_path,
        ckpt_path=ckpt_path,
        ensemble_size=ensemble_size,
        num_tries=num_tries,
        metric=metric,
        annotations_path=annotations_path,
        recordings_path=recordings_path,
        output_path=output_path,
    )
|
|
@@ -276,14 +276,14 @@ def rpt_epochs(
|
|
|
276
276
|
tester.initialize()
|
|
277
277
|
|
|
278
278
|
pr_stats = tester.get_pr_auc_stats()
|
|
279
|
-
pr_score = pr_stats["
|
|
279
|
+
pr_score = pr_stats["micro_pr_auc_trained"]
|
|
280
280
|
pr_scores.append(pr_score)
|
|
281
281
|
if pr_score > max_pr_score:
|
|
282
282
|
max_pr_score = pr_score
|
|
283
283
|
max_pr_epoch = epoch_num
|
|
284
284
|
|
|
285
285
|
roc_stats = tester.get_roc_auc_stats()
|
|
286
|
-
roc_score = roc_stats["
|
|
286
|
+
roc_score = roc_stats["micro_roc_auc_trained"]
|
|
287
287
|
roc_scores.append(roc_score)
|
|
288
288
|
if roc_score > max_roc_score:
|
|
289
289
|
max_roc_score = roc_score
|
|
@@ -323,8 +323,8 @@ def rpt_epochs(
|
|
|
323
323
|
plot_path = str(Path(output_path) / "training_scores.jpeg")
|
|
324
324
|
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
|
|
325
325
|
|
|
326
|
-
logging.info(f"Maximum PR-AUC score = {max_pr_score:.3f} at epoch {max_pr_epoch}")
|
|
327
|
-
logging.info(f"Maximum ROC-AUC score = {max_roc_score:.3f} at epoch {max_roc_epoch}")
|
|
326
|
+
logging.info(f"Maximum micro-averaged PR-AUC score = {max_pr_score:.3f} at epoch {max_pr_epoch}")
|
|
327
|
+
logging.info(f"Maximum micro-averaged ROC-AUC score = {max_roc_score:.3f} at epoch {max_roc_epoch}")
|
|
328
328
|
logging.info(f"See plot at {plot_path}")
|
|
329
329
|
|
|
330
330
|
|
|
@@ -18,7 +18,7 @@ def tune(
|
|
|
18
18
|
param_path: Optional[str] = None,
|
|
19
19
|
output_path: str = "",
|
|
20
20
|
annotations_path: str = "",
|
|
21
|
-
metric: str = "
|
|
21
|
+
metric: str = "micro_roc",
|
|
22
22
|
recordings_path: str = "",
|
|
23
23
|
train_log_path: str = "",
|
|
24
24
|
num_trials: int = 0,
|
|
@@ -159,7 +159,7 @@ def tune(
|
|
|
159
159
|
"micro_roc",
|
|
160
160
|
]
|
|
161
161
|
),
|
|
162
|
-
default="
|
|
162
|
+
default="micro_roc",
|
|
163
163
|
help="Metric used to compare runs. Macro-averaging uses annotated classes only, but micro-averaging uses all classes.",
|
|
164
164
|
)
|
|
165
165
|
@click.option(
|
|
@@ -86,6 +86,7 @@ packages = ["src/britekit"]
|
|
|
86
86
|
"src/britekit/commands/_db_add.py" = "britekit/commands/_db_add.py"
|
|
87
87
|
"src/britekit/commands/_db_delete.py" = "britekit/commands/_db_delete.py"
|
|
88
88
|
"src/britekit/commands/_embed.py" = "britekit/commands/_embed.py"
|
|
89
|
+
"src/britekit/commands/_ensemble.py" = "britekit/commands/_ensemble.py"
|
|
89
90
|
"src/britekit/commands/_extract.py" = "britekit/commands/_extract.py"
|
|
90
91
|
"src/britekit/commands/_find_dup.py" = "britekit/commands/_find_dup.py"
|
|
91
92
|
"src/britekit/commands/_inat.py" = "britekit/commands/_inat.py"
|
|
@@ -120,6 +121,7 @@ only-include = [
|
|
|
120
121
|
"src/britekit/commands/_db_add.py" = "britekit/commands/_db_add.py"
|
|
121
122
|
"src/britekit/commands/_db_delete.py" = "britekit/commands/_db_delete.py"
|
|
122
123
|
"src/britekit/commands/_embed.py" = "britekit/commands/_embed.py"
|
|
124
|
+
"src/britekit/commands/_ensemble.py" = "britekit/commands/_ensemble.py"
|
|
123
125
|
"src/britekit/commands/_extract.py" = "britekit/commands/_extract.py"
|
|
124
126
|
"src/britekit/commands/_find_dup.py" = "britekit/commands/_find_dup.py"
|
|
125
127
|
"src/britekit/commands/_inat.py" = "britekit/commands/_inat.py"
|
|
@@ -30,6 +30,7 @@ from .commands._db_delete import (
|
|
|
30
30
|
_del_stype_cmd,
|
|
31
31
|
)
|
|
32
32
|
from .commands._embed import _embed_cmd
|
|
33
|
+
from .commands._ensemble import _ensemble_cmd
|
|
33
34
|
from .commands._extract import _extract_all_cmd, _extract_by_image_cmd
|
|
34
35
|
from .commands._find_dup import _find_dup_cmd
|
|
35
36
|
from .commands._inat import _inat_cmd
|
|
@@ -80,6 +81,7 @@ cli.add_command(_del_src_cmd)
|
|
|
80
81
|
cli.add_command(_del_stype_cmd)
|
|
81
82
|
|
|
82
83
|
cli.add_command(_embed_cmd)
|
|
84
|
+
cli.add_command(_ensemble_cmd)
|
|
83
85
|
cli.add_command(_extract_all_cmd)
|
|
84
86
|
cli.add_command(_extract_by_image_cmd)
|
|
85
87
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
# Defer some imports to improve initialization performance.
|
|
2
|
+
import logging
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
|
|
4
5
|
from britekit.core.config_loader import get_config
|
|
@@ -43,6 +44,7 @@ class Trainer:
|
|
|
43
44
|
# load all the data once for performance, then split as needed in each fold
|
|
44
45
|
dm = DataModule()
|
|
45
46
|
|
|
47
|
+
val_rocs = []
|
|
46
48
|
for k in range(self.cfg.train.num_folds):
|
|
47
49
|
logger = TensorBoardLogger(
|
|
48
50
|
save_dir="logs", name=f"fold-{k}", default_hp_metric=False
|
|
@@ -116,6 +118,27 @@ class Trainer:
|
|
|
116
118
|
if self.cfg.train.test_pickle is not None:
|
|
117
119
|
trainer.test(model, dm)
|
|
118
120
|
|
|
121
|
+
# save stats from k-fold cross-validation
|
|
122
|
+
if self.cfg.train.num_folds > 1 and "val_roc" in trainer.callback_metrics:
|
|
123
|
+
val_rocs.append(float(trainer.callback_metrics["val_roc"]))
|
|
124
|
+
|
|
125
|
+
if val_rocs:
|
|
126
|
+
import math
|
|
127
|
+
import numpy as np
|
|
128
|
+
|
|
129
|
+
mean = float(np.mean(val_rocs))
|
|
130
|
+
std = float(np.std(val_rocs, ddof=1)) if len(val_rocs) > 1 else 0.0
|
|
131
|
+
n = len(val_rocs)
|
|
132
|
+
se = std / math.sqrt(n) if n > 1 else 0.0
|
|
133
|
+
ci95 = 1.96 * se # 95% CI using normal approximation
|
|
134
|
+
|
|
135
|
+
logging.info("Using micro-averaged ROC AUC")
|
|
136
|
+
scores_str = ", ".join(f"{v:.4f}" for v in val_rocs)
|
|
137
|
+
logging.info(f"folds: {scores_str}")
|
|
138
|
+
logging.info(f"mean: {mean:.4f}")
|
|
139
|
+
logging.info(f"standard deviation: {std:.4f}")
|
|
140
|
+
logging.info(f"95% confidence interval: {mean-ci95:.4f} to {mean+ci95:.4f}")
|
|
141
|
+
|
|
119
142
|
def find_lr(self, num_batches: int = 100):
|
|
120
143
|
"""
|
|
121
144
|
Suggest a learning rate and produce a plot.
|
|
@@ -288,7 +288,7 @@ class Tuner:
|
|
|
288
288
|
self.cfg.misc.ckpt_folder = str(
|
|
289
289
|
Path(self.train_log_dir) / train_dir / "checkpoints"
|
|
290
290
|
)
|
|
291
|
-
|
|
291
|
+
logging.info(f"Using checkpoints in {self.cfg.misc.ckpt_folder}")
|
|
292
292
|
self.cfg.infer.min_score = 0
|
|
293
293
|
|
|
294
294
|
# suppress console output during inference and test analysis
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{britekit-0.0.7 → britekit-0.0.9}/britekit/install/data/audioset/unbalanced_train_segments.csv
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|