sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/evaluation.py +8 -0
  4. sleap_nn/export/__init__.py +21 -0
  5. sleap_nn/export/cli.py +1778 -0
  6. sleap_nn/export/exporters/__init__.py +51 -0
  7. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  8. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  9. sleap_nn/export/metadata.py +225 -0
  10. sleap_nn/export/predictors/__init__.py +63 -0
  11. sleap_nn/export/predictors/base.py +22 -0
  12. sleap_nn/export/predictors/onnx.py +154 -0
  13. sleap_nn/export/predictors/tensorrt.py +312 -0
  14. sleap_nn/export/utils.py +307 -0
  15. sleap_nn/export/wrappers/__init__.py +25 -0
  16. sleap_nn/export/wrappers/base.py +96 -0
  17. sleap_nn/export/wrappers/bottomup.py +243 -0
  18. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  19. sleap_nn/export/wrappers/centered_instance.py +56 -0
  20. sleap_nn/export/wrappers/centroid.py +58 -0
  21. sleap_nn/export/wrappers/single_instance.py +83 -0
  22. sleap_nn/export/wrappers/topdown.py +180 -0
  23. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  24. sleap_nn/inference/postprocessing.py +284 -0
  25. sleap_nn/predict.py +29 -0
  26. sleap_nn/train.py +64 -0
  27. sleap_nn/training/callbacks.py +62 -20
  28. sleap_nn/training/lightning_modules.py +332 -30
  29. sleap_nn/training/model_trainer.py +35 -67
  30. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
  31. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
  32. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  33. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  34. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py CHANGED
@@ -50,7 +50,7 @@ logger.add(
50
50
  colorize=False,
51
51
  )
52
52
 
53
- __version__ = "0.1.0a2"
53
+ __version__ = "0.1.0a3"
54
54
 
55
55
  # Public API
56
56
  from sleap_nn.evaluation import load_metrics
sleap_nn/cli.py CHANGED
@@ -7,6 +7,8 @@ from omegaconf import OmegaConf, DictConfig
7
7
  import sleap_io as sio
8
8
  from sleap_nn.predict import run_inference, frame_list
9
9
  from sleap_nn.evaluation import run_evaluation
10
+ from sleap_nn.export.cli import export as export_command
11
+ from sleap_nn.export.cli import predict as predict_command
10
12
  from sleap_nn.train import run_training
11
13
  from sleap_nn import __version__
12
14
  import hydra
@@ -417,6 +419,36 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
417
419
  default=0.2,
418
420
  help="Minimum confidence map value to consider a peak as valid.",
419
421
  )
422
+ @click.option(
423
+ "--filter_overlapping",
424
+ is_flag=True,
425
+ default=False,
426
+ help=(
427
+ "Enable filtering of overlapping instances after inference using greedy NMS. "
428
+ "Applied independently of tracking. (default: False)"
429
+ ),
430
+ )
431
+ @click.option(
432
+ "--filter_overlapping_method",
433
+ type=click.Choice(["iou", "oks"]),
434
+ default="iou",
435
+ help=(
436
+ "Similarity metric for filtering overlapping instances. "
437
+ "'iou': bounding box intersection-over-union. "
438
+ "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
439
+ ),
440
+ )
441
+ @click.option(
442
+ "--filter_overlapping_threshold",
443
+ type=float,
444
+ default=0.8,
445
+ help=(
446
+ "Similarity threshold for filtering overlapping instances. "
447
+ "Instances with similarity above this threshold are removed, "
448
+ "keeping the higher-scoring instance. "
449
+ "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
450
+ ),
451
+ )
420
452
  @click.option(
421
453
  "--integral_refinement",
422
454
  type=str,
@@ -613,5 +645,9 @@ def system():
613
645
  print_system_info()
614
646
 
615
647
 
648
+ cli.add_command(export_command)
649
+ cli.add_command(predict_command)
650
+
651
+
616
652
  if __name__ == "__main__":
617
653
  cli()
sleap_nn/evaluation.py CHANGED
@@ -639,11 +639,19 @@ class Evaluator:
639
639
  mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
640
640
  mPCK = mPCK_parts.mean()
641
641
 
642
+ # Precompute PCK at common thresholds
643
+ idx_5 = np.argmin(np.abs(thresholds - 5))
644
+ idx_10 = np.argmin(np.abs(thresholds - 10))
645
+ pck5 = pcks[:, :, idx_5].mean()
646
+ pck10 = pcks[:, :, idx_10].mean()
647
+
642
648
  return {
643
649
  "thresholds": thresholds,
644
650
  "pcks": pcks,
645
651
  "mPCK_parts": mPCK_parts,
646
652
  "mPCK": mPCK,
653
+ "PCK@5": pck5,
654
+ "PCK@10": pck10,
647
655
  }
648
656
 
649
657
  def visibility_metrics(self):
sleap_nn/export/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ """Export utilities for sleap-nn."""
2
+
3
+ from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
4
+ from sleap_nn.export.metadata import ExportMetadata
5
+ from sleap_nn.export.predictors import (
6
+ load_exported_model,
7
+ ONNXPredictor,
8
+ TensorRTPredictor,
9
+ )
10
+ from sleap_nn.export.utils import build_bottomup_candidate_template
11
+
12
+ __all__ = [
13
+ "export_model",
14
+ "export_to_onnx",
15
+ "export_to_tensorrt",
16
+ "load_exported_model",
17
+ "ONNXPredictor",
18
+ "TensorRTPredictor",
19
+ "ExportMetadata",
20
+ "build_bottomup_candidate_template",
21
+ ]