sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,9 +47,6 @@ class CentroidCrop(L.LightningModule):
  crop_hw: Tuple (height, width) representing the crop size.
  input_scale: Float indicating if the images should be resized before being
  passed to the model.
- precrop_resize: Float indicating the factor by which the original images
- (not images resized for centroid model) should be resized before cropping.
- Note: This resize happens only after getting the predictions for centroid model.
  max_stride: Maximum stride in a model that the images must be divisible by.
  If > 1, this will pad the bottom and right of the images to ensure they meet
  this divisibility criteria. Padding is applied after the scaling specified
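As a rough, standalone illustration (not part of the diff) of the `max_stride` padding described in the docstring above, with made-up image dimensions:

    import math

    # Pad bottom/right so height and width become divisible by max_stride.
    height, width, max_stride = 518, 775, 32   # illustrative values
    pad_bottom = math.ceil(height / max_stride) * max_stride - height
    pad_right = math.ceil(width / max_stride) * max_stride - width
    print(pad_bottom, pad_right)  # 26, 25 -> padded image is 544 x 800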
@@ -74,7 +71,6 @@ class CentroidCrop(L.LightningModule):
  return_crops: bool = False,
  crop_hw: Optional[List[int]] = None,
  input_scale: float = 1.0,
- precrop_resize: float = 1.0,
  max_stride: int = 1,
  use_gt_centroids: bool = False,
  anchor_ind: Optional[int] = None,
@@ -92,22 +88,25 @@ class CentroidCrop(L.LightningModule):
  self.return_crops = return_crops
  self.crop_hw = crop_hw
  self.input_scale = input_scale
- self.precrop_resize = precrop_resize
  self.max_stride = max_stride
  self.use_gt_centroids = use_gt_centroids
  self.anchor_ind = anchor_ind

- def _generate_crops(self, inputs):
+ def _generate_crops(self, inputs, cms: Optional[torch.Tensor] = None):
  """Generate Crops from the predicted centroids."""
  crops_dict = []
- for centroid, centroid_val, image, fidx, vidx, sz, eff_sc in zip(
- self.refined_peaks_batched,
- self.peak_vals_batched,
- inputs["image"],
- inputs["frame_idx"],
- inputs["video_idx"],
- inputs["orig_size"],
- inputs["eff_scale"],
+ if cms is not None:
+ cms = cms.detach()
+ for idx, (centroid, centroid_val, image, fidx, vidx, sz, eff_sc) in enumerate(
+ zip(
+ self.refined_peaks_batched,
+ self.peak_vals_batched,
+ inputs["image"],
+ inputs["frame_idx"],
+ inputs["video_idx"],
+ inputs["orig_size"],
+ inputs["eff_scale"],
+ )
  ):
  if torch.any(torch.isnan(centroid)):
  if torch.all(torch.isnan(centroid)):
@@ -149,6 +148,11 @@ class CentroidCrop(L.LightningModule):
  ex["instance_image"] = instance_image.unsqueeze(dim=1)
  ex["orig_size"] = torch.cat([torch.Tensor(sz)] * n)
  ex["eff_scale"] = torch.Tensor([eff_sc] * n)
+ ex["pred_centroids"] = centroid
+ if self.return_confmaps:
+ ex["pred_centroid_confmaps"] = torch.cat(
+ [cms[idx].unsqueeze(dim=0)] * n
+ )
  crops_dict.append(ex)

  return crops_dict
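A standalone sketch (shapes are illustrative, not taken from the diff) of the duplication pattern used for `pred_centroid_confmaps` above: one frame-level centroid confidence map is repeated once per crop generated from that frame.

    import torch

    n = 3                        # number of crops generated from one frame
    cm = torch.rand(1, 96, 96)   # that frame's centroid confmap (illustrative shape)
    per_crop = torch.cat([cm.unsqueeze(dim=0)] * n)
    assert per_crop.shape == (n, 1, 96, 96)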
@@ -204,12 +208,6 @@ class CentroidCrop(L.LightningModule):

  if self.return_crops:
  crops_dict = self._generate_crops(inputs)
- inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
- inputs["centroids"] *= self.precrop_resize
- scaled_refined_peaks = []
- for ref_peak in self.refined_peaks_batched:
- scaled_refined_peaks.append(ref_peak * self.precrop_resize)
- self.refined_peaks_batched = scaled_refined_peaks
  return crops_dict
  else:
  return inputs
@@ -274,19 +272,13 @@ class CentroidCrop(L.LightningModule):

  # Generate crops if return_crops=True to pass the crops to CenteredInstance model.
  if self.return_crops:
- inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
- scaled_refined_peaks = []
- for ref_peak in self.refined_peaks_batched:
- scaled_refined_peaks.append(ref_peak * self.precrop_resize)
- self.refined_peaks_batched = scaled_refined_peaks
-
  inputs.update(
  {
  "centroids": self.refined_peaks_batched,
  "centroid_vals": self.peak_vals_batched,
  }
  )
- crops_dict = self._generate_crops(inputs)
+ crops_dict = self._generate_crops(inputs, cms)
  return crops_dict
  else:
  # batch the peaks to pass it to FindInstancePeaksGroundTruth class.
@@ -359,7 +351,11 @@ class FindInstancePeaksGroundTruth(L.LightningModule):

  def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, np.array]:
  """Return the ground truth instance peaks given a set of crops."""
- b, _, max_inst, nodes, _ = batch["instances"].shape
+ b, _, _, nodes, _ = batch["instances"].shape
+ # Use number of centroids as max_inst to ensure consistent output shape
+ # This handles the case where max_instances limits centroids but instances
+ # tensor has a different (global) max_instances from the labels file
+ num_centroids = batch["centroids"].shape[2]
  inst = (
  batch["instances"].unsqueeze(dim=-4).float()
  ) # (batch, 1, 1, n_inst, nodes, 2)
@@ -389,26 +385,26 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
  parsed = 0
  for i in range(b):
  if i not in matched_batch_inds:
- batch_peaks = torch.full((max_inst, nodes, 2), torch.nan)
- vals = torch.full((max_inst, nodes), torch.nan)
+ batch_peaks = torch.full((num_centroids, nodes, 2), torch.nan)
+ vals = torch.full((num_centroids, nodes), torch.nan)
  else:
  c = counts[i]
  batch_peaks = peaks_list[parsed : parsed + c]
  num_inst = len(batch_peaks)
  vals = torch.ones((num_inst, nodes))
- if c < max_inst:
+ if c < num_centroids:
  batch_peaks = torch.cat(
  [
  batch_peaks,
- torch.full((max_inst - num_inst, nodes, 2), torch.nan),
+ torch.full((num_centroids - num_inst, nodes, 2), torch.nan),
  ]
  )
  vals = torch.cat(
- [vals, torch.full((max_inst - num_inst, nodes), torch.nan)]
+ [vals, torch.full((num_centroids - num_inst, nodes), torch.nan)]
  )
  else:
- batch_peaks = batch_peaks[:max_inst]
- vals = vals[:max_inst]
+ batch_peaks = batch_peaks[:num_centroids]
+ vals = vals[:num_centroids]
  parsed += c

  batch_peaks = batch_peaks.unsqueeze(dim=0)
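The change above keys the output shape to the number of centroids rather than the instances tensor. A standalone sketch of that pad-or-truncate pattern (values are illustrative):

    import torch

    num_centroids, nodes = 3, 5
    batch_peaks = torch.rand(2, nodes, 2)   # only 2 matched instances in this frame
    num_inst = batch_peaks.shape[0]
    if num_inst < num_centroids:
        # Pad missing rows with NaN so every frame yields exactly num_centroids instances.
        batch_peaks = torch.cat(
            [batch_peaks, torch.full((num_centroids - num_inst, nodes, 2), torch.nan)]
        )
    else:
        batch_peaks = batch_peaks[:num_centroids]
    assert batch_peaks.shape == (num_centroids, nodes, 2)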
@@ -432,33 +428,45 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
  peaks_output["pred_instance_peaks"] = peaks
  peaks_output["pred_peak_values"] = peaks_vals

- batch_size, num_centroids = (
- batch["centroids"].shape[0],
- batch["centroids"].shape[2],
- )
+ batch_size = batch["centroids"].shape[0]
  output_dict = {}
  output_dict["centroid"] = batch["centroids"].squeeze(dim=1).reshape(-1, 1, 2)
  output_dict["centroid_val"] = batch["centroid_vals"].reshape(-1)
- output_dict["pred_instance_peaks"] = batch["pred_instance_peaks"].reshape(
- -1, nodes, 2
+ output_dict["pred_instance_peaks"] = peaks_output[
+ "pred_instance_peaks"
+ ].reshape(-1, nodes, 2)
+ output_dict["pred_peak_values"] = peaks_output["pred_peak_values"].reshape(
+ -1, nodes
  )
- output_dict["pred_peak_values"] = batch["pred_peak_values"].reshape(-1, nodes)
  output_dict["instance_bbox"] = torch.zeros(
  (batch_size * num_centroids, 1, 4, 2)
  )
  frame_inds = []
  video_inds = []
  orig_szs = []
+ images = []
+ centroid_confmaps = []
  for b_idx in range(b):
  curr_batch_size = len(batch["centroids"][b_idx][0])
  frame_inds.extend([batch["frame_idx"][b_idx]] * curr_batch_size)
  video_inds.extend([batch["video_idx"][b_idx]] * curr_batch_size)
  orig_szs.append(torch.cat([batch["orig_size"][b_idx]] * curr_batch_size))
+ images.append(
+ batch["image"][b_idx].unsqueeze(0).repeat(curr_batch_size, 1, 1, 1, 1)
+ )
+ if "pred_centroid_confmaps" in batch:
+ centroid_confmaps.append(
+ batch["pred_centroid_confmaps"][b_idx]
+ .unsqueeze(0)
+ .repeat(curr_batch_size, 1, 1, 1)
+ )

  output_dict["frame_idx"] = torch.tensor(frame_inds)
  output_dict["video_idx"] = torch.tensor(video_inds)
  output_dict["orig_size"] = torch.concatenate(orig_szs, dim=0)
-
+ output_dict["image"] = torch.cat(images, dim=0)
+ if centroid_confmaps:
+ output_dict["pred_centroid_confmaps"] = torch.cat(centroid_confmaps, dim=0)
  return output_dict


@@ -548,6 +556,8 @@ class FindInstancePeaks(L.LightningModule):
  # Network forward pass.
  # resize and pad the input image
  input_image = inputs["instance_image"]
+ # resize the crop image
+ input_image = resize_image(input_image, self.input_scale)
  if self.max_stride != 1:
  input_image = apply_pad_to_stride(input_image, self.max_stride)

@@ -569,8 +579,6 @@ class FindInstancePeaks(L.LightningModule):
  inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
  )

- inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
  inputs["instance_bbox"] = inputs["instance_bbox"] / (
  inputs["eff_scale"]
  .unsqueeze(dim=1)
@@ -679,6 +687,8 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
  # Network forward pass.
  # resize and pad the input image
  input_image = inputs["instance_image"]
+ # resize the crop image
+ input_image = resize_image(input_image, self.input_scale)
  if self.max_stride != 1:
  input_image = apply_pad_to_stride(input_image, self.max_stride)

@@ -702,8 +712,6 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
  inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
  )

- inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
  inputs["instance_bbox"] = inputs["instance_bbox"] / (
  inputs["eff_scale"]
  .unsqueeze(dim=1)
sleap_nn/predict.py CHANGED
@@ -15,6 +15,11 @@ from sleap_nn.tracking.tracker import (
  connect_single_breaks,
  cull_instances,
  )
+ from sleap_nn.system_info import get_startup_info_string
+ from sleap_nn.inference.provenance import (
+ build_inference_provenance,
+ build_tracking_only_provenance,
+ )
  from omegaconf import OmegaConf
  import sleap_io as sio
  from pathlib import Path
@@ -58,6 +63,8 @@ def run_inference(
  anchor_part: Optional[str] = None,
  only_labeled_frames: bool = False,
  only_suggested_frames: bool = False,
+ exclude_user_labeled: bool = False,
+ only_predicted_frames: bool = False,
  no_empty_frames: bool = False,
  batch_size: int = 4,
  queue_maxsize: int = 8,
@@ -136,6 +143,8 @@ def run_inference(
  provided, the anchor part in the `training_config.yaml` is used. Default: `None`.
  only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
  only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+ exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+ only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
  no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
  batch_size: (int) Number of samples per batch. Default: 4.
  queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
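A hypothetical call illustrating the two new frame filters documented above; the file paths and model directories are placeholders, not values from this diff:

    from sleap_nn.predict import run_inference

    preds = run_inference(
        data_path="labels.v001.slp",                                   # placeholder .slp input
        model_paths=["models/centroid", "models/centered_instance"],   # placeholder model dirs
        exclude_user_labeled=True,    # skip frames that already have user-labeled instances
        # only_predicted_frames=True  # alternatively: restrict to frames that have predictions
    )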
@@ -145,7 +154,7 @@ def run_inference(
  video_input_format: (str) The input_format for HDF5 videos.
  frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
  crop_size: (int) Crop size. If not provided, the crop size from training_config.yaml is used.
- Default: None.
+ If `input_scale` is provided, then the cropped image will be resized according to `input_scale`. Default: None.
  peak_threshold: (float) Minimum confidence threshold. Peaks with values below
  this will be ignored. Default: 0.2. This can also be `List[float]` for topdown
  centroid and centered-instance model, where the first element corresponds
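A small sketch of the interaction noted in the `crop_size` docstring change above (numbers are illustrative): the crop is taken at `crop_size` and then rescaled by `input_scale` before being fed to the centered-instance model.

    crop_size, input_scale = 160, 0.5        # illustrative values
    model_input_side = int(crop_size * input_scale)
    print(model_input_side)                  # 80 -> the 160 x 160 crop is resized to 80 x 80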
@@ -256,6 +265,27 @@ def run_inference(
  "scale": input_scale,
  }

+ # Validate mutually exclusive frame filter flags
+ if only_labeled_frames and exclude_user_labeled:
+ message = (
+ "--only_labeled_frames and --exclude_user_labeled are mutually exclusive "
+ "(would result in zero frames)"
+ )
+ logger.error(message)
+ raise ValueError(message)
+
+ if (
+ only_predicted_frames
+ and data_path is not None
+ and not data_path.endswith(".slp")
+ ):
+ message = (
+ "--only_predicted_frames requires a .slp file input "
+ "(need Labels to know which frames have predictions)"
+ )
+ logger.error(message)
+ raise ValueError(message)
+
  if model_paths is None or not len(
  model_paths
  ): # if model paths is not provided, run tracking-only pipeline.
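Sketch (placeholder arguments) of how the new validation above surfaces to a caller: combining the mutually exclusive filters raises a ValueError up front.

    from sleap_nn.predict import run_inference

    try:
        run_inference(
            data_path="labels.v001.slp",        # placeholder
            model_paths=["models/centroid"],    # placeholder
            only_labeled_frames=True,
            exclude_user_labeled=True,          # conflicts with only_labeled_frames
        )
    except ValueError as err:
        print(err)  # "--only_labeled_frames and --exclude_user_labeled are mutually exclusive ..."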
@@ -273,7 +303,8 @@ def run_inference(
  raise ValueError(message)

  start_inf_time = time()
- start_timestamp = str(datetime.now())
+ start_datetime = datetime.now()
+ start_timestamp = str(start_datetime)
  logger.info(f"Started tracking at: {start_timestamp}")

  labels = sio.load_slp(data_path) if input_labels is None else input_labels
@@ -302,7 +333,22 @@ def run_inference(

  if post_connect_single_breaks or tracking_pre_cull_to_target:
  if tracking_target_instance_count is None and max_instances is None:
- message = "Both tracking_target_instance_count and max_instances is set to 0. To connect single breaks or pre-cull to target, at least one of them should be set to an integer."
+ features_requested = []
+ if post_connect_single_breaks:
+ features_requested.append("--post_connect_single_breaks")
+ if tracking_pre_cull_to_target:
+ features_requested.append("--tracking_pre_cull_to_target")
+ features_str = " and ".join(features_requested)
+
+ if max_tracks is not None:
+ suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
+ else:
+ suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
+
+ message = (
+ f"{features_str} requires --tracking_target_instance_count to be set. "
+ f"{suggestion}"
+ )
  logger.error(message)
  raise ValueError(message)
  elif tracking_target_instance_count is None:
@@ -332,21 +378,53 @@ def run_inference(
  tracking_clean_iou_threshold=tracking_clean_iou_threshold,
  )

- finish_timestamp = str(datetime.now())
+ end_datetime = datetime.now()
+ finish_timestamp = str(end_datetime)
  total_elapsed = time() - start_inf_time
  logger.info(f"Finished tracking at: {finish_timestamp}")
  logger.info(f"Total runtime: {total_elapsed} secs")

+ # Build tracking-only provenance
+ tracking_params = {
+ "window_size": tracking_window_size,
+ "min_new_track_points": min_new_track_points,
+ "candidates_method": candidates_method,
+ "min_match_points": min_match_points,
+ "features": features,
+ "scoring_method": scoring_method,
+ "scoring_reduction": scoring_reduction,
+ "robust_best_instance": robust_best_instance,
+ "track_matching_method": track_matching_method,
+ "max_tracks": max_tracks,
+ "use_flow": use_flow,
+ "post_connect_single_breaks": post_connect_single_breaks,
+ }
+ provenance = build_tracking_only_provenance(
+ input_labels=labels,
+ input_path=data_path,
+ start_time=start_datetime,
+ end_time=end_datetime,
+ tracking_params=tracking_params,
+ frames_processed=len(tracked_frames),
+ )
+
  output = sio.Labels(
  labeled_frames=tracked_frames,
  videos=labels.videos,
  skeletons=labels.skeletons,
+ provenance=provenance,
  )

  else:
  start_inf_time = time()
- start_timestamp = str(datetime.now())
+ start_datetime = datetime.now()
+ start_timestamp = str(start_datetime)
  logger.info(f"Started inference at: {start_timestamp}")
+ logger.info(get_startup_info_string())
+
+ # Convert device to string if it's a torch.device object
+ if hasattr(device, "type"):
+ device = str(device)

  if device == "auto":
  device = (
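Usage sketch for the tracking-only provenance added above, assuming the tracked output was saved to an .slp file (the path is a placeholder): the dictionary travels with the Labels object and can be inspected later with sleap-io.

    import sleap_io as sio

    labels = sio.load_slp("tracked.predictions.slp")  # placeholder output path
    print(labels.provenance)  # dictionary assembled by build_tracking_only_provenance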
@@ -387,7 +465,22 @@ def run_inference(
  ):
  if post_connect_single_breaks or tracking_pre_cull_to_target:
  if tracking_target_instance_count is None and max_instances is None:
- message = "Both tracking_target_instance_count and max_instances is set to 0. To connect single breaks or pre-cull to target, at least one of them should be set to an integer."
+ features_requested = []
+ if post_connect_single_breaks:
+ features_requested.append("--post_connect_single_breaks")
+ if tracking_pre_cull_to_target:
+ features_requested.append("--tracking_pre_cull_to_target")
+ features_str = " and ".join(features_requested)
+
+ if max_tracks is not None:
+ suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
+ else:
+ suggestion = "Add --tracking_target_instance_count N or --max_instances N where N is the expected number of instances per frame."
+
+ message = (
+ f"{features_str} requires --tracking_target_instance_count or --max_instances to be set. "
+ f"{suggestion}"
+ )
  logger.error(message)
  raise ValueError(message)
  elif tracking_target_instance_count is None:
@@ -448,6 +541,8 @@ def run_inference(
  frames=frames,
  only_labeled_frames=only_labeled_frames,
  only_suggested_frames=only_suggested_frames,
+ exclude_user_labeled=exclude_user_labeled,
+ only_predicted_frames=only_predicted_frames,
  video_index=video_index,
  video_dataset=video_dataset,
  video_input_format=video_input_format,
@@ -492,12 +587,94 @@ def run_inference(
  skeletons=output.skeletons,
  )

- finish_timestamp = str(datetime.now())
+ end_datetime = datetime.now()
+ finish_timestamp = str(end_datetime)
  total_elapsed = time() - start_inf_time
  logger.info(f"Finished inference at: {finish_timestamp}")
- logger.info(
- f"Total runtime: {total_elapsed} secs"
- ) # TODO: add number of predicted frames
+ logger.info(f"Total runtime: {total_elapsed} secs")
+
+ # Determine input labels for provenance preservation
+ input_labels_for_prov = None
+ if input_labels is not None:
+ input_labels_for_prov = input_labels
+ elif data_path is not None and data_path.endswith(".slp"):
+ # Load input labels to preserve provenance (if not already loaded)
+ try:
+ input_labels_for_prov = sio.load_slp(data_path)
+ except Exception:
+ pass
+
+ # Build inference parameters for provenance
+ inference_params = {
+ "peak_threshold": peak_threshold,
+ "integral_refinement": integral_refinement,
+ "integral_patch_size": integral_patch_size,
+ "batch_size": batch_size,
+ "max_instances": max_instances,
+ "crop_size": crop_size,
+ "input_scale": input_scale,
+ "anchor_part": anchor_part,
+ }
+
+ # Build tracking parameters if tracking was enabled
+ tracking_params_prov = None
+ if tracking:
+ tracking_params_prov = {
+ "window_size": tracking_window_size,
+ "min_new_track_points": min_new_track_points,
+ "candidates_method": candidates_method,
+ "min_match_points": min_match_points,
+ "features": features,
+ "scoring_method": scoring_method,
+ "scoring_reduction": scoring_reduction,
+ "robust_best_instance": robust_best_instance,
+ "track_matching_method": track_matching_method,
+ "max_tracks": max_tracks,
+ "use_flow": use_flow,
+ "post_connect_single_breaks": post_connect_single_breaks,
+ }
+
+ # Determine frame selection method
+ frame_selection_method = "all"
+ if only_labeled_frames:
+ frame_selection_method = "labeled"
+ elif only_suggested_frames:
+ frame_selection_method = "suggested"
+ elif only_predicted_frames:
+ frame_selection_method = "predicted"
+ elif frames is not None:
+ frame_selection_method = "specified"
+
+ # Determine model type from predictor class
+ predictor_type_map = {
+ "TopDownPredictor": "top_down",
+ "SingleInstancePredictor": "single_instance",
+ "BottomUpPredictor": "bottom_up",
+ "BottomUpMultiClassPredictor": "bottom_up_multi_class",
+ "TopDownMultiClassPredictor": "top_down_multi_class",
+ }
+ model_type = predictor_type_map.get(type(predictor).__name__)
+
+ # Build and set provenance (only for Labels objects)
+ if make_labels and isinstance(output, sio.Labels):
+ provenance = build_inference_provenance(
+ model_paths=model_paths,
+ model_type=model_type,
+ start_time=start_datetime,
+ end_time=end_datetime,
+ input_labels=input_labels_for_prov,
+ input_path=data_path,
+ frames_processed=(
+ len(output.labeled_frames)
+ if hasattr(output, "labeled_frames")
+ else None
+ ),
+ frame_selection_method=frame_selection_method,
+ inference_params=inference_params,
+ tracking_params=tracking_params_prov,
+ device=device,
+ )
+ output.provenance = provenance

  if no_empty_frames:
  output.clean(frames=True, skeletons=False)
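Hypothetical follow-up (placeholder paths, assuming typical arguments) showing where the inference provenance assembled above ends up: when `make_labels=True`, the returned `sio.Labels` carries it directly.

    from sleap_nn.predict import run_inference

    preds = run_inference(
        data_path="video.mp4",                   # placeholder video
        model_paths=["models/single_instance"],  # placeholder model directory
        make_labels=True,
    )
    print(preds.provenance)  # dictionary built by build_inference_provenance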