sleap-nn 0.1.0a0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sleap_nn/__init__.py CHANGED
@@ -41,14 +41,16 @@ def _safe_print(msg):
 
 
 # Add logger with the custom filter
+# Disable colorization to avoid ANSI codes in captured output
 logger.add(
     _safe_print,
     level="DEBUG",
     filter=_should_log,
-    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
+    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
+    colorize=False,
 )
 
-__version__ = "0.1.0a0"
+__version__ = "0.1.0a1"
 
 # Public API
 from sleap_nn.evaluation import load_metrics
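For illustration, the same log record (hypothetical timestamp, message, and call site) renders differently under the old and new sink formats, and the new sink no longer emits ANSI color codes into captured output:

    # old: "{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}"
    2025-01-01 12:00:00 | INFO | sleap_nn.train:train:123 | Started training
    # new: "{time:YYYY-MM-DD HH:mm:ss} | {message}", with colorize=False
    2025-01-01 12:00:00 | Started training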
sleap_nn/config/get_config.py CHANGED
@@ -677,6 +677,7 @@ def get_trainer_config(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
+    wandb_delete_local_logs: Optional[bool] = None,
     optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -746,6 +747,9 @@ def get_trainer_config(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
             ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
+        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+            If False, keep the folder. If None (default), automatically delete if logging
+            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -846,6 +850,7 @@ def get_trainer_config(
             save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
             prv_runid=wandb_resume_prv_runid,
             group=wandb_group_name,
+            delete_local_logs=wandb_delete_local_logs,
         ),
         save_ckpt=save_ckpt,
         ckpt_dir=ckpt_dir,
sleap_nn/config/trainer_config.py CHANGED
@@ -90,6 +90,10 @@ class WandBConfig:
         viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
         viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
         log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
+        delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
+            training. If False, keep the folder. If None (default), automatically delete
+            if logging online (wandb_mode != "offline") and keep if logging offline.
+            *Default*: `None`.
     """
 
     entity: Optional[str] = None
@@ -107,6 +111,7 @@ class WandBConfig:
     viz_box_size: float = 5.0
     viz_confmap_threshold: float = 0.1
     log_viz_table: bool = False
+    delete_local_logs: Optional[bool] = None
 
 
 @define
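A minimal usage sketch for opting out of the cleanup via the config class (the import path is assumed from the RECORD listing below; all other fields keep their defaults):

    from sleap_nn.config.trainer_config import WandBConfig

    # Keep the local wandb run folder even when logging online.
    wandb_cfg = WandBConfig(delete_local_logs=False)

Leaving delete_local_logs at None (the default) removes the folder only when wandb_mode is not "offline".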
sleap_nn/data/custom_datasets.py CHANGED
@@ -13,6 +13,14 @@ from omegaconf import DictConfig, OmegaConf
 import numpy as np
 from PIL import Image
 from loguru import logger
+from rich.progress import (
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    BarColumn,
+    TimeElapsedColumn,
+)
+from rich.console import Console
 import torch
 import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, DistributedSampler
@@ -215,17 +223,51 @@ class BaseDataset(Dataset):
     def _fill_cache(self, labels: List[sio.Labels]):
         """Load all samples to cache."""
         # TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
-        for sample in self.lf_idx_list:
-            labels_idx = sample["labels_idx"]
-            lf_idx = sample["lf_idx"]
-            img = labels[labels_idx][lf_idx].image
-            if img.shape[-1] == 1:
-                img = np.squeeze(img)
-            if self.cache_img == "disk":
-                f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
-                Image.fromarray(img).save(f_name, format="JPEG")
-            if self.cache_img == "memory":
-                self.cache[(labels_idx, lf_idx)] = img
+        import os
+        import sys
+
+        total_samples = len(self.lf_idx_list)
+        cache_type = "disk" if self.cache_img == "disk" else "memory"
+
+        # Check for NO_COLOR env var or non-interactive terminal
+        no_color = (
+            os.environ.get("NO_COLOR") is not None
+            or os.environ.get("FORCE_COLOR") == "0"
+        )
+        use_progress = sys.stdout.isatty() and not no_color
+
+        def process_samples(progress=None, task=None):
+            for sample in self.lf_idx_list:
+                labels_idx = sample["labels_idx"]
+                lf_idx = sample["lf_idx"]
+                img = labels[labels_idx][lf_idx].image
+                if img.shape[-1] == 1:
+                    img = np.squeeze(img)
+                if self.cache_img == "disk":
+                    f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
+                    Image.fromarray(img).save(f_name, format="JPEG")
+                if self.cache_img == "memory":
+                    self.cache[(labels_idx, lf_idx)] = img
+                if progress is not None:
+                    progress.update(task, advance=1)
+
+        if use_progress:
+            with Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(),
+                TextColumn("{task.completed}/{task.total}"),
+                TimeElapsedColumn(),
+                console=Console(force_terminal=True),
+                transient=True,
+            ) as progress:
+                task = progress.add_task(
+                    f"Caching images to {cache_type}", total=total_samples
+                )
+                process_samples(progress, task)
+        else:
+            logger.info(f"Caching {total_samples} images to {cache_type}...")
+            process_samples()
 
     def __len__(self) -> int:
         """Return the number of samples in the dataset."""
sleap_nn/train.py CHANGED
@@ -175,6 +175,7 @@ def train(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
+    wandb_delete_local_logs: Optional[bool] = None,
    optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -353,6 +354,9 @@ def train(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
             ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
+        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+            If False, keep the folder. If None (default), automatically delete if logging
+            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -456,6 +460,7 @@ def train(
         wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
         wandb_resume_prv_runid=wandb_resume_prv_runid,
         wandb_group_name=wandb_group_name,
+        wandb_delete_local_logs=wandb_delete_local_logs,
         optimizer=optimizer,
         learning_rate=learning_rate,
         amsgrad=amsgrad,
sleap_nn/training/model_trainer.py CHANGED
@@ -898,6 +898,17 @@ class ModelTrainer:
             )
             loggers.append(wandb_logger)
 
+            # Log message about wandb local logs cleanup
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                logger.info(
+                    "WandB local logs will be deleted after training completes. "
+                    "To keep logs, set trainer_config.wandb.delete_local_logs=false"
+                )
+
             # Learning rate monitor callback - logs LR at each step for dynamic schedulers
             # Only added when wandb is enabled since it requires a logger
             callbacks.append(LearningRateMonitor(logging_interval="step"))
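The tri-state resolution above can be written out as a standalone function for clarity (a paraphrase of the expression in the hunk, not part of the package API):

    from typing import Optional

    def should_delete_wandb_logs(delete_local_logs: Optional[bool], wandb_mode: str) -> bool:
        """True/False force the behavior; None deletes only when logging online."""
        if delete_local_logs is None:
            return wandb_mode != "offline"
        return delete_local_logs

    assert should_delete_wandb_logs(None, "online")
    assert not should_delete_wandb_logs(None, "offline")
    assert not should_delete_wandb_logs(False, "online")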
@@ -1314,6 +1325,25 @@ class ModelTrainer:
         if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
             wandb.finish()
 
+        # Delete local wandb logs if configured
+        wandb_config = self.config.trainer_config.wandb
+        should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+            wandb_config.delete_local_logs is None
+            and wandb_config.wandb_mode != "offline"
+        )
+        if should_delete_wandb_logs:
+            wandb_dir = (
+                Path(self.config.trainer_config.ckpt_dir)
+                / self.config.trainer_config.run_name
+                / "wandb"
+            )
+            if wandb_dir.exists():
+                logger.info(
+                    f"Deleting local wandb logs at {wandb_dir}... "
+                    "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                )
+                shutil.rmtree(wandb_dir, ignore_errors=True)
+
         # delete image disk caching
         if (
             self.config.data_config.data_pipeline_fw
{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.0a0
+Version: 0.1.0a1
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a1.dist-info}/RECORD RENAMED
@@ -1,11 +1,11 @@
 sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
-sleap_nn/__init__.py,sha256=DzQeiZIFUmfhpf6mk4j1AKAY2bofVMyIa31xbiSu-ls,1317
+sleap_nn/__init__.py,sha256=l5Lwiad8GOurqkAhMwWw8-UcpH6af2TnMURf-oKj_U8,1362
 sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
 sleap_nn/evaluation.py,sha256=3u7y85wFoBgCwOB2xOGTJIDrd2dUPWOo4m0s0oW3da4,31095
 sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
 sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
 sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
-sleap_nn/train.py,sha256=fWx_b1HqkadQ-GM_VEM1frCd8WkzJLqRARBNn8UoUbo,27181
+sleap_nn/train.py,sha256=XvVhzMXL9rNQLx1-6jIcp5BAO1pR7AZjdphMn5ZX-_I,27558
 sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
 sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
 sleap_nn/architectures/convnext.py,sha256=l9lMJDxIMb-9MI3ShOtVwbOUMuwOLtSQlxiVyYHqjvE,13953
@@ -17,15 +17,15 @@ sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcg
 sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
 sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42
 sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
-sleap_nn/config/get_config.py,sha256=vN_aOPTj9F-QBqGGfVSv8_aFSAYl-RfXY0pdbdcqjcM,42021
+sleap_nn/config/get_config.py,sha256=rjNUffKU9z-ohLwrOVmJNGCqwUM93eh68h4KJfrSy8Y,42396
 sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
-sleap_nn/config/trainer_config.py,sha256=PaoNtRSNc2xgzwN955aR9kTZL8IxCWdevGljLxS6jOk,28073
+sleap_nn/config/trainer_config.py,sha256=ZMXxns6VYakgYHRhkM541Eje76DdaTdDi4FFPNjJtP4,28413
 sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
 sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
 sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
 sleap_nn/data/augmentation.py,sha256=Kqw_DayPth_DBsmaO1G8Voou_-cYZuSPOjSQWSajgRI,13618
 sleap_nn/data/confidence_maps.py,sha256=PTRqZWSAz1S7viJhxu7QgIC1aHiek97c_dCUsKUwG1o,6217
-sleap_nn/data/custom_datasets.py,sha256=2qAyLeiCPI9uudFFP7zlj6d_tbxc5OVzpnzT23mRkVw,98472
+sleap_nn/data/custom_datasets.py,sha256=SO-aNB1-bB9DL5Zw-oGYDsliBxwI4iKX_FmwgZjKOgQ,99975
 sleap_nn/data/edge_maps.py,sha256=75qG_7zHRw7fC8JUCVI2tzYakIoxxneWWmcrTwjcHPo,12519
 sleap_nn/data/identity.py,sha256=7vNup6PudST4yDLyDT9wDO-cunRirTEvx4sP77xrlfk,5193
 sleap_nn/data/instance_centroids.py,sha256=SF-3zJt_VMTbZI5ssbrvmZQZDd3684bn55EAtvcbQ6o,2172
@@ -55,11 +55,11 @@ sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U
 sleap_nn/training/callbacks.py,sha256=TVnQ6plNC2MnlTiY2rSCRuw2WRk5cQSziek_VPUcOEg,25994
 sleap_nn/training/lightning_modules.py,sha256=G3c4xJkYWW-iSRawzkgTqkGd4lTsbPiMTcB5Nvq7jes,85512
 sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
-sleap_nn/training/model_trainer.py,sha256=InDKHrQxBwbltZKutW4yrBR9NThLdRpWNUGhmB0xAi4,57863
+sleap_nn/training/model_trainer.py,sha256=loCmEX0DfBtdV_pN-W8s31fn2_L-lbpWaq3OQXeSp-0,59337
 sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
-sleap_nn-0.1.0a0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
-sleap_nn-0.1.0a0.dist-info/METADATA,sha256=lxSmGNTUg9eetqHCvhw8Tv5zJtia-dIM5RzOeoDccj8,5637
-sleap_nn-0.1.0a0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-sleap_nn-0.1.0a0.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
-sleap_nn-0.1.0a0.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
-sleap_nn-0.1.0a0.dist-info/RECORD,,
+sleap_nn-0.1.0a1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+sleap_nn-0.1.0a1.dist-info/METADATA,sha256=h3d4WPIu_JunY32jaRqJ4-fXp4KruTWT57FWb3L6dps,5637
+sleap_nn-0.1.0a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sleap_nn-0.1.0a1.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
+sleap_nn-0.1.0a1.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
+sleap_nn-0.1.0a1.dist-info/RECORD,,