britekit 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of britekit might be problematic. Click here for more details.
- britekit/__about__.py +1 -1
- britekit/commands/_reports.py +4 -4
- britekit/commands/_tune.py +2 -0
- britekit/core/data_module.py +1 -0
- britekit/core/trainer.py +22 -0
- britekit/core/tuner.py +17 -2
- britekit/models/base_model.py +3 -2
- {britekit-0.0.6.dist-info → britekit-0.0.8.dist-info}/METADATA +1 -1
- {britekit-0.0.6.dist-info → britekit-0.0.8.dist-info}/RECORD +12 -12
- {britekit-0.0.6.dist-info → britekit-0.0.8.dist-info}/WHEEL +0 -0
- {britekit-0.0.6.dist-info → britekit-0.0.8.dist-info}/entry_points.txt +0 -0
- {britekit-0.0.6.dist-info → britekit-0.0.8.dist-info}/licenses/LICENSE.txt +0 -0
britekit/__about__.py
CHANGED
britekit/commands/_reports.py
CHANGED
|
@@ -276,14 +276,14 @@ def rpt_epochs(
|
|
|
276
276
|
tester.initialize()
|
|
277
277
|
|
|
278
278
|
pr_stats = tester.get_pr_auc_stats()
|
|
279
|
-
pr_score = pr_stats["
|
|
279
|
+
pr_score = pr_stats["micro_pr_auc"]
|
|
280
280
|
pr_scores.append(pr_score)
|
|
281
281
|
if pr_score > max_pr_score:
|
|
282
282
|
max_pr_score = pr_score
|
|
283
283
|
max_pr_epoch = epoch_num
|
|
284
284
|
|
|
285
285
|
roc_stats = tester.get_roc_auc_stats()
|
|
286
|
-
roc_score = roc_stats["
|
|
286
|
+
roc_score = roc_stats["micro_roc_auc"]
|
|
287
287
|
roc_scores.append(roc_score)
|
|
288
288
|
if roc_score > max_roc_score:
|
|
289
289
|
max_roc_score = roc_score
|
|
@@ -323,8 +323,8 @@ def rpt_epochs(
|
|
|
323
323
|
plot_path = str(Path(output_path) / "training_scores.jpeg")
|
|
324
324
|
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
|
|
325
325
|
|
|
326
|
-
logging.info(f"Maximum PR-AUC score = {max_pr_score:.3f} at epoch {max_pr_epoch}")
|
|
327
|
-
logging.info(f"Maximum ROC-AUC score = {max_roc_score:.3f} at epoch {max_roc_epoch}")
|
|
326
|
+
logging.info(f"Maximum micro-averaged PR-AUC score = {max_pr_score:.3f} at epoch {max_pr_epoch}")
|
|
327
|
+
logging.info(f"Maximum micro-averaged ROC-AUC score = {max_roc_score:.3f} at epoch {max_roc_epoch}")
|
|
328
328
|
logging.info(f"See plot at {plot_path}")
|
|
329
329
|
|
|
330
330
|
|
britekit/commands/_tune.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Optional
|
|
|
9
9
|
|
|
10
10
|
import click
|
|
11
11
|
|
|
12
|
+
from britekit.core.config_loader import get_config
|
|
12
13
|
from britekit.core import util
|
|
13
14
|
|
|
14
15
|
|
|
@@ -57,6 +58,7 @@ def tune(
|
|
|
57
58
|
from britekit.core.tuner import Tuner
|
|
58
59
|
|
|
59
60
|
try:
|
|
61
|
+
cfg, _ = get_config(cfg_path)
|
|
60
62
|
if extract and skip_training:
|
|
61
63
|
logging.error(
|
|
62
64
|
"Performing spectrogram extract is incompatible with skipping training."
|
britekit/core/data_module.py
CHANGED
britekit/core/trainer.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
# Defer some imports to improve initialization performance.
|
|
2
|
+
import logging
|
|
2
3
|
from pathlib import Path
|
|
3
4
|
|
|
4
5
|
from britekit.core.config_loader import get_config
|
|
@@ -43,6 +44,7 @@ class Trainer:
|
|
|
43
44
|
# load all the data once for performance, then split as needed in each fold
|
|
44
45
|
dm = DataModule()
|
|
45
46
|
|
|
47
|
+
val_rocs = []
|
|
46
48
|
for k in range(self.cfg.train.num_folds):
|
|
47
49
|
logger = TensorBoardLogger(
|
|
48
50
|
save_dir="logs", name=f"fold-{k}", default_hp_metric=False
|
|
@@ -116,6 +118,26 @@ class Trainer:
|
|
|
116
118
|
if self.cfg.train.test_pickle is not None:
|
|
117
119
|
trainer.test(model, dm)
|
|
118
120
|
|
|
121
|
+
# save stats from k-fold cross-validation
|
|
122
|
+
if self.cfg.train.num_folds > 1 and "val_roc" in trainer.callback_metrics:
|
|
123
|
+
val_rocs.append(float(trainer.callback_metrics["val_roc"]))
|
|
124
|
+
|
|
125
|
+
if val_rocs:
|
|
126
|
+
import math
|
|
127
|
+
import numpy as np
|
|
128
|
+
mean = float(np.mean(val_rocs))
|
|
129
|
+
std = float(np.std(val_rocs, ddof=1)) if len(val_rocs) > 1 else 0.0
|
|
130
|
+
n = len(val_rocs)
|
|
131
|
+
se = std / math.sqrt(n) if n > 1 else 0.0
|
|
132
|
+
ci95 = 1.96 * se # 95% CI using normal approximation
|
|
133
|
+
|
|
134
|
+
logging.info("Using micro-averaged ROC AUC")
|
|
135
|
+
scores_str = ", ".join(f"{v:.4f}" for v in val_rocs)
|
|
136
|
+
logging.info(f"folds: {scores_str}")
|
|
137
|
+
logging.info(f"mean: {mean:.4f}")
|
|
138
|
+
logging.info(f"standard deviation: {std:.4f}")
|
|
139
|
+
logging.info(f"95% confidence interval: {mean-ci95:.4f} to {mean+ci95:.4f}")
|
|
140
|
+
|
|
119
141
|
def find_lr(self, num_batches: int = 100):
|
|
120
142
|
"""
|
|
121
143
|
Suggest a learning rate and produce a plot.
|
britekit/core/tuner.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
# Defer some imports to improve initialization performance.
|
|
2
2
|
import copy
|
|
3
3
|
import logging
|
|
4
|
-
import os
|
|
5
4
|
from pathlib import Path
|
|
6
5
|
import random
|
|
7
6
|
import re
|
|
@@ -263,6 +262,21 @@ class Tuner:
|
|
|
263
262
|
self.trial_num += 1
|
|
264
263
|
self.trial_metrics[self.trial_num] = {}
|
|
265
264
|
|
|
265
|
+
@staticmethod
|
|
266
|
+
def _find_latest_version_dir(root):
|
|
267
|
+
root = Path(root)
|
|
268
|
+
version_dirs = []
|
|
269
|
+
for d in root.iterdir():
|
|
270
|
+
if d.is_dir() and d.name.startswith("version"):
|
|
271
|
+
m = re.search(r"\d+", d.name)
|
|
272
|
+
if m:
|
|
273
|
+
version_dirs.append((int(m.group()), d))
|
|
274
|
+
|
|
275
|
+
assert version_dirs, "Failed to find training log directory"
|
|
276
|
+
|
|
277
|
+
# Sort numerically by the extracted version number
|
|
278
|
+
return max(version_dirs, key=lambda x: x[0])[1].name
|
|
279
|
+
|
|
266
280
|
def _run_test(self):
|
|
267
281
|
"""
|
|
268
282
|
Run inference with the generated checkpoints and return the selected metric.
|
|
@@ -270,10 +284,11 @@ class Tuner:
|
|
|
270
284
|
from britekit.core.analyzer import Analyzer
|
|
271
285
|
from britekit.testing.per_segment_tester import PerSegmentTester
|
|
272
286
|
|
|
273
|
-
train_dir =
|
|
287
|
+
train_dir = self._find_latest_version_dir(self.train_log_dir)
|
|
274
288
|
self.cfg.misc.ckpt_folder = str(
|
|
275
289
|
Path(self.train_log_dir) / train_dir / "checkpoints"
|
|
276
290
|
)
|
|
291
|
+
logging.info(f"Using checkpoints in {self.cfg.misc.ckpt_folder}")
|
|
277
292
|
self.cfg.infer.min_score = 0
|
|
278
293
|
|
|
279
294
|
# suppress console output during inference and test analysis
|
britekit/models/base_model.py
CHANGED
|
@@ -11,7 +11,6 @@ import torch
|
|
|
11
11
|
from torch import nn
|
|
12
12
|
import torch.nn.functional as F
|
|
13
13
|
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
|
14
|
-
from torchmetrics.functional import accuracy
|
|
15
14
|
|
|
16
15
|
from britekit.core.config_loader import get_config
|
|
17
16
|
from britekit.core import util
|
|
@@ -195,7 +194,9 @@ class BaseModel(pl.LightningModule):
|
|
|
195
194
|
if self.multi_label:
|
|
196
195
|
preds = torch.sigmoid(seg_logits)
|
|
197
196
|
roc_auc = metrics.roc_auc_score(y.cpu(), preds.cpu(), average="micro")
|
|
198
|
-
self.log(
|
|
197
|
+
self.log(
|
|
198
|
+
"test_roc_auc", roc_auc, on_step=False, on_epoch=True, prog_bar=True
|
|
199
|
+
)
|
|
199
200
|
|
|
200
201
|
return loss
|
|
201
202
|
|
|
@@ -4,17 +4,17 @@ britekit/core/audio.py,sha256=8QLbNDAiQyViEhrVC8jU0n32we4C22W_jPfc_KcOlmQ,15853
|
|
|
4
4
|
britekit/core/augmentation.py,sha256=5_wyB-6gt7uM68Zl-rO_fPu1D6tlsd2m5oWhA6l0W9Q,5721
|
|
5
5
|
britekit/core/base_config.py,sha256=wbCJI9cEH9mktUTSfLSCnU5AhJT6xhxhwZS4QWRYTpM,8744
|
|
6
6
|
britekit/core/config_loader.py,sha256=epHNlH7yi_sCJX01FwgOsM6vPFk25rgUwkkcoGvVsYg,1341
|
|
7
|
-
britekit/core/data_module.py,sha256=
|
|
7
|
+
britekit/core/data_module.py,sha256=f6BL-ngqkklX06Q7xZ9PMKEkvfeCzQTLYdmXRg3RmCo,9108
|
|
8
8
|
britekit/core/dataset.py,sha256=Xu-lTz3TsHMuW10lHg4NN_r1baS9OQEhAD7EVz1a3A4,5804
|
|
9
9
|
britekit/core/exceptions.py,sha256=ti_ve7ZdhDmzgTuspXXqyw__SUt5NoAXGEwoe3agPU8,443
|
|
10
10
|
britekit/core/pickler.py,sha256=Vj-_DdFQUQj2bIVoyWe5puI8g8dTP9x7ZavbvM1iQZo,5788
|
|
11
11
|
britekit/core/plot.py,sha256=hLuLB1VdtdFyaSHVDGl5tjjFCRgOJJ1ucTVJHM_3D_0,5332
|
|
12
12
|
britekit/core/predictor.py,sha256=u4H8horTTvcg4Oqfpy5PG44eiiMeR5RU3aPZnMiXRCw,22914
|
|
13
13
|
britekit/core/reextractor.py,sha256=gazhIZN8V1K4T_Q_kc-ihxUYbkNnc_hoAS6bpYQc95I,8396
|
|
14
|
-
britekit/core/trainer.py,sha256=
|
|
15
|
-
britekit/core/tuner.py,sha256=
|
|
14
|
+
britekit/core/trainer.py,sha256=N5EsbCzxw3wXxs2PTJJ0OfYFkIi49HCRM0ylT5zSSZk,6439
|
|
15
|
+
britekit/core/tuner.py,sha256=FMmy4p3_j2Tojs4ONPzuUeRpCPWGlttr4rUJac7Hkyk,16435
|
|
16
16
|
britekit/core/util.py,sha256=0JsEEN09hFPQzuttCKaejWofXAjCGSvWEewjkiLAh3E,19172
|
|
17
|
-
britekit/models/base_model.py,sha256=
|
|
17
|
+
britekit/models/base_model.py,sha256=9T7TwHx3K8fl10Vb-qUuypK3NDDZM-ktB8ZLHzqQhdc,16883
|
|
18
18
|
britekit/models/dla.py,sha256=ALMY997AbERN7-sHqQuE5e43llRjpUDPZSFGL-Flv4M,3137
|
|
19
19
|
britekit/models/effnet.py,sha256=e7WdZMsLPXe8jcWChk6n97c8DMV0YyGV6lDP_Jv6Wz4,3129
|
|
20
20
|
britekit/models/gernet.py,sha256=7MEUZaDTfr-6oa8eE8dyDQb2LgahGBOEp1pTZSu1KOE,7022
|
|
@@ -32,7 +32,7 @@ britekit/testing/per_segment_tester.py,sha256=FnaozQ8VmH99aYc1ibmDFfOk_ADgsXQGU_
|
|
|
32
32
|
britekit/training_db/extractor.py,sha256=pT7lAUsNzYs3RXDzpMv7q0MKg6TktiFLKrRtKTWv6ho,8409
|
|
33
33
|
britekit/training_db/training_data_provider.py,sha256=V5aBjsCvrWViZ0Jv05hgcKRizcAXmqoj4q3hAHedoD8,5651
|
|
34
34
|
britekit/training_db/training_db.py,sha256=OOfD1pcbq5HVJbzhmuI-D-gkPHWSoz0cCO4zIUGFvoY,65011
|
|
35
|
-
britekit/__about__.py,sha256
|
|
35
|
+
britekit/__about__.py,sha256=-uGInVbPaVLti1Rr4PYUteRetwYfxeLtIuqiLmEcRjA,122
|
|
36
36
|
britekit/__init__.py,sha256=RpruzdjbvTcFNf21zJYY8HrAhJei91FtNNLjIBmw-kw,1857
|
|
37
37
|
britekit/install/data/classes.csv,sha256=OdTZ8oQdx7N-HKyhftxZStGZYsjhCy4UbanwtQJ2wBM,54
|
|
38
38
|
britekit/install/data/ignore.txt,sha256=RbKvEHtUCbgRYolwR1IucClwyD3q7l2s6QuRjph-Us4,68
|
|
@@ -79,16 +79,16 @@ britekit/commands/_init.py,sha256=FmaQRY-7SYSHCLXL__47LEPecWir7X6zEB05KpradFw,28
|
|
|
79
79
|
britekit/commands/_pickle.py,sha256=p990FsJGfSXcgjtBzH7nPGPh023b8cH0D7RZywQQ5Aw,3488
|
|
80
80
|
britekit/commands/_plot.py,sha256=7vZXsYP9dv4PbHb8K3YbJFZc65YoPIBjEMBolyh6Has,13084
|
|
81
81
|
britekit/commands/_reextract.py,sha256=kCmSjeghg6mhrJ46ibRTmBkGVytU7flFvTbqsnYhBvY,3770
|
|
82
|
-
britekit/commands/_reports.py,sha256=
|
|
82
|
+
britekit/commands/_reports.py,sha256=KVYtpeFQpUC4jAIm2k2xV7aiNq826DL6sUrYEJD38X0,22023
|
|
83
83
|
britekit/commands/_search.py,sha256=HIUXwfPvh3rxpgaFSR3bAAI38OtGVPyMo5GMfLtLX-8,9991
|
|
84
84
|
britekit/commands/_train.py,sha256=vGFKlfcv35cOelArQNbVbTRbDWogT_IMg0wZt5virHY,4158
|
|
85
|
-
britekit/commands/_tune.py,sha256=
|
|
85
|
+
britekit/commands/_tune.py,sha256=8dEZZURE769C0JZwhNpzB6pQxVklzl2w2cyXyWyhWXs,7331
|
|
86
86
|
britekit/commands/_wav2mp3.py,sha256=2Q4cjT6OhJmBPTNzGRMrDd6dSdBBufuQdjhH1V8ghLo,2167
|
|
87
87
|
britekit/commands/_xeno.py,sha256=_6YxQ7xFdaSy5DNUaigkbYp3E8EhtOhTC9b6OFS0MFA,6026
|
|
88
88
|
britekit/commands/_youtube.py,sha256=_u1LrwY_2GxllKd505N_2ArFMbACQ_PtVxuqUCYxFe0,2214
|
|
89
89
|
britekit/core/__init__.py,sha256=QcjcFyvO5KqJLF_HBeqiCk925uU5jTUjIV5lJix9XY4,556
|
|
90
|
-
britekit-0.0.
|
|
91
|
-
britekit-0.0.
|
|
92
|
-
britekit-0.0.
|
|
93
|
-
britekit-0.0.
|
|
94
|
-
britekit-0.0.
|
|
90
|
+
britekit-0.0.8.dist-info/METADATA,sha256=Qtzlff9X_WI1Cz8zpTyntAwFemS8hNbS0ClWJV9KVXk,18555
|
|
91
|
+
britekit-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
92
|
+
britekit-0.0.8.dist-info/entry_points.txt,sha256=ycnPy5DLX14RTf7lKfkQAVyIf1B1zTL1gMsHm455wmg,46
|
|
93
|
+
britekit-0.0.8.dist-info/licenses/LICENSE.txt,sha256=kPoHm6iop8-CUa_720Tt8gqyvLD6D_7218u1hCCpErk,1092
|
|
94
|
+
britekit-0.0.8.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|