sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/training/model_trainer.py

@@ -2,6 +2,7 @@
 
 import os
 import shutil
+import copy
 import attrs
 import torch
 import random
@@ -15,13 +16,14 @@ import yaml
 from pathlib import Path
 from typing import List, Optional
 from datetime import datetime
-from itertools import count
+from itertools import cycle, count
 from omegaconf import DictConfig, OmegaConf
 from lightning.pytorch.loggers import WandbLogger
 from sleap_nn.data.utils import check_cache_memory
 from lightning.pytorch.callbacks import (
     ModelCheckpoint,
     EarlyStopping,
+    LearningRateMonitor,
 )
 from lightning.pytorch.profilers import (
     SimpleProfiler,
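The new `cycle` import is used further down to turn the finite visualization datasets into endless iterators, so the per-epoch callbacks can simply call `next(...)`. A minimal illustration of the two `itertools` helpers (the sample values here are made up):

```python
from itertools import cycle, count

# cycle() repeats a finite iterable forever: ideal for pulling one
# visualization sample per epoch without ever exhausting the dataset.
samples = cycle(["frame_0", "frame_1"])  # stand-in for a viz dataset
assert [next(samples) for _ in range(3)] == ["frame_0", "frame_1", "frame_0"]

# count() yields 1, 2, 3, ... and was already imported in 0.1.0,
# e.g. for generating numeric suffixes.
suffixes = count(1)
assert next(suffixes) == 1
```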
@@ -53,11 +55,12 @@ from sleap_nn.config.training_job_config import verify_training_cfg
 from sleap_nn.training.callbacks import (
     ProgressReporterZMQ,
     TrainingControllerZMQ,
+    MatplotlibSaver,
+    WandBPredImageLogger,
+    WandBVizCallback,
+    WandBVizCallbackWithPAFs,
     CSVLoggerCallback,
     SleapProgressBar,
-    EpochEndEvaluationCallback,
-    CentroidEvaluationCallback,
-    UnifiedVizCallback,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -487,36 +490,16 @@ class ModelTrainer:
         ckpt_dir = "."
         self.config.trainer_config.ckpt_dir = ckpt_dir
         run_name = self.config.trainer_config.run_name
-
-
-        # Validate: multi-GPU + disk cache requires explicit run_name
-        if run_name_is_empty:
-            is_disk_caching = (
-                self.config.data_config.data_pipeline_fw
-                == "torch_dataset_cache_img_disk"
-            )
-            num_devices = self._get_trainer_devices()
-
-            if is_disk_caching and num_devices > 1:
-                raise ValueError(
-                    f"Multi-GPU training with disk caching requires an explicit `run_name`.\n\n"
-                    f"Detected {num_devices} device(s) with "
-                    f"`data_pipeline_fw='torch_dataset_cache_img_disk'`.\n"
-                    f"Without an explicit run_name, each GPU worker generates a different "
-                    f"timestamp-based directory, causing cache synchronization failures.\n\n"
-                    f"Please provide a run_name using one of these methods:\n"
-                    f"  - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
-                    f"  - Config file: Set `trainer_config.run_name: my_experiment`\n"
-                    f"  - Python API: train(..., run_name='my_experiment')"
-                )
-
-        # Auto-generate timestamp-based run_name (safe for single GPU or non-disk-cache)
+        if run_name is None or run_name == "" or run_name == "None":
             sum_train_lfs = sum([len(train_label) for train_label in self.train_labels])
             sum_val_lfs = sum([len(val_label) for val_label in self.val_labels])
-
-
-
-
+            if self._get_trainer_devices() > 1:
+                run_name = f"{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
+            else:
+                run_name = (
+                    datetime.now().strftime("%y%m%d_%H%M%S")
+                    + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
+                )
 
         # If checkpoint path already exists, add suffix to prevent overwriting
         if (Path(ckpt_dir) / run_name).exists() and (
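The replacement collapses the removed multi-GPU guard into the naming rule itself: with more than one device the name is deterministic, so every DDP rank computes the same string and resolves the same checkpoint directory, while single-device runs keep the timestamp prefix. A standalone sketch of that scheme; `model_type`, `n_labels`, and `num_devices` are stand-in inputs:

```python
from datetime import datetime

def make_run_name(model_type: str, n_labels: int, num_devices: int) -> str:
    # Multi-GPU: deterministic name so every DDP rank agrees on the ckpt dir.
    if num_devices > 1:
        return f"{model_type}.n={n_labels}"
    # Single device: timestamp keeps repeated runs from colliding.
    return datetime.now().strftime("%y%m%d_%H%M%S") + f".{model_type}.n={n_labels}"

print(make_run_name("centroid", 1200, num_devices=2))  # -> "centroid.n=1200"
```

The trailing context line ("add suffix to prevent overwriting") handles the collision case that a deterministic name reintroduces.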
@@ -659,6 +642,10 @@ class ModelTrainer:
         if self.config.trainer_config.wandb.prv_runid == "":
             self.config.trainer_config.wandb.prv_runid = None
 
+        # Default wandb run name to trainer run_name if not specified
+        if self.config.trainer_config.wandb.name is None:
+            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
+
         # compute preprocessing parameters from the labels objects and fill in the config
         self._setup_preprocessing_config()
 
@@ -708,14 +695,9 @@ class ModelTrainer:
             )
         )
 
-        # setup checkpoint path
+        # setup checkpoint path
         self._setup_ckpt_path()
 
-        # Default wandb run name to trainer run_name if not specified
-        # Note: This must come after _setup_ckpt_path() which generates run_name
-        if self.config.trainer_config.wandb.name is None:
-            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
-
         # verify input_channels in model_config based on input image and pretrained model weights
         self._verify_model_input_channels()
 
@@ -727,15 +709,15 @@ class ModelTrainer:
         ).as_posix()
         logger.info(f"Setting up model ckpt dir: `{ckpt_path}`...")
 
-
+        if not Path(ckpt_path).exists():
+            try:
+                Path(ckpt_path).mkdir(parents=True, exist_ok=True)
+            except OSError as e:
+                message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
+                logger.error(message)
+                raise OSError(message)
+
         if RANK in [0, -1]:
-            if not Path(ckpt_path).exists():
-                try:
-                    Path(ckpt_path).mkdir(parents=True, exist_ok=True)
-                except OSError as e:
-                    message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
-                    logger.error(message)
-                    raise OSError(message)
             # Check if we should filter to user-labeled frames only
             user_instances_only = OmegaConf.select(
                 self.config, "data_config.user_instances_only", default=True
@@ -808,40 +790,10 @@ class ModelTrainer:
         base_cache_img_path = None
         if self.config.data_config.data_pipeline_fw == "torch_dataset_cache_img_memory":
             # check available memory. If insufficient memory, default to disk caching.
-            # Account for DataLoader worker memory overhead
-            train_num_workers = self.config.trainer_config.train_data_loader.num_workers
-            val_num_workers = self.config.trainer_config.val_data_loader.num_workers
-            max_num_workers = max(train_num_workers, val_num_workers)
-
             mem_available = check_cache_memory(
-                self.train_labels,
-                self.val_labels,
-                memory_buffer=MEMORY_BUFFER,
-                num_workers=max_num_workers,
+                self.train_labels, self.val_labels, memory_buffer=MEMORY_BUFFER
             )
             if not mem_available:
-                # Validate: multi-GPU + auto-generated run_name + fallback to disk cache
-                original_run_name = self._initial_config.trainer_config.run_name
-                run_name_was_auto = (
-                    original_run_name is None
-                    or original_run_name == ""
-                    or original_run_name == "None"
-                )
-                if run_name_was_auto and self.trainer.num_devices > 1:
-                    raise ValueError(
-                        f"Memory caching failed and disk caching fallback requires an "
-                        f"explicit `run_name` for multi-GPU training.\n\n"
-                        f"Detected {self.trainer.num_devices} device(s) with insufficient "
-                        f"memory for in-memory caching.\n"
-                        f"Without an explicit run_name, each GPU worker generates a different "
-                        f"timestamp-based directory, causing cache synchronization failures.\n\n"
-                        f"Please provide a run_name using one of these methods:\n"
-                        f"  - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
-                        f"  - Config file: Set `trainer_config.run_name: my_experiment`\n"
-                        f"  - Python API: train(..., run_name='my_experiment')\n\n"
-                        f"Alternatively, use `data_pipeline_fw='torch_dataset'` to disable caching."
-                    )
-
                 self.config.data_config.data_pipeline_fw = (
                     "torch_dataset_cache_img_disk"
                 )
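`check_cache_memory` loses its `num_workers` argument, so DataLoader worker overhead no longer inflates the estimate. The function body is not part of this diff; a rough sketch of what such a check might do (the `psutil` call and the per-frame size estimate are assumptions, not the actual sleap_nn implementation):

```python
import psutil

def check_cache_memory_sketch(train_labels, val_labels, memory_buffer=0.1):
    """Return True if the estimated uint8 image cache fits in available RAM.

    Hypothetical stand-in: assumes each labels object is iterable over
    frames exposing a (height, width, channels) `shape` attribute.
    """
    total_bytes = 0
    for labels in (*train_labels, *val_labels):
        for frame in labels:
            h, w, c = frame.shape
            total_bytes += h * w * c  # 1 byte per uint8 pixel
    available = psutil.virtual_memory().available
    return total_bytes <= available * (1 - memory_buffer)
```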
@@ -885,7 +837,7 @@ class ModelTrainer:
                 / self.config.trainer_config.run_name
             ).as_posix(),
             filename="best",
-            monitor="
+            monitor="val_loss",
             mode="min",
         )
         callbacks.append(checkpoint_callback)
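With the monitored metric renamed to a flat `val_loss` key, the checkpoint callback is the stock Lightning pattern. A self-contained equivalent (the directory is a placeholder):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="ckpts/my_run",  # placeholder for ckpt_dir / run_name
    filename="best",         # written as best.ckpt
    monitor="val_loss",      # must match a key logged with self.log("val_loss", ...)
    mode="min",              # keep the checkpoint with the lowest validation loss
)
```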
@@ -893,52 +845,18 @@ class ModelTrainer:
         # csv log callback
         csv_log_keys = [
             "epoch",
-            "
-            "
+            "train_loss",
+            "val_loss",
             "learning_rate",
-            "
-            "
+            "train_time",
+            "val_time",
         ]
-        # Add model-specific keys for wandb parity
         if self.model_type in [
             "single_instance",
             "centered_instance",
             "multi_class_topdown",
         ]:
-            csv_log_keys.extend(
-                [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
-            )
-        if self.model_type == "bottomup":
-            csv_log_keys.extend(
-                [
-                    "train/confmaps_loss",
-                    "train/paf_loss",
-                    "val/confmaps_loss",
-                    "val/paf_loss",
-                ]
-            )
-        if self.model_type == "multi_class_bottomup":
-            csv_log_keys.extend(
-                [
-                    "train/confmaps_loss",
-                    "train/classmap_loss",
-                    "train/class_accuracy",
-                    "val/confmaps_loss",
-                    "val/classmap_loss",
-                    "val/class_accuracy",
-                ]
-            )
-        if self.model_type == "multi_class_topdown":
-            csv_log_keys.extend(
-                [
-                    "train/confmaps_loss",
-                    "train/classvector_loss",
-                    "train/class_accuracy",
-                    "val/confmaps_loss",
-                    "val/classvector_loss",
-                    "val/class_accuracy",
-                ]
-            )
+            csv_log_keys.extend(self.skeletons[0].node_names)
         csv_logger = CSVLoggerCallback(
             filepath=Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
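The model-specific key blocks (`train/confmaps_loss`, `val/paf_loss`, ...) collapse into a flat schema: six fixed columns plus one column per skeleton node for the single-instance and top-down models. A sketch of a writer that consumes such a key list; `CSVLoggerCallback`'s real internals are not shown in this diff, so this is only illustrative:

```python
import csv
from pathlib import Path

class CSVLoggerSketch:
    """Hypothetical minimal logger: one row per epoch, fixed column order,
    blanks for metrics missing from a given epoch."""

    def __init__(self, filepath, keys):
        self.filepath, self.keys = Path(filepath), keys
        with open(self.filepath, "w", newline="") as f:
            csv.writer(f).writerow(self.keys)

    def log(self, metrics):
        with open(self.filepath, "a", newline="") as f:
            csv.writer(f).writerow([metrics.get(k, "") for k in self.keys])

keys = ["epoch", "train_loss", "val_loss", "learning_rate", "train_time", "val_time"]
logger_cb = CSVLoggerSketch("training_log.csv", keys + ["head", "thorax"])  # + node names
logger_cb.log({"epoch": 0, "train_loss": 0.12, "val_loss": 0.15, "head": 0.04})
```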
@@ -951,7 +869,7 @@ class ModelTrainer:
         # early stopping callback
         callbacks.append(
             EarlyStopping(
-                monitor="
+                monitor="val_loss",
                 mode="min",
                 verbose=False,
                 min_delta=self.config.trainer_config.early_stopping.min_delta,
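Early stopping gets the same `val_loss` rename; in stock Lightning form (the numeric values are illustrative, the trainer reads them from `trainer_config.early_stopping`):

```python
from lightning.pytorch.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor="val_loss",  # same flat key as the checkpoint callback
    mode="min",
    verbose=False,
    min_delta=1e-5,      # illustrative; taken from config in the trainer
    patience=10,         # illustrative
)
```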
@@ -991,6 +909,10 @@ class ModelTrainer:
                     "To keep logs, set trainer_config.wandb.delete_local_logs=false"
                 )
 
+        # Learning rate monitor callback - logs LR at each step for dynamic schedulers
+        # Only added when wandb is enabled since it requires a logger
+        callbacks.append(LearningRateMonitor(logging_interval="step"))
+
         # save the configs as yaml in the checkpoint dir
         # Mask API key in both configs to prevent saving to disk
         self.config.trainer_config.wandb.api_key = ""
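`LearningRateMonitor` is a built-in Lightning callback; attached to a trainer with a logger, it records the optimizer's learning rate every step, which makes warmup and decay schedules visible as charts:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = Trainer(callbacks=[lr_monitor])  # needs an attached logger to record to
```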
@@ -1009,8 +931,11 @@ class ModelTrainer:
             )
             callbacks.append(ProgressReporterZMQ(address=publish_address))
 
-        # viz callbacks
+        # viz callbacks
         if self.config.trainer_config.visualize_preds_during_training:
+            train_viz_pipeline = cycle(viz_train_dataset)
+            val_viz_pipeline = cycle(viz_val_dataset)
+
             viz_dir = (
                 Path(self.config.trainer_config.ckpt_dir)
                 / self.config.trainer_config.run_name
@@ -1020,77 +945,147 @@ class ModelTrainer:
             if RANK in [0, -1]:
                 Path(viz_dir).mkdir(parents=True, exist_ok=True)
 
-            # Get wandb viz config options
-            log_wandb = self.config.trainer_config.use_wandb and OmegaConf.select(
-                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
-            )
-            wandb_modes = []
-            if log_wandb:
-                if OmegaConf.select(
-                    self.config, "trainer_config.wandb.viz_enabled", default=True
-                ):
-                    wandb_modes.append("direct")
-                if OmegaConf.select(
-                    self.config, "trainer_config.wandb.viz_boxes", default=False
-                ):
-                    wandb_modes.append("boxes")
-                if OmegaConf.select(
-                    self.config, "trainer_config.wandb.viz_masks", default=False
-                ):
-                    wandb_modes.append("masks")
-
-            # Single unified callback handles all visualization outputs
             callbacks.append(
-
-
-
-
-                    model_type=self.model_type,
-                    save_local=self.config.trainer_config.save_ckpt,
-                    local_save_dir=viz_dir,
-                    log_wandb=log_wandb,
-                    wandb_modes=wandb_modes if wandb_modes else ["direct"],
-                    wandb_box_size=OmegaConf.select(
-                        self.config, "trainer_config.wandb.viz_box_size", default=5.0
+                MatplotlibSaver(
+                    save_folder=viz_dir,
+                    plot_fn=lambda: self.lightning_model.visualize_example(
+                        next(train_viz_pipeline)
                     ),
-
-
-
-
-
-
+                    prefix="train",
+                )
+            )
+            callbacks.append(
+                MatplotlibSaver(
+                    save_folder=viz_dir,
+                    plot_fn=lambda: self.lightning_model.visualize_example(
+                        next(val_viz_pipeline)
                     ),
+                    prefix="validation",
                 )
             )
 
-
-
-
+            if self.model_type == "bottomup":
+                train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
+                val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
+                callbacks.append(
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_pafs_example(
+                            next(train_viz_pipeline1)
+                        ),
+                        prefix="train.pafs_magnitude",
+                    )
+                )
+                callbacks.append(
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_pafs_example(
+                            next(val_viz_pipeline1)
+                        ),
+                        prefix="validation.pafs_magnitude",
+                    )
+                )
 
-
-
-
-            # Use centroid-specific evaluation with distance-based metrics
+            if self.model_type == "multi_class_bottomup":
+                train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
+                val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
                 callbacks.append(
-
-
-
-
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
+                            next(train_viz_pipeline1)
+                        ),
+                        prefix="train.class_maps",
                     )
                 )
-        else:
-            # Use standard OKS/PCK evaluation for pose models
                 callbacks.append(
-
-
-
-
-
-
+                    MatplotlibSaver(
+                        save_folder=viz_dir,
+                        plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
+                            next(val_viz_pipeline1)
+                        ),
+                        prefix="validation.class_maps",
                     )
                 )
 
+            if self.config.trainer_config.use_wandb and OmegaConf.select(
+                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
+            ):
+                # Get wandb viz config options
+                viz_enabled = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_enabled", default=True
+                )
+                viz_boxes = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_boxes", default=False
+                )
+                viz_masks = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_masks", default=False
+                )
+                viz_box_size = OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_box_size", default=5.0
+                )
+                viz_confmap_threshold = OmegaConf.select(
+                    self.config,
+                    "trainer_config.wandb.viz_confmap_threshold",
+                    default=0.1,
+                )
+                log_viz_table = OmegaConf.select(
+                    self.config, "trainer_config.wandb.log_viz_table", default=False
+                )
+
+                # Create viz data pipelines for wandb callback
+                wandb_train_viz_pipeline = cycle(copy.deepcopy(viz_train_dataset))
+                wandb_val_viz_pipeline = cycle(copy.deepcopy(viz_val_dataset))
+
+                if self.model_type == "bottomup":
+                    # Bottom-up model needs PAF visualizations
+                    wandb_train_pafs_pipeline = cycle(copy.deepcopy(viz_train_dataset))
+                    wandb_val_pafs_pipeline = cycle(copy.deepcopy(viz_val_dataset))
+                    callbacks.append(
+                        WandBVizCallbackWithPAFs(
+                            train_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_train_viz_pipeline)
+                            ),
+                            val_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_val_viz_pipeline)
+                            ),
+                            train_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_train_pafs_pipeline), include_pafs=True
+                            ),
+                            val_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_val_pafs_pipeline), include_pafs=True
+                            ),
+                            viz_enabled=viz_enabled,
+                            viz_boxes=viz_boxes,
+                            viz_masks=viz_masks,
+                            box_size=viz_box_size,
+                            confmap_threshold=viz_confmap_threshold,
+                            log_table=log_viz_table,
+                        )
+                    )
+                else:
+                    # Standard models
+                    callbacks.append(
+                        WandBVizCallback(
+                            train_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_train_viz_pipeline)
+                            ),
+                            val_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                                next(wandb_val_viz_pipeline)
+                            ),
+                            viz_enabled=viz_enabled,
+                            viz_boxes=viz_boxes,
+                            viz_masks=viz_masks,
+                            box_size=viz_box_size,
+                            confmap_threshold=viz_confmap_threshold,
+                            log_table=log_viz_table,
+                        )
+                    )
+
+        # Add custom progress bar with better metric formatting
+        if self.config.trainer_config.enable_progress_bar:
+            callbacks.append(SleapProgressBar())
+
         return loggers, callbacks
 
     def _delete_cache_imgs(self):
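The removed unified viz wiring gives way to several `MatplotlibSaver` instances, each owning a `plot_fn` closure over its own cycled dataset. The class body is not part of this diff; a rough sketch of the callback shape the call sites imply, assuming `plot_fn` returns a matplotlib `Figure`:

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend for training jobs
import matplotlib.pyplot as plt
from lightning.pytorch.callbacks import Callback

class MatplotlibSaverSketch(Callback):
    """Hypothetical minimal version: render plot_fn() once per epoch and
    save it as <save_folder>/<prefix>.<epoch>.png."""

    def __init__(self, save_folder, plot_fn, prefix):
        self.save_folder = Path(save_folder)
        self.plot_fn = plot_fn  # e.g. lambda: model.visualize_example(next(pipeline))
        self.prefix = prefix

    def on_train_epoch_end(self, trainer, pl_module):
        fig = self.plot_fn()
        fig.savefig(self.save_folder / f"{self.prefix}.{trainer.current_epoch:04d}.png")
        plt.close(fig)
```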
@@ -1179,11 +1174,6 @@ class ModelTrainer:
                 : self.config.trainer_config.trainer_devices
             ]
         ]
-        # Sort device indices in ascending order for NCCL compatibility.
-        # NCCL expects devices in consistent ascending order across ranks
-        # to properly set up communication rings. Without sorting, DDP may
-        # assign multiple ranks to the same GPU, causing "Duplicate GPU detected" errors.
-        devices.sort()
         logger.info(f"Using GPUs with most available memory: {devices}")
 
         # create lightning.Trainer instance.
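The dropped `devices.sort()` sat just after the selection the log line describes: ranking GPUs by free memory and keeping the top N. A sketch of that selection using PyTorch's memory query (the helper name is made up; 0.1.0 additionally sorted the resulting indices for NCCL, 0.1.0a1 does not):

```python
import torch

def pick_gpus_by_free_memory(n):
    """Hypothetical helper: indices of the n GPUs with the most free memory."""
    free_per_device = []
    for i in range(torch.cuda.device_count()):
        free_bytes, _total_bytes = torch.cuda.mem_get_info(i)
        free_per_device.append((free_bytes, i))
    return [i for _free, i in sorted(free_per_device, reverse=True)[:n]]
```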
@@ -1205,10 +1195,6 @@ class ModelTrainer:
         # setup datasets
         train_dataset, val_dataset = self._setup_datasets()
 
-        # Barrier after dataset creation to ensure all workers wait for disk caching
-        # (rank 0 caches to disk, others must wait before reading cached files)
-        self.trainer.strategy.barrier()
-
         # set-up steps per epoch
         train_steps_per_epoch = self.config.trainer_config.train_steps_per_epoch
         if train_steps_per_epoch is None:
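When `train_steps_per_epoch` is left unset, the usual derivation is dataset length over batch size, rounded up; a one-line sketch of that convention (whether sleap_nn computes it exactly this way is not shown in this hunk):

```python
import math

dataset_len, batch_size = 1200, 4  # illustrative values
train_steps_per_epoch = math.ceil(dataset_len / batch_size)  # -> 300
```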
@@ -1282,21 +1268,18 @@ class ModelTrainer:
             # Define custom x-axes for wandb metrics
             # Epoch-level metrics use epoch as x-axis, step-level use default global_step
             wandb.define_metric("epoch")
-
-
-            wandb.define_metric("
-
-
-
-
-
-
-            wandb.define_metric("
-
-
-            wandb.define_metric("viz/*", step_metric="epoch")
-            wandb.define_metric("viz/train/*", step_metric="epoch")
-            wandb.define_metric("viz/val/*", step_metric="epoch")
+            wandb.define_metric("val_loss", step_metric="epoch")
+            wandb.define_metric("val_time", step_metric="epoch")
+            wandb.define_metric("train_time", step_metric="epoch")
+            # Per-node losses use epoch as x-axis
+            for node_name in self.skeletons[0].node_names:
+                wandb.define_metric(node_name, step_metric="epoch")
+
+            # Visualization images use epoch as x-axis
+            wandb.define_metric("train_predictions*", step_metric="epoch")
+            wandb.define_metric("val_predictions*", step_metric="epoch")
+            wandb.define_metric("train_pafs*", step_metric="epoch")
+            wandb.define_metric("val_pafs*", step_metric="epoch")
 
             self.config.trainer_config.wandb.current_run_id = wandb.run.id
             wandb.config["run_name"] = self.config.trainer_config.wandb.name
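`wandb.define_metric` maps a metric name (or glob pattern) to a custom x-axis; anything matched is plotted against `epoch` instead of the default global step. A self-contained example of the pattern used above (offline mode so it runs without an account):

```python
import wandb

run = wandb.init(project="define-metric-demo", mode="offline")
wandb.define_metric("epoch")
wandb.define_metric("val_loss", step_metric="epoch")
wandb.define_metric("train_predictions*", step_metric="epoch")  # glob pattern
wandb.log({"epoch": 0, "val_loss": 0.21})
run.finish()
```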
@@ -1339,7 +1322,27 @@ class ModelTrainer:
         logger.info(
             f"Finished training loop. [{(time.time() - start_train_time) / 60:.1f} min]"
         )
-
+        if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
+            wandb.finish()
+
+            # Delete local wandb logs if configured
+            wandb_config = self.config.trainer_config.wandb
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                wandb_dir = (
+                    Path(self.config.trainer_config.ckpt_dir)
+                    / self.config.trainer_config.run_name
+                    / "wandb"
+                )
+                if wandb_dir.exists():
+                    logger.info(
+                        f"Deleting local wandb logs at {wandb_dir}... "
+                        "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                    )
+                    shutil.rmtree(wandb_dir, ignore_errors=True)
 
         # delete image disk caching
         if (
{sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.0
+Version: 0.1.0a1
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
@@ -13,10 +13,10 @@ Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: sleap-io<0.7.0,>=0.6.
+Requires-Dist: sleap-io<0.7.0,>=0.6.0
 Requires-Dist: numpy
 Requires-Dist: lightning
-Requires-Dist:
+Requires-Dist: kornia
 Requires-Dist: jsonpickle
 Requires-Dist: scipy
 Requires-Dist: attrs
@@ -32,7 +32,6 @@ Requires-Dist: hydra-core
 Requires-Dist: jupyter
 Requires-Dist: jupyterlab
 Requires-Dist: pyzmq
-Requires-Dist: rich-click>=1.9.5
 Provides-Extra: torch
 Requires-Dist: torch; extra == "torch"
 Requires-Dist: torchvision>=0.20.0; extra == "torch"
@@ -48,17 +47,6 @@ Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda128"
 Provides-Extra: torch-cuda130
 Requires-Dist: torch; extra == "torch-cuda130"
 Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda130"
-Provides-Extra: export
-Requires-Dist: onnx>=1.15.0; extra == "export"
-Requires-Dist: onnxruntime>=1.16.0; extra == "export"
-Requires-Dist: onnxscript>=0.1.0; extra == "export"
-Provides-Extra: export-gpu
-Requires-Dist: onnx>=1.15.0; extra == "export-gpu"
-Requires-Dist: onnxruntime-gpu>=1.16.0; extra == "export-gpu"
-Requires-Dist: onnxscript>=0.1.0; extra == "export-gpu"
-Provides-Extra: tensorrt
-Requires-Dist: tensorrt>=10.13.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
-Requires-Dist: torch-tensorrt>=2.5.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
 Dynamic: license-file
 
 # sleap-nn