sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (37)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/config/trainer_config.py +18 -0
  4. sleap_nn/evaluation.py +81 -22
  5. sleap_nn/export/__init__.py +21 -0
  6. sleap_nn/export/cli.py +1778 -0
  7. sleap_nn/export/exporters/__init__.py +51 -0
  8. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  9. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  10. sleap_nn/export/metadata.py +225 -0
  11. sleap_nn/export/predictors/__init__.py +63 -0
  12. sleap_nn/export/predictors/base.py +22 -0
  13. sleap_nn/export/predictors/onnx.py +154 -0
  14. sleap_nn/export/predictors/tensorrt.py +312 -0
  15. sleap_nn/export/utils.py +307 -0
  16. sleap_nn/export/wrappers/__init__.py +25 -0
  17. sleap_nn/export/wrappers/base.py +96 -0
  18. sleap_nn/export/wrappers/bottomup.py +243 -0
  19. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  20. sleap_nn/export/wrappers/centered_instance.py +56 -0
  21. sleap_nn/export/wrappers/centroid.py +58 -0
  22. sleap_nn/export/wrappers/single_instance.py +83 -0
  23. sleap_nn/export/wrappers/topdown.py +180 -0
  24. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  25. sleap_nn/inference/bottomup.py +86 -20
  26. sleap_nn/inference/postprocessing.py +284 -0
  27. sleap_nn/predict.py +29 -0
  28. sleap_nn/train.py +64 -0
  29. sleap_nn/training/callbacks.py +324 -8
  30. sleap_nn/training/lightning_modules.py +542 -32
  31. sleap_nn/training/model_trainer.py +48 -57
  32. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
  33. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
  34. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  35. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  36. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  37. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
@@ -295,8 +295,8 @@ class WandBVizCallback(Callback):
             suffix = "" if mode_name == "direct" else f"_{mode_name}"
             train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
             val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
-            log_dict[f"train_predictions{suffix}"] = train_img
-            log_dict[f"val_predictions{suffix}"] = val_img
+            log_dict[f"viz/train/predictions{suffix}"] = train_img
+            log_dict[f"viz/val/predictions{suffix}"] = val_img

         if log_dict:
             # Include epoch so wandb can use it as x-axis (via define_metric)
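The renamed keys here put the visualization images under a `viz/` namespace: in Weights & Biases, slash-separated key prefixes group panels into sections in the UI, and the comment above refers to `wandb.define_metric`, which makes a logged `epoch` value the x-axis for those panels. A minimal standalone sketch of that mechanism (the project name and dummy image below are placeholders, not part of this package):

import numpy as np
import wandb

run = wandb.init(project="sleap-nn-demo")  # placeholder project name

# Pin the viz panels to the logged "epoch" value instead of wandb's internal step.
run.define_metric("epoch")
run.define_metric("viz/*", step_metric="epoch")

# Slash-separated keys ("viz/train/...") are grouped under a "viz" section in the UI.
dummy = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for a rendered prediction image
run.log({"viz/train/predictions": wandb.Image(dummy), "epoch": 0})
run.finish()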
@@ -394,8 +394,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
             suffix = "" if mode_name == "direct" else f"_{mode_name}"
             train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
             val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
-            log_dict[f"train_predictions{suffix}"] = train_img
-            log_dict[f"val_predictions{suffix}"] = val_img
+            log_dict[f"viz/train/predictions{suffix}"] = train_img
+            log_dict[f"viz/val/predictions{suffix}"] = val_img

         # Render PAFs (always use matplotlib/direct for PAFs)
         from io import BytesIO
@@ -408,7 +408,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
         buf.seek(0)
         plt.close(train_pafs_fig)
         train_pafs_pil = Image.open(buf)
-        log_dict["train_pafs"] = wandb.Image(
+        log_dict["viz/train/pafs"] = wandb.Image(
             train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
         )

@@ -418,7 +418,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
         buf.seek(0)
         plt.close(val_pafs_fig)
         val_pafs_pil = Image.open(buf)
-        log_dict["val_pafs"] = wandb.Image(
+        log_dict["viz/val/pafs"] = wandb.Image(
             val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
         )

@@ -444,8 +444,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
                     epoch,
                     train_img,
                     val_img,
-                    log_dict["train_pafs"],
-                    log_dict["val_pafs"],
+                    log_dict["viz/train/pafs"],
+                    log_dict["viz/val/pafs"],
                 ]
             ],
         )
@@ -662,3 +662,319 @@ class ProgressReporterZMQ(Callback):
         return {
             k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
         }
+
+
+class EpochEndEvaluationCallback(Callback):
+    """Callback to run full evaluation metrics at end of validation epochs.
+
+    This callback collects predictions and ground truth during validation,
+    then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
+    metrics to WandB.
+
+    Attributes:
+        skeleton: sio.Skeleton for creating instances.
+        videos: List of sio.Video objects.
+        eval_frequency: Run evaluation every N epochs (default: 1).
+        oks_stddev: OKS standard deviation (default: 0.025).
+        oks_scale: Optional OKS scale override.
+        metrics_to_log: List of metric keys to log.
+    """
+
+    def __init__(
+        self,
+        skeleton: "sio.Skeleton",
+        videos: list,
+        eval_frequency: int = 1,
+        oks_stddev: float = 0.025,
+        oks_scale: Optional[float] = None,
+        metrics_to_log: Optional[list] = None,
+    ):
+        """Initialize the callback.
+
+        Args:
+            skeleton: sio.Skeleton for creating instances.
+            videos: List of sio.Video objects.
+            eval_frequency: Run evaluation every N epochs (default: 1).
+            oks_stddev: OKS standard deviation (default: 0.025).
+            oks_scale: Optional OKS scale override.
+            metrics_to_log: List of metric keys to log. If None, logs all available.
+        """
+        super().__init__()
+        self.skeleton = skeleton
+        self.videos = videos
+        self.eval_frequency = eval_frequency
+        self.oks_stddev = oks_stddev
+        self.oks_scale = oks_scale
+        self.metrics_to_log = metrics_to_log or [
+            "mOKS",
+            "oks_voc.mAP",
+            "oks_voc.mAR",
+            "distance/avg",
+            "distance/p50",
+            "distance/p95",
+            "distance/p99",
+            "mPCK",
+            "PCK@5",
+            "PCK@10",
+            "visibility_precision",
+            "visibility_recall",
+        ]
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        """Enable prediction collection at the start of validation.
+
+        Skip during sanity check to avoid inference issues.
+        """
+        if trainer.sanity_checking:
+            return
+        pl_module._collect_val_predictions = True
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Run evaluation and log metrics at end of validation epoch."""
+        import sleap_io as sio
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+        from sleap_nn.evaluation import Evaluator
+
+        # Check frequency (epoch is 0-indexed, so add 1)
+        if (trainer.current_epoch + 1) % self.eval_frequency != 0:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Only run on rank 0 for distributed training
+        if not trainer.is_global_zero:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Check if we have predictions
+        if not pl_module.val_predictions or not pl_module.val_ground_truth:
+            logger.warning("No predictions collected for epoch-end evaluation")
+            pl_module._collect_val_predictions = False
+            return
+
+        try:
+            # Build sio.Labels from accumulated predictions and ground truth
+            pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
+            gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)
+
+            # Check if we have valid frames to evaluate
+            if len(pred_labels) == 0:
+                logger.warning(
+                    "No valid predictions for epoch-end evaluation "
+                    "(all predictions may be empty or NaN)"
+                )
+                pl_module._collect_val_predictions = False
+                pl_module.val_predictions = []
+                pl_module.val_ground_truth = []
+                return
+
+            # Run evaluation
+            evaluator = Evaluator(
+                ground_truth_instances=gt_labels,
+                predicted_instances=pred_labels,
+                oks_stddev=self.oks_stddev,
+                oks_scale=self.oks_scale,
+                user_labels_only=False,  # All validation frames are "user" frames
+            )
+            metrics = evaluator.evaluate()
+
+            # Log to WandB
+            self._log_metrics(trainer, metrics, trainer.current_epoch)
+
+            logger.info(
+                f"Epoch {trainer.current_epoch} evaluation: "
+                f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
+                f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
+                f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
+            )
+
+        except Exception as e:
+            logger.warning(f"Epoch-end evaluation failed: {e}")
+
+        # Cleanup
+        pl_module._collect_val_predictions = False
+        pl_module.val_predictions = []
+        pl_module.val_ground_truth = []
+
+    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
+        """Convert prediction dicts to sio.Labels."""
+        labeled_frames = []
+        for pred in predictions:
+            pred_peaks = pred["pred_peaks"]
+            pred_scores = pred["pred_scores"]
+
+            # Handle NaN/missing predictions
+            if pred_peaks is None or (
+                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
+            ):
+                continue
+
+            # Handle multi-instance predictions (bottomup)
+            if len(pred_peaks.shape) == 2:
+                # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
+                pred_peaks = pred_peaks.reshape(1, -1, 2)
+                pred_scores = pred_scores.reshape(1, -1)
+
+            instances = []
+            for inst_idx in range(len(pred_peaks)):
+                inst_points = pred_peaks[inst_idx]
+                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
+
+                # Skip if all NaN
+                if np.isnan(inst_points).all():
+                    continue
+
+                inst = sio.PredictedInstance.from_numpy(
+                    points_data=inst_points,
+                    skeleton=self.skeleton,
+                    point_scores=(
+                        inst_scores
+                        if inst_scores is not None
+                        else np.ones(len(inst_points))
+                    ),
+                    score=(
+                        float(np.nanmean(inst_scores))
+                        if inst_scores is not None
+                        else 1.0
+                    ),
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[pred["video_idx"]],
+                    frame_idx=pred["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
+        """Convert ground truth dicts to sio.Labels."""
+        labeled_frames = []
+        for gt in ground_truth:
+            instances = []
+            gt_instances = gt["gt_instances"]
+
+            # Handle shape variations
+            if len(gt_instances.shape) == 2:
+                # (n_nodes, 2) -> (1, n_nodes, 2)
+                gt_instances = gt_instances.reshape(1, -1, 2)
+
+            for i in range(min(gt["num_instances"], len(gt_instances))):
+                inst_data = gt_instances[i]
+                if np.isnan(inst_data).all():
+                    continue
+                inst = sio.Instance.from_numpy(
+                    points_data=inst_data,
+                    skeleton=self.skeleton,
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[gt["video_idx"]],
+                    frame_idx=gt["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _log_metrics(self, trainer, metrics: dict, epoch: int):
+        """Log evaluation metrics to WandB."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Get WandB logger
+        wandb_logger = None
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                wandb_logger = log
+                break
+
+        if wandb_logger is None:
+            return
+
+        log_dict = {"epoch": epoch}
+
+        # Extract key metrics with consistent naming
+        # All eval metrics use eval/val/ prefix since they're computed on validation data
+        if "mOKS" in self.metrics_to_log:
+            log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
+
+        if "oks_voc.mAP" in self.metrics_to_log:
+            log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
+
+        if "oks_voc.mAR" in self.metrics_to_log:
+            log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
+
+        # Distance metrics grouped under eval/val/distance/
+        if "distance/avg" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["avg"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/avg"] = val
+
+        if "distance/p50" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p50"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p50"] = val
+
+        if "distance/p95" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p95"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p95"] = val
+
+        if "distance/p99" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p99"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p99"] = val
+
+        # PCK metrics
+        if "mPCK" in self.metrics_to_log:
+            log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
+
+        # PCK at specific thresholds (precomputed in evaluation.py)
+        if "PCK@5" in self.metrics_to_log:
+            log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
+
+        if "PCK@10" in self.metrics_to_log:
+            log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
+
+        # Visibility metrics
+        if "visibility_precision" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["precision"]
+            if not np.isnan(val):
+                log_dict["eval/val/visibility_precision"] = val
+
+        if "visibility_recall" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["recall"]
+            if not np.isnan(val):
+                log_dict["eval/val/visibility_recall"] = val
+
+        wandb_logger.experiment.log(log_dict, commit=False)
+
+        # Update best metrics in summary (excluding epoch)
+        for key, value in log_dict.items():
+            if key == "epoch":
+                continue
+            # Create summary key like "best/eval/val/mOKS"
+            summary_key = f"best/{key}"
+            current_best = wandb_logger.experiment.summary.get(summary_key)
+            # For distance metrics, lower is better; for others, higher is better
+            is_distance = "distance" in key
+            if current_best is None:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif is_distance and value < current_best:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif not is_distance and value > current_best:
+                wandb_logger.experiment.summary[summary_key] = value
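For orientation, a minimal sketch of how the new EpochEndEvaluationCallback could be wired into a Lightning trainer. The constructor arguments match the ones added above; the labels path, project name, and the `lightning_module`/`data_module` objects are placeholders for the usual sleap-nn training setup (the LightningModule is expected to populate `val_predictions` and `val_ground_truth` while `_collect_val_predictions` is set, presumably via the `lightning_modules.py` changes also shipped in this release).

import sleap_io as sio
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

# Assumed: validation labels from a SLEAP project; the skeleton and videos are reused
# by the callback when rebuilding sio.Labels for evaluation.
labels = sio.load_slp("val_labels.slp")  # placeholder path

eval_cb = EpochEndEvaluationCallback(
    skeleton=labels.skeletons[0],
    videos=labels.videos,
    eval_frequency=5,  # run the full OKS/PCK evaluation every 5 epochs
    metrics_to_log=["mOKS", "oks_voc.mAP", "PCK@5"],
)

trainer = Trainer(
    max_epochs=100,
    logger=WandbLogger(project="sleap-nn-demo"),  # placeholder project name
    callbacks=[eval_cb],
)
# trainer.fit(lightning_module, datamodule=data_module)  # both assumed, not shown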