sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. sleap_nn/__init__.py +2 -4
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -10
  8. sleap_nn/config/trainer_config.py +0 -76
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +39 -411
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -74
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +184 -211
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  import os
4
4
  import shutil
5
+ import copy
5
6
  import attrs
6
7
  import torch
7
8
  import random
@@ -15,13 +16,14 @@ import yaml
15
16
  from pathlib import Path
16
17
  from typing import List, Optional
17
18
  from datetime import datetime
18
- from itertools import count
19
+ from itertools import cycle, count
19
20
  from omegaconf import DictConfig, OmegaConf
20
21
  from lightning.pytorch.loggers import WandbLogger
21
22
  from sleap_nn.data.utils import check_cache_memory
22
23
  from lightning.pytorch.callbacks import (
23
24
  ModelCheckpoint,
24
25
  EarlyStopping,
26
+ LearningRateMonitor,
25
27
  )
26
28
  from lightning.pytorch.profilers import (
27
29
  SimpleProfiler,
@@ -53,11 +55,12 @@ from sleap_nn.config.training_job_config import verify_training_cfg
53
55
  from sleap_nn.training.callbacks import (
54
56
  ProgressReporterZMQ,
55
57
  TrainingControllerZMQ,
58
+ MatplotlibSaver,
59
+ WandBPredImageLogger,
60
+ WandBVizCallback,
61
+ WandBVizCallbackWithPAFs,
56
62
  CSVLoggerCallback,
57
63
  SleapProgressBar,
58
- EpochEndEvaluationCallback,
59
- CentroidEvaluationCallback,
60
- UnifiedVizCallback,
61
64
  )
62
65
  from sleap_nn import RANK
63
66
  from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -487,36 +490,16 @@ class ModelTrainer:
487
490
  ckpt_dir = "."
488
491
  self.config.trainer_config.ckpt_dir = ckpt_dir
489
492
  run_name = self.config.trainer_config.run_name
490
- run_name_is_empty = run_name is None or run_name == "" or run_name == "None"
491
-
492
- # Validate: multi-GPU + disk cache requires explicit run_name
493
- if run_name_is_empty:
494
- is_disk_caching = (
495
- self.config.data_config.data_pipeline_fw
496
- == "torch_dataset_cache_img_disk"
497
- )
498
- num_devices = self._get_trainer_devices()
499
-
500
- if is_disk_caching and num_devices > 1:
501
- raise ValueError(
502
- f"Multi-GPU training with disk caching requires an explicit `run_name`.\n\n"
503
- f"Detected {num_devices} device(s) with "
504
- f"`data_pipeline_fw='torch_dataset_cache_img_disk'`.\n"
505
- f"Without an explicit run_name, each GPU worker generates a different "
506
- f"timestamp-based directory, causing cache synchronization failures.\n\n"
507
- f"Please provide a run_name using one of these methods:\n"
508
- f" - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
509
- f" - Config file: Set `trainer_config.run_name: my_experiment`\n"
510
- f" - Python API: train(..., run_name='my_experiment')"
511
- )
512
-
513
- # Auto-generate timestamp-based run_name (safe for single GPU or non-disk-cache)
493
+ if run_name is None or run_name == "" or run_name == "None":
514
494
  sum_train_lfs = sum([len(train_label) for train_label in self.train_labels])
515
495
  sum_val_lfs = sum([len(val_label) for val_label in self.val_labels])
516
- run_name = (
517
- datetime.now().strftime("%y%m%d_%H%M%S")
518
- + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
519
- )
496
+ if self._get_trainer_devices() > 1:
497
+ run_name = f"{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
498
+ else:
499
+ run_name = (
500
+ datetime.now().strftime("%y%m%d_%H%M%S")
501
+ + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
502
+ )
520
503
 
521
504
  # If checkpoint path already exists, add suffix to prevent overwriting
522
505
  if (Path(ckpt_dir) / run_name).exists() and (
@@ -659,6 +642,10 @@ class ModelTrainer:
659
642
  if self.config.trainer_config.wandb.prv_runid == "":
660
643
  self.config.trainer_config.wandb.prv_runid = None
661
644
 
645
+ # Default wandb run name to trainer run_name if not specified
646
+ if self.config.trainer_config.wandb.name is None:
647
+ self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
648
+
662
649
  # compute preprocessing parameters from the labels objects and fill in the config
663
650
  self._setup_preprocessing_config()
664
651
 
@@ -708,14 +695,9 @@ class ModelTrainer:
708
695
  )
709
696
  )
