sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/training/model_trainer.py
@@ -2,7 +2,6 @@
 
 import os
 import shutil
-import copy
 import attrs
 import torch
 import random
@@ -16,11 +15,14 @@ import yaml
 from pathlib import Path
 from typing import List, Optional
 from datetime import datetime
-from itertools import cycle, count
+from itertools import count
 from omegaconf import DictConfig, OmegaConf
 from lightning.pytorch.loggers import WandbLogger
 from sleap_nn.data.utils import check_cache_memory
-from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
+from lightning.pytorch.callbacks import (
+    ModelCheckpoint,
+    EarlyStopping,
+)
 from lightning.pytorch.profilers import (
     SimpleProfiler,
     AdvancedProfiler,
@@ -28,7 +30,11 @@ from lightning.pytorch.profilers import (
     PassThroughProfiler,
 )
 from sleap_io.io.skeleton import SkeletonYAMLEncoder
-from sleap_nn.data.instance_cropping import find_instance_crop_size
+from sleap_nn.data.instance_cropping import (
+    find_instance_crop_size,
+    find_max_instance_bbox_size,
+    compute_augmentation_padding,
+)
 from sleap_nn.data.providers import get_max_height_width
 from sleap_nn.data.custom_datasets import (
     get_train_val_dataloaders,
@@ -47,9 +53,11 @@ from sleap_nn.config.training_job_config import verify_training_cfg
 from sleap_nn.training.callbacks import (
     ProgressReporterZMQ,
     TrainingControllerZMQ,
-    MatplotlibSaver,
-    WandBPredImageLogger,
     CSVLoggerCallback,
+    SleapProgressBar,
+    EpochEndEvaluationCallback,
+    CentroidEvaluationCallback,
+    UnifiedVizCallback,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -207,6 +215,52 @@ class ModelTrainer:
             trainer_devices = 1
         return trainer_devices
 
+    def _count_labeled_frames(
+        self, labels_list: List[sio.Labels], user_only: bool = True
+    ) -> int:
+        """Count labeled frames, optionally filtering to user-labeled only.
+
+        Args:
+            labels_list: List of Labels objects to count frames from.
+            user_only: If True, count only frames with user instances.
+
+        Returns:
+            Total count of labeled frames.
+        """
+        total = 0
+        for label in labels_list:
+            if user_only:
+                total += sum(1 for lf in label if lf.has_user_instances)
+            else:
+                total += len(label)
+        return total
+
+    def _filter_to_user_labeled(self, labels: sio.Labels) -> sio.Labels:
+        """Filter a Labels object to only include user-labeled frames.
+
+        Args:
+            labels: Labels object to filter.
+
+        Returns:
+            New Labels object containing only frames with user instances.
+        """
+        # Filter labeled frames to only those with user instances
+        user_lfs = [lf for lf in labels if lf.has_user_instances]
+
+        # Set instances to user instances only
+        for lf in user_lfs:
+            lf.instances = lf.user_instances
+
+        # Create new Labels with filtered frames
+        return sio.Labels(
+            labeled_frames=user_lfs,
+            videos=labels.videos,
+            skeletons=labels.skeletons,
+            tracks=labels.tracks,
+            suggestions=labels.suggestions,
+            provenance=labels.provenance,
+        )
+
     def _setup_train_val_labels(
         self,
         labels: Optional[List[sio.Labels]] = None,
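The two helpers above centralize how user-labeled frames are handled: counts, splits, and the saved ground-truth files consider only frames carrying user instances when `data_config.user_instances_only` is enabled. A minimal sketch of the same filtering outside the trainer, using only sleap_io calls that already appear in this diff (the label file path is a placeholder):

    import sleap_io as sio

    labels = sio.load_slp("labels.pkg.slp")  # placeholder path
    # Same predicate used by _count_labeled_frames / _filter_to_user_labeled above.
    user_lfs = [lf for lf in labels if lf.has_user_instances]
    print(f"{len(user_lfs)} of {len(labels)} labeled frames have user instances")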
@@ -218,21 +272,35 @@ class ModelTrainer:
         total_val_lfs = 0
         self.skeletons = labels[0].skeletons
 
+        # Check if we should count only user-labeled frames
+        user_instances_only = OmegaConf.select(
+            self.config, "data_config.user_instances_only", default=True
+        )
+
         # check if all `.slp` file shave same skeleton structure (if multiple slp file paths are provided)
         skeleton = self.skeletons[0]
         for index, train_label in enumerate(labels):
             skel_temp = train_label.skeletons[0]
             skeletons_equal = skeleton.matches(skel_temp)
-            if skeletons_equal:
-                total_train_lfs += len(train_label)
-            else:
+            if not skeletons_equal:
                 message = f"The skeletons in the training labels: {index + 1} do not match the skeleton in the first training label file."
                 logger.error(message)
                 raise ValueError(message)
 
-        if val_labels is None or not len(val_labels):
+        # Check for same-data mode (train = val, for intentional overfitting)
+        use_same = OmegaConf.select(
+            self.config, "data_config.use_same_data_for_val", default=False
+        )
+
+        if use_same:
+            # Same mode: use identical data for train and val (for overfitting)
+            logger.info("Using same data for train and val (overfit mode)")
+            self.train_labels = labels
+            self.val_labels = labels
+            total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
+            total_val_lfs = total_train_lfs
+        elif val_labels is None or not len(val_labels):
             # if val labels are not provided, split from train
-            total_train_lfs = 0
             val_fraction = OmegaConf.select(
                 self.config, "data_config.validation_fraction", default=0.1
             )
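Both switches in this hunk are plain config reads, so an overfit run needs no code changes. A hedged override sketch (key names come from the `OmegaConf.select` calls above; the surrounding schema is assumed):

    from omegaconf import OmegaConf

    overrides = OmegaConf.create(
        {
            "data_config": {
                "user_instances_only": True,    # count/save only user-labeled frames
                "use_same_data_for_val": True,  # reuse the training labels as validation (overfit mode)
            }
        }
    )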
@@ -250,13 +318,14 @@
             )
             self.train_labels.append(train_split)
             self.val_labels.append(val_split)
+            # make_training_splits returns only user-labeled frames
             total_train_lfs += len(train_split)
             total_val_lfs += len(val_split)
         else:
             self.train_labels = labels
             self.val_labels = val_labels
-            for val_l in self.val_labels:
-                total_val_lfs += len(val_l)
+            total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
+            total_val_lfs = self._count_labeled_frames(val_labels, user_instances_only)
 
         logger.info(f"# Train Labeled frames: {total_train_lfs}")
         logger.info(f"# Val Labeled frames: {total_val_lfs}")
@@ -291,13 +360,70 @@
         ):
             # compute crop size if not provided in config
             if crop_size is None:
+                # Get padding from config or auto-compute from augmentation settings
+                padding = self.config.data_config.preprocessing.crop_padding
+                if padding is None:
+                    # Auto-compute padding based on augmentation settings
+                    aug_config = self.config.data_config.augmentation_config
+                    if (
+                        self.config.data_config.use_augmentations_train
+                        and aug_config is not None
+                        and aug_config.geometric is not None
+                    ):
+                        geo = aug_config.geometric
+                        # Check if rotation is enabled (via rotation_p or affine_p)
+                        rotation_enabled = (
+                            geo.rotation_p is not None and geo.rotation_p > 0
+                        ) or (
+                            geo.rotation_p is None
+                            and geo.scale_p is None
+                            and geo.translate_p is None
+                            and geo.affine_p > 0
+                        )
+                        # Check if scale is enabled (via scale_p or affine_p)
+                        scale_enabled = (
+                            geo.scale_p is not None and geo.scale_p > 0
+                        ) or (
+                            geo.rotation_p is None
+                            and geo.scale_p is None
+                            and geo.translate_p is None
+                            and geo.affine_p > 0
+                        )
+
+                        if rotation_enabled or scale_enabled:
+                            # First find the actual max bbox size from labels
+                            bbox_size = find_max_instance_bbox_size(train_label)
+                            bbox_size = max(
+                                bbox_size,
+                                self.config.data_config.preprocessing.min_crop_size
+                                or 100,
+                            )
+                            rotation_max = (
+                                max(
+                                    abs(geo.rotation_min),
+                                    abs(geo.rotation_max),
+                                )
+                                if rotation_enabled
+                                else 0.0
+                            )
+                            scale_max = geo.scale_max if scale_enabled else 1.0
+                            padding = compute_augmentation_padding(
+                                bbox_size=bbox_size,
+                                rotation_max=rotation_max,
+                                scale_max=scale_max,
+                            )
+                        else:
+                            padding = 0
+                    else:
+                        padding = 0
+
                 crop_sz = find_instance_crop_size(
                     labels=train_label,
+                    padding=padding,
                     maximum_stride=self.config.model_config.backbone_config[
                         f"{self.backbone_type}"
                     ]["max_stride"],
                     min_crop_size=self.config.data_config.preprocessing.min_crop_size,
-                    input_scaling=self.config.data_config.preprocessing.scale,
                 )
 
             if crop_sz > max_crop_size:
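`compute_augmentation_padding` itself is not shown in this diff, so its exact formula is unknown; the intuition is that rotating and scaling a box of side `bbox_size` grows its axis-aligned extent, and the crop padding must cover that growth. A hypothetical back-of-envelope version of that bound (illustration only, not the library's implementation):

    import math

    def rough_padding(bbox_size: float, rotation_max: float, scale_max: float) -> int:
        # Hypothetical: worst-case axis-aligned extent of a square box of side bbox_size
        # after scaling by scale_max and rotating by rotation_max degrees.
        t = math.radians(rotation_max)
        expanded = bbox_size * scale_max * (abs(math.cos(t)) + abs(math.sin(t)))
        # Extra margin needed on each side of the original box.
        return math.ceil((expanded - bbox_size) / 2)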
@@ -361,16 +487,36 @@
             ckpt_dir = "."
         self.config.trainer_config.ckpt_dir = ckpt_dir
         run_name = self.config.trainer_config.run_name
-        if run_name is None or run_name == "" or run_name == "None":
+        run_name_is_empty = run_name is None or run_name == "" or run_name == "None"
+
+        # Validate: multi-GPU + disk cache requires explicit run_name
+        if run_name_is_empty:
+            is_disk_caching = (
+                self.config.data_config.data_pipeline_fw
+                == "torch_dataset_cache_img_disk"
+            )
+            num_devices = self._get_trainer_devices()
+
+            if is_disk_caching and num_devices > 1:
+                raise ValueError(
+                    f"Multi-GPU training with disk caching requires an explicit `run_name`.\n\n"
+                    f"Detected {num_devices} device(s) with "
+                    f"`data_pipeline_fw='torch_dataset_cache_img_disk'`.\n"
+                    f"Without an explicit run_name, each GPU worker generates a different "
+                    f"timestamp-based directory, causing cache synchronization failures.\n\n"
+                    f"Please provide a run_name using one of these methods:\n"
+                    f" - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
+                    f" - Config file: Set `trainer_config.run_name: my_experiment`\n"
+                    f" - Python API: train(..., run_name='my_experiment')"
+                )
+
+            # Auto-generate timestamp-based run_name (safe for single GPU or non-disk-cache)
             sum_train_lfs = sum([len(train_label) for train_label in self.train_labels])
             sum_val_lfs = sum([len(val_label) for val_label in self.val_labels])
-            if self._get_trainer_devices() > 1:
-                run_name = f"{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
-            else:
-                run_name = (
-                    datetime.now().strftime("%y%m%d_%H%M%S")
-                    + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
-                )
+            run_name = (
+                datetime.now().strftime("%y%m%d_%H%M%S")
+                + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
+            )
 
         # If checkpoint path already exists, add suffix to prevent overwriting
         if (Path(ckpt_dir) / run_name).exists() and (
@@ -509,6 +655,10 @@
         """Compute config parameters."""
         logger.info("Setting up config...")
 
+        # Normalize empty strings to None for optional wandb fields
+        if self.config.trainer_config.wandb.prv_runid == "":
+            self.config.trainer_config.wandb.prv_runid = None
+
         # compute preprocessing parameters from the labels objects and fill in the config
         self._setup_preprocessing_config()
 
@@ -558,39 +708,89 @@
             )
         )
 
-        # setup checkpoint path
+        # setup checkpoint path (generates run_name if not specified)
         self._setup_ckpt_path()
 
+        # Default wandb run name to trainer run_name if not specified
+        # Note: This must come after _setup_ckpt_path() which generates run_name
+        if self.config.trainer_config.wandb.name is None:
+            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
+
         # verify input_channels in model_config based on input image and pretrained model weights
         self._verify_model_input_channels()
 
     def _setup_model_ckpt_dir(self):
-        """Create the model ckpt folder."""
+        """Create the model ckpt folder and save ground truth labels."""
         ckpt_path = (
             Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
         ).as_posix()
         logger.info(f"Setting up model ckpt dir: `{ckpt_path}`...")
 
-        if not Path(ckpt_path).exists():
-            try:
-                Path(ckpt_path).mkdir(parents=True, exist_ok=True)
-            except OSError as e:
-                message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
-                logger.error(message)
-                raise OSError(message)
-
+        # Only rank 0 (or non-distributed) should create directories and save files
         if RANK in [0, -1]:
+            if not Path(ckpt_path).exists():
+                try:
+                    Path(ckpt_path).mkdir(parents=True, exist_ok=True)
+                except OSError as e:
+                    message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
+                    logger.error(message)
+                    raise OSError(message)
+            # Check if we should filter to user-labeled frames only
+            user_instances_only = OmegaConf.select(
+                self.config, "data_config.user_instances_only", default=True
+            )
+
+            # Save train and val ground truth labels
             for idx, (train, val) in enumerate(zip(self.train_labels, self.val_labels)):
-                train.save(
-                    Path(ckpt_path) / f"labels_train_gt_{idx}.slp",
+                # Filter to user-labeled frames if needed (for evaluation)
+                if user_instances_only:
+                    train_filtered = self._filter_to_user_labeled(train)
+                    val_filtered = self._filter_to_user_labeled(val)
+                else:
+                    train_filtered = train
+                    val_filtered = val
+
+                train_filtered.save(
+                    Path(ckpt_path) / f"labels_gt.train.{idx}.slp",
                     restore_original_videos=False,
                 )
-                val.save(
-                    Path(ckpt_path) / f"labels_val_gt_{idx}.slp",
+                val_filtered.save(
+                    Path(ckpt_path) / f"labels_gt.val.{idx}.slp",
                     restore_original_videos=False,
                 )
 
+            # Save test ground truth labels if test paths are provided
+            test_file_path = OmegaConf.select(
+                self.config, "data_config.test_file_path", default=None
+            )
+            if test_file_path is not None:
+                # Normalize to list of strings
+                if isinstance(test_file_path, str):
+                    test_paths = [test_file_path]
+                else:
+                    test_paths = list(test_file_path)
+
+                for idx, test_path in enumerate(test_paths):
+                    # Only save if it's a .slp file (not a video file)
+                    if test_path.endswith(".slp") or test_path.endswith(".pkg.slp"):
+                        try:
+                            test_labels = sio.load_slp(test_path)
+                            if user_instances_only:
+                                test_filtered = self._filter_to_user_labeled(
+                                    test_labels
+                                )
+                            else:
+                                test_filtered = test_labels
+                            test_filtered.save(
+                                Path(ckpt_path) / f"labels_gt.test.{idx}.slp",
+                                restore_original_videos=False,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Could not save test ground truth for {test_path}: {e}"
+                            )
+
     def _setup_viz_datasets(self):
         """Setup dataloaders."""
         data_viz_config = self.config.copy()
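With this change the run directory stores the (optionally user-filtered) ground truth for every split under the new `labels_gt.<split>.<idx>.slp` naming, replacing the old `labels_train_gt_{idx}.slp` / `labels_val_gt_{idx}.slp` files. A small sketch of loading them back for a sanity check (the run directory path is a placeholder):

    from pathlib import Path
    import sleap_io as sio

    run_dir = Path("models/my_experiment")  # placeholder run directory
    train_gt = sio.load_slp(run_dir / "labels_gt.train.0.slp")
    val_gt = sio.load_slp(run_dir / "labels_gt.val.0.slp")
    print(len(train_gt), len(val_gt))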
@@ -608,10 +808,40 @@
         base_cache_img_path = None
         if self.config.data_config.data_pipeline_fw == "torch_dataset_cache_img_memory":
             # check available memory. If insufficient memory, default to disk caching.
+            # Account for DataLoader worker memory overhead
+            train_num_workers = self.config.trainer_config.train_data_loader.num_workers
+            val_num_workers = self.config.trainer_config.val_data_loader.num_workers
+            max_num_workers = max(train_num_workers, val_num_workers)
+
             mem_available = check_cache_memory(
-                self.train_labels, self.val_labels, memory_buffer=MEMORY_BUFFER
+                self.train_labels,
+                self.val_labels,
+                memory_buffer=MEMORY_BUFFER,
+                num_workers=max_num_workers,
             )
             if not mem_available:
+                # Validate: multi-GPU + auto-generated run_name + fallback to disk cache
+                original_run_name = self._initial_config.trainer_config.run_name
+                run_name_was_auto = (
+                    original_run_name is None
+                    or original_run_name == ""
+                    or original_run_name == "None"
+                )
+                if run_name_was_auto and self.trainer.num_devices > 1:
+                    raise ValueError(
+                        f"Memory caching failed and disk caching fallback requires an "
+                        f"explicit `run_name` for multi-GPU training.\n\n"
+                        f"Detected {self.trainer.num_devices} device(s) with insufficient "
+                        f"memory for in-memory caching.\n"
+                        f"Without an explicit run_name, each GPU worker generates a different "
+                        f"timestamp-based directory, causing cache synchronization failures.\n\n"
+                        f"Please provide a run_name using one of these methods:\n"
+                        f" - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
+                        f" - Config file: Set `trainer_config.run_name: my_experiment`\n"
+                        f" - Python API: train(..., run_name='my_experiment')\n\n"
+                        f"Alternatively, use `data_pipeline_fw='torch_dataset'` to disable caching."
+                    )
+
                 self.config.data_config.data_pipeline_fw = (
                     "torch_dataset_cache_img_disk"
                 )
@@ -655,7 +885,7 @@
                 / self.config.trainer_config.run_name
             ).as_posix(),
             filename="best",
-            monitor="val_loss",
+            monitor="val/loss",
             mode="min",
         )
         callbacks.append(checkpoint_callback)
@@ -663,18 +893,52 @@
         # csv log callback
         csv_log_keys = [
             "epoch",
-            "train_loss",
-            "val_loss",
+            "train/loss",
+            "val/loss",
             "learning_rate",
-            "train_time",
-            "val_time",
+            "train/time",
+            "val/time",
         ]
+        # Add model-specific keys for wandb parity
         if self.model_type in [
             "single_instance",
             "centered_instance",
             "multi_class_topdown",
         ]:
-            csv_log_keys.extend(self.skeletons[0].node_names)
+            csv_log_keys.extend(
+                [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
+            )
+        if self.model_type == "bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/paf_loss",
+                    "val/confmaps_loss",
+                    "val/paf_loss",
+                ]
+            )
+        if self.model_type == "multi_class_bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classmap_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classmap_loss",
+                    "val/class_accuracy",
+                ]
+            )
+        if self.model_type == "multi_class_topdown":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classvector_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classvector_loss",
+                    "val/class_accuracy",
+                ]
+            )
         csv_logger = CSVLoggerCallback(
             filepath=Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
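The CSV logger now uses the same slash-namespaced metric names as wandb (`train/loss`, `val/confmaps_loss`, `train/confmaps/<node>`, ...), so the columns can be consumed directly. A sketch of reading them back with pandas; the CSV filename is cut off in this hunk, so the name below is a placeholder:

    import pandas as pd

    log = pd.read_csv("models/my_experiment/training_log.csv")  # placeholder path/filename
    print(log[["epoch", "train/loss", "val/loss", "learning_rate"]].tail())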
@@ -687,7 +951,7 @@
         # early stopping callback
         callbacks.append(
             EarlyStopping(
-                monitor="val_loss",
+                monitor="val/loss",
                 mode="min",
                 verbose=False,
                 min_delta=self.config.trainer_config.early_stopping.min_delta,
@@ -716,6 +980,17 @@
             )
             loggers.append(wandb_logger)
 
+            # Log message about wandb local logs cleanup
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                logger.info(
+                    "WandB local logs will be deleted after training completes. "
+                    "To keep logs, set trainer_config.wandb.delete_local_logs=false"
+                )
+
         # save the configs as yaml in the checkpoint dir
         # Mask API key in both configs to prevent saving to disk
         self.config.trainer_config.wandb.api_key = ""
@@ -734,11 +1009,8 @@
             )
             callbacks.append(ProgressReporterZMQ(address=publish_address))
 
-        # viz callbacks
+        # viz callbacks - use unified callback for all visualization outputs
         if self.config.trainer_config.visualize_preds_during_training:
-            train_viz_pipeline = cycle(viz_train_dataset)
-            val_viz_pipeline = cycle(viz_val_dataset)
-
             viz_dir = (
                 Path(self.config.trainer_config.ckpt_dir)
                 / self.config.trainer_config.run_name
@@ -748,77 +1020,74 @@
             if RANK in [0, -1]:
                 Path(viz_dir).mkdir(parents=True, exist_ok=True)
 
-            callbacks.append(
-                MatplotlibSaver(
-                    save_folder=viz_dir,
-                    plot_fn=lambda: self.lightning_model.visualize_example(
-                        next(train_viz_pipeline)
-                    ),
-                    prefix="train",
-                )
+            # Get wandb viz config options
+            log_wandb = self.config.trainer_config.use_wandb and OmegaConf.select(
+                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
             )
+            wandb_modes = []
+            if log_wandb:
+                if OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_enabled", default=True
+                ):
+                    wandb_modes.append("direct")
+                if OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_boxes", default=False
+                ):
+                    wandb_modes.append("boxes")
+                if OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_masks", default=False
+                ):
+                    wandb_modes.append("masks")
+
+            # Single unified callback handles all visualization outputs
             callbacks.append(
-                MatplotlibSaver(
-                    save_folder=viz_dir,
-                    plot_fn=lambda: self.lightning_model.visualize_example(
-                        next(val_viz_pipeline)
+                UnifiedVizCallback(
+                    model_trainer=self,
+                    train_dataset=viz_train_dataset,
+                    val_dataset=viz_val_dataset,
+                    model_type=self.model_type,
+                    save_local=self.config.trainer_config.save_ckpt,
+                    local_save_dir=viz_dir,
+                    log_wandb=log_wandb,
+                    wandb_modes=wandb_modes if wandb_modes else ["direct"],
+                    wandb_box_size=OmegaConf.select(
+                        self.config, "trainer_config.wandb.viz_box_size", default=5.0
+                    ),
+                    wandb_confmap_threshold=OmegaConf.select(
+                        self.config,
+                        "trainer_config.wandb.viz_confmap_threshold",
+                        default=0.1,
+                    ),
+                    log_wandb_table=OmegaConf.select(
+                        self.config, "trainer_config.wandb.log_viz_table", default=False
                     ),
-                    prefix="validation",
                 )
             )
 
-            if self.model_type == "bottomup":
-                train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
-                val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
-                callbacks.append(
-                    MatplotlibSaver(
-                        save_folder=viz_dir,
-                        plot_fn=lambda: self.lightning_model.visualize_pafs_example(
-                            next(train_viz_pipeline1)
-                        ),
-                        prefix="train.pafs_magnitude",
-                    )
-                )
-                callbacks.append(
-                    MatplotlibSaver(
-                        save_folder=viz_dir,
-                        plot_fn=lambda: self.lightning_model.visualize_pafs_example(
-                            next(val_viz_pipeline1)
-                        ),
-                        prefix="validation.pafs_magnitude",
-                    )
-                )
+        # Add custom progress bar with better metric formatting
+        if self.config.trainer_config.enable_progress_bar:
+            callbacks.append(SleapProgressBar())
 
-            if self.model_type == "multi_class_bottomup":
-                train_viz_pipeline1 = cycle(copy.deepcopy(viz_train_dataset))
-                val_viz_pipeline1 = cycle(copy.deepcopy(viz_val_dataset))
+        # Add epoch-end evaluation callback if enabled
+        if self.config.trainer_config.eval.enabled:
+            if self.model_type == "centroid":
+                # Use centroid-specific evaluation with distance-based metrics
                 callbacks.append(
-                    MatplotlibSaver(
-                        save_folder=viz_dir,
-                        plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
-                            next(train_viz_pipeline1)
-                        ),
-                        prefix="train.class_maps",
+                    CentroidEvaluationCallback(
+                        videos=self.val_labels[0].videos,
+                        eval_frequency=self.config.trainer_config.eval.frequency,
+                        match_threshold=self.config.trainer_config.eval.match_threshold,
                     )
                 )
+            else:
+                # Use standard OKS/PCK evaluation for pose models
                 callbacks.append(
-                    MatplotlibSaver(
-                        save_folder=viz_dir,
-                        plot_fn=lambda: self.lightning_model.visualize_class_maps_example(
-                            next(val_viz_pipeline1)
-                        ),
-                        prefix="validation.class_maps",
-                    )
-                )
-
-            if self.config.trainer_config.use_wandb and OmegaConf.select(
-                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
-            ):
-                callbacks.append(
-                    WandBPredImageLogger(
-                        viz_folder=viz_dir,
-                        wandb_run_name=self.config.trainer_config.wandb.name,
-                        is_bottomup=(self.model_type == "bottomup"),
+                    EpochEndEvaluationCallback(
+                        skeleton=self.skeletons[0],
+                        videos=self.val_labels[0].videos,
+                        eval_frequency=self.config.trainer_config.eval.frequency,
+                        oks_stddev=self.config.trainer_config.eval.oks_stddev,
+                        oks_scale=self.config.trainer_config.eval.oks_scale,
                     )
                 )
 
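All of the per-model MatplotlibSaver/WandBPredImageLogger wiring collapses into a single UnifiedVizCallback, and epoch-end evaluation becomes opt-in; both are driven by config keys read via `OmegaConf.select` above. A hedged override sketch (only keys visible in this hunk are listed; values other than the defaults shown above are assumptions):

    from omegaconf import OmegaConf

    viz_eval_overrides = OmegaConf.create(
        {
            "trainer_config": {
                "visualize_preds_during_training": True,
                "wandb": {
                    "save_viz_imgs_wandb": True,  # read into log_wandb above
                    "viz_enabled": True,          # adds the "direct" wandb mode
                    "viz_boxes": False,
                    "viz_masks": False,
                    "viz_box_size": 5.0,
                    "viz_confmap_threshold": 0.1,
                    "log_viz_table": False,
                },
                "eval": {"enabled": True},  # frequency / oks_* / match_threshold keys also exist per this hunk
            }
        }
    )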
@@ -910,6 +1179,11 @@
                     : self.config.trainer_config.trainer_devices
                 ]
             ]
+            # Sort device indices in ascending order for NCCL compatibility.
+            # NCCL expects devices in consistent ascending order across ranks
+            # to properly set up communication rings. Without sorting, DDP may
+            # assign multiple ranks to the same GPU, causing "Duplicate GPU detected" errors.
+            devices.sort()
             logger.info(f"Using GPUs with most available memory: {devices}")
 
         # create lightning.Trainer instance.
@@ -931,6 +1205,10 @@
         # setup datasets
         train_dataset, val_dataset = self._setup_datasets()
 
+        # Barrier after dataset creation to ensure all workers wait for disk caching
+        # (rank 0 caches to disk, others must wait before reading cached files)
+        self.trainer.strategy.barrier()
+
         # set-up steps per epoch
         train_steps_per_epoch = self.config.trainer_config.train_steps_per_epoch
         if train_steps_per_epoch is None:
@@ -959,7 +1237,7 @@
         logger.info(f"Backbone model: {self.lightning_model.model.backbone}")
         logger.info(f"Head model: {self.lightning_model.model.head_layers}")
         total_params = sum(p.numel() for p in self.lightning_model.parameters())
-        logger.info(f"Total model parameters: {total_params}")
+        logger.info(f"Total model parameters: {total_params:,}")
         self.config.model_config.total_params = total_params
 
         # setup dataloaders
@@ -1000,6 +1278,26 @@
                 id=self.config.trainer_config.wandb.prv_runid,
                 group=self.config.trainer_config.wandb.group,
             )
+
+            # Define custom x-axes for wandb metrics
+            # Epoch-level metrics use epoch as x-axis, step-level use default global_step
+            wandb.define_metric("epoch")
+
+            # Training metrics (train/ prefix for grouping) - all use epoch x-axis
+            wandb.define_metric("train/*", step_metric="epoch")
+            wandb.define_metric("train/confmaps/*", step_metric="epoch")
+
+            # Validation metrics (val/ prefix for grouping)
+            wandb.define_metric("val/*", step_metric="epoch")
+
+            # Evaluation metrics (eval/ prefix for grouping)
+            wandb.define_metric("eval/*", step_metric="epoch")
+
+            # Visualization images (need explicit nested paths)
+            wandb.define_metric("viz/*", step_metric="epoch")
+            wandb.define_metric("viz/train/*", step_metric="epoch")
+            wandb.define_metric("viz/val/*", step_metric="epoch")
+
             self.config.trainer_config.wandb.current_run_id = wandb.run.id
             wandb.config["run_name"] = self.config.trainer_config.wandb.name
             wandb.config["run_config"] = OmegaConf.to_container(
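`wandb.define_metric` only tells the wandb UI which x-axis to use; the epoch-level series still need an `epoch` value logged alongside them for the mapping to apply. A minimal self-contained sketch of the mechanism (offline mode and placeholder values, independent of the trainer):

    import wandb

    run = wandb.init(mode="offline")  # offline so this runs without a wandb account
    wandb.define_metric("epoch")
    wandb.define_metric("train/*", step_metric="epoch")
    wandb.log({"epoch": 0, "train/loss": 0.12})  # placeholder values
    run.finish()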
@@ -1017,6 +1315,9 @@
 
         self.trainer.strategy.barrier()
 
+        # Flag to track if training was interrupted (not completed normally)
+        training_interrupted = False
+
         try:
             logger.info(
                 f"Finished trainer set up. [{time.time() - start_setup_time:.1f}s]"
@@ -1032,13 +1333,13 @@
 
         except KeyboardInterrupt:
             logger.info("Stopping training...")
+            training_interrupted = True
 
         finally:
             logger.info(
                 f"Finished training loop. [{(time.time() - start_train_time) / 60:.1f} min]"
             )
-            if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
-                wandb.finish()
+            # Note: wandb.finish() is called in train.py after post-training evaluation
 
             # delete image disk caching
             if (
@@ -1063,3 +1364,15 @@
                 if viz_dir.exists():
                     logger.info(f"Deleting viz folder at {viz_dir}...")
                     shutil.rmtree(viz_dir, ignore_errors=True)
+
+            # Clean up entire run folder if training was interrupted (KeyboardInterrupt)
+            if training_interrupted and self.trainer.global_rank == 0:
+                run_dir = (
+                    Path(self.config.trainer_config.ckpt_dir)
+                    / self.config.trainer_config.run_name
+                )
+                if run_dir.exists():
+                    logger.info(
+                        f"Training canceled - cleaning up run folder at {run_dir}..."
+                    )
+                    shutil.rmtree(run_dir, ignore_errors=True)