sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/inference/predictors.py
CHANGED
|
@@ -21,9 +21,6 @@ from sleap_nn.data.resizing import (
|
|
|
21
21
|
apply_sizematcher,
|
|
22
22
|
apply_resizer,
|
|
23
23
|
)
|
|
24
|
-
from sleap_nn.data.normalization import (
|
|
25
|
-
apply_normalization,
|
|
26
|
-
)
|
|
27
24
|
from sleap_nn.config.utils import get_model_type_from_cfg
|
|
28
25
|
from sleap_nn.inference.paf_grouping import PAFScorer
|
|
29
26
|
from sleap_nn.training.lightning_modules import (
|
|
@@ -61,6 +58,47 @@ from rich.progress import (
|
|
|
61
58
|
from time import time
|
|
62
59
|
|
|
63
60
|
|
|
61
|
+
def _filter_user_labeled_frames(
|
|
62
|
+
labels: sio.Labels,
|
|
63
|
+
video: sio.Video,
|
|
64
|
+
frames: Optional[list],
|
|
65
|
+
exclude_user_labeled: bool,
|
|
66
|
+
) -> Optional[list]:
|
|
67
|
+
"""Filter out user-labeled frames from a frame list.
|
|
68
|
+
|
|
69
|
+
This function is used when running inference with VideoReader (video_index specified)
|
|
70
|
+
to implement the exclude_user_labeled functionality.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
labels: The Labels object containing labeled frames.
|
|
74
|
+
video: The video to filter frames for.
|
|
75
|
+
frames: List of frame indices to filter. If None, builds list of all frames.
|
|
76
|
+
exclude_user_labeled: If True, filter out user-labeled frames.
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
Filtered list of frame indices excluding user-labeled frames if
|
|
80
|
+
exclude_user_labeled is True. Returns original frames if exclude_user_labeled
|
|
81
|
+
is False or if there are no user-labeled frames.
|
|
82
|
+
"""
|
|
83
|
+
if not exclude_user_labeled:
|
|
84
|
+
return frames
|
|
85
|
+
|
|
86
|
+
# Get user-labeled frame indices for this video
|
|
87
|
+
user_frame_indices = {
|
|
88
|
+
lf.frame_idx for lf in labels.find(video=video) if lf.has_user_instances
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
if not user_frame_indices:
|
|
92
|
+
return frames
|
|
93
|
+
|
|
94
|
+
# Build full frame list if frames is None
|
|
95
|
+
if frames is None:
|
|
96
|
+
frames = list(range(len(video)))
|
|
97
|
+
|
|
98
|
+
# Filter out user-labeled frames
|
|
99
|
+
return [f for f in frames if f not in user_frame_indices]
|
|
100
|
+
|
|
101
|
+
|
|
64
102
|
class RateColumn(rich.progress.ProgressColumn):
|
|
65
103
|
"""Renders the progress rate."""
|
|
66
104
|
|
|
@@ -321,6 +359,8 @@ class Predictor(ABC):
|
|
|
321
359
|
frames: Optional[list] = None,
|
|
322
360
|
only_labeled_frames: bool = False,
|
|
323
361
|
only_suggested_frames: bool = False,
|
|
362
|
+
exclude_user_labeled: bool = False,
|
|
363
|
+
only_predicted_frames: bool = False,
|
|
324
364
|
video_index: Optional[int] = None,
|
|
325
365
|
video_dataset: Optional[str] = None,
|
|
326
366
|
video_input_format: str = "channels_last",
|
|
@@ -393,7 +433,6 @@ class Predictor(ABC):
|
|
|
393
433
|
if frame["image"] is None:
|
|
394
434
|
done = True
|
|
395
435
|
break
|
|
396
|
-
frame["image"] = apply_normalization(frame["image"])
|
|
397
436
|
frame["image"], eff_scale = apply_sizematcher(
|
|
398
437
|
frame["image"],
|
|
399
438
|
self.preprocess_config["max_height"],
|
|
@@ -671,9 +710,6 @@ class TopDownPredictor(Predictor):
|
|
|
671
710
|
max_stride=max_stride,
|
|
672
711
|
input_scale=self.confmap_config.data_config.preprocessing.scale,
|
|
673
712
|
)
|
|
674
|
-
centroid_crop_layer.precrop_resize = (
|
|
675
|
-
self.confmap_config.data_config.preprocessing.scale
|
|
676
|
-
)
|
|
677
713
|
|
|
678
714
|
if self.centroid_config is None and self.confmap_config is not None:
|
|
679
715
|
self.instances_key = (
|
|
@@ -788,6 +824,7 @@ class TopDownPredictor(Predictor):
|
|
|
788
824
|
learning_rate=centroid_config.trainer_config.optimizer.lr,
|
|
789
825
|
amsgrad=centroid_config.trainer_config.optimizer.amsgrad,
|
|
790
826
|
map_location=device,
|
|
827
|
+
weights_only=False,
|
|
791
828
|
)
|
|
792
829
|
else:
|
|
793
830
|
# Load the converted model
|
|
@@ -904,6 +941,7 @@ class TopDownPredictor(Predictor):
|
|
|
904
941
|
amsgrad=confmap_config.trainer_config.optimizer.amsgrad,
|
|
905
942
|
backbone_type=centered_instance_backbone_type,
|
|
906
943
|
map_location=device,
|
|
944
|
+
weights_only=False,
|
|
907
945
|
)
|
|
908
946
|
else:
|
|
909
947
|
# Load the converted model
|
|
@@ -1073,6 +1111,8 @@ class TopDownPredictor(Predictor):
|
|
|
1073
1111
|
frames: Optional[list] = None,
|
|
1074
1112
|
only_labeled_frames: bool = False,
|
|
1075
1113
|
only_suggested_frames: bool = False,
|
|
1114
|
+
exclude_user_labeled: bool = False,
|
|
1115
|
+
only_predicted_frames: bool = False,
|
|
1076
1116
|
video_index: Optional[int] = None,
|
|
1077
1117
|
video_dataset: Optional[str] = None,
|
|
1078
1118
|
video_input_format: str = "channels_last",
|
|
@@ -1085,6 +1125,8 @@ class TopDownPredictor(Predictor):
|
|
|
1085
1125
|
frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
|
|
1086
1126
|
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
|
|
1087
1127
|
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
|
|
1128
|
+
exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
|
|
1129
|
+
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
1088
1130
|
video_index: (int) Integer index of video in .slp file to predict on. To be used
|
|
1089
1131
|
with an .slp path as an alternative to specifying the video path.
|
|
1090
1132
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
@@ -1119,6 +1161,8 @@ class TopDownPredictor(Predictor):
|
|
|
1119
1161
|
instances_key=self.instances_key,
|
|
1120
1162
|
only_labeled_frames=only_labeled_frames,
|
|
1121
1163
|
only_suggested_frames=only_suggested_frames,
|
|
1164
|
+
exclude_user_labeled=exclude_user_labeled,
|
|
1165
|
+
only_predicted_frames=only_predicted_frames,
|
|
1122
1166
|
)
|
|
1123
1167
|
self.videos = self.pipeline.labels.videos
|
|
1124
1168
|
|
|
@@ -1136,10 +1180,15 @@ class TopDownPredictor(Predictor):
|
|
|
1136
1180
|
|
|
1137
1181
|
if isinstance(inference_object, sio.Labels) and video_index is not None:
|
|
1138
1182
|
labels = inference_object
|
|
1183
|
+
video = labels.videos[video_index]
|
|
1184
|
+
# Filter out user-labeled frames if requested
|
|
1185
|
+
filtered_frames = _filter_user_labeled_frames(
|
|
1186
|
+
labels, video, frames, exclude_user_labeled
|
|
1187
|
+
)
|
|
1139
1188
|
self.pipeline = provider.from_video(
|
|
1140
|
-
video=
|
|
1189
|
+
video=video,
|
|
1141
1190
|
queue_maxsize=queue_maxsize,
|
|
1142
|
-
frames=
|
|
1191
|
+
frames=filtered_frames,
|
|
1143
1192
|
)
|
|
1144
1193
|
|
|
1145
1194
|
else: # for mp4 or hdf5 videos
|
|
@@ -1394,6 +1443,7 @@ class SingleInstancePredictor(Predictor):
|
|
|
1394
1443
|
learning_rate=confmap_config.trainer_config.optimizer.lr,
|
|
1395
1444
|
amsgrad=confmap_config.trainer_config.optimizer.amsgrad,
|
|
1396
1445
|
map_location=device,
|
|
1446
|
+
weights_only=False,
|
|
1397
1447
|
)
|
|
1398
1448
|
else:
|
|
1399
1449
|
confmap_converted_model = load_legacy_model(
|
|
@@ -1491,6 +1541,8 @@ class SingleInstancePredictor(Predictor):
|
|
|
1491
1541
|
frames: Optional[list] = None,
|
|
1492
1542
|
only_labeled_frames: bool = False,
|
|
1493
1543
|
only_suggested_frames: bool = False,
|
|
1544
|
+
exclude_user_labeled: bool = False,
|
|
1545
|
+
only_predicted_frames: bool = False,
|
|
1494
1546
|
video_index: Optional[int] = None,
|
|
1495
1547
|
video_dataset: Optional[str] = None,
|
|
1496
1548
|
video_input_format: str = "channels_last",
|
|
@@ -1503,6 +1555,8 @@ class SingleInstancePredictor(Predictor):
|
|
|
1503
1555
|
frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
|
|
1504
1556
|
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
|
|
1505
1557
|
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
|
|
1558
|
+
exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
|
|
1559
|
+
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
1506
1560
|
video_index: (int) Integer index of video in .slp file to predict on. To be used
|
|
1507
1561
|
with an .slp path as an alternative to specifying the video path.
|
|
1508
1562
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
@@ -1536,6 +1590,8 @@ class SingleInstancePredictor(Predictor):
|
|
|
1536
1590
|
frame_buffer=frame_buffer,
|
|
1537
1591
|
only_labeled_frames=only_labeled_frames,
|
|
1538
1592
|
only_suggested_frames=only_suggested_frames,
|
|
1593
|
+
exclude_user_labeled=exclude_user_labeled,
|
|
1594
|
+
only_predicted_frames=only_predicted_frames,
|
|
1539
1595
|
)
|
|
1540
1596
|
self.videos = self.pipeline.labels.videos
|
|
1541
1597
|
|
|
@@ -1544,10 +1600,15 @@ class SingleInstancePredictor(Predictor):
|
|
|
1544
1600
|
|
|
1545
1601
|
if isinstance(inference_object, sio.Labels) and video_index is not None:
|
|
1546
1602
|
labels = inference_object
|
|
1603
|
+
video = labels.videos[video_index]
|
|
1604
|
+
# Filter out user-labeled frames if requested
|
|
1605
|
+
filtered_frames = _filter_user_labeled_frames(
|
|
1606
|
+
labels, video, frames, exclude_user_labeled
|
|
1607
|
+
)
|
|
1547
1608
|
self.pipeline = provider.from_video(
|
|
1548
|
-
video=
|
|
1609
|
+
video=video,
|
|
1549
1610
|
queue_maxsize=queue_maxsize,
|
|
1550
|
-
frames=
|
|
1611
|
+
frames=filtered_frames,
|
|
1551
1612
|
)
|
|
1552
1613
|
|
|
1553
1614
|
else: # for mp4 or hdf5 videos
|
|
@@ -1833,6 +1894,7 @@ class BottomUpPredictor(Predictor):
|
|
|
1833
1894
|
backbone_type=backbone_type,
|
|
1834
1895
|
model_type="bottomup",
|
|
1835
1896
|
map_location=device,
|
|
1897
|
+
weights_only=False,
|
|
1836
1898
|
)
|
|
1837
1899
|
else:
|
|
1838
1900
|
bottomup_converted_model = load_legacy_model(
|
|
@@ -1929,6 +1991,8 @@ class BottomUpPredictor(Predictor):
|
|
|
1929
1991
|
frames: Optional[list] = None,
|
|
1930
1992
|
only_labeled_frames: bool = False,
|
|
1931
1993
|
only_suggested_frames: bool = False,
|
|
1994
|
+
exclude_user_labeled: bool = False,
|
|
1995
|
+
only_predicted_frames: bool = False,
|
|
1932
1996
|
video_index: Optional[int] = None,
|
|
1933
1997
|
video_dataset: Optional[str] = None,
|
|
1934
1998
|
video_input_format: str = "channels_last",
|
|
@@ -1941,6 +2005,8 @@ class BottomUpPredictor(Predictor):
|
|
|
1941
2005
|
frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
|
|
1942
2006
|
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
|
|
1943
2007
|
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
|
|
2008
|
+
exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
|
|
2009
|
+
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
1944
2010
|
video_index: (int) Integer index of video in .slp file to predict on. To be used
|
|
1945
2011
|
with an .slp path as an alternative to specifying the video path.
|
|
1946
2012
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
@@ -1974,6 +2040,8 @@ class BottomUpPredictor(Predictor):
|
|
|
1974
2040
|
frame_buffer=frame_buffer,
|
|
1975
2041
|
only_labeled_frames=only_labeled_frames,
|
|
1976
2042
|
only_suggested_frames=only_suggested_frames,
|
|
2043
|
+
exclude_user_labeled=exclude_user_labeled,
|
|
2044
|
+
only_predicted_frames=only_predicted_frames,
|
|
1977
2045
|
)
|
|
1978
2046
|
|
|
1979
2047
|
self.videos = self.pipeline.labels.videos
|
|
@@ -1983,10 +2051,15 @@ class BottomUpPredictor(Predictor):
|
|
|
1983
2051
|
|
|
1984
2052
|
if isinstance(inference_object, sio.Labels) and video_index is not None:
|
|
1985
2053
|
labels = inference_object
|
|
2054
|
+
video = labels.videos[video_index]
|
|
2055
|
+
# Filter out user-labeled frames if requested
|
|
2056
|
+
filtered_frames = _filter_user_labeled_frames(
|
|
2057
|
+
labels, video, frames, exclude_user_labeled
|
|
2058
|
+
)
|
|
1986
2059
|
self.pipeline = provider.from_video(
|
|
1987
|
-
video=
|
|
2060
|
+
video=video,
|
|
1988
2061
|
queue_maxsize=queue_maxsize,
|
|
1989
|
-
frames=
|
|
2062
|
+
frames=filtered_frames,
|
|
1990
2063
|
)
|
|
1991
2064
|
|
|
1992
2065
|
else: # for mp4 or hdf5 videos
|
|
@@ -2260,6 +2333,7 @@ class BottomUpMultiClassPredictor(Predictor):
|
|
|
2260
2333
|
optimizer=bottomup_config.trainer_config.optimizer_name,
|
|
2261
2334
|
learning_rate=bottomup_config.trainer_config.optimizer.lr,
|
|
2262
2335
|
amsgrad=bottomup_config.trainer_config.optimizer.amsgrad,
|
|
2336
|
+
weights_only=False,
|
|
2263
2337
|
)
|
|
2264
2338
|
else:
|
|
2265
2339
|
bottomup_converted_model = load_legacy_model(
|
|
@@ -2364,6 +2438,8 @@ class BottomUpMultiClassPredictor(Predictor):
|
|
|
2364
2438
|
frames: Optional[list] = None,
|
|
2365
2439
|
only_labeled_frames: bool = False,
|
|
2366
2440
|
only_suggested_frames: bool = False,
|
|
2441
|
+
exclude_user_labeled: bool = False,
|
|
2442
|
+
only_predicted_frames: bool = False,
|
|
2367
2443
|
video_index: Optional[int] = None,
|
|
2368
2444
|
video_dataset: Optional[str] = None,
|
|
2369
2445
|
video_input_format: str = "channels_last",
|
|
@@ -2376,6 +2452,8 @@ class BottomUpMultiClassPredictor(Predictor):
|
|
|
2376
2452
|
frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
|
|
2377
2453
|
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
|
|
2378
2454
|
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
|
|
2455
|
+
exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
|
|
2456
|
+
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
2379
2457
|
video_index: (int) Integer index of video in .slp file to predict on. To be used
|
|
2380
2458
|
with an .slp path as an alternative to specifying the video path.
|
|
2381
2459
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
@@ -2411,6 +2489,8 @@ class BottomUpMultiClassPredictor(Predictor):
|
|
|
2411
2489
|
frame_buffer=frame_buffer,
|
|
2412
2490
|
only_labeled_frames=only_labeled_frames,
|
|
2413
2491
|
only_suggested_frames=only_suggested_frames,
|
|
2492
|
+
exclude_user_labeled=exclude_user_labeled,
|
|
2493
|
+
only_predicted_frames=only_predicted_frames,
|
|
2414
2494
|
)
|
|
2415
2495
|
|
|
2416
2496
|
self.videos = self.pipeline.labels.videos
|
|
@@ -2420,10 +2500,15 @@ class BottomUpMultiClassPredictor(Predictor):
|
|
|
2420
2500
|
|
|
2421
2501
|
if isinstance(inference_object, sio.Labels) and video_index is not None:
|
|
2422
2502
|
labels = inference_object
|
|
2503
|
+
video = labels.videos[video_index]
|
|
2504
|
+
# Filter out user-labeled frames if requested
|
|
2505
|
+
filtered_frames = _filter_user_labeled_frames(
|
|
2506
|
+
labels, video, frames, exclude_user_labeled
|
|
2507
|
+
)
|
|
2423
2508
|
self.pipeline = provider.from_video(
|
|
2424
|
-
video=
|
|
2509
|
+
video=video,
|
|
2425
2510
|
queue_maxsize=queue_maxsize,
|
|
2426
|
-
frames=
|
|
2511
|
+
frames=filtered_frames,
|
|
2427
2512
|
)
|
|
2428
2513
|
|
|
2429
2514
|
else: # for mp4 or hdf5 videos
|
|
@@ -2682,9 +2767,6 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
2682
2767
|
max_stride=max_stride,
|
|
2683
2768
|
input_scale=self.confmap_config.data_config.preprocessing.scale,
|
|
2684
2769
|
)
|
|
2685
|
-
centroid_crop_layer.precrop_resize = (
|
|
2686
|
-
self.confmap_config.data_config.preprocessing.scale
|
|
2687
|
-
)
|
|
2688
2770
|
|
|
2689
2771
|
if self.centroid_config is None:
|
|
2690
2772
|
self.instances_key = (
|
|
@@ -2805,6 +2887,7 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
2805
2887
|
learning_rate=centroid_config.trainer_config.optimizer.lr,
|
|
2806
2888
|
amsgrad=centroid_config.trainer_config.optimizer.amsgrad,
|
|
2807
2889
|
map_location=device,
|
|
2890
|
+
weights_only=False,
|
|
2808
2891
|
)
|
|
2809
2892
|
|
|
2810
2893
|
else:
|
|
@@ -2929,6 +3012,7 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
2929
3012
|
learning_rate=confmap_config.trainer_config.optimizer.lr,
|
|
2930
3013
|
amsgrad=confmap_config.trainer_config.optimizer.amsgrad,
|
|
2931
3014
|
map_location=device,
|
|
3015
|
+
weights_only=False,
|
|
2932
3016
|
)
|
|
2933
3017
|
else:
|
|
2934
3018
|
confmap_converted_model = load_legacy_model(
|
|
@@ -3109,6 +3193,8 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
3109
3193
|
frames: Optional[list] = None,
|
|
3110
3194
|
only_labeled_frames: bool = False,
|
|
3111
3195
|
only_suggested_frames: bool = False,
|
|
3196
|
+
exclude_user_labeled: bool = False,
|
|
3197
|
+
only_predicted_frames: bool = False,
|
|
3112
3198
|
video_index: Optional[int] = None,
|
|
3113
3199
|
video_dataset: Optional[str] = None,
|
|
3114
3200
|
video_input_format: str = "channels_last",
|
|
@@ -3121,6 +3207,8 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
3121
3207
|
frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
|
|
3122
3208
|
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
|
|
3123
3209
|
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
|
|
3210
|
+
exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
|
|
3211
|
+
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
3124
3212
|
video_index: (int) Integer index of video in .slp file to predict on. To be used
|
|
3125
3213
|
with an .slp path as an alternative to specifying the video path.
|
|
3126
3214
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
@@ -3155,6 +3243,8 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
3155
3243
|
instances_key=self.instances_key,
|
|
3156
3244
|
only_labeled_frames=only_labeled_frames,
|
|
3157
3245
|
only_suggested_frames=only_suggested_frames,
|
|
3246
|
+
exclude_user_labeled=exclude_user_labeled,
|
|
3247
|
+
only_predicted_frames=only_predicted_frames,
|
|
3158
3248
|
)
|
|
3159
3249
|
self.videos = self.pipeline.labels.videos
|
|
3160
3250
|
|
|
@@ -3172,10 +3262,15 @@ class TopDownMultiClassPredictor(Predictor):
|
|
|
3172
3262
|
|
|
3173
3263
|
if isinstance(inference_object, sio.Labels) and video_index is not None:
|
|
3174
3264
|
labels = inference_object
|
|
3265
|
+
video = labels.videos[video_index]
|
|
3266
|
+
# Filter out user-labeled frames if requested
|
|
3267
|
+
filtered_frames = _filter_user_labeled_frames(
|
|
3268
|
+
labels, video, frames, exclude_user_labeled
|
|
3269
|
+
)
|
|
3175
3270
|
self.pipeline = provider.from_video(
|
|
3176
|
-
video=
|
|
3271
|
+
video=video,
|
|
3177
3272
|
queue_maxsize=queue_maxsize,
|
|
3178
|
-
frames=
|
|
3273
|
+
frames=filtered_frames,
|
|
3179
3274
|
)
|
|
3180
3275
|
|
|
3181
3276
|
else: # for mp4 or hdf5 videos
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
"""Provenance metadata utilities for inference outputs.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for building and managing provenance metadata
|
|
4
|
+
that is stored in SLP files produced during inference. Provenance metadata
|
|
5
|
+
helps track where predictions came from and how they were generated.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Optional, Union
|
|
11
|
+
|
|
12
|
+
import sleap_io as sio
|
|
13
|
+
|
|
14
|
+
import sleap_nn
|
|
15
|
+
from sleap_nn.system_info import get_system_info_dict
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def build_inference_provenance(
    model_paths: Optional[list[str]] = None,
    model_type: Optional[str] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    input_labels: Optional[sio.Labels] = None,
    input_path: Optional[Union[str, Path]] = None,
    frames_processed: Optional[int] = None,
    frames_total: Optional[int] = None,
    frame_selection_method: Optional[str] = None,
    inference_params: Optional[dict[str, Any]] = None,
    tracking_params: Optional[dict[str, Any]] = None,
    device: Optional[str] = None,
    cli_args: Optional[dict[str, Any]] = None,
    include_system_info: bool = True,
) -> dict[str, Any]:
    """Build provenance metadata dictionary for inference output.

    This function creates a comprehensive provenance dictionary that captures
    all relevant metadata about an inference run, enabling reproducibility
    and tracking of prediction origins. Every argument is optional; a section
    is only written when there is something to record for it.

    Args:
        model_paths: List of paths to model checkpoints used for inference.
            Stored under "model_paths" as resolved POSIX-style strings.
        model_type: Type of model used (e.g., "top_down", "bottom_up",
            "single_instance").
        start_time: Datetime when inference started.
        end_time: Datetime when inference finished. When both timestamps are
            given, "inference_runtime_seconds" is recorded as well.
        input_labels: Input Labels object if inference was run on an SLP file.
            A non-empty `provenance` mapping on this object is copied to
            "input_provenance"; a missing or None `provenance` is ignored.
        input_path: Path to input file (SLP or video).
        frames_processed: Number of frames that were processed.
        frames_total: Total number of frames in the input.
        frame_selection_method: Method used to select frames (e.g., "all",
            "labeled", "suggested", "range"). Recorded even when neither
            frame count is provided.
        inference_params: Dictionary of inference parameters (peak_threshold,
            integral_refinement, batch_size, etc.). None values are dropped
            and `Path` values are stored as POSIX strings.
        tracking_params: Dictionary of tracking parameters if tracking was run.
            None values are dropped.
        device: Device used for inference (e.g., "cuda:0", "cpu", "mps").
        cli_args: Command-line arguments if available. None values are dropped.
        include_system_info: If True, include detailed system information.
            Set to False for lighter-weight provenance.

    Returns:
        Dictionary containing provenance metadata suitable for storing in
        Labels.provenance. "sleap_nn_version" and "sleap_io_version" are always
        present; all other keys depend on the arguments supplied.

    Example:
        >>> from datetime import datetime
        >>> provenance = build_inference_provenance(
        ...     model_paths=["/path/to/model.ckpt"],
        ...     model_type="top_down",
        ...     start_time=datetime.now(),
        ...     end_time=datetime.now(),
        ...     device="cuda:0",
        ... )
        >>> labels.provenance = provenance
        >>> labels.save("predictions.slp")
    """
    provenance: dict[str, Any] = {}

    # Timestamps
    if start_time is not None:
        provenance["inference_start_timestamp"] = start_time.isoformat()
    if end_time is not None:
        provenance["inference_end_timestamp"] = end_time.isoformat()
    if start_time is not None and end_time is not None:
        runtime_seconds = (end_time - start_time).total_seconds()
        provenance["inference_runtime_seconds"] = runtime_seconds

    # Version information
    provenance["sleap_nn_version"] = sleap_nn.__version__
    provenance["sleap_io_version"] = sio.__version__

    # Model information
    if model_paths is not None:
        # Store as absolute POSIX paths for cross-platform compatibility
        provenance["model_paths"] = [
            Path(p).resolve().as_posix() if isinstance(p, (str, Path)) else str(p)
            for p in model_paths
        ]
    if model_type is not None:
        provenance["model_type"] = model_type

    # Input data lineage
    if input_path is not None:
        provenance["source_file"] = (
            Path(input_path).resolve().as_posix()
            if isinstance(input_path, (str, Path))
            else str(input_path)
        )

    # Preserve input provenance if available. The attribute may be absent or
    # explicitly None; both are treated as "nothing to preserve" instead of
    # letting dict(None) raise.
    if input_labels is not None:
        input_prov = dict(getattr(input_labels, "provenance", None) or {})
        if input_prov:
            provenance["input_provenance"] = input_prov
            # Also set source_labels for compatibility with sleap-io conventions
            if "filename" in input_prov:
                provenance["source_labels"] = input_prov["filename"]

    # Frame selection information. Built unconditionally so that a selection
    # method supplied without frame counts is not silently discarded.
    frame_info: dict[str, Any] = {}
    if frame_selection_method is not None:
        frame_info["method"] = frame_selection_method
    if frames_processed is not None:
        frame_info["frames_processed"] = frames_processed
    if frames_total is not None:
        frame_info["frames_total"] = frames_total
    if frame_info:
        provenance["frame_selection"] = frame_info

    # Inference parameters
    if inference_params is not None:
        # Filter out None values and convert paths
        clean_params = {
            key: value.as_posix() if isinstance(value, Path) else value
            for key, value in inference_params.items()
            if value is not None
        }
        if clean_params:
            provenance["inference_config"] = clean_params

    # Tracking parameters
    if tracking_params is not None:
        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
        if clean_tracking:
            provenance["tracking_config"] = clean_tracking

    # Device information
    if device is not None:
        provenance["device"] = device

    # CLI arguments
    if cli_args is not None:
        # Filter out None values
        clean_cli = {k: v for k, v in cli_args.items() if v is not None}
        if clean_cli:
            provenance["cli_args"] = clean_cli

    # System information (can be disabled for lighter provenance)
    if include_system_info:
        try:
            system_info = get_system_info_dict()
            # Extract key fields for provenance (avoid excessive nesting)
            summary: dict[str, Any] = {
                "python_version": system_info.get("python_version"),
                "platform": system_info.get("platform"),
                "pytorch_version": system_info.get("pytorch_version"),
                "cuda_version": system_info.get("cuda_version"),
                "accelerator": system_info.get("accelerator"),
                "gpu_count": system_info.get("gpu_count"),
            }
            # Include GPU names if available
            if system_info.get("gpus"):
                summary["gpus"] = [gpu.get("name") for gpu in system_info["gpus"]]
            provenance["system_info"] = summary
        except Exception:
            # Deliberate best-effort: system info is optional metadata, so a
            # failure while collecting it must not abort inference. The
            # "system_info" key is simply left out in that case.
            pass

    return provenance
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def build_tracking_only_provenance(
    input_labels: Optional[sio.Labels] = None,
    input_path: Optional[Union[str, Path]] = None,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
    tracking_params: Optional[dict[str, Any]] = None,
    frames_processed: Optional[int] = None,
    include_system_info: bool = True,
) -> dict[str, Any]:
    """Build provenance metadata for tracking-only pipeline.

    This is a simplified version of build_inference_provenance for when
    only tracking is run without model inference. The result is always tagged
    with `"pipeline_type": "tracking_only"`.

    Args:
        input_labels: Input Labels object with existing predictions. A
            non-empty `provenance` mapping on this object is copied to
            "input_provenance"; a missing or None `provenance` is ignored.
        input_path: Path to input SLP file. Stored under "source_file" as a
            resolved POSIX-style string.
        start_time: Datetime when tracking started.
        end_time: Datetime when tracking finished. When both timestamps are
            given, "tracking_runtime_seconds" is recorded as well.
        tracking_params: Dictionary of tracking parameters. None values are
            dropped.
        frames_processed: Number of frames that were tracked.
        include_system_info: If True, include system information.

    Returns:
        Dictionary containing provenance metadata. "sleap_nn_version",
        "sleap_io_version" and "pipeline_type" are always present; all other
        keys depend on the arguments supplied.
    """
    provenance: dict[str, Any] = {}

    # Timestamps
    if start_time is not None:
        provenance["tracking_start_timestamp"] = start_time.isoformat()
    if end_time is not None:
        provenance["tracking_end_timestamp"] = end_time.isoformat()
    if start_time is not None and end_time is not None:
        runtime_seconds = (end_time - start_time).total_seconds()
        provenance["tracking_runtime_seconds"] = runtime_seconds

    # Version information
    provenance["sleap_nn_version"] = sleap_nn.__version__
    provenance["sleap_io_version"] = sio.__version__

    # Note that this is tracking-only
    provenance["pipeline_type"] = "tracking_only"

    # Input data lineage
    if input_path is not None:
        provenance["source_file"] = (
            Path(input_path).resolve().as_posix()
            if isinstance(input_path, (str, Path))
            else str(input_path)
        )

    # Preserve input provenance if available. The attribute may be absent or
    # explicitly None; both are treated as "nothing to preserve" instead of
    # letting dict(None) raise.
    if input_labels is not None:
        input_prov = dict(getattr(input_labels, "provenance", None) or {})
        if input_prov:
            provenance["input_provenance"] = input_prov
            if "filename" in input_prov:
                provenance["source_labels"] = input_prov["filename"]

    # Frame information
    if frames_processed is not None:
        provenance["frames_processed"] = frames_processed

    # Tracking parameters
    if tracking_params is not None:
        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
        if clean_tracking:
            provenance["tracking_config"] = clean_tracking

    # System information
    if include_system_info:
        try:
            system_info = get_system_info_dict()
            provenance["system_info"] = {
                "python_version": system_info.get("python_version"),
                "platform": system_info.get("platform"),
                "pytorch_version": system_info.get("pytorch_version"),
                "accelerator": system_info.get("accelerator"),
            }
        except Exception:
            # Deliberate best-effort: system info is optional metadata, so a
            # failure while collecting it must not abort the tracking run.
            pass

    return provenance
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def merge_provenance(
    base_provenance: dict[str, Any],
    additional: dict[str, Any],
    overwrite: bool = True,
) -> dict[str, Any]:
    """Combine two provenance dictionaries into a new one.

    Neither input is modified; the merge is shallow, so nested values are
    shared with the inputs rather than copied.

    Args:
        base_provenance: Provenance dictionary to start from.
        additional: Fields to add on top of the base.
        overwrite: If True, a key present in both takes its value from
            `additional`. If False, the value from `base_provenance` is kept.

    Returns:
        A new dictionary with the keys of `base_provenance` followed by any
        keys that only appear in `additional`.
    """
    merged = dict(base_provenance)
    if overwrite:
        # Colliding keys are replaced in place; new keys are appended.
        merged.update(additional)
    else:
        # Only fill in keys the base does not already define.
        for name, entry in additional.items():
            merged.setdefault(name, entry)
    return merged
|