sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/config/trainer_config.py +18 -0
- sleap_nn/evaluation.py +81 -22
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +324 -8
- sleap_nn/training/lightning_modules.py +542 -32
- sleap_nn/training/model_trainer.py +48 -57
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/training/callbacks.py
CHANGED
@@ -295,8 +295,8 @@ class WandBVizCallback(Callback):
             suffix = "" if mode_name == "direct" else f"_{mode_name}"
             train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
             val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
-            log_dict[f"
-            log_dict[f"
+            log_dict[f"viz/train/predictions{suffix}"] = train_img
+            log_dict[f"viz/val/predictions{suffix}"] = val_img

         if log_dict:
             # Include epoch so wandb can use it as x-axis (via define_metric)
@@ -394,8 +394,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
             suffix = "" if mode_name == "direct" else f"_{mode_name}"
             train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
             val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
-            log_dict[f"
-            log_dict[f"
+            log_dict[f"viz/train/predictions{suffix}"] = train_img
+            log_dict[f"viz/val/predictions{suffix}"] = val_img

         # Render PAFs (always use matplotlib/direct for PAFs)
         from io import BytesIO
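The renamed keys matter because W&B nests slash-separated key prefixes into sections in the run UI. A minimal sketch (the project name and image paths are hypothetical, not from the diff) of how the new `viz/train/...` and `viz/val/...` keys group together:

```python
# Hypothetical sketch: W&B groups slash-separated keys, so "viz/train/..."
# and "viz/val/..." render under a shared "viz" section instead of the
# flat keys the old code used.
import wandb

run = wandb.init(project="demo")  # assumed project name
run.log(
    {
        "viz/train/predictions": wandb.Image("train_preds.png"),  # hypothetical path
        "viz/val/predictions": wandb.Image("val_preds.png"),      # hypothetical path
        "epoch": 0,  # logged alongside so epoch can serve as the x-axis
    }
)
```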
@@ -408,7 +408,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
         buf.seek(0)
         plt.close(train_pafs_fig)
         train_pafs_pil = Image.open(buf)
-        log_dict["
+        log_dict["viz/train/pafs"] = wandb.Image(
             train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
         )

@@ -418,7 +418,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
         buf.seek(0)
         plt.close(val_pafs_fig)
         val_pafs_pil = Image.open(buf)
-        log_dict["
+        log_dict["viz/val/pafs"] = wandb.Image(
             val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
         )

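Both PAF hunks rely on the same in-memory handoff from matplotlib to W&B. A self-contained sketch of that pattern (the plotted data is a stand-in for the real PAF rendering):

```python
# Minimal sketch of the matplotlib -> BytesIO -> PIL -> wandb.Image pattern
# used in the hunks above; the figure contents here are a placeholder.
from io import BytesIO

import matplotlib.pyplot as plt
import wandb
from PIL import Image

wandb.init(project="demo")  # assumed project name

fig, ax = plt.subplots()
ax.plot([0, 1], [1, 0])  # placeholder for the PAF visualization

buf = BytesIO()
fig.savefig(buf, format="png")  # rasterize into memory instead of a file
buf.seek(0)                     # rewind so PIL reads from the start
plt.close(fig)                  # release the figure before logging
pafs_pil = Image.open(buf)

wandb.log({"viz/train/pafs": wandb.Image(pafs_pil, caption="Train PAFs Epoch 0")})
```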
@@ -444,8 +444,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
                     epoch,
                     train_img,
                     val_img,
-                    log_dict["
-                    log_dict["
+                    log_dict["viz/train/pafs"],
+                    log_dict["viz/val/pafs"],
                 ]
             ],
         )
@@ -662,3 +662,319 @@ class ProgressReporterZMQ(Callback):
         return {
             k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
         }
+
+
+class EpochEndEvaluationCallback(Callback):
+    """Callback to run full evaluation metrics at end of validation epochs.
+
+    This callback collects predictions and ground truth during validation,
+    then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
+    metrics to WandB.
+
+    Attributes:
+        skeleton: sio.Skeleton for creating instances.
+        videos: List of sio.Video objects.
+        eval_frequency: Run evaluation every N epochs (default: 1).
+        oks_stddev: OKS standard deviation (default: 0.025).
+        oks_scale: Optional OKS scale override.
+        metrics_to_log: List of metric keys to log.
+    """
+
+    def __init__(
+        self,
+        skeleton: "sio.Skeleton",
+        videos: list,
+        eval_frequency: int = 1,
+        oks_stddev: float = 0.025,
+        oks_scale: Optional[float] = None,
+        metrics_to_log: Optional[list] = None,
+    ):
+        """Initialize the callback.
+
+        Args:
+            skeleton: sio.Skeleton for creating instances.
+            videos: List of sio.Video objects.
+            eval_frequency: Run evaluation every N epochs (default: 1).
+            oks_stddev: OKS standard deviation (default: 0.025).
+            oks_scale: Optional OKS scale override.
+            metrics_to_log: List of metric keys to log. If None, logs all available.
+        """
+        super().__init__()
+        self.skeleton = skeleton
+        self.videos = videos
+        self.eval_frequency = eval_frequency
+        self.oks_stddev = oks_stddev
+        self.oks_scale = oks_scale
+        self.metrics_to_log = metrics_to_log or [
+            "mOKS",
+            "oks_voc.mAP",
+            "oks_voc.mAR",
+            "distance/avg",
+            "distance/p50",
+            "distance/p95",
+            "distance/p99",
+            "mPCK",
+            "PCK@5",
+            "PCK@10",
+            "visibility_precision",
+            "visibility_recall",
+        ]
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        """Enable prediction collection at the start of validation.
+
+        Skip during sanity check to avoid inference issues.
+        """
+        if trainer.sanity_checking:
+            return
+        pl_module._collect_val_predictions = True
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Run evaluation and log metrics at end of validation epoch."""
+        import sleap_io as sio
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+        from sleap_nn.evaluation import Evaluator
+
+        # Check frequency (epoch is 0-indexed, so add 1)
+        if (trainer.current_epoch + 1) % self.eval_frequency != 0:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Only run on rank 0 for distributed training
+        if not trainer.is_global_zero:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Check if we have predictions
+        if not pl_module.val_predictions or not pl_module.val_ground_truth:
+            logger.warning("No predictions collected for epoch-end evaluation")
+            pl_module._collect_val_predictions = False
+            return
+
+        try:
+            # Build sio.Labels from accumulated predictions and ground truth
+            pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
+            gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)
+
+            # Check if we have valid frames to evaluate
+            if len(pred_labels) == 0:
+                logger.warning(
+                    "No valid predictions for epoch-end evaluation "
+                    "(all predictions may be empty or NaN)"
+                )
+                pl_module._collect_val_predictions = False
+                pl_module.val_predictions = []
+                pl_module.val_ground_truth = []
+                return
+
+            # Run evaluation
+            evaluator = Evaluator(
+                ground_truth_instances=gt_labels,
+                predicted_instances=pred_labels,
+                oks_stddev=self.oks_stddev,
+                oks_scale=self.oks_scale,
+                user_labels_only=False,  # All validation frames are "user" frames
+            )
+            metrics = evaluator.evaluate()
+
+            # Log to WandB
+            self._log_metrics(trainer, metrics, trainer.current_epoch)
+
+            logger.info(
+                f"Epoch {trainer.current_epoch} evaluation: "
+                f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
+                f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
+                f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
+            )
+
+        except Exception as e:
+            logger.warning(f"Epoch-end evaluation failed: {e}")
+
+        # Cleanup
+        pl_module._collect_val_predictions = False
+        pl_module.val_predictions = []
+        pl_module.val_ground_truth = []
+
+    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
+        """Convert prediction dicts to sio.Labels."""
+        labeled_frames = []
+        for pred in predictions:
+            pred_peaks = pred["pred_peaks"]
+            pred_scores = pred["pred_scores"]
+
+            # Handle NaN/missing predictions
+            if pred_peaks is None or (
+                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
+            ):
+                continue
+
+            # Handle multi-instance predictions (bottomup)
+            if len(pred_peaks.shape) == 2:
+                # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
+                pred_peaks = pred_peaks.reshape(1, -1, 2)
+                pred_scores = pred_scores.reshape(1, -1)
+
+            instances = []
+            for inst_idx in range(len(pred_peaks)):
+                inst_points = pred_peaks[inst_idx]
+                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
+
+                # Skip if all NaN
+                if np.isnan(inst_points).all():
+                    continue
+
+                inst = sio.PredictedInstance.from_numpy(
+                    points_data=inst_points,
+                    skeleton=self.skeleton,
+                    point_scores=(
+                        inst_scores
+                        if inst_scores is not None
+                        else np.ones(len(inst_points))
+                    ),
+                    score=(
+                        float(np.nanmean(inst_scores))
+                        if inst_scores is not None
+                        else 1.0
+                    ),
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[pred["video_idx"]],
+                    frame_idx=pred["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
+        """Convert ground truth dicts to sio.Labels."""
+        labeled_frames = []
+        for gt in ground_truth:
+            instances = []
+            gt_instances = gt["gt_instances"]
+
+            # Handle shape variations
+            if len(gt_instances.shape) == 2:
+                # (n_nodes, 2) -> (1, n_nodes, 2)
+                gt_instances = gt_instances.reshape(1, -1, 2)
+
+            for i in range(min(gt["num_instances"], len(gt_instances))):
+                inst_data = gt_instances[i]
+                if np.isnan(inst_data).all():
+                    continue
+                inst = sio.Instance.from_numpy(
+                    points_data=inst_data,
+                    skeleton=self.skeleton,
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[gt["video_idx"]],
+                    frame_idx=gt["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _log_metrics(self, trainer, metrics: dict, epoch: int):
+        """Log evaluation metrics to WandB."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Get WandB logger
+        wandb_logger = None
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                wandb_logger = log
+                break
+
+        if wandb_logger is None:
+            return
+
+        log_dict = {"epoch": epoch}
+
+        # Extract key metrics with consistent naming
+        # All eval metrics use eval/val/ prefix since they're computed on validation data
+        if "mOKS" in self.metrics_to_log:
+            log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
+
+        if "oks_voc.mAP" in self.metrics_to_log:
+            log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
+
+        if "oks_voc.mAR" in self.metrics_to_log:
+            log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
+
+        # Distance metrics grouped under eval/val/distance/
+        if "distance/avg" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["avg"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/avg"] = val
+
+        if "distance/p50" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p50"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p50"] = val
+
+        if "distance/p95" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p95"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p95"] = val
+
+        if "distance/p99" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p99"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p99"] = val
+
+        # PCK metrics
+        if "mPCK" in self.metrics_to_log:
+            log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
+
+        # PCK at specific thresholds (precomputed in evaluation.py)
+        if "PCK@5" in self.metrics_to_log:
+            log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
+
+        if "PCK@10" in self.metrics_to_log:
+            log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
+
+        # Visibility metrics
+        if "visibility_precision" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["precision"]
+            if not np.isnan(val):
+                log_dict["eval/val/visibility_precision"] = val
+
+        if "visibility_recall" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["recall"]
+            if not np.isnan(val):
+                log_dict["eval/val/visibility_recall"] = val
+
+        wandb_logger.experiment.log(log_dict, commit=False)
+
+        # Update best metrics in summary (excluding epoch)
+        for key, value in log_dict.items():
+            if key == "epoch":
+                continue
+            # Create summary key like "best/eval/val/mOKS"
+            summary_key = f"best/{key}"
+            current_best = wandb_logger.experiment.summary.get(summary_key)
+            # For distance metrics, lower is better; for others, higher is better
+            is_distance = "distance" in key
+            if current_best is None:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif is_distance and value < current_best:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif not is_distance and value > current_best:
+                wandb_logger.experiment.summary[summary_key] = value