sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,7 +20,11 @@ from itertools import cycle, count
 from omegaconf import DictConfig, OmegaConf
 from lightning.pytorch.loggers import WandbLogger
 from sleap_nn.data.utils import check_cache_memory
-from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
+from lightning.pytorch.callbacks import (
+    ModelCheckpoint,
+    EarlyStopping,
+    LearningRateMonitor,
+)
 from lightning.pytorch.profilers import (
     SimpleProfiler,
     AdvancedProfiler,
@@ -28,7 +32,11 @@ from lightning.pytorch.profilers import (
     PassThroughProfiler,
 )
 from sleap_io.io.skeleton import SkeletonYAMLEncoder
-from sleap_nn.data.instance_cropping import find_instance_crop_size
+from sleap_nn.data.instance_cropping import (
+    find_instance_crop_size,
+    find_max_instance_bbox_size,
+    compute_augmentation_padding,
+)
 from sleap_nn.data.providers import get_max_height_width
 from sleap_nn.data.custom_datasets import (
     get_train_val_dataloaders,
@@ -49,7 +57,10 @@ from sleap_nn.training.callbacks import (
     TrainingControllerZMQ,
     MatplotlibSaver,
     WandBPredImageLogger,
+    WandBVizCallback,
+    WandBVizCallbackWithPAFs,
     CSVLoggerCallback,
+    SleapProgressBar,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -207,6 +218,52 @@ class ModelTrainer:
         trainer_devices = 1
         return trainer_devices
 
+    def _count_labeled_frames(
+        self, labels_list: List[sio.Labels], user_only: bool = True
+    ) -> int:
+        """Count labeled frames, optionally filtering to user-labeled only.
+
+        Args:
+            labels_list: List of Labels objects to count frames from.
+            user_only: If True, count only frames with user instances.
+
+        Returns:
+            Total count of labeled frames.
+        """
+        total = 0
+        for label in labels_list:
+            if user_only:
+                total += sum(1 for lf in label if lf.has_user_instances)
+            else:
+                total += len(label)
+        return total
+
+    def _filter_to_user_labeled(self, labels: sio.Labels) -> sio.Labels:
+        """Filter a Labels object to only include user-labeled frames.
+
+        Args:
+            labels: Labels object to filter.
+
+        Returns:
+            New Labels object containing only frames with user instances.
+        """
+        # Filter labeled frames to only those with user instances
+        user_lfs = [lf for lf in labels if lf.has_user_instances]
+
+        # Set instances to user instances only
+        for lf in user_lfs:
+            lf.instances = lf.user_instances
+
+        # Create new Labels with filtered frames
+        return sio.Labels(
+            labeled_frames=user_lfs,
+            videos=labels.videos,
+            skeletons=labels.skeletons,
+            tracks=labels.tracks,
+            suggestions=labels.suggestions,
+            provenance=labels.provenance,
+        )
+
     def _setup_train_val_labels(
         self,
         labels: Optional[List[sio.Labels]] = None,
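
The two helpers added above are self-contained, so their behavior is easy to check in isolation. A standalone sketch of the same filtering logic using sleap-io directly (not part of the diff; the `example.pkg.slp` path is hypothetical):

```python
# Standalone sketch of the user-labeled filtering above; not the diff's code.
import sleap_io as sio

labels = sio.load_slp("example.pkg.slp")  # hypothetical labels file

# Keep only frames that carry at least one user (hand-labeled) instance,
# and drop predicted instances from the frames we keep.
user_lfs = [lf for lf in labels if lf.has_user_instances]
for lf in user_lfs:
    lf.instances = lf.user_instances

filtered = sio.Labels(
    labeled_frames=user_lfs,
    videos=labels.videos,
    skeletons=labels.skeletons,
)
print(f"{len(filtered)} of {len(labels)} frames are user-labeled")
```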
@@ -218,21 +275,35 @@ class ModelTrainer:
         total_val_lfs = 0
         self.skeletons = labels[0].skeletons
 
+        # Check if we should count only user-labeled frames
+        user_instances_only = OmegaConf.select(
+            self.config, "data_config.user_instances_only", default=True
+        )
+
         # check if all `.slp` files have the same skeleton structure (if multiple slp file paths are provided)
         skeleton = self.skeletons[0]
         for index, train_label in enumerate(labels):
             skel_temp = train_label.skeletons[0]
             skeletons_equal = skeleton.matches(skel_temp)
-            if skeletons_equal:
-                total_train_lfs += len(train_label)
-            else:
+            if not skeletons_equal:
                 message = f"The skeletons in the training labels: {index + 1} do not match the skeleton in the first training label file."
                 logger.error(message)
                 raise ValueError(message)
 
-        if val_labels is None or not len(val_labels):
+        # Check for same-data mode (train = val, for intentional overfitting)
+        use_same = OmegaConf.select(
+            self.config, "data_config.use_same_data_for_val", default=False
+        )
+
+        if use_same:
+            # Same mode: use identical data for train and val (for overfitting)
+            logger.info("Using same data for train and val (overfit mode)")
+            self.train_labels = labels
+            self.val_labels = labels
+            total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
+            total_val_lfs = total_train_lfs
+        elif val_labels is None or not len(val_labels):
             # if val labels are not provided, split from train
-            total_train_lfs = 0
             val_fraction = OmegaConf.select(
                 self.config, "data_config.validation_fraction", default=0.1
            )
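
Both new switches are read with `OmegaConf.select`, which falls back to the supplied default when a key is missing, so configs written for 0.0.5 keep working unchanged. A minimal sketch of the relevant keys (only the key paths come from the diff; the values are illustrative):

```python
from omegaconf import OmegaConf

# Illustrative config fragment; key paths are taken from the diff above.
cfg = OmegaConf.create(
    {
        "data_config": {
            "user_instances_only": True,    # count only user-labeled frames
            "use_same_data_for_val": True,  # overfit mode: val == train
            "validation_fraction": 0.1,     # only used when val is split off
        }
    }
)

# Missing keys resolve to the default rather than raising.
assert OmegaConf.select(cfg, "data_config.use_same_data_for_val", default=False)
assert OmegaConf.select(cfg, "data_config.not_a_key", default=None) is None
```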
@@ -250,13 +321,14 @@ class ModelTrainer:
             )
             self.train_labels.append(train_split)
             self.val_labels.append(val_split)
+            # make_training_splits returns only user-labeled frames
             total_train_lfs += len(train_split)
             total_val_lfs += len(val_split)
         else:
             self.train_labels = labels
             self.val_labels = val_labels
-            for val_l in self.val_labels:
-                total_val_lfs += len(val_l)
+            total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
+            total_val_lfs = self._count_labeled_frames(val_labels, user_instances_only)
 
         logger.info(f"# Train Labeled frames: {total_train_lfs}")
         logger.info(f"# Val Labeled frames: {total_val_lfs}")
@@ -291,13 +363,70 @@ class ModelTrainer:
         ):
             # compute crop size if not provided in config
             if crop_size is None:
+                # Get padding from config or auto-compute from augmentation settings
+                padding = self.config.data_config.preprocessing.crop_padding
+                if padding is None:
+                    # Auto-compute padding based on augmentation settings
+                    aug_config = self.config.data_config.augmentation_config
+                    if (
+                        self.config.data_config.use_augmentations_train
+                        and aug_config is not None
+                        and aug_config.geometric is not None
+                    ):
+                        geo = aug_config.geometric
+                        # Check if rotation is enabled (via rotation_p or affine_p)
+                        rotation_enabled = (
+                            geo.rotation_p is not None and geo.rotation_p > 0
+                        ) or (
+                            geo.rotation_p is None
+                            and geo.scale_p is None
+                            and geo.translate_p is None
+                            and geo.affine_p > 0
+                        )
+                        # Check if scale is enabled (via scale_p or affine_p)
+                        scale_enabled = (
+                            geo.scale_p is not None and geo.scale_p > 0
+                        ) or (
+                            geo.rotation_p is None
+                            and geo.scale_p is None
+                            and geo.translate_p is None
+                            and geo.affine_p > 0
+                        )
+
+                        if rotation_enabled or scale_enabled:
+                            # First find the actual max bbox size from labels
+                            bbox_size = find_max_instance_bbox_size(train_label)
+                            bbox_size = max(
+                                bbox_size,
+                                self.config.data_config.preprocessing.min_crop_size
+                                or 100,
+                            )
+                            rotation_max = (
+                                max(
+                                    abs(geo.rotation_min),
+                                    abs(geo.rotation_max),
+                                )
+                                if rotation_enabled
+                                else 0.0
+                            )
+                            scale_max = geo.scale_max if scale_enabled else 1.0
+                            padding = compute_augmentation_padding(
+                                bbox_size=bbox_size,
+                                rotation_max=rotation_max,
+                                scale_max=scale_max,
+                            )
+                        else:
+                            padding = 0
+                    else:
+                        padding = 0
+
                 crop_sz = find_instance_crop_size(
                     labels=train_label,
+                    padding=padding,
                     maximum_stride=self.config.model_config.backbone_config[
                         f"{self.backbone_type}"
                     ]["max_stride"],
                     min_crop_size=self.config.data_config.preprocessing.min_crop_size,
-                    input_scaling=self.config.data_config.preprocessing.scale,
                 )
 
             if crop_sz > max_crop_size:
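
The diff does not include `compute_augmentation_padding` itself, but its arguments suggest standard bounding-box geometry: a square of side s, scaled by up to `scale_max` and rotated by up to `rotation_max` degrees, grows to an axis-aligned extent of at most s * scale * (|cos t| + |sin t|), and the crop needs half of that overshoot on each side. A plausible stand-in, not the package's actual implementation:

```python
import math

def augmentation_padding_sketch(
    bbox_size: float, rotation_max: float, scale_max: float
) -> int:
    """Plausible stand-in for compute_augmentation_padding (an assumption,
    not the real code).

    The axis-aligned extent of a rotated square peaks at 45 degrees, so the
    worst case over [0, rotation_max] caps the angle there.
    """
    theta = math.radians(min(abs(rotation_max), 45.0))
    extent = bbox_size * scale_max * (math.cos(theta) + math.sin(theta))
    return math.ceil(max(0.0, extent - bbox_size) / 2)

# Rotation alone, worst case 45 degrees: the extent grows by sqrt(2), so a
# 100 px box needs about 21 px of padding per side.
print(augmentation_padding_sketch(100, 45.0, 1.0))  # 21
```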
@@ -509,6 +638,14 @@ class ModelTrainer:
         """Compute config parameters."""
         logger.info("Setting up config...")
 
+        # Normalize empty strings to None for optional wandb fields
+        if self.config.trainer_config.wandb.prv_runid == "":
+            self.config.trainer_config.wandb.prv_runid = None
+
+        # Default wandb run name to trainer run_name if not specified
+        if self.config.trainer_config.wandb.name is None:
+            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
+
         # compute preprocessing parameters from the labels objects and fill in the config
         self._setup_preprocessing_config()
 
@@ -565,7 +702,7 @@ class ModelTrainer:
         self._verify_model_input_channels()
 
     def _setup_model_ckpt_dir(self):
-        """Create the model ckpt folder."""
+        """Create the model ckpt folder and save ground truth labels."""
         ckpt_path = (
             Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
@@ -581,16 +718,61 @@ class ModelTrainer:
                 raise OSError(message)
 
         if RANK in [0, -1]:
+            # Check if we should filter to user-labeled frames only
+            user_instances_only = OmegaConf.select(
+                self.config, "data_config.user_instances_only", default=True
+            )
+
+            # Save train and val ground truth labels
             for idx, (train, val) in enumerate(zip(self.train_labels, self.val_labels)):
-                train.save(
-                    Path(ckpt_path) / f"labels_train_gt_{idx}.slp",
+                # Filter to user-labeled frames if needed (for evaluation)
+                if user_instances_only:
+                    train_filtered = self._filter_to_user_labeled(train)
+                    val_filtered = self._filter_to_user_labeled(val)
+                else:
+                    train_filtered = train
+                    val_filtered = val
+
+                train_filtered.save(
+                    Path(ckpt_path) / f"labels_gt.train.{idx}.slp",
                     restore_original_videos=False,
                 )
-                val.save(
-                    Path(ckpt_path) / f"labels_val_gt_{idx}.slp",
+                val_filtered.save(
+                    Path(ckpt_path) / f"labels_gt.val.{idx}.slp",
                     restore_original_videos=False,
                 )
 
+            # Save test ground truth labels if test paths are provided
+            test_file_path = OmegaConf.select(
+                self.config, "data_config.test_file_path", default=None
+            )
+            if test_file_path is not None:
+                # Normalize to list of strings
+                if isinstance(test_file_path, str):
+                    test_paths = [test_file_path]
+                else:
+                    test_paths = list(test_file_path)
+
+                for idx, test_path in enumerate(test_paths):
+                    # Only save if it's a .slp file (not a video file)
+                    if test_path.endswith(".slp") or test_path.endswith(".pkg.slp"):
+                        try:
+                            test_labels = sio.load_slp(test_path)
+                            if user_instances_only:
+                                test_filtered = self._filter_to_user_labeled(
+                                    test_labels
+                                )
+                            else:
+                                test_filtered = test_labels
+                            test_filtered.save(
+                                Path(ckpt_path) / f"labels_gt.test.{idx}.slp",
+                                restore_original_videos=False,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Could not save test ground truth for {test_path}: {e}"
+                            )
+
     def _setup_viz_datasets(self):
         """Setup dataloaders."""
         data_viz_config = self.config.copy()
@@ -716,6 +898,10 @@ class ModelTrainer:
             )
             loggers.append(wandb_logger)
 
+            # Learning rate monitor callback - logs LR at each step for dynamic schedulers
+            # Only added when wandb is enabled since it requires a logger
+            callbacks.append(LearningRateMonitor(logging_interval="step"))
+
         # save the configs as yaml in the checkpoint dir
         # Mask API key in both configs to prevent saving to disk
         self.config.trainer_config.wandb.api_key = ""
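
`LearningRateMonitor` is a stock Lightning callback: with `logging_interval="step"` it reads the learning rate from each optimizer on every training step and forwards it to the attached logger, which is why the diff only adds it on the wandb branch. Minimal usage outside this trainer (the project name and model are illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger

# LearningRateMonitor needs a logger to write to; without one, Lightning
# raises a MisconfigurationException at fit time.
trainer = Trainer(
    logger=WandbLogger(project="sleap-nn-demo"),  # hypothetical project
    callbacks=[LearningRateMonitor(logging_interval="step")],
    max_epochs=10,
)
# trainer.fit(model)  # any LightningModule with a configured optimizer
```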
@@ -814,13 +1000,80 @@ class ModelTrainer:
         if self.config.trainer_config.use_wandb and OmegaConf.select(
             self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
         ):
-            callbacks.append(
-                WandBPredImageLogger(
-                    viz_folder=viz_dir,
-                    wandb_run_name=self.config.trainer_config.wandb.name,
-                    is_bottomup=(self.model_type == "bottomup"),
-                )
+            # Get wandb viz config options
+            viz_enabled = OmegaConf.select(
+                self.config, "trainer_config.wandb.viz_enabled", default=True
             )
+            viz_boxes = OmegaConf.select(
+                self.config, "trainer_config.wandb.viz_boxes", default=False
+            )
+            viz_masks = OmegaConf.select(
+                self.config, "trainer_config.wandb.viz_masks", default=False
+            )
+            viz_box_size = OmegaConf.select(
+                self.config, "trainer_config.wandb.viz_box_size", default=5.0
+            )
+            viz_confmap_threshold = OmegaConf.select(
+                self.config,
+                "trainer_config.wandb.viz_confmap_threshold",
+                default=0.1,
+            )
+            log_viz_table = OmegaConf.select(
+                self.config, "trainer_config.wandb.log_viz_table", default=False
+            )
+
+            # Create viz data pipelines for wandb callback
+            wandb_train_viz_pipeline = cycle(copy.deepcopy(viz_train_dataset))
+            wandb_val_viz_pipeline = cycle(copy.deepcopy(viz_val_dataset))
+
+            if self.model_type == "bottomup":
+                # Bottom-up model needs PAF visualizations
+                wandb_train_pafs_pipeline = cycle(copy.deepcopy(viz_train_dataset))
+                wandb_val_pafs_pipeline = cycle(copy.deepcopy(viz_val_dataset))
+                callbacks.append(
+                    WandBVizCallbackWithPAFs(
+                        train_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                            next(wandb_train_viz_pipeline)
+                        ),
+                        val_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                            next(wandb_val_viz_pipeline)
+                        ),
+                        train_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                            next(wandb_train_pafs_pipeline), include_pafs=True
+                        ),
+                        val_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                            next(wandb_val_pafs_pipeline), include_pafs=True
+                        ),
+                        viz_enabled=viz_enabled,
+                        viz_boxes=viz_boxes,
+                        viz_masks=viz_masks,
+                        box_size=viz_box_size,
+                        confmap_threshold=viz_confmap_threshold,
+                        log_table=log_viz_table,
+                    )
+                )
+            else:
+                # Standard models
+                callbacks.append(
+                    WandBVizCallback(
+                        train_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                            next(wandb_train_viz_pipeline)
+                        ),
+                        val_viz_fn=lambda: self.lightning_model.get_visualization_data(
+                            next(wandb_val_viz_pipeline)
+                        ),
+                        viz_enabled=viz_enabled,
+                        viz_boxes=viz_boxes,
+                        viz_masks=viz_masks,
+                        box_size=viz_box_size,
+                        confmap_threshold=viz_confmap_threshold,
+                        log_table=log_viz_table,
+                    )
+                )
+
+        # Add custom progress bar with better metric formatting
+        if self.config.trainer_config.enable_progress_bar:
+            callbacks.append(SleapProgressBar())
 
         return loggers, callbacks
 
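One design note on the callbacks above: each `*_viz_fn` is a lambda over its own `itertools.cycle` of a deep-copied dataset, so the visualization stream never exhausts, never advances another callback's stream, and never mutates the dataset the real dataloaders consume. The pattern in isolation:

```python
import copy
from itertools import cycle

viz_dataset = [{"frame": i} for i in range(3)]  # stand-in for a viz dataset

# Each consumer gets its own endlessly repeating, independent copy.
stream_a = cycle(copy.deepcopy(viz_dataset))
stream_b = cycle(copy.deepcopy(viz_dataset))

fetch_a = lambda: next(stream_a)  # mirrors the train_viz_fn lambdas above
print([fetch_a()["frame"] for _ in range(5)])  # [0, 1, 2, 0, 1]
print(next(stream_b)["frame"])                 # 0; stream_b is unaffected
```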
@@ -959,7 +1212,7 @@ class ModelTrainer:
         logger.info(f"Backbone model: {self.lightning_model.model.backbone}")
         logger.info(f"Head model: {self.lightning_model.model.head_layers}")
         total_params = sum(p.numel() for p in self.lightning_model.parameters())
-        logger.info(f"Total model parameters: {total_params}")
+        logger.info(f"Total model parameters: {total_params:,}")
         self.config.model_config.total_params = total_params
 
         # setup dataloaders
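
The only change here is the `:,` format spec, Python's built-in thousands separator, so large parameter counts become readable at a glance:

```python
total_params = 23456789
print(f"Total model parameters: {total_params:,}")
# Total model parameters: 23,456,789
```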
@@ -1000,6 +1253,23 @@ class ModelTrainer:
                 id=self.config.trainer_config.wandb.prv_runid,
                 group=self.config.trainer_config.wandb.group,
             )
+
+            # Define custom x-axes for wandb metrics
+            # Epoch-level metrics use epoch as x-axis, step-level use default global_step
+            wandb.define_metric("epoch")
+            wandb.define_metric("val_loss", step_metric="epoch")
+            wandb.define_metric("val_time", step_metric="epoch")
+            wandb.define_metric("train_time", step_metric="epoch")
+            # Per-node losses use epoch as x-axis
+            for node_name in self.skeletons[0].node_names:
+                wandb.define_metric(node_name, step_metric="epoch")
+
+            # Visualization images use epoch as x-axis
+            wandb.define_metric("train_predictions*", step_metric="epoch")
+            wandb.define_metric("val_predictions*", step_metric="epoch")
+            wandb.define_metric("train_pafs*", step_metric="epoch")
+            wandb.define_metric("val_pafs*", step_metric="epoch")
+
             self.config.trainer_config.wandb.current_run_id = wandb.run.id
             wandb.config["run_name"] = self.config.trainer_config.wandb.name
             wandb.config["run_config"] = OmegaConf.to_container(
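
`wandb.define_metric` pins a metric (or a glob of metric names) to a custom x-axis; without it, everything is plotted against the global step, which makes epoch-level validation curves awkward to read. The mechanism in isolation (the project name is illustrative):

```python
import wandb

run = wandb.init(project="sleap-nn-demo", mode="offline")  # hypothetical project

wandb.define_metric("epoch")
wandb.define_metric("val_loss", step_metric="epoch")          # exact name
wandb.define_metric("val_predictions*", step_metric="epoch")  # glob pattern

for epoch in range(3):
    # Logging "epoch" alongside the metric anchors it to the custom x-axis.
    wandb.log({"epoch": epoch, "val_loss": 1.0 / (epoch + 1)})

run.finish()
```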
@@ -1017,6 +1287,9 @@ class ModelTrainer:
 
         self.trainer.strategy.barrier()
 
+        # Flag to track if training was interrupted (not completed normally)
+        training_interrupted = False
+
         try:
             logger.info(
                 f"Finished trainer set up. [{time.time() - start_setup_time:.1f}s]"
@@ -1032,6 +1305,7 @@ class ModelTrainer:
 
         except KeyboardInterrupt:
             logger.info("Stopping training...")
+            training_interrupted = True
 
         finally:
             logger.info(
@@ -1063,3 +1337,15 @@ class ModelTrainer:
             if viz_dir.exists():
                 logger.info(f"Deleting viz folder at {viz_dir}...")
                 shutil.rmtree(viz_dir, ignore_errors=True)
+
+            # Clean up entire run folder if training was interrupted (KeyboardInterrupt)
+            if training_interrupted and self.trainer.global_rank == 0:
+                run_dir = (
+                    Path(self.config.trainer_config.ckpt_dir)
+                    / self.config.trainer_config.run_name
+                )
+                if run_dir.exists():
+                    logger.info(
+                        f"Training canceled - cleaning up run folder at {run_dir}..."
+                    )
+                    shutil.rmtree(run_dir, ignore_errors=True)
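
The interrupt handling added in the last three hunks forms one pattern: catch `KeyboardInterrupt` so the `finally` block still runs, record the cancellation in a flag, and let the cleanup code decide whether the run directory is worth keeping. The bare shape, with a hypothetical `run_dir` and a stand-in `train()`:

```python
import shutil
from pathlib import Path

def train() -> None:
    """Stand-in for trainer.fit(...); imagine Ctrl-C arriving mid-run."""
    raise KeyboardInterrupt

run_dir = Path("models/demo_run")  # hypothetical checkpoint directory
run_dir.mkdir(parents=True, exist_ok=True)
training_interrupted = False

try:
    train()
except KeyboardInterrupt:
    training_interrupted = True
finally:
    # A completed run keeps its artifacts; a canceled one is removed so
    # half-written checkpoints are not mistaken for trained models.
    if training_interrupted and run_dir.exists():
        shutil.rmtree(run_dir, ignore_errors=True)
```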