710
697
 
711
- # setup checkpoint path (generates run_name if not specified)
698
+ # setup checkpoint path
712
699
  self._setup_ckpt_path()
713
700
 
714
- # Default wandb run name to trainer run_name if not specified
715
- # Note: This must come after _setup_ckpt_path() which generates run_name
716
- if self.config.trainer_config.wandb.name is None:
717
- self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
718
-
719
701
  # verify input_channels in model_config based on input image and pretrained model weights
720
702
  self._verify_model_input_channels()
721
703
 
@@ -727,15 +709,15 @@ class ModelTrainer:
727
709
  ).as_posix()
728
710
  logger.info(f"Setting up model ckpt dir: `{ckpt_path}`...")
729
711
 
730
- # Only rank 0 (or non-distributed) should create directories and save files
712
+ if not Path(ckpt_path).exists():
713
+ try:
714
+ Path(ckpt_path).mkdir(parents=True, exist_ok=True)
715
+ except OSError as e:
716
+ message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
717
+ logger.error(message)
718
+ raise OSError(message)
719
+
731
720
  if RANK in [0, -1]:
732
- if not Path(ckpt_path).exists():
733
- try:
734
- Path(ckpt_path).mkdir(parents=True, exist_ok=True)
735
- except OSError as e:
736
- message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
737
- logger.error(message)
738
- raise OSError(message)
739
721
  # Check if we should filter to user-labeled frames only
740
722
  user_instances_only = OmegaConf.select(
741
723
  self.config, "data_config.user_instances_only", default=True
@@ -808,40 +790,10 @@ class ModelTrainer:
808
790
  base_cache_img_path = None
809
791
  if self.config.data_config.data_pipeline_fw == "torch_dataset_cache_img_memory":
810
792
  # check available memory. If insufficient memory, default to disk caching.
811
- # Account for DataLoader worker memory overhead
812
- train_num_workers = self.config.trainer_config.train_data_loader.num_workers
813
- val_num_workers = self.config.trainer_config.val_data_loader.num_workers
814
- max_num_workers = max(train_num_workers, val_num_workers)
815
-
816
793
  mem_available = check_cache_memory(
817
- self.train_labels,
818
- self.val_labels,
819
- memory_buffer=MEMORY_BUFFER,
820
- num_workers=max_num_workers,
794
+ self.train_labels, self.val_labels, memory_buffer=MEMORY_BUFFER
821
795
  )
822
796
  if not mem_available:
823
- # Validate: multi-GPU + auto-generated run_name + fallback to disk cache
824
- original_run_name = self._initial_config.trainer_config.run_name
825
- run_name_was_auto = (
826
- original_run_name is None
827
- or original_run_name == ""
828
- or original_run_name == "None"
829
- )
830
- if run_name_was_auto and self.trainer.num_devices > 1:
831
- raise ValueError(
832
- f"Memory caching failed and disk caching fallback requires an "
833
- f"explicit `run_name` for multi-GPU training.\n\n"
834
- f"Detected {self.trainer.num_devices} device(s) with insufficient "
835
- f"memory for in-memory caching.\n"
836
- f"Without an explicit run_name, each GPU worker generates a different "
837
- f"timestamp-based directory, causing cache synchronization failures.\n\n"
838
- f"Please provide a run_name using one of these methods:\n"
839
- f" - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
840
- f" - Config file: Set `trainer_config.run_name: my_experiment`\n"
841
- f" - Python API: train(..., run_name='my_experiment')\n\n"
842
- f"Alternatively, use `data_pipeline_fw='torch_dataset'` to disable caching."
843
- )
844
-
845
797
  self.config.data_config.data_pipeline_fw = (
846
798
  "torch_dataset_cache_img_disk"
847
799
  )
@@ -885,7 +837,7 @@ class ModelTrainer:
885
837
  / self.config.trainer_config.run_name
886
838
  ).as_posix(),
887
839
  filename="best",
888
- monitor="val/loss",
840
+ monitor="val_loss",
889
841
  mode="min",
890
842
  )
891
843
  callbacks.append(checkpoint_callback)
@@ -893,52 +845,18 @@ class ModelTrainer:
893
845
  # csv log callback
