sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

Files changed (37)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/config/trainer_config.py +18 -0
  4. sleap_nn/evaluation.py +81 -22
  5. sleap_nn/export/__init__.py +21 -0
  6. sleap_nn/export/cli.py +1778 -0
  7. sleap_nn/export/exporters/__init__.py +51 -0
  8. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  9. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  10. sleap_nn/export/metadata.py +225 -0
  11. sleap_nn/export/predictors/__init__.py +63 -0
  12. sleap_nn/export/predictors/base.py +22 -0
  13. sleap_nn/export/predictors/onnx.py +154 -0
  14. sleap_nn/export/predictors/tensorrt.py +312 -0
  15. sleap_nn/export/utils.py +307 -0
  16. sleap_nn/export/wrappers/__init__.py +25 -0
  17. sleap_nn/export/wrappers/base.py +96 -0
  18. sleap_nn/export/wrappers/bottomup.py +243 -0
  19. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  20. sleap_nn/export/wrappers/centered_instance.py +56 -0
  21. sleap_nn/export/wrappers/centroid.py +58 -0
  22. sleap_nn/export/wrappers/single_instance.py +83 -0
  23. sleap_nn/export/wrappers/topdown.py +180 -0
  24. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  25. sleap_nn/inference/bottomup.py +86 -20
  26. sleap_nn/inference/postprocessing.py +284 -0
  27. sleap_nn/predict.py +29 -0
  28. sleap_nn/train.py +64 -0
  29. sleap_nn/training/callbacks.py +324 -8
  30. sleap_nn/training/lightning_modules.py +542 -32
  31. sleap_nn/training/model_trainer.py +48 -57
  32. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
  33. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
  34. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  35. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  36. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  37. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py CHANGED
@@ -50,7 +50,7 @@ logger.add(
      colorize=False,
  )

- __version__ = "0.1.0a1"
+ __version__ = "0.1.0a3"

  # Public API
  from sleap_nn.evaluation import load_metrics
sleap_nn/cli.py CHANGED
@@ -7,6 +7,8 @@ from omegaconf import OmegaConf, DictConfig
  import sleap_io as sio
  from sleap_nn.predict import run_inference, frame_list
  from sleap_nn.evaluation import run_evaluation
+ from sleap_nn.export.cli import export as export_command
+ from sleap_nn.export.cli import predict as predict_command
  from sleap_nn.train import run_training
  from sleap_nn import __version__
  import hydra
@@ -417,6 +419,36 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
      default=0.2,
      help="Minimum confidence map value to consider a peak as valid.",
  )
+ @click.option(
+     "--filter_overlapping",
+     is_flag=True,
+     default=False,
+     help=(
+         "Enable filtering of overlapping instances after inference using greedy NMS. "
+         "Applied independently of tracking. (default: False)"
+     ),
+ )
+ @click.option(
+     "--filter_overlapping_method",
+     type=click.Choice(["iou", "oks"]),
+     default="iou",
+     help=(
+         "Similarity metric for filtering overlapping instances. "
+         "'iou': bounding box intersection-over-union. "
+         "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
+     ),
+ )
+ @click.option(
+     "--filter_overlapping_threshold",
+     type=float,
+     default=0.8,
+     help=(
+         "Similarity threshold for filtering overlapping instances. "
+         "Instances with similarity above this threshold are removed, "
+         "keeping the higher-scoring instance. "
+         "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
+     ),
+ )
  @click.option(
      "--integral_refinement",
      type=str,
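Note on the new flags above: they describe a greedy non-maximum-suppression pass over predicted instances, applied after inference and independently of tracking. A minimal conceptual sketch of that filtering step follows; the (score, keypoints) tuples and the pluggable similarity_fn are illustrative assumptions, not the actual sleap_nn/inference/postprocessing.py implementation.

# Conceptual sketch of greedy overlap filtering, as described by the CLI help text.
# The data layout (score, keypoints) and `similarity_fn` are assumptions for
# illustration; the real logic lives in sleap_nn/inference/postprocessing.py.
from typing import Callable, List, Tuple
import numpy as np

def filter_overlapping_instances(
    instances: List[Tuple[float, np.ndarray]],                  # (score, keypoints) per instance
    similarity_fn: Callable[[np.ndarray, np.ndarray], float],   # e.g. box IoU or OKS
    threshold: float = 0.8,
) -> List[Tuple[float, np.ndarray]]:
    """Keep the highest-scoring instances; drop ones too similar to a kept instance."""
    ordered = sorted(instances, key=lambda inst: inst[0], reverse=True)
    kept: List[Tuple[float, np.ndarray]] = []
    for score, pts in ordered:
        # Keep this instance only if it is not a near-duplicate of anything already kept.
        if all(similarity_fn(pts, kept_pts) <= threshold for _, kept_pts in kept):
            kept.append((score, pts))
    return kept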
@@ -613,5 +645,9 @@ def system():
      print_system_info()


+ cli.add_command(export_command)
+ cli.add_command(predict_command)
+
+
  if __name__ == "__main__":
      cli()
sleap_nn/config/trainer_config.py CHANGED
@@ -208,6 +208,23 @@ class EarlyStoppingConfig:
      stop_training_on_plateau: bool = True


+ @define
+ class EvalConfig:
+     """Configuration for epoch-end evaluation.
+
+     Attributes:
+         enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
+         frequency: (int) Evaluate every N epochs. *Default*: `1`.
+         oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
+         oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
+     """
+
+     enabled: bool = False
+     frequency: int = field(default=1, validator=validators.ge(1))
+     oks_stddev: float = field(default=0.025, validator=validators.gt(0))
+     oks_scale: Optional[float] = None
+
+
  @define
  class HardKeypointMiningConfig:
      """Configuration for online hard keypoint mining.
@@ -310,6 +327,7 @@ class TrainerConfig:
          factory=HardKeypointMiningConfig
      )
      zmq: Optional[ZMQConfig] = field(factory=ZMQConfig)  # Required for SLEAP GUI
+     eval: EvalConfig = field(factory=EvalConfig)  # Epoch-end evaluation config

      @staticmethod
      def validate_optimizer_name(value):
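For orientation, a minimal sketch of constructing the new eval block directly, assuming EvalConfig is importable from sleap_nn.config.trainer_config as defined above; in practice these values would normally be set through the OmegaConf/Hydra training config rather than instantiated by hand.

# Sketch only: exercising the EvalConfig defaults and validators shown above.
from sleap_nn.config.trainer_config import EvalConfig

eval_cfg = EvalConfig(enabled=True, frequency=5)  # evaluate every 5 epochs
print(eval_cfg.oks_stddev)  # 0.025 (default)

# The attrs validators reject out-of-range values, e.g. frequency must be >= 1:
# EvalConfig(frequency=0)  -> raises ValueError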
sleap_nn/evaluation.py CHANGED
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
      """
      instance_list = []
      frame_idx = labeled_frame.frame_idx
-     video_path = (
-         labeled_frame.video.backend.source_filename
-         if hasattr(labeled_frame.video.backend, "source_filename")
-         else labeled_frame.video.backend.filename
-     )
+
+     # Extract video path with fallbacks for embedded videos
+     video = labeled_frame.video
+     video_path = None
+     if video is not None:
+         backend = getattr(video, "backend", None)
+         if backend is not None:
+             # Try source_filename first (for embedded videos with provenance)
+             video_path = getattr(backend, "source_filename", None)
+             if video_path is None:
+                 video_path = getattr(backend, "filename", None)
+         # Fallback to video.filename if backend doesn't have it
+         if video_path is None:
+             video_path = getattr(video, "filename", None)
+         # Handle list filenames (image sequences)
+         if isinstance(video_path, list) and video_path:
+             video_path = video_path[0]
+     # Final fallback: use a unique identifier
+     if video_path is None:
+         video_path = f"video_{id(video)}" if video is not None else "unknown"
+
      for instance in labeled_frame.instances:
          match_instance = MatchInstance(
              instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ def find_frame_pairs(
  ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
      """Find corresponding frames across two sets of labels.

+     This function uses sleap-io's robust video matching API to handle various
+     scenarios including embedded videos, cross-platform paths, and videos with
+     different metadata.
+
      Args:
          labels_gt: A `sio.Labels` instance with ground truth instances.
          labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@ def find_frame_pairs(
      Returns:
          A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
      """
+     # Use sleap-io's robust video matching API (added in 0.6.2)
+     # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
+     match_result = labels_gt.match(labels_pr)
+
      frame_pairs = []
-     for video_gt in labels_gt.videos:
-         # Find matching video instance in predictions.
-         video_pr = None
-         for video in labels_pr.videos:
-             if video_gt.matches_content(video) and video_gt.matches_path(video):
-                 video_pr = video
-                 break
-
-         if video_pr is None:
+     # Iterate over matched video pairs (pred_video -> gt_video mapping)
+     for video_pr, video_gt in match_result.video_map.items():
+         if video_gt is None:
+             # No match found for this prediction video
              continue

          # Find labeled frames in this video.
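For orientation, a short usage sketch of the video-matching flow introduced above. It relies only on the sleap-io API the hunk itself uses (Labels.match() returning a result whose video_map maps predicted videos to their ground-truth counterparts); the file paths are placeholders.

# Sketch of the matching step used by find_frame_pairs() above.
import sleap_io as sio

labels_gt = sio.load_slp("ground_truth.slp")  # placeholder paths
labels_pr = sio.load_slp("predictions.slp")

match_result = labels_gt.match(labels_pr)
for video_pr, video_gt in match_result.video_map.items():
    if video_gt is None:
        continue  # prediction video with no ground-truth counterpart
    print(video_pr.filename, "->", video_gt.filename)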
@@ -620,11 +639,19 @@ class Evaluator:
          mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
          mPCK = mPCK_parts.mean()

+         # Precompute PCK at common thresholds
+         idx_5 = np.argmin(np.abs(thresholds - 5))
+         idx_10 = np.argmin(np.abs(thresholds - 10))
+         pck5 = pcks[:, :, idx_5].mean()
+         pck10 = pcks[:, :, idx_10].mean()
+
          return {
              "thresholds": thresholds,
              "pcks": pcks,
              "mPCK_parts": mPCK_parts,
              "mPCK": mPCK,
+             "PCK@5": pck5,
+             "PCK@10": pck10,
          }

      def visibility_metrics(self):
@@ -786,11 +813,26 @@ def run_evaluation(
      """Evaluate SLEAP-NN model predictions against ground truth labels."""
      logger.info("Loading ground truth labels...")
      ground_truth_instances = sio.load_slp(ground_truth_path)
+     logger.info(
+         f" Ground truth: {len(ground_truth_instances.videos)} videos, "
+         f"{len(ground_truth_instances.labeled_frames)} frames"
+     )

      logger.info("Loading predicted labels...")
      predicted_instances = sio.load_slp(predicted_path)
+     logger.info(
+         f" Predictions: {len(predicted_instances.videos)} videos, "
+         f"{len(predicted_instances.labeled_frames)} frames"
+     )
+
+     logger.info("Matching videos and frames...")
+     # Get match stats before creating evaluator
+     match_result = ground_truth_instances.match(predicted_instances)
+     logger.info(
+         f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
+     )

-     logger.info("Creating evaluator...")
+     logger.info("Matching instances...")
      evaluator = Evaluator(
          ground_truth_instances=ground_truth_instances,
          predicted_instances=predicted_instances,
@@ -799,21 +841,38 @@ def run_evaluation(
          match_threshold=match_threshold,
          user_labels_only=user_labels_only,
      )
+     logger.info(
+         f" Frame pairs: {len(evaluator.frame_pairs)}, "
+         f"Matched instances: {len(evaluator.positive_pairs)}, "
+         f"Unmatched GT: {len(evaluator.false_negatives)}"
+     )

      logger.info("Computing evaluation metrics...")
      metrics = evaluator.evaluate()

+     # Compute PCK at specific thresholds (5 and 10 pixels)
+     dists = metrics["distance_metrics"]["dists"]
+     dists_clean = np.copy(dists)
+     dists_clean[np.isnan(dists_clean)] = np.inf
+     pck_5 = (dists_clean < 5).mean()
+     pck_10 = (dists_clean < 10).mean()
+
      # Print key metrics
      logger.info("Evaluation Results:")
-     logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
-     logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-     logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-     logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
-     logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+     logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
+     logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+     logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+     logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
+     logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
+     logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
+     logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
+     logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+     logger.info(f" PCK@5px: {pck_5:.4f}")
+     logger.info(f" PCK@10px: {pck_10:.4f}")
      logger.info(
-         f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+         f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
      )
-     logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+     logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")

      # Save metrics if path provided
      if save_metrics:
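The PCK@5px / PCK@10px figures logged above reduce to a simple fraction of matched keypoint distances under a pixel threshold, with NaN distances (unmatched or invisible keypoints) counted as misses. A small worked example with made-up distances:

# Worked example of the PCK-at-threshold computation shown above; `dists` is made up.
import numpy as np

dists = np.array([[1.2, 3.4, np.nan],
                  [7.9, 2.0, 11.5]])  # (instances, keypoints), pixel errors

dists_clean = np.copy(dists)
dists_clean[np.isnan(dists_clean)] = np.inf  # NaN -> never below any threshold

pck_5 = (dists_clean < 5).mean()    # 3 of 6 entries under 5 px  -> 0.5
pck_10 = (dists_clean < 10).mean()  # 4 of 6 entries under 10 px -> ~0.667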
sleap_nn/export/__init__.py ADDED
@@ -0,0 +1,21 @@
+ """Export utilities for sleap-nn."""
+
+ from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
+ from sleap_nn.export.metadata import ExportMetadata
+ from sleap_nn.export.predictors import (
+     load_exported_model,
+     ONNXPredictor,
+     TensorRTPredictor,
+ )
+ from sleap_nn.export.utils import build_bottomup_candidate_template
+
+ __all__ = [
+     "export_model",
+     "export_to_onnx",
+     "export_to_tensorrt",
+     "load_exported_model",
+     "ONNXPredictor",
+     "TensorRTPredictor",
+     "ExportMetadata",
+     "build_bottomup_candidate_template",
+ ]
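A hypothetical usage sketch of the new public surface; only the imported names come from __all__ above, and every argument shown (paths, runtime) is an assumption rather than a documented signature.

# Hypothetical sketch only: the names are re-exported above, but the arguments
# (model_dir, runtime, output_dir, and the predictor call) are illustrative guesses.
from sleap_nn.export import export_model, load_exported_model

export_model(model_dir="models/centroid", runtime="onnx", output_dir="exports/centroid")
predictor = load_exported_model("exports/centroid")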