sleap-nn 0.1.0a0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +4 -2
- sleap_nn/config/get_config.py +5 -0
- sleap_nn/config/trainer_config.py +5 -0
- sleap_nn/data/custom_datasets.py +53 -11
- sleap_nn/train.py +5 -0
- sleap_nn/training/model_trainer.py +30 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +1 -1
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/RECORD +12 -12
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED
|
@@ -41,14 +41,16 @@ def _safe_print(msg):
|
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
# Add logger with the custom filter
|
|
44
|
+
# Disable colorization to avoid ANSI codes in captured output
|
|
44
45
|
logger.add(
|
|
45
46
|
_safe_print,
|
|
46
47
|
level="DEBUG",
|
|
47
48
|
filter=_should_log,
|
|
48
|
-
format="{time:YYYY-MM-DD HH:mm:ss} | {
|
|
49
|
+
format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
|
|
50
|
+
colorize=False,
|
|
49
51
|
)
|
|
50
52
|
|
|
51
|
-
__version__ = "0.1.
|
|
53
|
+
__version__ = "0.1.0a1"
|
|
52
54
|
|
|
53
55
|
# Public API
|
|
54
56
|
from sleap_nn.evaluation import load_metrics
|
sleap_nn/config/get_config.py
CHANGED
|
@@ -677,6 +677,7 @@ def get_trainer_config(
|
|
|
677
677
|
wandb_save_viz_imgs_wandb: bool = False,
|
|
678
678
|
wandb_resume_prv_runid: Optional[str] = None,
|
|
679
679
|
wandb_group_name: Optional[str] = None,
|
|
680
|
+
wandb_delete_local_logs: Optional[bool] = None,
|
|
680
681
|
optimizer: str = "Adam",
|
|
681
682
|
learning_rate: float = 1e-3,
|
|
682
683
|
amsgrad: bool = False,
|
|
@@ -746,6 +747,9 @@ def get_trainer_config(
|
|
|
746
747
|
wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
|
|
747
748
|
ckpt. Default: None
|
|
748
749
|
wandb_group_name: Group name for the wandb run. Default: None.
|
|
750
|
+
wandb_delete_local_logs: If True, delete local wandb logs folder after training.
|
|
751
|
+
If False, keep the folder. If None (default), automatically delete if logging
|
|
752
|
+
online (wandb_mode != "offline") and keep if logging offline. Default: None.
|
|
749
753
|
optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
|
|
750
754
|
learning_rate: Learning rate of type float. Default: 1e-3.
|
|
751
755
|
amsgrad: Enable AMSGrad with the optimizer. Default: False.
|
|
@@ -846,6 +850,7 @@ def get_trainer_config(
|
|
|
846
850
|
save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
|
|
847
851
|
prv_runid=wandb_resume_prv_runid,
|
|
848
852
|
group=wandb_group_name,
|
|
853
|
+
delete_local_logs=wandb_delete_local_logs,
|
|
849
854
|
),
|
|
850
855
|
save_ckpt=save_ckpt,
|
|
851
856
|
ckpt_dir=ckpt_dir,
|
|
@@ -90,6 +90,10 @@ class WandBConfig:
|
|
|
90
90
|
viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
|
|
91
91
|
viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
|
|
92
92
|
log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
|
|
93
|
+
delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
|
|
94
|
+
training. If False, keep the folder. If None (default), automatically delete
|
|
95
|
+
if logging online (wandb_mode != "offline") and keep if logging offline.
|
|
96
|
+
*Default*: `None`.
|
|
93
97
|
"""
|
|
94
98
|
|
|
95
99
|
entity: Optional[str] = None
|
|
@@ -107,6 +111,7 @@ class WandBConfig:
|
|
|
107
111
|
viz_box_size: float = 5.0
|
|
108
112
|
viz_confmap_threshold: float = 0.1
|
|
109
113
|
log_viz_table: bool = False
|
|
114
|
+
delete_local_logs: Optional[bool] = None
|
|
110
115
|
|
|
111
116
|
|
|
112
117
|
@define
|
sleap_nn/data/custom_datasets.py
CHANGED
|
@@ -13,6 +13,14 @@ from omegaconf import DictConfig, OmegaConf
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
from PIL import Image
|
|
15
15
|
from loguru import logger
|
|
16
|
+
from rich.progress import (
|
|
17
|
+
Progress,
|
|
18
|
+
SpinnerColumn,
|
|
19
|
+
TextColumn,
|
|
20
|
+
BarColumn,
|
|
21
|
+
TimeElapsedColumn,
|
|
22
|
+
)
|
|
23
|
+
from rich.console import Console
|
|
16
24
|
import torch
|
|
17
25
|
import torchvision.transforms as T
|
|
18
26
|
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
|
@@ -215,17 +223,51 @@ class BaseDataset(Dataset):
|
|
|
215
223
|
def _fill_cache(self, labels: List[sio.Labels]):
|
|
216
224
|
"""Load all samples to cache."""
|
|
217
225
|
# TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
226
|
+
import os
|
|
227
|
+
import sys
|
|
228
|
+
|
|
229
|
+
total_samples = len(self.lf_idx_list)
|
|
230
|
+
cache_type = "disk" if self.cache_img == "disk" else "memory"
|
|
231
|
+
|
|
232
|
+
# Check for NO_COLOR env var or non-interactive terminal
|
|
233
|
+
no_color = (
|
|
234
|
+
os.environ.get("NO_COLOR") is not None
|
|
235
|
+
or os.environ.get("FORCE_COLOR") == "0"
|
|
236
|
+
)
|
|
237
|
+
use_progress = sys.stdout.isatty() and not no_color
|
|
238
|
+
|
|
239
|
+
def process_samples(progress=None, task=None):
|
|
240
|
+
for sample in self.lf_idx_list:
|
|
241
|
+
labels_idx = sample["labels_idx"]
|
|
242
|
+
lf_idx = sample["lf_idx"]
|
|
243
|
+
img = labels[labels_idx][lf_idx].image
|
|
244
|
+
if img.shape[-1] == 1:
|
|
245
|
+
img = np.squeeze(img)
|
|
246
|
+
if self.cache_img == "disk":
|
|
247
|
+
f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
|
|
248
|
+
Image.fromarray(img).save(f_name, format="JPEG")
|
|
249
|
+
if self.cache_img == "memory":
|
|
250
|
+
self.cache[(labels_idx, lf_idx)] = img
|
|
251
|
+
if progress is not None:
|
|
252
|
+
progress.update(task, advance=1)
|
|
253
|
+
|
|
254
|
+
if use_progress:
|
|
255
|
+
with Progress(
|
|
256
|
+
SpinnerColumn(),
|
|
257
|
+
TextColumn("[progress.description]{task.description}"),
|
|
258
|
+
BarColumn(),
|
|
259
|
+
TextColumn("{task.completed}/{task.total}"),
|
|
260
|
+
TimeElapsedColumn(),
|
|
261
|
+
console=Console(force_terminal=True),
|
|
262
|
+
transient=True,
|
|
263
|
+
) as progress:
|
|
264
|
+
task = progress.add_task(
|
|
265
|
+
f"Caching images to {cache_type}", total=total_samples
|
|
266
|
+
)
|
|
267
|
+
process_samples(progress, task)
|
|
268
|
+
else:
|
|
269
|
+
logger.info(f"Caching {total_samples} images to {cache_type}...")
|
|
270
|
+
process_samples()
|
|
229
271
|
|
|
230
272
|
def __len__(self) -> int:
|
|
231
273
|
"""Return the number of samples in the dataset."""
|
sleap_nn/train.py
CHANGED
|
@@ -175,6 +175,7 @@ def train(
|
|
|
175
175
|
wandb_save_viz_imgs_wandb: bool = False,
|
|
176
176
|
wandb_resume_prv_runid: Optional[str] = None,
|
|
177
177
|
wandb_group_name: Optional[str] = None,
|
|
178
|
+
wandb_delete_local_logs: Optional[bool] = None,
|
|
178
179
|
optimizer: str = "Adam",
|
|
179
180
|
learning_rate: float = 1e-3,
|
|
180
181
|
amsgrad: bool = False,
|
|
@@ -353,6 +354,9 @@ def train(
|
|
|
353
354
|
wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
|
|
354
355
|
ckpt. Default: None
|
|
355
356
|
wandb_group_name: Group name for the wandb run. Default: None.
|
|
357
|
+
wandb_delete_local_logs: If True, delete local wandb logs folder after training.
|
|
358
|
+
If False, keep the folder. If None (default), automatically delete if logging
|
|
359
|
+
online (wandb_mode != "offline") and keep if logging offline. Default: None.
|
|
356
360
|
optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
|
|
357
361
|
learning_rate: Learning rate of type float. Default: 1e-3.
|
|
358
362
|
amsgrad: Enable AMSGrad with the optimizer. Default: False.
|
|
@@ -456,6 +460,7 @@ def train(
|
|
|
456
460
|
wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
|
|
457
461
|
wandb_resume_prv_runid=wandb_resume_prv_runid,
|
|
458
462
|
wandb_group_name=wandb_group_name,
|
|
463
|
+
wandb_delete_local_logs=wandb_delete_local_logs,
|
|
459
464
|
optimizer=optimizer,
|
|
460
465
|
learning_rate=learning_rate,
|
|
461
466
|
amsgrad=amsgrad,
|
|
@@ -898,6 +898,17 @@ class ModelTrainer:
|
|
|
898
898
|
)
|
|
899
899
|
loggers.append(wandb_logger)
|
|
900
900
|
|
|
901
|
+
# Log message about wandb local logs cleanup
|
|
902
|
+
should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
|
|
903
|
+
wandb_config.delete_local_logs is None
|
|
904
|
+
and wandb_config.wandb_mode != "offline"
|
|
905
|
+
)
|
|
906
|
+
if should_delete_wandb_logs:
|
|
907
|
+
logger.info(
|
|
908
|
+
"WandB local logs will be deleted after training completes. "
|
|
909
|
+
"To keep logs, set trainer_config.wandb.delete_local_logs=false"
|
|
910
|
+
)
|
|
911
|
+
|
|
901
912
|
# Learning rate monitor callback - logs LR at each step for dynamic schedulers
|
|
902
913
|
# Only added when wandb is enabled since it requires a logger
|
|
903
914
|
callbacks.append(LearningRateMonitor(logging_interval="step"))
|
|
@@ -1314,6 +1325,25 @@ class ModelTrainer:
|
|
|
1314
1325
|
if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
|
|
1315
1326
|
wandb.finish()
|
|
1316
1327
|
|
|
1328
|
+
# Delete local wandb logs if configured
|
|
1329
|
+
wandb_config = self.config.trainer_config.wandb
|
|
1330
|
+
should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
|
|
1331
|
+
wandb_config.delete_local_logs is None
|
|
1332
|
+
and wandb_config.wandb_mode != "offline"
|
|
1333
|
+
)
|
|
1334
|
+
if should_delete_wandb_logs:
|
|
1335
|
+
wandb_dir = (
|
|
1336
|
+
Path(self.config.trainer_config.ckpt_dir)
|
|
1337
|
+
/ self.config.trainer_config.run_name
|
|
1338
|
+
/ "wandb"
|
|
1339
|
+
)
|
|
1340
|
+
if wandb_dir.exists():
|
|
1341
|
+
logger.info(
|
|
1342
|
+
f"Deleting local wandb logs at {wandb_dir}... "
|
|
1343
|
+
"(set trainer_config.wandb.delete_local_logs=false to disable)"
|
|
1344
|
+
)
|
|
1345
|
+
shutil.rmtree(wandb_dir, ignore_errors=True)
|
|
1346
|
+
|
|
1317
1347
|
# delete image disk caching
|
|
1318
1348
|
if (
|
|
1319
1349
|
self.config.data_config.data_pipeline_fw
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sleap-nn
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a1
|
|
4
4
|
Summary: Neural network backend for training and inference for animal pose estimation.
|
|
5
5
|
Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
|
|
6
6
|
License: BSD-3-Clause
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
|
|
2
|
-
sleap_nn/__init__.py,sha256=
|
|
2
|
+
sleap_nn/__init__.py,sha256=l5Lwiad8GOurqkAhMwWw8-UcpH6af2TnMURf-oKj_U8,1362
|
|
3
3
|
sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
|
|
4
4
|
sleap_nn/evaluation.py,sha256=3u7y85wFoBgCwOB2xOGTJIDrd2dUPWOo4m0s0oW3da4,31095
|
|
5
5
|
sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
|
|
6
6
|
sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
|
|
7
7
|
sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
|
|
8
|
-
sleap_nn/train.py,sha256=
|
|
8
|
+
sleap_nn/train.py,sha256=XvVhzMXL9rNQLx1-6jIcp5BAO1pR7AZjdphMn5ZX-_I,27558
|
|
9
9
|
sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
|
|
10
10
|
sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
|
|
11
11
|
sleap_nn/architectures/convnext.py,sha256=l9lMJDxIMb-9MI3ShOtVwbOUMuwOLtSQlxiVyYHqjvE,13953
|
|
@@ -17,15 +17,15 @@ sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcg
|
|
|
17
17
|
sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
|
|
18
18
|
sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42
|
|
19
19
|
sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
|
|
20
|
-
sleap_nn/config/get_config.py,sha256=
|
|
20
|
+
sleap_nn/config/get_config.py,sha256=rjNUffKU9z-ohLwrOVmJNGCqwUM93eh68h4KJfrSy8Y,42396
|
|
21
21
|
sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
|
|
22
|
-
sleap_nn/config/trainer_config.py,sha256=
|
|
22
|
+
sleap_nn/config/trainer_config.py,sha256=ZMXxns6VYakgYHRhkM541Eje76DdaTdDi4FFPNjJtP4,28413
|
|
23
23
|
sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
|
|
24
24
|
sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
|
|
25
25
|
sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
|
|
26
26
|
sleap_nn/data/augmentation.py,sha256=Kqw_DayPth_DBsmaO1G8Voou_-cYZuSPOjSQWSajgRI,13618
|
|
27
27
|
sleap_nn/data/confidence_maps.py,sha256=PTRqZWSAz1S7viJhxu7QgIC1aHiek97c_dCUsKUwG1o,6217
|
|
28
|
-
sleap_nn/data/custom_datasets.py,sha256=
|
|
28
|
+
sleap_nn/data/custom_datasets.py,sha256=SO-aNB1-bB9DL5Zw-oGYDsliBxwI4iKX_FmwgZjKOgQ,99975
|
|
29
29
|
sleap_nn/data/edge_maps.py,sha256=75qG_7zHRw7fC8JUCVI2tzYakIoxxneWWmcrTwjcHPo,12519
|
|
30
30
|
sleap_nn/data/identity.py,sha256=7vNup6PudST4yDLyDT9wDO-cunRirTEvx4sP77xrlfk,5193
|
|
31
31
|
sleap_nn/data/instance_centroids.py,sha256=SF-3zJt_VMTbZI5ssbrvmZQZDd3684bn55EAtvcbQ6o,2172
|
|
@@ -55,11 +55,11 @@ sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U
|
|
|
55
55
|
sleap_nn/training/callbacks.py,sha256=TVnQ6plNC2MnlTiY2rSCRuw2WRk5cQSziek_VPUcOEg,25994
|
|
56
56
|
sleap_nn/training/lightning_modules.py,sha256=G3c4xJkYWW-iSRawzkgTqkGd4lTsbPiMTcB5Nvq7jes,85512
|
|
57
57
|
sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
|
|
58
|
-
sleap_nn/training/model_trainer.py,sha256=
|
|
58
|
+
sleap_nn/training/model_trainer.py,sha256=loCmEX0DfBtdV_pN-W8s31fn2_L-lbpWaq3OQXeSp-0,59337
|
|
59
59
|
sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
|
|
60
|
-
sleap_nn-0.1.
|
|
61
|
-
sleap_nn-0.1.
|
|
62
|
-
sleap_nn-0.1.
|
|
63
|
-
sleap_nn-0.1.
|
|
64
|
-
sleap_nn-0.1.
|
|
65
|
-
sleap_nn-0.1.
|
|
60
|
+
sleap_nn-0.1.0a1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
61
|
+
sleap_nn-0.1.0a1.dist-info/METADATA,sha256=h3d4WPIu_JunY32jaRqJ4-fXp4KruTWT57FWb3L6dps,5637
|
|
62
|
+
sleap_nn-0.1.0a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
63
|
+
sleap_nn-0.1.0a1.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
|
|
64
|
+
sleap_nn-0.1.0a1.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
|
|
65
|
+
sleap_nn-0.1.0a1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|