sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/training/model_trainer.py
@@ -2,6 +2,7 @@
 
 import os
 import shutil
+import copy
 import attrs
 import torch
 import random
@@ -15,13 +16,14 @@ import yaml
 from pathlib import Path
 from typing import List, Optional
 from datetime import datetime
-from itertools import count
+from itertools import cycle, count
 from omegaconf import DictConfig, OmegaConf
 from lightning.pytorch.loggers import WandbLogger
 from sleap_nn.data.utils import check_cache_memory
 from lightning.pytorch.callbacks import (
     ModelCheckpoint,
     EarlyStopping,
+    LearningRateMonitor,
 )
 from lightning.pytorch.profilers import (
     SimpleProfiler,
@@ -53,11 +55,12 @@ from sleap_nn.config.training_job_config import verify_training_cfg
 from sleap_nn.training.callbacks import (
     ProgressReporterZMQ,
     TrainingControllerZMQ,
+    MatplotlibSaver,
+    WandBPredImageLogger,
+    WandBVizCallback,
+    WandBVizCallbackWithPAFs,
     CSVLoggerCallback,
     SleapProgressBar,
-    EpochEndEvaluationCallback,
-    CentroidEvaluationCallback,
-    UnifiedVizCallback,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -487,36 +490,16 @@ class ModelTrainer:
         ckpt_dir = "."
         self.config.trainer_config.ckpt_dir = ckpt_dir
         run_name = self.config.trainer_config.run_name
-        run_name_is_empty = run_name is None or run_name == "" or run_name == "None"
-
-        # Validate: multi-GPU + disk cache requires explicit run_name
-        if run_name_is_empty:
-            is_disk_caching = (
-                self.config.data_config.data_pipeline_fw
-                == "torch_dataset_cache_img_disk"
-            )
-            num_devices = self._get_trainer_devices()
-
-            if is_disk_caching and num_devices > 1:
-                raise ValueError(
-                    f"Multi-GPU training with disk caching requires an explicit `run_name`.\n\n"
-                    f"Detected {num_devices} device(s) with "
-                    f"`data_pipeline_fw='torch_dataset_cache_img_disk'`.\n"
-                    f"Without an explicit run_name, each GPU worker generates a different "
-                    f"timestamp-based directory, causing cache synchronization failures.\n\n"
-                    f"Please provide a run_name using one of these methods:\n"
-                    f" - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
-                    f" - Config file: Set `trainer_config.run_name: my_experiment`\n"
-                    f" - Python API: train(..., run_name='my_experiment')"
-                )
-
-            # Auto-generate timestamp-based run_name (safe for single GPU or non-disk-cache)
+        if run_name is None or run_name == "" or run_name == "None":
             sum_train_lfs = sum([len(train_label) for train_label in self.train_labels])
             sum_val_lfs = sum([len(val_label) for val_label in self.val_labels])
-            run_name = (
-                datetime.now().strftime("%y%m%d_%H%M%S")
-                + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
-            )
+            if self._get_trainer_devices() > 1:
+                run_name = f"{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
+            else:
+                run_name = (
+                    datetime.now().strftime("%y%m%d_%H%M%S")
+                    + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
+                )
 
         # If checkpoint path already exists, add suffix to prevent overwriting
         if (Path(ckpt_dir) / run_name).exists() and (
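
For reference, the auto-generated run name on the 0.1.0a1 side combines a timestamp with the model type and the number of labeled frames (the timestamp is dropped in the multi-device case). A minimal sketch of that naming scheme; the model type and frame count below are illustrative placeholders, not values from the diff:

    from datetime import datetime

    model_type = "centroid"   # hypothetical model type
    n_frames = 1200           # hypothetical train + val labeled-frame count

    # Single-device case, e.g. "250611_142530.centroid.n=1200"
    single_device = datetime.now().strftime("%y%m%d_%H%M%S") + f".{model_type}.n={n_frames}"
    # Multi-device case, e.g. "centroid.n=1200"
    multi_device = f"{model_type}.n={n_frames}"
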
@@ -659,6 +642,10 @@ class ModelTrainer:
         if self.config.trainer_config.wandb.prv_runid == "":
             self.config.trainer_config.wandb.prv_runid = None
 
+        # Default wandb run name to trainer run_name if not specified
+        if self.config.trainer_config.wandb.name is None:
+            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
+
         # compute preprocessing parameters from the labels objects and fill in the config
         self._setup_preprocessing_config()
 
@@ -708,14 +695,9 @@ class ModelTrainer:
             )
         )
 
-        # setup checkpoint path (generates run_name if not specified)
+        # setup checkpoint path
         self._setup_ckpt_path()
 
-        # Default wandb run name to trainer run_name if not specified
-        # Note: This must come after _setup_ckpt_path() which generates run_name
-        if self.config.trainer_config.wandb.name is None:
-            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
-
         # verify input_channels in model_config based on input image and pretrained model weights
         self._verify_model_input_channels()
 
@@ -727,15 +709,15 @@ class ModelTrainer:
         ).as_posix()
         logger.info(f"Setting up model ckpt dir: `{ckpt_path}`...")
 
-        # Only rank 0 (or non-distributed) should create directories and save files
+        if not Path(ckpt_path).exists():
+            try:
+                Path(ckpt_path).mkdir(parents=True, exist_ok=True)
+            except OSError as e:
+                message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
+                logger.error(message)
+                raise OSError(message)
+
         if RANK in [0, -1]:
-            if not Path(ckpt_path).exists():
-                try:
-                    Path(ckpt_path).mkdir(parents=True, exist_ok=True)
-                except OSError as e:
-                    message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
-                    logger.error(message)
-                    raise OSError(message)
             # Check if we should filter to user-labeled frames only
             user_instances_only = OmegaConf.select(
                 self.config, "data_config.user_instances_only", default=True
@@ -808,40 +790,10 @@ class ModelTrainer:
         base_cache_img_path = None
         if self.config.data_config.data_pipeline_fw == "torch_dataset_cache_img_memory":
             # check available memory. If insufficient memory, default to disk caching.
-            # Account for DataLoader worker memory overhead
-            train_num_workers = self.config.trainer_config.train_data_loader.num_workers
-            val_num_workers = self.config.trainer_config.val_data_loader.num_workers
-            max_num_workers = max(train_num_workers, val_num_workers)
-
             mem_available = check_cache_memory(
-                self.train_labels,
-                self.val_labels,
-                memory_buffer=MEMORY_BUFFER,
-                num_workers=max_num_workers,
+                self.train_labels, self.val_labels, memory_buffer=MEMORY_BUFFER
             )
             if not mem_available:
-                # Validate: multi-GPU + auto-generated run_name + fallback to disk cache
-                original_run_name = self._initial_config.trainer_config.run_name
-                run_name_was_auto = (
-                    original_run_name is None
-                    or original_run_name == ""
-                    or original_run_name == "None"
-                )
-                if run_name_was_auto and self.trainer.num_devices > 1:
-                    raise ValueError(
-                        f"Memory caching failed and disk caching fallback requires an "
-                        f"explicit `run_name` for multi-GPU training.\n\n"
-                        f"Detected {self.trainer.num_devices} device(s) with insufficient "
-                        f"memory for in-memory caching.\n"
-                        f"Without an explicit run_name, each GPU worker generates a different "
-                        f"timestamp-based directory, causing cache synchronization failures.\n\n"
-                        f"Please provide a run_name using one of these methods:\n"
-                        f" - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
-                        f" - Config file: Set `trainer_config.run_name: my_experiment`\n"
-                        f" - Python API: train(..., run_name='my_experiment')\n\n"
-                        f"Alternatively, use `data_pipeline_fw='torch_dataset'` to disable caching."
-                    )
-
                 self.config.data_config.data_pipeline_fw = (
                     "torch_dataset_cache_img_disk"
                 )
@@ -885,7 +837,7 @@ class ModelTrainer:
                     / self.config.trainer_config.run_name
                 ).as_posix(),
                 filename="best",
-                monitor="val/loss",
+                monitor="val_loss",
                 mode="min",
             )
             callbacks.append(checkpoint_callback)
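
Both sides monitor a validation loss, but the key changes from "val/loss" to "val_loss"; in Lightning the `monitor` argument of ModelCheckpoint (and EarlyStopping, changed the same way in a later hunk) has to match a metric name logged by the LightningModule. A minimal sketch, with an illustrative module that is not taken from sleap-nn:

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

    class LitModule(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = self._loss(batch)    # hypothetical loss helper
            self.log("val_loss", loss)  # key must match `monitor` below

    checkpoint = ModelCheckpoint(filename="best", monitor="val_loss", mode="min")
    early_stop = EarlyStopping(monitor="val_loss", mode="min")
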
@@ -893,52 +845,18 @@ class ModelTrainer:
         # csv log callback
         csv_log_keys = [
             "epoch",
-            "train/loss",
-            "val/loss",
+            "train_loss",
+            "val_loss",
             "learning_rate",
-            "train/time",
-            "val/time",
+            "train_time",
+            "val_time",
         ]
-        # Add model-specific keys for wandb parity
         if self.model_type in [
             "single_instance",
             "centered_instance",
             "multi_class_topdown",
         ]:
-            csv_log_keys.extend(
-                [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
-            )
-        if self.model_type == "bottomup":
-            csv_log_keys.extend(
-                [
-                    "train/confmaps_loss",
-                    "train/paf_loss",
-                    "val/confmaps_loss",
-                    "val/paf_loss",
-                ]
-            )
-        if self.model_type == "multi_class_bottomup":
-            csv_log_keys.extend(
-                [
-                    "train/confmaps_loss",
-                    "train/classmap_loss",
-                    "train/class_accuracy",
-                    "val/confmaps_loss",
-                    "val/classmap_loss",
-                    "val/class_accuracy",
-                ]
-            )
-        if self.model_type == "multi_class_topdown":
-            csv_log_keys.extend(
-                [
-                    "train/confmaps_loss",
-                    "train/classvector_loss",
-                    "train/class_accuracy",
-                    "val/confmaps_loss",
-                    "val/classvector_loss",
-                    "val/class_accuracy",
-                ]
-            )
+            csv_log_keys.extend(self.skeletons[0].node_names)
         csv_logger = CSVLoggerCallback(
             filepath=Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
@@ -951,7 +869,7 @@ class ModelTrainer:
         # early stopping callback
         callbacks.append(
             EarlyStopping(
-                monitor="val/loss",
+                monitor="val_loss",
                 mode="min",
                 verbose=False,
                 min_delta=self.config.trainer_config.early_stopping.min_delta,
@@ -991,6 +909,10 @@ class ModelTrainer:
                     "To keep logs, set trainer_config.wandb.delete_local_logs=false"
                 )
 
+            # Learning rate monitor callback - logs LR at each step for dynamic schedulers
+            # Only added when wandb is enabled since it requires a logger
+            callbacks.append(LearningRateMonitor(logging_interval="step"))
+
         # save the configs as yaml in the checkpoint dir
         # Mask API key in both configs to prevent saving to disk
         self.config.trainer_config.wandb.api_key = ""
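
LearningRateMonitor is a stock Lightning callback: it reads the current learning rate from the optimizer and records it through whatever logger is attached to the Trainer, which is why the comment above notes it is only added when wandb logging is on. A minimal usage sketch; the project name is a placeholder:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import LearningRateMonitor
    from lightning.pytorch.loggers import WandbLogger

    trainer = Trainer(
        callbacks=[LearningRateMonitor(logging_interval="step")],  # log LR every step
        logger=WandbLogger(project="my-project"),  # a logger must be attached
    )
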
@@ -1009,8 +931,11 @@ class ModelTrainer:
             )
             callbacks.append(ProgressReporterZMQ(address=publish_address))
 
-        # viz callbacks - use unified callback for all visualization outputs
+        # viz callbacks
         if self.config.trainer_config.visualize_preds_during_training:
+            train_viz_pipeline = cycle(viz_train_dataset)
+            val_viz_pipeline = cycle(viz_val_dataset)
+
             viz_dir = (
                 Path(self.config.trainer_config.ckpt_dir)
                 / self.config.trainer_config.run_name
@@ -1020,77 +945,147 @@ class ModelTrainer:
             if RANK in [0, -1]:
                 Path(viz_dir).mkdir(parents=True, exist_ok=True)
 
-            # Get wandb viz config options
-            log_wandb = self.config.trainer_config.use_wandb and OmegaConf.select(
-                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
-            )
-            wandb_modes = []
-            if log_wandb:
-                if OmegaConf.select(
-                    self.config, "trainer_config.wandb.viz_enabled", default=True
-                ):
-                    wandb_modes.append("direct")
-                if OmegaConf.select(
-                    self.config, "trainer_config.wandb.viz_boxes", default=False
-                ):
-                    wandb_modes.append("boxes")
-                if OmegaConf.select(
-                    self.config, "trainer_config.wandb.viz_masks", default=False
-                ):
-                    wandb_modes.append("masks")
-
-            # Single unified callback handles all visualization outputs
             callbacks.append(
-                UnifiedVizCallback(
-                    model_trainer=self,
-                    train_dataset=viz_train_dataset,
-                    val_dataset=viz_val_dataset,
-                    model_type=self.model_type,
-                    save_local=self.config.trainer_config.save_ckpt,
-                    local_save_dir=viz_dir,
-                    log_wandb=log_wandb,
-                    wandb_modes=wandb_modes if wandb_modes else ["direct"],
-                    wandb_box_size=OmegaConf.select(
-                        self.config, "trainer_config.wandb.viz_box_size", default=5.0
+                MatplotlibSaver(
+                    save_folder=viz_dir,
+                    plot_fn=lambda: self.lightning_model.visualize_example(
+                        next(train_viz_pipeline)
                     ),
-                    wandb_confmap_threshold=OmegaConf.select(
-                        self.config,
-                        "trainer_config.wandb.viz_confmap_threshold",
-                        default=0.1,
-                    ),
-                    log_wandb_table=OmegaConf.select(
-                        self.config, "trainer_config.wandb.log_viz_table", default=False
+                    prefix="train",
+                )
+            )
+            callbacks.append(
+                MatplotlibSaver(
+                    save_folder=viz_dir,
+                    plot_fn=lambda: self.lightning_model.visualize_example(
+                        next(val_viz_pipeline)
                     ),
+                    prefix="validation",
                 )
             )
 
-        # Add custom progress bar with better metric formatting
-        if self.config.trainer_config.enable_progress_bar:
-            callbacks.append(SleapProgressBar())
+            if self.model_type == "bottomup":
+                train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
+                val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
+                callbacks.append(
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_pafs_example(
+                            next(train_viz_pipeline1)
+                        ),
+                        prefix="train.pafs_magnitude",
+                    )
+                )
+                callbacks.append(
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_pafs_example(
+                            next(val_viz_pipeline1)
+                        ),
+                        prefix="validation.pafs_magnitude",
+                    )
+                )
 
-        # Add epoch-end evaluation callback if enabled
-        if self.config.trainer_config.eval.enabled:
-            if self.model_type == "centroid":
-                # Use centroid-specific evaluation with distance-based metrics
+            if self.model_type == "multi_class_bottomup":
+                train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
+                val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
                 callbacks.append(
-                    CentroidEvaluationCallback(
-                        videos=self.val_labels[0].videos,
-                        eval_frequency=self.config.trainer_config.eval.frequency,
-                        match_threshold=self.config.trainer_config.eval.match_threshold,
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
+                            next(train_viz_pipeline1)
+                        ),
+                        prefix="train.class_maps",
                     )
                 )
-            else:
-                # Use standard OKS/PCK evaluation for pose models
                 callbacks.append(
-                    EpochEndEvaluationCallback(
-                        skeleton=self.skeletons[0],
-                        videos=self.val_labels[0].videos,
-                        eval_frequency=self.config.trainer_config.eval.frequency,
-                        oks_stddev=self.config.trainer_config.eval.oks_stddev,
-                        oks_scale=self.config.trainer_config.eval.oks_scale,
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
+                            next(val_viz_pipeline1)
+                        ),
+                        prefix="validation.class_maps",
                     )
                 )
 
+            if self.config.trainer_config.use_wandb and OmegaConf.select(
+                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
+            ):
+                # Get wandb viz config options
+                viz_enabled = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_enabled", default=True
+                )
+                viz_boxes = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_boxes", default=False
+                )
+                viz_masks = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_masks", default=False
+                )
+                viz_box_size = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_box_size", default=5.0
+                )
+                viz_confmap_threshold = OmegaConf.select(
+                    self.config,
+                    "trainer_config.wandb.viz_confmap_threshold",
+                    default=0.1,
+                )
+                log_viz_table = OmegaConf.select(
+                    self.config, "trainer_config.wandb.log_viz_table", default=False
+                )
+
+                # Create viz data pipelines for wandb callback
+                wandb_train_viz_pipeline = cycle(copy.deepcopy(viz_train_dataset))
+                wandb_val_viz_pipeline = cycle(copy.deepcopy(viz_val_dataset))
+
+                if self.model_type == "bottomup":
+                    # Bottom-up model needs PAF visualizations
+                    wandb_train_pafs_pipeline = cycle(copy.deepcopy(viz_train_dataset))
+                    wandb_val_pafs_pipeline = cycle(copy.deepcopy(viz_val_dataset))
+                    callbacks.append(
+                        WandBVizCallbackWithPAFs(
+                            train_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_train_viz_pipeline)
+                            ),
+                            val_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_val_viz_pipeline)
+                            ),
+                            train_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_train_pafs_pipeline), include_pafs=True
+                            ),
+                            val_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_val_pafs_pipeline), include_pafs=True
+                            ),
+                            viz_enabled=viz_enabled,
+                            viz_boxes=viz_boxes,
+                            viz_masks=viz_masks,
+                            box_size=viz_box_size,
+                            confmap_threshold=viz_confmap_threshold,
+                            log_table=log_viz_table,
+                        )
+                    )
+                else:
+                    # Standard models
+                    callbacks.append(
+                        WandBVizCallback(
+                            train_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_train_viz_pipeline)
+                            ),
+                            val_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_val_viz_pipeline)
+                            ),
+                            viz_enabled=viz_enabled,
+                            viz_boxes=viz_boxes,
+                            viz_masks=viz_masks,
+                            box_size=viz_box_size,
+                            confmap_threshold=viz_confmap_threshold,
+                            log_table=log_viz_table,
+                        )
+                    )
+
+        # Add custom progress bar with better metric formatting
+        if self.config.trainer_config.enable_progress_bar:
+            callbacks.append(SleapProgressBar())
+
         return loggers, callbacks
 
     def _delete_cache_imgs(self):
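
The MatplotlibSaver callbacks above all use the same pattern: wrap the finite visualization dataset in itertools.cycle so the plot_fn lambda can keep calling next() every epoch, and deep-copy the dataset whenever a second callback needs its own independent iterator. A minimal sketch of that pattern with a stand-in dataset:

    import copy
    from itertools import cycle

    viz_dataset = [{"frame": 0}, {"frame": 1}]          # stand-in for the viz dataset

    confmaps_pipeline = cycle(viz_dataset)              # used by one callback
    pafs_pipeline = cycle(copy.deepcopy(viz_dataset))   # independent iterator state

    plot_confmaps = lambda: next(confmaps_pipeline)     # called once per epoch
    plot_pafs = lambda: next(pafs_pipeline)

    assert plot_confmaps()["frame"] == 0
    assert plot_confmaps()["frame"] == 1                # wraps back to 0 afterwards
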
@@ -1179,11 +1174,6 @@ class ModelTrainer:
                     : self.config.trainer_config.trainer_devices
                 ]
             ]
-            # Sort device indices in ascending order for NCCL compatibility.
-            # NCCL expects devices in consistent ascending order across ranks
-            # to properly set up communication rings. Without sorting, DDP may
-            # assign multiple ranks to the same GPU, causing "Duplicate GPU detected" errors.
-            devices.sort()
             logger.info(f"Using GPUs with most available memory: {devices}")
 
         # create lightning.Trainer instance.
@@ -1205,10 +1195,6 @@ class ModelTrainer:
         # setup datasets
         train_dataset, val_dataset = self._setup_datasets()
 
-        # Barrier after dataset creation to ensure all workers wait for disk caching
-        # (rank 0 caches to disk, others must wait before reading cached files)
-        self.trainer.strategy.barrier()
-
         # set-up steps per epoch
         train_steps_per_epoch = self.config.trainer_config.train_steps_per_epoch
         if train_steps_per_epoch is None:
@@ -1282,21 +1268,18 @@ class ModelTrainer:
             # Define custom x-axes for wandb metrics
             # Epoch-level metrics use epoch as x-axis, step-level use default global_step
             wandb.define_metric("epoch")
-
-            # Training metrics (train/ prefix for grouping) - all use epoch x-axis
-            wandb.define_metric("train/*", step_metric="epoch")
-            wandb.define_metric("train/confmaps/*", step_metric="epoch")
-
-            # Validation metrics (val/ prefix for grouping)
-            wandb.define_metric("val/*", step_metric="epoch")
-
-            # Evaluation metrics (eval/ prefix for grouping)
-            wandb.define_metric("eval/*", step_metric="epoch")
-
-            # Visualization images (need explicit nested paths)
-            wandb.define_metric("viz/*", step_metric="epoch")
-            wandb.define_metric("viz/train/*", step_metric="epoch")
-            wandb.define_metric("viz/val/*", step_metric="epoch")
+            wandb.define_metric("val_loss", step_metric="epoch")
+            wandb.define_metric("val_time", step_metric="epoch")
+            wandb.define_metric("train_time", step_metric="epoch")
+            # Per-node losses use epoch as x-axis
+            for node_name in self.skeletons[0].node_names:
+                wandb.define_metric(node_name, step_metric="epoch")
+
+            # Visualization images use epoch as x-axis
+            wandb.define_metric("train_predictions*", step_metric="epoch")
+            wandb.define_metric("val_predictions*", step_metric="epoch")
+            wandb.define_metric("train_pafs*", step_metric="epoch")
+            wandb.define_metric("val_pafs*", step_metric="epoch")
 
             self.config.trainer_config.wandb.current_run_id = wandb.run.id
             wandb.config["run_name"] = self.config.trainer_config.wandb.name
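
wandb.define_metric ties a logged key (or glob pattern) to a custom x-axis, so the epoch-level metrics above are plotted against "epoch" rather than the default global step. A minimal standalone sketch; the project name and values are placeholders, and offline mode is used only so the snippet runs without credentials:

    import wandb

    run = wandb.init(project="my-project", mode="offline")
    wandb.define_metric("epoch")
    wandb.define_metric("val_loss", step_metric="epoch")   # plot val_loss against epoch
    wandb.log({"epoch": 0, "val_loss": 0.42})
    run.finish()
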
@@ -1339,7 +1322,27 @@ class ModelTrainer:
         logger.info(
             f"Finished training loop. [{(time.time() - start_train_time) / 60:.1f} min]"
         )
-        # Note: wandb.finish() is called in train.py after post-training evaluation
+        if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
+            wandb.finish()
+
+            # Delete local wandb logs if configured
+            wandb_config = self.config.trainer_config.wandb
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                wandb_dir = (
+                    Path(self.config.trainer_config.ckpt_dir)
+                    / self.config.trainer_config.run_name
+                    / "wandb"
+                )
+                if wandb_dir.exists():
+                    logger.info(
+                        f"Deleting local wandb logs at {wandb_dir}... "
+                        "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                    )
+                    shutil.rmtree(wandb_dir, ignore_errors=True)
 
         # delete image disk caching
         if (
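
The deletion rule added on the 0.1.0a1 side is a three-way switch: an explicit True always deletes, an explicit False never does, and None (unset) deletes only when wandb is not in offline mode. A small sketch of that condition as a standalone function:

    def should_delete(delete_local_logs, wandb_mode):
        # Mirrors the condition in the hunk above.
        return delete_local_logs is True or (
            delete_local_logs is None and wandb_mode != "offline"
        )

    assert should_delete(True, "offline") is True
    assert should_delete(False, "online") is False
    assert should_delete(None, "online") is True
    assert should_delete(None, "offline") is False
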
{sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.0
+Version: 0.1.0a1
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
@@ -13,10 +13,10 @@ Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: sleap-io<0.7.0,>=0.6.2
+Requires-Dist: sleap-io<0.7.0,>=0.6.0
 Requires-Dist: numpy
 Requires-Dist: lightning
-Requires-Dist: skia-python>=87.0
+Requires-Dist: kornia
 Requires-Dist: jsonpickle
 Requires-Dist: scipy
 Requires-Dist: attrs
@@ -32,7 +32,6 @@ Requires-Dist: hydra-core
 Requires-Dist: jupyter
 Requires-Dist: jupyterlab
 Requires-Dist: pyzmq
-Requires-Dist: rich-click>=1.9.5
 Provides-Extra: torch
 Requires-Dist: torch; extra == "torch"
 Requires-Dist: torchvision>=0.20.0; extra == "torch"
@@ -48,17 +47,6 @@ Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda128"
 Provides-Extra: torch-cuda130
 Requires-Dist: torch; extra == "torch-cuda130"
 Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda130"
-Provides-Extra: export
-Requires-Dist: onnx>=1.15.0; extra == "export"
-Requires-Dist: onnxruntime>=1.16.0; extra == "export"
-Requires-Dist: onnxscript>=0.1.0; extra == "export"
-Provides-Extra: export-gpu
-Requires-Dist: onnx>=1.15.0; extra == "export-gpu"
-Requires-Dist: onnxruntime-gpu>=1.16.0; extra == "export-gpu"
-Requires-Dist: onnxscript>=0.1.0; extra == "export-gpu"
-Provides-Extra: tensorrt
-Requires-Dist: tensorrt>=10.13.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
-Requires-Dist: torch-tensorrt>=2.5.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
 Dynamic: license-file
 
 # sleap-nn