sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,9 +21,6 @@ from sleap_nn.data.resizing import (
     apply_sizematcher,
     apply_resizer,
 )
-from sleap_nn.data.normalization import (
-    apply_normalization,
-)
 from sleap_nn.config.utils import get_model_type_from_cfg
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.training.lightning_modules import (
@@ -61,6 +58,47 @@ from rich.progress import (
 from time import time


+def _filter_user_labeled_frames(
+    labels: sio.Labels,
+    video: sio.Video,
+    frames: Optional[list],
+    exclude_user_labeled: bool,
+) -> Optional[list]:
+    """Filter out user-labeled frames from a frame list.
+
+    This function is used when running inference with VideoReader (video_index specified)
+    to implement the exclude_user_labeled functionality.
+
+    Args:
+        labels: The Labels object containing labeled frames.
+        video: The video to filter frames for.
+        frames: List of frame indices to filter. If None, builds list of all frames.
+        exclude_user_labeled: If True, filter out user-labeled frames.
+
+    Returns:
+        Filtered list of frame indices excluding user-labeled frames if
+        exclude_user_labeled is True. Returns original frames if exclude_user_labeled
+        is False or if there are no user-labeled frames.
+    """
+    if not exclude_user_labeled:
+        return frames
+
+    # Get user-labeled frame indices for this video
+    user_frame_indices = {
+        lf.frame_idx for lf in labels.find(video=video) if lf.has_user_instances
+    }
+
+    if not user_frame_indices:
+        return frames
+
+    # Build full frame list if frames is None
+    if frames is None:
+        frames = list(range(len(video)))
+
+    # Filter out user-labeled frames
+    return [f for f in frames if f not in user_frame_indices]
+
+
 class RateColumn(rich.progress.ProgressColumn):
     """Renders the progress rate."""

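The new `_filter_user_labeled_frames` helper above is self-contained, so its behavior can be sanity-checked in isolation. A minimal usage sketch, assuming only the sleap-io API already used in the hunk (`Labels.find`, `LabeledFrame.has_user_instances`) and a hypothetical input path:

import sleap_io as sio

labels = sio.load_slp("labels.v001.slp")  # hypothetical path
video = labels.videos[0]

# frames=None means "consider every frame in the video"; the helper returns
# only the indices that have no user-labeled instances.
frames_to_predict = _filter_user_labeled_frames(
    labels, video, frames=None, exclude_user_labeled=True
)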
@@ -321,6 +359,8 @@ class Predictor(ABC):
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
         video_index: Optional[int] = None,
         video_dataset: Optional[str] = None,
         video_input_format: str = "channels_last",
@@ -393,7 +433,6 @@ class Predictor(ABC):
             if frame["image"] is None:
                 done = True
                 break
-            frame["image"] = apply_normalization(frame["image"])
             frame["image"], eff_scale = apply_sizematcher(
                 frame["image"],
                 self.preprocess_config["max_height"],
@@ -671,9 +710,6 @@ class TopDownPredictor(Predictor):
                 max_stride=max_stride,
                 input_scale=self.confmap_config.data_config.preprocessing.scale,
             )
-            centroid_crop_layer.precrop_resize = (
-                self.confmap_config.data_config.preprocessing.scale
-            )

         if self.centroid_config is None and self.confmap_config is not None:
             self.instances_key = (
@@ -788,6 +824,7 @@ class TopDownPredictor(Predictor):
                 learning_rate=centroid_config.trainer_config.optimizer.lr,
                 amsgrad=centroid_config.trainer_config.optimizer.amsgrad,
                 map_location=device,
+                weights_only=False,
             )
         else:
             # Load the converted model
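This `weights_only=False` addition, repeated for every checkpoint-loading path below, tracks PyTorch 2.6 flipping the default of `torch.load` to `weights_only=True`, which rejects checkpoints that pickle arbitrary Python objects (e.g., stored configs) alongside tensors. A standalone sketch of the behavior being opted out of (not sleap-nn code; path hypothetical):

import torch

# Under PyTorch >= 2.6, the default weights_only=True raises an error on
# checkpoints containing non-tensor pickled objects; weights_only=False
# restores the old behavior and should only be used with trusted files.
ckpt = torch.load("model.ckpt", map_location="cpu", weights_only=False)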
@@ -904,6 +941,7 @@ class TopDownPredictor(Predictor):
                 amsgrad=confmap_config.trainer_config.optimizer.amsgrad,
                 backbone_type=centered_instance_backbone_type,
                 map_location=device,
+                weights_only=False,
             )
         else:
             # Load the converted model
@@ -1073,6 +1111,8 @@ class TopDownPredictor(Predictor):
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
         video_index: Optional[int] = None,
         video_dataset: Optional[str] = None,
         video_input_format: str = "channels_last",
@@ -1085,6 +1125,8 @@ class TopDownPredictor(Predictor):
             frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+            exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+            only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
             video_index: (int) Integer index of video in .slp file to predict on. To be used
                 with an .slp path as an alternative to specifying the video path.
             video_dataset: (str) The dataset for HDF5 videos.
@@ -1119,6 +1161,8 @@ class TopDownPredictor(Predictor):
                 instances_key=self.instances_key,
                 only_labeled_frames=only_labeled_frames,
                 only_suggested_frames=only_suggested_frames,
+                exclude_user_labeled=exclude_user_labeled,
+                only_predicted_frames=only_predicted_frames,
             )
             self.videos = self.pipeline.labels.videos

@@ -1136,10 +1180,15 @@ class TopDownPredictor(Predictor):

         if isinstance(inference_object, sio.Labels) and video_index is not None:
             labels = inference_object
+            video = labels.videos[video_index]
+            # Filter out user-labeled frames if requested
+            filtered_frames = _filter_user_labeled_frames(
+                labels, video, frames, exclude_user_labeled
+            )
             self.pipeline = provider.from_video(
-                video=labels.videos[video_index],
+                video=video,
                 queue_maxsize=queue_maxsize,
-                frames=frames,
+                frames=filtered_frames,
             )

         else:  # for mp4 or hdf5 videos
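The same five-line wiring recurs verbatim in each predictor class below. As a usage sketch, the new flags pass through the pipeline-construction method whose parameters appear in the hunks above; the method name and predictor setup here are illustrative, only the parameters come from this diff:

# Hypothetical call; `predictor` is any of the Predictor subclasses below,
# and `make_pipeline` stands in for the method whose name is outside the
# visible hunks.
predictor.make_pipeline(
    inference_object=labels,      # an sio.Labels object
    video_index=0,                # predict on one video from the .slp
    frames=None,                  # all frames...
    exclude_user_labeled=True,    # ...minus frames with user instances
)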
@@ -1394,6 +1443,7 @@ class SingleInstancePredictor(Predictor):
                 learning_rate=confmap_config.trainer_config.optimizer.lr,
                 amsgrad=confmap_config.trainer_config.optimizer.amsgrad,
                 map_location=device,
+                weights_only=False,
             )
         else:
             confmap_converted_model = load_legacy_model(
@@ -1491,6 +1541,8 @@ class SingleInstancePredictor(Predictor):
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
         video_index: Optional[int] = None,
         video_dataset: Optional[str] = None,
         video_input_format: str = "channels_last",
@@ -1503,6 +1555,8 @@ class SingleInstancePredictor(Predictor):
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+            exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+            only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
             video_index: (int) Integer index of video in .slp file to predict on. To be used
                 with an .slp path as an alternative to specifying the video path.
             video_dataset: (str) The dataset for HDF5 videos.
@@ -1536,6 +1590,8 @@ class SingleInstancePredictor(Predictor):
                 frame_buffer=frame_buffer,
                 only_labeled_frames=only_labeled_frames,
                 only_suggested_frames=only_suggested_frames,
+                exclude_user_labeled=exclude_user_labeled,
+                only_predicted_frames=only_predicted_frames,
             )
             self.videos = self.pipeline.labels.videos

@@ -1544,10 +1600,15 @@ class SingleInstancePredictor(Predictor):

         if isinstance(inference_object, sio.Labels) and video_index is not None:
             labels = inference_object
+            video = labels.videos[video_index]
+            # Filter out user-labeled frames if requested
+            filtered_frames = _filter_user_labeled_frames(
+                labels, video, frames, exclude_user_labeled
+            )
             self.pipeline = provider.from_video(
-                video=labels.videos[video_index],
+                video=video,
                 queue_maxsize=queue_maxsize,
-                frames=frames,
+                frames=filtered_frames,
             )

         else:  # for mp4 or hdf5 videos
@@ -1833,6 +1894,7 @@ class BottomUpPredictor(Predictor):
                 backbone_type=backbone_type,
                 model_type="bottomup",
                 map_location=device,
+                weights_only=False,
             )
         else:
             bottomup_converted_model = load_legacy_model(
@@ -1929,6 +1991,8 @@ class BottomUpPredictor(Predictor):
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
         video_index: Optional[int] = None,
         video_dataset: Optional[str] = None,
         video_input_format: str = "channels_last",
@@ -1941,6 +2005,8 @@ class BottomUpPredictor(Predictor):
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+            exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+            only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
             video_index: (int) Integer index of video in .slp file to predict on. To be used
                 with an .slp path as an alternative to specifying the video path.
             video_dataset: (str) The dataset for HDF5 videos.
@@ -1974,6 +2040,8 @@ class BottomUpPredictor(Predictor):
                 frame_buffer=frame_buffer,
                 only_labeled_frames=only_labeled_frames,
                 only_suggested_frames=only_suggested_frames,
+                exclude_user_labeled=exclude_user_labeled,
+                only_predicted_frames=only_predicted_frames,
             )

         self.videos = self.pipeline.labels.videos
@@ -1983,10 +2051,15 @@ class BottomUpPredictor(Predictor):

         if isinstance(inference_object, sio.Labels) and video_index is not None:
             labels = inference_object
+            video = labels.videos[video_index]
+            # Filter out user-labeled frames if requested
+            filtered_frames = _filter_user_labeled_frames(
+                labels, video, frames, exclude_user_labeled
+            )
             self.pipeline = provider.from_video(
-                video=labels.videos[video_index],
+                video=video,
                 queue_maxsize=queue_maxsize,
-                frames=frames,
+                frames=filtered_frames,
             )

         else:  # for mp4 or hdf5 videos
@@ -2260,6 +2333,7 @@ class BottomUpMultiClassPredictor(Predictor):
                 optimizer=bottomup_config.trainer_config.optimizer_name,
                 learning_rate=bottomup_config.trainer_config.optimizer.lr,
                 amsgrad=bottomup_config.trainer_config.optimizer.amsgrad,
+                weights_only=False,
             )
         else:
             bottomup_converted_model = load_legacy_model(
@@ -2364,6 +2438,8 @@ class BottomUpMultiClassPredictor(Predictor):
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
         video_index: Optional[int] = None,
         video_dataset: Optional[str] = None,
         video_input_format: str = "channels_last",
@@ -2376,6 +2452,8 @@ class BottomUpMultiClassPredictor(Predictor):
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+            exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+            only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
             video_index: (int) Integer index of video in .slp file to predict on. To be used
                 with an .slp path as an alternative to specifying the video path.
             video_dataset: (str) The dataset for HDF5 videos.
@@ -2411,6 +2489,8 @@ class BottomUpMultiClassPredictor(Predictor):
                 frame_buffer=frame_buffer,
                 only_labeled_frames=only_labeled_frames,
                 only_suggested_frames=only_suggested_frames,
+                exclude_user_labeled=exclude_user_labeled,
+                only_predicted_frames=only_predicted_frames,
             )

         self.videos = self.pipeline.labels.videos
@@ -2420,10 +2500,15 @@ class BottomUpMultiClassPredictor(Predictor):

         if isinstance(inference_object, sio.Labels) and video_index is not None:
             labels = inference_object
+            video = labels.videos[video_index]
+            # Filter out user-labeled frames if requested
+            filtered_frames = _filter_user_labeled_frames(
+                labels, video, frames, exclude_user_labeled
+            )
             self.pipeline = provider.from_video(
-                video=labels.videos[video_index],
+                video=video,
                 queue_maxsize=queue_maxsize,
-                frames=frames,
+                frames=filtered_frames,
             )

         else:  # for mp4 or hdf5 videos
@@ -2682,9 +2767,6 @@ class TopDownMultiClassPredictor(Predictor):
                 max_stride=max_stride,
                 input_scale=self.confmap_config.data_config.preprocessing.scale,
             )
-            centroid_crop_layer.precrop_resize = (
-                self.confmap_config.data_config.preprocessing.scale
-            )

         if self.centroid_config is None:
             self.instances_key = (
@@ -2805,6 +2887,7 @@ class TopDownMultiClassPredictor(Predictor):
                 learning_rate=centroid_config.trainer_config.optimizer.lr,
                 amsgrad=centroid_config.trainer_config.optimizer.amsgrad,
                 map_location=device,
+                weights_only=False,
             )

         else:
@@ -2929,6 +3012,7 @@ class TopDownMultiClassPredictor(Predictor):
                 learning_rate=confmap_config.trainer_config.optimizer.lr,
                 amsgrad=confmap_config.trainer_config.optimizer.amsgrad,
                 map_location=device,
+                weights_only=False,
             )
         else:
             confmap_converted_model = load_legacy_model(
@@ -3109,6 +3193,8 @@ class TopDownMultiClassPredictor(Predictor):
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
         video_index: Optional[int] = None,
         video_dataset: Optional[str] = None,
         video_input_format: str = "channels_last",
@@ -3121,6 +3207,8 @@ class TopDownMultiClassPredictor(Predictor):
             frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+            exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+            only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
             video_index: (int) Integer index of video in .slp file to predict on. To be used
                 with an .slp path as an alternative to specifying the video path.
             video_dataset: (str) The dataset for HDF5 videos.
@@ -3155,6 +3243,8 @@ class TopDownMultiClassPredictor(Predictor):
                 instances_key=self.instances_key,
                 only_labeled_frames=only_labeled_frames,
                 only_suggested_frames=only_suggested_frames,
+                exclude_user_labeled=exclude_user_labeled,
+                only_predicted_frames=only_predicted_frames,
             )
             self.videos = self.pipeline.labels.videos

@@ -3172,10 +3262,15 @@ class TopDownMultiClassPredictor(Predictor):

         if isinstance(inference_object, sio.Labels) and video_index is not None:
             labels = inference_object
+            video = labels.videos[video_index]
+            # Filter out user-labeled frames if requested
+            filtered_frames = _filter_user_labeled_frames(
+                labels, video, frames, exclude_user_labeled
+            )
             self.pipeline = provider.from_video(
-                video=labels.videos[video_index],
+                video=video,
                 queue_maxsize=queue_maxsize,
-                frames=frames,
+                frames=filtered_frames,
             )

         else:  # for mp4 or hdf5 videos
@@ -0,0 +1,292 @@
+"""Provenance metadata utilities for inference outputs.
+
+This module provides utilities for building and managing provenance metadata
+that is stored in SLP files produced during inference. Provenance metadata
+helps track where predictions came from and how they were generated.
+"""
+
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import sleap_io as sio
+
+import sleap_nn
+from sleap_nn.system_info import get_system_info_dict
+
+
+def build_inference_provenance(
+    model_paths: Optional[list[str]] = None,
+    model_type: Optional[str] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    input_labels: Optional[sio.Labels] = None,
+    input_path: Optional[Union[str, Path]] = None,
+    frames_processed: Optional[int] = None,
+    frames_total: Optional[int] = None,
+    frame_selection_method: Optional[str] = None,
+    inference_params: Optional[dict[str, Any]] = None,
+    tracking_params: Optional[dict[str, Any]] = None,
+    device: Optional[str] = None,
+    cli_args: Optional[dict[str, Any]] = None,
+    include_system_info: bool = True,
+) -> dict[str, Any]:
+    """Build provenance metadata dictionary for inference output.
+
+    This function creates a comprehensive provenance dictionary that captures
+    all relevant metadata about an inference run, enabling reproducibility
+    and tracking of prediction origins.
+
+    Args:
+        model_paths: List of paths to model checkpoints used for inference.
+        model_type: Type of model used (e.g., "top_down", "bottom_up",
+            "single_instance").
+        start_time: Datetime when inference started.
+        end_time: Datetime when inference finished.
+        input_labels: Input Labels object if inference was run on an SLP file.
+            The provenance from this object will be preserved.
+        input_path: Path to input file (SLP or video).
+        frames_processed: Number of frames that were processed.
+        frames_total: Total number of frames in the input.
+        frame_selection_method: Method used to select frames (e.g., "all",
+            "labeled", "suggested", "range").
+        inference_params: Dictionary of inference parameters (peak_threshold,
+            integral_refinement, batch_size, etc.).
+        tracking_params: Dictionary of tracking parameters if tracking was run.
+        device: Device used for inference (e.g., "cuda:0", "cpu", "mps").
+        cli_args: Command-line arguments if available.
+        include_system_info: If True, include detailed system information.
+            Set to False for lighter-weight provenance.
+
+    Returns:
+        Dictionary containing provenance metadata suitable for storing in
+        Labels.provenance.
+
+    Example:
+        >>> from datetime import datetime
+        >>> provenance = build_inference_provenance(
+        ...     model_paths=["/path/to/model.ckpt"],
+        ...     model_type="top_down",
+        ...     start_time=datetime.now(),
+        ...     end_time=datetime.now(),
+        ...     device="cuda:0",
+        ... )
+        >>> labels.provenance = provenance
+        >>> labels.save("predictions.slp")
+    """
+    provenance: dict[str, Any] = {}
+
+    # Timestamps
+    if start_time is not None:
+        provenance["inference_start_timestamp"] = start_time.isoformat()
+    if end_time is not None:
+        provenance["inference_end_timestamp"] = end_time.isoformat()
+    if start_time is not None and end_time is not None:
+        runtime_seconds = (end_time - start_time).total_seconds()
+        provenance["inference_runtime_seconds"] = runtime_seconds
+
+    # Version information
+    provenance["sleap_nn_version"] = sleap_nn.__version__
+    provenance["sleap_io_version"] = sio.__version__
+
+    # Model information
+    if model_paths is not None:
+        # Store as absolute POSIX paths for cross-platform compatibility
+        provenance["model_paths"] = [
+            Path(p).resolve().as_posix() if isinstance(p, (str, Path)) else str(p)
+            for p in model_paths
+        ]
+    if model_type is not None:
+        provenance["model_type"] = model_type
+
+    # Input data lineage
+    if input_path is not None:
+        provenance["source_file"] = (
+            Path(input_path).resolve().as_posix()
+            if isinstance(input_path, (str, Path))
+            else str(input_path)
+        )
+
+    # Preserve input provenance if available
+    if input_labels is not None and hasattr(input_labels, "provenance"):
+        input_prov = dict(input_labels.provenance)
+        if input_prov:
+            provenance["input_provenance"] = input_prov
+            # Also set source_labels for compatibility with sleap-io conventions
+            if "filename" in input_prov:
+                provenance["source_labels"] = input_prov["filename"]
+
+    # Frame selection information
+    if frames_processed is not None or frames_total is not None:
+        frame_info: dict[str, Any] = {}
+        if frame_selection_method is not None:
+            frame_info["method"] = frame_selection_method
+        if frames_processed is not None:
+            frame_info["frames_processed"] = frames_processed
+        if frames_total is not None:
+            frame_info["frames_total"] = frames_total
+        if frame_info:
+            provenance["frame_selection"] = frame_info
+
+    # Inference parameters
+    if inference_params is not None:
+        # Filter out None values and convert paths
+        clean_params = {}
+        for key, value in inference_params.items():
+            if value is not None:
+                if isinstance(value, Path):
+                    clean_params[key] = value.as_posix()
+                else:
+                    clean_params[key] = value
+        if clean_params:
+            provenance["inference_config"] = clean_params
+
+    # Tracking parameters
+    if tracking_params is not None:
+        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
+        if clean_tracking:
+            provenance["tracking_config"] = clean_tracking
+
+    # Device information
+    if device is not None:
+        provenance["device"] = device
+
+    # CLI arguments
+    if cli_args is not None:
+        # Filter out None values
+        clean_cli = {k: v for k, v in cli_args.items() if v is not None}
+        if clean_cli:
+            provenance["cli_args"] = clean_cli
+
+    # System information (can be disabled for lighter provenance)
+    if include_system_info:
+        try:
+            system_info = get_system_info_dict()
+            # Extract key fields for provenance (avoid excessive nesting)
+            provenance["system_info"] = {
+                "python_version": system_info.get("python_version"),
+                "platform": system_info.get("platform"),
+                "pytorch_version": system_info.get("pytorch_version"),
+                "cuda_version": system_info.get("cuda_version"),
+                "accelerator": system_info.get("accelerator"),
+                "gpu_count": system_info.get("gpu_count"),
+            }
+            # Include GPU names if available
+            if system_info.get("gpus"):
+                provenance["system_info"]["gpus"] = [
+                    gpu.get("name") for gpu in system_info["gpus"]
+                ]
+        except Exception:
+            # Don't fail inference if system info collection fails
+            pass
+
+    return provenance
+
+
+def build_tracking_only_provenance(
+    input_labels: Optional[sio.Labels] = None,
+    input_path: Optional[Union[str, Path]] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    tracking_params: Optional[dict[str, Any]] = None,
+    frames_processed: Optional[int] = None,
+    include_system_info: bool = True,
+) -> dict[str, Any]:
+    """Build provenance metadata for tracking-only pipeline.
+
+    This is a simplified version of build_inference_provenance for when
+    only tracking is run without model inference.
+
+    Args:
+        input_labels: Input Labels object with existing predictions.
+        input_path: Path to input SLP file.
+        start_time: Datetime when tracking started.
+        end_time: Datetime when tracking finished.
+        tracking_params: Dictionary of tracking parameters.
+        frames_processed: Number of frames that were tracked.
+        include_system_info: If True, include system information.
+
+    Returns:
+        Dictionary containing provenance metadata.
+    """
+    provenance: dict[str, Any] = {}
+
+    # Timestamps
+    if start_time is not None:
+        provenance["tracking_start_timestamp"] = start_time.isoformat()
+    if end_time is not None:
+        provenance["tracking_end_timestamp"] = end_time.isoformat()
+    if start_time is not None and end_time is not None:
+        runtime_seconds = (end_time - start_time).total_seconds()
+        provenance["tracking_runtime_seconds"] = runtime_seconds
+
+    # Version information
+    provenance["sleap_nn_version"] = sleap_nn.__version__
+    provenance["sleap_io_version"] = sio.__version__
+
+    # Note that this is tracking-only
+    provenance["pipeline_type"] = "tracking_only"
+
+    # Input data lineage
+    if input_path is not None:
+        provenance["source_file"] = (
+            Path(input_path).resolve().as_posix()
+            if isinstance(input_path, (str, Path))
+            else str(input_path)
+        )
+
+    # Preserve input provenance if available
+    if input_labels is not None and hasattr(input_labels, "provenance"):
+        input_prov = dict(input_labels.provenance)
+        if input_prov:
+            provenance["input_provenance"] = input_prov
+            if "filename" in input_prov:
+                provenance["source_labels"] = input_prov["filename"]
+
+    # Frame information
+    if frames_processed is not None:
+        provenance["frames_processed"] = frames_processed
+
+    # Tracking parameters
+    if tracking_params is not None:
+        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
+        if clean_tracking:
+            provenance["tracking_config"] = clean_tracking
+
+    # System information
+    if include_system_info:
+        try:
+            system_info = get_system_info_dict()
+            provenance["system_info"] = {
+                "python_version": system_info.get("python_version"),
+                "platform": system_info.get("platform"),
+                "pytorch_version": system_info.get("pytorch_version"),
+                "accelerator": system_info.get("accelerator"),
+            }
+        except Exception:
+            pass
+
+    return provenance
+
+
+def merge_provenance(
+    base_provenance: dict[str, Any],
+    additional: dict[str, Any],
+    overwrite: bool = True,
+) -> dict[str, Any]:
+    """Merge additional provenance fields into base provenance.
+
+    Args:
+        base_provenance: Base provenance dictionary.
+        additional: Additional fields to merge.
+        overwrite: If True, additional fields overwrite base fields.
+            If False, base fields take precedence.
+
+    Returns:
+        Merged provenance dictionary.
+    """
+    result = dict(base_provenance)
+    for key, value in additional.items():
+        if key not in result or overwrite:
+            result[key] = value
+    return result
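Taken together, a short sketch of how the new module's helpers compose, using only the signatures defined above (paths hypothetical; `labels` is an sio.Labels object as in the docstring example):

from datetime import datetime

start = datetime.now()
# ... run inference here ...
end = datetime.now()

prov = build_inference_provenance(
    model_paths=["models/centroid.ckpt", "models/centered_instance.ckpt"],
    model_type="top_down",
    start_time=start,
    end_time=end,
    device="cuda:0",
)

# Layer extra fields on top; overwrite=False keeps existing fields when
# keys collide, so the inference metadata wins over the additions.
prov = merge_provenance(prov, {"notes": "pilot run"}, overwrite=False)

labels.provenance = prov
labels.save("predictions.slp")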