894
846
  csv_log_keys = [
895
847
  "epoch",
896
- "train/loss",
897
- "val/loss",
848
+ "train_loss",
849
+ "val_loss",
898
850
  "learning_rate",
899
- "train/time",
900
- "val/time",
851
+ "train_time",
852
+ "val_time",
901
853
  ]
902
- # Add model-specific keys for wandb parity
903
854
  if self.model_type in [
904
855
  "single_instance",
905
856
  "centered_instance",
906
857
  "multi_class_topdown",
907
858
  ]:
908
- csv_log_keys.extend(
909
- [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
910
- )
911
- if self.model_type == "bottomup":
912
- csv_log_keys.extend(
913
- [
914
- "train/confmaps_loss",
915
- "train/paf_loss",
916
- "val/confmaps_loss",
917
- "val/paf_loss",
918
- ]
919
- )
920
- if self.model_type == "multi_class_bottomup":
921
- csv_log_keys.extend(
922
- [
923
- "train/confmaps_loss",
924
- "train/classmap_loss",
925
- "train/class_accuracy",
926
- "val/confmaps_loss",
927
- "val/classmap_loss",
928
- "val/class_accuracy",
929
- ]
930
- )
931
- if self.model_type == "multi_class_topdown":
932
- csv_log_keys.extend(
933
- [
934
- "train/confmaps_loss",
935
- "train/classvector_loss",
936
- "train/class_accuracy",
937
- "val/confmaps_loss",
938
- "val/classvector_loss",
939
- "val/class_accuracy",
940
- ]
941
- )
859
+ csv_log_keys.extend(self.skeletons[0].node_names)
942
860
  csv_logger = CSVLoggerCallback(
943
861
  filepath=Path(self.config.trainer_config.ckpt_dir)
944
862
  / self.config.trainer_config.run_name
@@ -951,7 +869,7 @@ class ModelTrainer:
951
869
  # early stopping callback
952
870
  callbacks.append(
953
871
  EarlyStopping(
954
- monitor="val/loss",
872
+ monitor="val_loss",
955
873
  mode="min",
956
874
  verbose=False,
957
875
  min_delta=self.config.trainer_config.early_stopping.min_delta,
@@ -980,16 +898,9 @@ class ModelTrainer:
980
898
  )
981
899
  loggers.append(wandb_logger)
982
900
 
983
- # Log message about wandb local logs cleanup
984
- should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
985
- wandb_config.delete_local_logs is None
986
- and wandb_config.wandb_mode != "offline"
987
- )
988
- if should_delete_wandb_logs:
989
- logger.info(
990
- "WandB local logs will be deleted after training completes. "
991
- "To keep logs, set trainer_config.wandb.delete_local_logs=false"
992
- )
901
+ # Learning rate monitor callback - logs LR at each step for dynamic schedulers
902
+ # Only added when wandb is enabled since it requires a logger
903
+ callbacks.append(LearningRateMonitor(logging_interval="step"))
993
904
 
994
905
  # save the configs as yaml in the checkpoint dir
995
906
  # Mask API key in both configs to prevent saving to disk
@@ -1009,8 +920,11 @@ class ModelTrainer:
1009
920
  )
1010
921
  callbacks.append(ProgressReporterZMQ(address=publish_address))
1011
922
 
1012
- # viz callbacks - use unified callback for all visualization outputs
923
+ # viz callbacks
1013
924
  if self.config.trainer_config.visualize_preds_during_training:
925
+ train_viz_pipeline = cycle(viz_train_dataset)
926
+ val_viz_pipeline = cycle(viz_val_dataset)
927
+
1014
928
  viz_dir = (
1015
929
  Path(self.config.trainer_config.ckpt_dir)
1016
930
  / self.config.trainer_config.run_name
@@ -1020,77 +934,147 @@ class ModelTrainer:
1020
934
  if RANK in [0, -1]:
1021
935
  Path(viz_dir).mkdir(parents=True, exist_ok=True)
1022
936
 
1023
- # Get wandb viz config options
1024
- log_wandb = self.config.trainer_config.use_wandb and OmegaConf.select(
1025
- self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
1026
- )
1027
- wandb_modes = []
1028
- if log_wandb:
1029
- if OmegaConf.select(
1030
- self.config, "trainer_config.wandb.viz_enabled", default=True
1031
- ):
1032
- wandb_modes.append("direct")
1033
- if OmegaConf.select(
1034
- self.config, "trainer_config.wandb.viz_boxes", default=False
1035
- ):
1036
- wandb_modes.append("boxes")
1037
- if OmegaConf.select(
1038
- self.config, "trainer_config.wandb.viz_masks", default=False
1039
- ):
1040
- wandb_modes.append("masks")
1041
-
1042
- # Single unified callback handles all visualization outputs
1043
937
  callbacks.append(
1044
- UnifiedVizCallback(
1045
- model_trainer=self,
1046
- train_dataset=viz_train_dataset,
1047
- val_dataset=viz_val_dataset,
1048
- model_type=self.model_type,
1049
- save_local=self.config.trainer_config.save_ckpt,
1050
- local_save_dir=viz_dir,
1051
- log_wandb=log_wandb,
1052
- wandb_modes=wandb_modes if wandb_modes else ["direct"],
1053
- wandb_box_size=OmegaConf.select(
1054
- self.config, "trainer_config.wandb.viz_box_size", default=5.0
1055
- ),
1056
- wandb_confmap_threshold=OmegaConf.select(
1057
- self.config,
1058
- "trainer_config.wandb.viz_confmap_threshold",
1059
- default=0.1,
938
+ MatplotlibSaver(
939
+ save_folder=viz_dir,
940
+ plot_fn=lambda: self.lightning_model.visualize_example(
941
+ next(train_viz_pipeline)
1060
942
  ),
1061
- log_wandb_table=OmegaConf.select(
1062
- self.config, "trainer_config.wandb.log_viz_table", default=False
943
+ prefix="train",
944
+ )
945
+ )
946
+ callbacks.append(
947
+ MatplotlibSaver(
948
+ save_folder=viz_dir,
949
+ plot_fn=lambda: self.lightning_model.visualize_example(
950
+ next(val_viz_pipeline)
1063
951
  ),
952
+ prefix="validation",
1064
953
  )
1065
954
  )
1066
955
 
1067
- # Add custom progress bar with better metric formatting
1068
- if self.config.trainer_config.enable_progress_bar:
1069
- callbacks.append(SleapProgressBar())
956
+ if self.model_type == "bottomup":
957
+ train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
958
+ val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
959
+ callbacks.append(
960
+ MatplotlibSaver(
961
+ save_folder=viz_dir,
962
+ plot_fn=lambda: self.lightning_model.visualize_pafs_example(
963
+ next(train_viz_pipeline1)
964
+ ),
965
+ prefix="train.pafs_magnitude",
966
+ )
967
+ )
968
+ callbacks.append(
969
+ MatplotlibSaver(
970
+ save_folder=viz_dir,
971
+ plot_fn=lambda: self.lightning_model.visualize_pafs_example(
972
+ next(val_viz_pipeline1)
973
+ ),
974
+ prefix="validation.pafs_magnitude",
975
+ )
976
+ )
1070
977
 
1071
- # Add epoch-end evaluation callback if enabled
1072
- if self.config.trainer_config.eval.enabled:
1073
- if self.model_type == "centroid":
1074
- # Use centroid-specific evaluation with distance-based metrics
978
+ if self.model_type == "multi_class_bottomup":
979
+ train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
980
+ val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
1075
981
  callbacks.append(
1076
- CentroidEvaluationCallback(
1077
- videos=self.val_labels[0].videos,
1078
- eval_frequency=self.config.trainer_config.eval.frequency,
1079
- match_threshold=self.config.trainer_config.eval.match_threshold,
982
+ MatplotlibSaver(
983
+ save_folder=viz_dir,
984
+ plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
985
+ next(train_viz_pipeline1)
986
+ ),
987
+ prefix="train.class_maps",
1080
988
  )
1081
989
  )
1082
- else:
1083
- # Use standard OKS/PCK evaluation for pose models
1084
990
  callbacks.append(
1085
- EpochEndEvaluationCallback(
1086
- skeleton=self.skeletons[0],
1087
- videos=self.val_labels[0].videos,
1088
- eval_frequency=self.config.trainer_config.eval.frequency,
1089
- oks_stddev=self.config.trainer_config.eval.oks_stddev,
1090
- oks_scale=self.config.trainer_config.eval.oks_scale,
991
+ MatplotlibSaver(
992
+ save_folder=viz_dir,
993
+ plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
994
+ next(val_viz_pipeline1)
995
+ ),
996
+ prefix="validation.class_maps",
1091
997
  )
1092
998
  )
1093
999
 
1000
+ if self.config.trainer_config.use_wandb and OmegaConf.select(
1001
+ self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
1002
+ ):
1003
+ # Get wandb viz config options
1004
+ viz_enabled = OmegaConf.select(
1005
+ self.config, "trainer_config.wandb.viz_enabled", default=True
1006
+ )
1007
+ viz_boxes = OmegaConf.select(
1008
+ self.config, "trainer_config.wandb.viz_boxes", default=False
1009
+ )
1010
+ viz_masks = OmegaConf.select(
1011
+ self.config, "trainer_config.wandb.viz_masks", default=False
1012
+ )
1013
+ viz_box_size = OmegaConf.select(
1014
+ self.config, "trainer_config.wandb.viz_box_size", default=5.0
1015
+ )
1016
+ viz_confmap_threshold = OmegaConf.select(
1017
+ self.config,
1018
+ "trainer_config.wandb.viz_confmap_threshold",
1019
+ default=0.1,
1020
+ )
1021
+ log_viz_table = OmegaConf.select(
1022
+ self.config, "trainer_config.wandb.log_viz_table", default=False
1023
+ )
1024
+
1025
+ # Create viz data pipelines for wandb callback
1026
+ wandb_train_viz_pipeline = cycle(copy.deepcopy(viz_train_dataset))
1027
+ wandb_val_viz_pipeline = cycle(copy.deepcopy(viz_val_dataset))
1028
+
1029
+ if self.model_type == "bottomup":
1030
+ # Bottom-up model needs PAF visualizations
1031
+ wandb_train_pafs_pipeline = cycle(copy.deepcopy(viz_train_dataset))
1032
+ wandb_val_pafs_pipeline = cycle(copy.deepcopy(viz_val_dataset))
1033
+ callbacks.append(
1034
+ WandBVizCallbackWithPAFs(
1035
+ train_viz_fn=lambda: self.lightning_model.get_visualization_data(
1036
+ next(wandb_train_viz_pipeline)
1037
+ ),
1038
+ val_viz_fn=lambda: self.lightning_model.get_visualization_data(
1039
+ next(wandb_val_viz_pipeline)
1040
+ ),
1041
+ train_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
1042
+ next(wandb_train_pafs_pipeline), include_pafs=True
1043
+ ),
1044
+ val_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
1045
+ next(wandb_val_pafs_pipeline), include_pafs=True
1046
+ ),
1047
+ viz_enabled=viz_enabled,
1048
+ viz_boxes=viz_boxes,
1049
+ viz_masks=viz_masks,
1050
+ box_size=viz_box_size,
1051
+ confmap_threshold=viz_confmap_threshold,
1052
+ log_table=log_viz_table,
1053
+ )
1054
+ )
1055
+ else:
1056
+ # Standard models
1057
+ callbacks.append(
1058
+ WandBVizCallback(
1059
+ train_viz_fn=lambda: self.lightning_model.get_visualization_data(
1060
+ next(wandb_train_viz_pipeline)
1061
+ ),
1062
+ val_viz_fn=lambda: self.lightning_model.get_visualization_data(
1063
+ next(wandb_val_viz_pipeline)
1064
+ ),
1065
+ viz_enabled=viz_enabled,
1066
+ viz_boxes=viz_boxes,
1067
+ viz_masks=viz_masks,
1068
+ box_size=viz_box_size,
1069
+ confmap_threshold=viz_confmap_threshold,
1070
+ log_table=log_viz_table,
1071
+ )
1072
+ )
1073
+
1074
+ # Add custom progress bar with better metric formatting
1075
+ if self.config.trainer_config.enable_progress_bar:
1076
+ callbacks.append(SleapProgressBar())
1077
+
1094
1078
  return loggers, callbacks
1095
1079
 
1096
1080
  def _delete_cache_imgs(self):
@@ -1179,11 +1163,6 @@ class ModelTrainer:
1179
1163
  : self.config.trainer_config.trainer_devices
1180
1164
  ]
1181
1165
  ]
1182
- # Sort device indices in ascending order for NCCL compatibility.
1183
- # NCCL expects devices in consistent ascending order across ranks
1184
- # to properly set up communication rings. Without sorting, DDP may
1185
- # assign multiple ranks to the same GPU, causing "Duplicate GPU detected" errors.
1186
- devices.sort()
1187
1166
  logger.info(f"Using GPUs with most available memory: {devices}")
1188
1167
 
1189
1168
  # create lightning.Trainer instance.
@@ -1205,10 +1184,6 @@ class ModelTrainer:
1205
1184
  # setup datasets
1206
1185
  train_dataset, val_dataset = self._setup_datasets()
1207
1186
 
1208
- # Barrier after dataset creation to ensure all workers wait for disk caching
1209
- # (rank 0 caches to disk, others must wait before reading cached files)
1210
- self.trainer.strategy.barrier()
1211
-
1212
1187
  # set-up steps per epoch
1213
1188
  train_steps_per_epoch = self.config.trainer_config.train_steps_per_epoch
1214
1189
  if train_steps_per_epoch is None:
@@ -1282,21 +1257,18 @@ class ModelTrainer:
1282
1257
  # Define custom x-axes for wandb metrics
1283
1258
  # Epoch-level metrics use epoch as x-axis, step-level use default global_step
1284
1259
  wandb.define_metric("epoch")
1285
-
1286
- # Training metrics (train/ prefix for grouping) - all use epoch x-axis
1287
- wandb.define_metric("train/*", step_metric="epoch")
1288
- wandb.define_metric("train/confmaps/*", step_metric="epoch")
1289
-
1290
- # Validation metrics (val/ prefix for grouping)
1291
- wandb.define_metric("val/*", step_metric="epoch")
1292
-
1293
- # Evaluation metrics (eval/ prefix for grouping)
1294
- wandb.define_metric("eval/*", step_metric="epoch")
1295
-
1296
- # Visualization images (need explicit nested paths)
1297
- wandb.define_metric("viz/*", step_metric="epoch")
1298
- wandb.define_metric("viz/train/*", step_metric="epoch")
1299
- wandb.define_metric("viz/val/*", step_metric="epoch")
1260
+ wandb.define_metric("val_loss", step_metric="epoch")
1261
+ wandb.define_metric("val_time", step_metric="epoch")
1262
+ wandb.define_metric("train_time", step_metric="epoch")
1263
+ # Per-node losses use epoch as x-axis
1264
+ for node_name in self.skeletons[0].node_names:
1265
+ wandb.define_metric(node_name, step_metric="epoch")
1266
+
1267
+ # Visualization images use epoch as x-axis
1268
+ wandb.define_metric("train_predictions*", step_metric="epoch")
1269
+ wandb.define_metric("val_predictions*", step_metric="epoch")
1270
+ wandb.define_metric("train_pafs*", step_metric="epoch")
1271
+ wandb.define_metric("val_pafs*", step_metric="epoch")
1300
1272
 
1301
1273
  self.config.trainer_config.wandb.current_run_id = wandb.run.id
1302
1274
  wandb.config["run_name"] = self.config.trainer_config.wandb.name
@@ -1339,7 +1311,8 @@ class ModelTrainer:
1339
1311
  logger.info(
1340
1312
  f"Finished training loop. [{(time.time() - start_train_time) / 60:.1f} min]"
1341
1313
  )
1342
- # Note: wandb.finish() is called in train.py after post-training evaluation
1314
+ if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
1315
+ wandb.finish()
1343
1316
 
1344
1317
  # delete image disk caching
1345
1318
  if (
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sleap-nn
3
- Version: 0.1.0
3
+ Version: 0.1.0a0
4
4
  Summary: Neural network backend for training and inference for animal pose estimation.
5
5
  Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
6
6
  License: BSD-3-Clause
@@ -13,10 +13,10 @@ Classifier: Programming Language :: Python :: 3.13
13
13
  Requires-Python: <3.14,>=3.11
14
14
  Description-Content-Type: text/markdown
15
15
  License-File: LICENSE
16
- Requires-Dist: sleap-io<0.7.0,>=0.6.2
16
+ Requires-Dist: sleap-io<0.7.0,>=0.6.0
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: lightning
19
- Requires-Dist: skia-python>=87.0
19
+ Requires-Dist: kornia
20
20
  Requires-Dist: jsonpickle
21
21
  Requires-Dist: scipy
22
22
  Requires-Dist: attrs
@@ -32,7 +32,6 @@ Requires-Dist: hydra-core
32
32
  Requires-Dist: jupyter
33
33
  Requires-Dist: jupyterlab
34
34
  Requires-Dist: pyzmq
35
- Requires-Dist: rich-click>=1.9.5
36
35
  Provides-Extra: torch
37
36
  Requires-Dist: torch; extra == "torch"
38
37
  Requires-Dist: torchvision>=0.20.0; extra == "torch"
@@ -48,17 +47,6 @@ Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda128"
48
47
  Provides-Extra: torch-cuda130
49
48
  Requires-Dist: torch; extra == "torch-cuda130"
50
49
  Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda130"
51
- Provides-Extra: export
52
- Requires-Dist: onnx>=1.15.0; extra == "export"
53
- Requires-Dist: onnxruntime>=1.16.0; extra == "export"
54
- Requires-Dist: onnxscript>=0.1.0; extra == "export"
55
- Provides-Extra: export-gpu
56
- Requires-Dist: onnx>=1.15.0; extra == "export-gpu"
57
- Requires-Dist: onnxruntime-gpu>=1.16.0; extra == "export-gpu"
58
- Requires-Dist: onnxscript>=0.1.0; extra == "export-gpu"
59
- Provides-Extra: tensorrt
60
- Requires-Dist: tensorrt>=10.13.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
61
- Requires-Dist: torch-tensorrt>=2.5.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
62
50
  Dynamic: license-file
63
51
 
64
52
  # sleap-nn
@@ -0,0 +1,65 @@
1
+ sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
2
+ sleap_nn/__init__.py,sha256=DzQeiZIFUmfhpf6mk4j1AKAY2bofVMyIa31xbiSu-ls,1317
3
+ sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
4
+ sleap_nn/evaluation.py,sha256=3u7y85wFoBgCwOB2xOGTJIDrd2dUPWOo4m0s0oW3da4,31095
5
+ sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
6
+ sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
7
+ sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
8
+ sleap_nn/train.py,sha256=fWx_b1HqkadQ-GM_VEM1frCd8WkzJLqRARBNn8UoUbo,27181
9
+ sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
10
+ sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
11
+ sleap_nn/architectures/convnext.py,sha256=l9lMJDxIMb-9MI3ShOtVwbOUMuwOLtSQlxiVyYHqjvE,13953
12
+ sleap_nn/architectures/encoder_decoder.py,sha256=f3DUFJo6RrIUposdC3Ytyblr5J0tAeZ_si9dm_m_PhM,28339
13
+ sleap_nn/architectures/heads.py,sha256=5E-7kQ-b2gsL0EviQ8z3KS1DAAMT4F2ZnEzx7eSG5gg,21001
14
+ sleap_nn/architectures/model.py,sha256=1_dsP_4T9fsEVJjDt3er0haMKtbeM6w6JC6tc2jD0Gw,7139
15
+ sleap_nn/architectures/swint.py,sha256=S66Wd0j8Hp-rGlv1C60WSw3AwGyAyGetgfwpL0nIK_M,14687
16
+ sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcgo,11723
17
+ sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
18
+ sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42
19
+ sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
20
+ sleap_nn/config/get_config.py,sha256=vN_aOPTj9F-QBqGGfVSv8_aFSAYl-RfXY0pdbdcqjcM,42021
21
+ sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
22
+ sleap_nn/config/trainer_config.py,sha256=PaoNtRSNc2xgzwN955aR9kTZL8IxCWdevGljLxS6jOk,28073
23
+ sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
24
+ sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
25
+ sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
26
+ sleap_nn/data/augmentation.py,sha256=Kqw_DayPth_DBsmaO1G8Voou_-cYZuSPOjSQWSajgRI,13618
27
+ sleap_nn/data/confidence_maps.py,sha256=PTRqZWSAz1S7viJhxu7QgIC1aHiek97c_dCUsKUwG1o,6217
28
+ sleap_nn/data/custom_datasets.py,sha256=2qAyLeiCPI9uudFFP7zlj6d_tbxc5OVzpnzT23mRkVw,98472
29
+ sleap_nn/data/edge_maps.py,sha256=75qG_7zHRw7fC8JUCVI2tzYakIoxxneWWmcrTwjcHPo,12519
30
+ sleap_nn/data/identity.py,sha256=7vNup6PudST4yDLyDT9wDO-cunRirTEvx4sP77xrlfk,5193
31
+ sleap_nn/data/instance_centroids.py,sha256=SF-3zJt_VMTbZI5ssbrvmZQZDd3684bn55EAtvcbQ6o,2172
32
+ sleap_nn/data/instance_cropping.py,sha256=2dYq5OTwkFN1PdMjoxyuMuHq1OEe03m3Vzqvcs_dkPE,8304
33
+ sleap_nn/data/normalization.py,sha256=5xEvcguG-fvAGObl4nWPZ9TEM5gvv0uYPGDuni34XII,2930
34
+ sleap_nn/data/providers.py,sha256=0x6GFP1s1c08ji4p0M5V6p-dhT4Z9c-SI_Aw1DWX-uM,14272
35
+ sleap_nn/data/resizing.py,sha256=YFpSQduIBkRK39FYmrqDL-v8zMySlEs6TJxh6zb_0ZU,5076
36
+ sleap_nn/data/utils.py,sha256=rT0w7KMOTlzaeKWq1TqjbgC4Lvjz_G96McllvEOqXx8,5641
37
+ sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShOM,170
38
+ sleap_nn/inference/bottomup.py,sha256=NqN-G8TzAOsvCoL3bttEjA1iGsuveLOnOCXIUeFCdSA,13684
39
+ sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
40
+ sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
41
+ sleap_nn/inference/peak_finding.py,sha256=L9LdYKt_Bfw7cxo6xEpgF8wXcZAwq5plCfmKJ839N40,13014
42
+ sleap_nn/inference/predictors.py,sha256=U114RlgOXKGm5iz1lnTfE3aN9S0WCh6gWhVP3KVewfc,158046
43
+ sleap_nn/inference/provenance.py,sha256=0BekXyvpLMb0Vv6DjpctlLduG9RN-Q8jt5zDm783eZE,11204
44
+ sleap_nn/inference/single_instance.py,sha256=rOns_5TsJ1rb-lwmHG3ZY-pOhXGN2D-SfW9RmBxxzcI,4089
45
+ sleap_nn/inference/topdown.py,sha256=Ha0Nwx-XCH_rebIuIGhP0qW68QpjLB3XRr9rxt05JLs,35108
46
+ sleap_nn/inference/utils.py,sha256=JnaJK4S_qLtHkWOSkHf4oRZjOmgnU9BGADQnntgGxxs,4689
47
+ sleap_nn/tracking/__init__.py,sha256=rGR35wpSW-n5d3cMiQUzQQ_Dy5II5DPjlXAoPw2QhmM,31
48
+ sleap_nn/tracking/track_instance.py,sha256=9k0uVy9VmpleaLcJh7sVWSeFUPXiw7yj95EYNdXJcks,1373
49
+ sleap_nn/tracking/tracker.py,sha256=_WT-HFruzyOsvcq3AtLm3vnI9MYSwyBmq-HlQvj1vmU,41955
50
+ sleap_nn/tracking/utils.py,sha256=uHVd_mzzZjviVDdLSKXJJ1T96n5ObKvkqIuGsl9Yy8U,11276
51
+ sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j4j1TvO5scSE,49
52
+ sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
53
+ sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
54
+ sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
55
+ sleap_nn/training/callbacks.py,sha256=TVnQ6plNC2MnlTiY2rSCRuw2WRk5cQSziek_VPUcOEg,25994
56
+ sleap_nn/training/lightning_modules.py,sha256=G3c4xJkYWW-iSRawzkgTqkGd4lTsbPiMTcB5Nvq7jes,85512
57
+ sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
58
+ sleap_nn/training/model_trainer.py,sha256=InDKHrQxBwbltZKutW4yrBR9NThLdRpWNUGhmB0xAi4,57863
59
+ sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
60
+ sleap_nn-0.1.0a0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
61
+ sleap_nn-0.1.0a0.dist-info/METADATA,sha256=lxSmGNTUg9eetqHCvhw8Tv5zJtia-dIM5RzOeoDccj8,5637
62
+ sleap_nn-0.1.0a0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
+ sleap_nn-0.1.0a0.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
64
+ sleap_nn-0.1.0a0.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
65
+ sleap_nn-0.1.0a0.dist-info/RECORD,,