sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/config/trainer_config.py +18 -0
- sleap_nn/evaluation.py +81 -22
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +324 -8
- sleap_nn/training/lightning_modules.py +542 -32
- sleap_nn/training/model_trainer.py +48 -57
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED
sleap_nn/cli.py
CHANGED
@@ -7,6 +7,8 @@ from omegaconf import OmegaConf, DictConfig
 import sleap_io as sio
 from sleap_nn.predict import run_inference, frame_list
 from sleap_nn.evaluation import run_evaluation
+from sleap_nn.export.cli import export as export_command
+from sleap_nn.export.cli import predict as predict_command
 from sleap_nn.train import run_training
 from sleap_nn import __version__
 import hydra
@@ -417,6 +419,36 @@ def train(config_name, config_dir, video_paths, video_path_map, prefix_map, over
     default=0.2,
     help="Minimum confidence map value to consider a peak as valid.",
 )
+@click.option(
+    "--filter_overlapping",
+    is_flag=True,
+    default=False,
+    help=(
+        "Enable filtering of overlapping instances after inference using greedy NMS. "
+        "Applied independently of tracking. (default: False)"
+    ),
+)
+@click.option(
+    "--filter_overlapping_method",
+    type=click.Choice(["iou", "oks"]),
+    default="iou",
+    help=(
+        "Similarity metric for filtering overlapping instances. "
+        "'iou': bounding box intersection-over-union. "
+        "'oks': Object Keypoint Similarity (pose-based). (default: iou)"
+    ),
+)
+@click.option(
+    "--filter_overlapping_threshold",
+    type=float,
+    default=0.8,
+    help=(
+        "Similarity threshold for filtering overlapping instances. "
+        "Instances with similarity above this threshold are removed, "
+        "keeping the higher-scoring instance. "
+        "Typical values: 0.3 (aggressive) to 0.8 (permissive). (default: 0.8)"
+    ),
+)
 @click.option(
     "--integral_refinement",
     type=str,
@@ -613,5 +645,9 @@ def system():
     print_system_info()


+cli.add_command(export_command)
+cli.add_command(predict_command)
+
+
 if __name__ == "__main__":
     cli()
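The three new --filter_overlapping* options describe a greedy NMS pass over predicted instances. Below is a minimal sketch of that idea for the "iou" method, assuming each instance is represented by a (score, bounding box) pair; it only illustrates the behavior the help text describes and is not the implementation that ships in sleap_nn/inference/postprocessing.py.

    def iou(box_a, box_b):
        """IoU of two boxes given as (x1, y1, x2, y2)."""
        x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
        x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
        inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        return inter / (area_a + area_b - inter + 1e-9)

    def filter_overlapping(instances, threshold=0.8):
        """Greedy NMS sketch: keep higher-scoring instances, drop overlapping ones.

        `instances` is a list of (score, box) tuples with boxes as (x1, y1, x2, y2).
        """
        kept = []
        for score, box in sorted(instances, key=lambda t: t[0], reverse=True):
            # Keep this instance only if it does not overlap any kept instance
            # above the similarity threshold.
            if all(iou(box, kept_box) <= threshold for _, kept_box in kept):
                kept.append((score, box))
        return kept

With --filter_overlapping_method oks, the same greedy loop would compare poses using Object Keypoint Similarity instead of box IoU, as the help text notes.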
sleap_nn/config/trainer_config.py
CHANGED
@@ -208,6 +208,23 @@ class EarlyStoppingConfig:
     stop_training_on_plateau: bool = True


+@define
+class EvalConfig:
+    """Configuration for epoch-end evaluation.
+
+    Attributes:
+        enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
+        frequency: (int) Evaluate every N epochs. *Default*: `1`.
+        oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
+        oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
+    """
+
+    enabled: bool = False
+    frequency: int = field(default=1, validator=validators.ge(1))
+    oks_stddev: float = field(default=0.025, validator=validators.gt(0))
+    oks_scale: Optional[float] = None
+
+
 @define
 class HardKeypointMiningConfig:
     """Configuration for online hard keypoint mining.
@@ -310,6 +327,7 @@ class TrainerConfig:
         factory=HardKeypointMiningConfig
     )
     zmq: Optional[ZMQConfig] = field(factory=ZMQConfig)  # Required for SLEAP GUI
+    eval: EvalConfig = field(factory=EvalConfig)  # Epoch-end evaluation config

     @staticmethod
     def validate_optimizer_name(value):
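For reference, the new EvalConfig can be constructed directly. A small sketch follows; how TrainerConfig.eval is consumed by the training loop lives in sleap_nn/training/callbacks.py and lightning_modules.py and is not shown in this diff.

    from sleap_nn.config.trainer_config import EvalConfig

    # Run epoch-end evaluation every 5 epochs with a larger OKS stddev.
    eval_cfg = EvalConfig(enabled=True, frequency=5, oks_stddev=0.05)

    # The attrs validators shown above reject invalid values at construction time,
    # e.g. EvalConfig(frequency=0) raises because frequency must be >= 1.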
sleap_nn/evaluation.py
CHANGED
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
     frame_idx = labeled_frame.frame_idx
-
-
-
-
-
+
+    # Extract video path with fallbacks for embedded videos
+    video = labeled_frame.video
+    video_path = None
+    if video is not None:
+        backend = getattr(video, "backend", None)
+        if backend is not None:
+            # Try source_filename first (for embedded videos with provenance)
+            video_path = getattr(backend, "source_filename", None)
+            if video_path is None:
+                video_path = getattr(backend, "filename", None)
+        # Fallback to video.filename if backend doesn't have it
+        if video_path is None:
+            video_path = getattr(video, "filename", None)
+        # Handle list filenames (image sequences)
+        if isinstance(video_path, list) and video_path:
+            video_path = video_path[0]
+    # Final fallback: use a unique identifier
+    if video_path is None:
+        video_path = f"video_{id(video)}" if video is not None else "unknown"
+
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ def find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.

+    This function uses sleap-io's robust video matching API to handle various
+    scenarios including embedded videos, cross-platform paths, and videos with
+    different metadata.
+
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
+    # Use sleap-io's robust video matching API (added in 0.6.2)
+    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
+    match_result = labels_gt.match(labels_pr)
+
     frame_pairs = []
-
-
-
-
-            if video_gt.matches_content(video) and video_gt.matches_path(video):
-                video_pr = video
-                break
-
-        if video_pr is None:
+    # Iterate over matched video pairs (pred_video -> gt_video mapping)
+    for video_pr, video_gt in match_result.video_map.items():
+        if video_gt is None:
+            # No match found for this prediction video
             continue

         # Find labeled frames in this video.
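The video pairing above is delegated to sleap-io's Labels.match() API. A short sketch of using it on its own, based only on the attributes exercised in this diff (match(), video_map, n_videos_matched); the file paths are placeholders.

    import sleap_io as sio

    labels_gt = sio.load_slp("ground_truth.slp")  # placeholder paths
    labels_pr = sio.load_slp("predictions.slp")

    match_result = labels_gt.match(labels_pr)
    print(f"Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}")

    # video_map maps each prediction video to its ground-truth counterpart (or None).
    unmatched = [v for v, gt in match_result.video_map.items() if gt is None]
    print(f"Prediction videos without a GT match: {len(unmatched)}")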
@@ -620,11 +639,19 @@ class Evaluator:
         mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
         mPCK = mPCK_parts.mean()

+        # Precompute PCK at common thresholds
+        idx_5 = np.argmin(np.abs(thresholds - 5))
+        idx_10 = np.argmin(np.abs(thresholds - 10))
+        pck5 = pcks[:, :, idx_5].mean()
+        pck10 = pcks[:, :, idx_10].mean()
+
         return {
             "thresholds": thresholds,
             "pcks": pcks,
             "mPCK_parts": mPCK_parts,
             "mPCK": mPCK,
+            "PCK@5": pck5,
+            "PCK@10": pck10,
         }

     def visibility_metrics(self):
@@ -786,11 +813,26 @@ def run_evaluation(
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
     ground_truth_instances = sio.load_slp(ground_truth_path)
+    logger.info(
+        f" Ground truth: {len(ground_truth_instances.videos)} videos, "
+        f"{len(ground_truth_instances.labeled_frames)} frames"
+    )

     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
+    logger.info(
+        f" Predictions: {len(predicted_instances.videos)} videos, "
+        f"{len(predicted_instances.labeled_frames)} frames"
+    )
+
+    logger.info("Matching videos and frames...")
+    # Get match stats before creating evaluator
+    match_result = ground_truth_instances.match(predicted_instances)
+    logger.info(
+        f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
+    )

-    logger.info("
+    logger.info("Matching instances...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -799,21 +841,38 @@
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
+    logger.info(
+        f" Frame pairs: {len(evaluator.frame_pairs)}, "
+        f"Matched instances: {len(evaluator.positive_pairs)}, "
+        f"Unmatched GT: {len(evaluator.false_negatives)}"
+    )

     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()

+    # Compute PCK at specific thresholds (5 and 10 pixels)
+    dists = metrics["distance_metrics"]["dists"]
+    dists_clean = np.copy(dists)
+    dists_clean[np.isnan(dists_clean)] = np.inf
+    pck_5 = (dists_clean < 5).mean()
+    pck_10 = (dists_clean < 10).mean()
+
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
-    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.
-    logger.info(f"
+    logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
+    logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
+    logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
+    logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
+    logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f" PCK@5px: {pck_5:.4f}")
+    logger.info(f" PCK@10px: {pck_10:.4f}")
     logger.info(
-        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+        f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
     )
-    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+    logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")

     # Save metrics if path provided
     if save_metrics:
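A toy check of the PCK-at-threshold computation used above: PCK@k is the fraction of matched keypoint distances below k pixels, with NaN (unmatched) keypoints counted as misses by mapping them to +inf.

    import numpy as np

    dists = np.array([[1.2, 4.8, np.nan],
                      [7.5, 2.0, 11.0]])          # pixel errors per instance x node
    dists_clean = np.where(np.isnan(dists), np.inf, dists)
    pck_5 = (dists_clean < 5).mean()              # 3 of 6 keypoints -> 0.50
    pck_10 = (dists_clean < 10).mean()            # 4 of 6 keypoints -> ~0.67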
sleap_nn/export/__init__.py
ADDED
@@ -0,0 +1,21 @@
+"""Export utilities for sleap-nn."""
+
+from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
+from sleap_nn.export.metadata import ExportMetadata
+from sleap_nn.export.predictors import (
+    load_exported_model,
+    ONNXPredictor,
+    TensorRTPredictor,
+)
+from sleap_nn.export.utils import build_bottomup_candidate_template
+
+__all__ = [
+    "export_model",
+    "export_to_onnx",
+    "export_to_tensorrt",
+    "load_exported_model",
+    "ONNXPredictor",
+    "TensorRTPredictor",
+    "ExportMetadata",
+    "build_bottomup_candidate_template",
+]