sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/predict.py
CHANGED
|
@@ -15,6 +15,11 @@ from sleap_nn.tracking.tracker import (
|
|
|
15
15
|
connect_single_breaks,
|
|
16
16
|
cull_instances,
|
|
17
17
|
)
|
|
18
|
+
from sleap_nn.system_info import get_startup_info_string
|
|
19
|
+
from sleap_nn.inference.provenance import (
|
|
20
|
+
build_inference_provenance,
|
|
21
|
+
build_tracking_only_provenance,
|
|
22
|
+
)
|
|
18
23
|
from omegaconf import OmegaConf
|
|
19
24
|
import sleap_io as sio
|
|
20
25
|
from pathlib import Path
|
|
@@ -58,15 +63,20 @@ def run_inference(
|
|
|
58
63
|
anchor_part: Optional[str] = None,
|
|
59
64
|
only_labeled_frames: bool = False,
|
|
60
65
|
only_suggested_frames: bool = False,
|
|
66
|
+
exclude_user_labeled: bool = False,
|
|
67
|
+
only_predicted_frames: bool = False,
|
|
61
68
|
no_empty_frames: bool = False,
|
|
62
69
|
batch_size: int = 4,
|
|
63
|
-
queue_maxsize: int =
|
|
70
|
+
queue_maxsize: int = 32,
|
|
64
71
|
video_index: Optional[int] = None,
|
|
65
72
|
video_dataset: Optional[str] = None,
|
|
66
73
|
video_input_format: str = "channels_last",
|
|
67
74
|
frames: Optional[list] = None,
|
|
68
75
|
crop_size: Optional[int] = None,
|
|
69
76
|
peak_threshold: Union[float, List[float]] = 0.2,
|
|
77
|
+
filter_overlapping: bool = False,
|
|
78
|
+
filter_overlapping_method: str = "iou",
|
|
79
|
+
filter_overlapping_threshold: float = 0.8,
|
|
70
80
|
integral_refinement: Optional[str] = "integral",
|
|
71
81
|
integral_patch_size: int = 5,
|
|
72
82
|
return_confmaps: bool = False,
|
|
@@ -103,6 +113,7 @@ def run_inference(
|
|
|
103
113
|
tracking_pre_cull_iou_threshold: float = 0,
|
|
104
114
|
tracking_clean_instance_count: int = 0,
|
|
105
115
|
tracking_clean_iou_threshold: float = 0,
|
|
116
|
+
gui: bool = False,
|
|
106
117
|
):
|
|
107
118
|
"""Entry point to run inference on trained SLEAP-NN models.
|
|
108
119
|
|
|
@@ -136,21 +147,32 @@ def run_inference(
|
|
|
136
147
|
provided, the anchor part in the `training_config.yaml` is used. Default: `None`.
|
|
137
148
|
only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
|
|
138
149
|
only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
|
|
150
|
+
exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
|
|
151
|
+
only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
|
|
139
152
|
no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
|
|
140
153
|
batch_size: (int) Number of samples per batch. Default: 4.
|
|
141
|
-
queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
|
|
154
|
+
queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
|
|
142
155
|
video_index: (int) Integer index of video in .slp file to predict on. To be used with
|
|
143
156
|
an .slp path as an alternative to specifying the video path.
|
|
144
157
|
video_dataset: (str) The dataset for HDF5 videos.
|
|
145
158
|
video_input_format: (str) The input_format for HDF5 videos.
|
|
146
159
|
frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
|
|
147
160
|
crop_size: (int) Crop size. If not provided, the crop size from training_config.yaml is used.
|
|
148
|
-
Default: None.
|
|
161
|
+
If `input_scale` is provided, then the cropped image will be resized according to `input_scale`. Default: None.
|
|
149
162
|
peak_threshold: (float) Minimum confidence threshold. Peaks with values below
|
|
150
163
|
this will be ignored. Default: 0.2. This can also be `List[float]` for topdown
|
|
151
164
|
centroid and centered-instance model, where the first element corresponds
|
|
152
165
|
to centroid model peak finding threshold and the second element is for
|
|
153
166
|
centered-instance model peak finding.
|
|
167
|
+
filter_overlapping: (bool) If True, removes overlapping instances after
|
|
168
|
+
inference using greedy NMS. Applied independently of tracking.
|
|
169
|
+
Default: False.
|
|
170
|
+
filter_overlapping_method: (str) Similarity metric for filtering overlapping
|
|
171
|
+
instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
|
|
172
|
+
Default: "iou".
|
|
173
|
+
filter_overlapping_threshold: (float) Similarity threshold for filtering.
|
|
174
|
+
Instances with similarity > threshold are removed (keeping higher-scoring).
|
|
175
|
+
Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
|
|
154
176
|
integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
|
|
155
177
|
If `"integral"`, peaks will be refined with integral regression.
|
|
156
178
|
Default: `"integral"`.
|
|
@@ -241,6 +263,8 @@ def run_inference(
|
|
|
241
263
|
tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
|
|
242
264
|
tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
|
|
243
265
|
tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
|
|
266
|
+
gui: (bool) If True, outputs JSON progress lines for GUI integration instead
|
|
267
|
+
of Rich progress bars. Default: False.
|
|
244
268
|
|
|
245
269
|
Returns:
|
|
246
270
|
Returns `sio.Labels` object if `make_labels` is True. Else this function returns
|
|
@@ -256,6 +280,27 @@ def run_inference(
|
|
|
256
280
|
"scale": input_scale,
|
|
257
281
|
}
|
|
258
282
|
|
|
283
|
+
# Validate mutually exclusive frame filter flags
|
|
284
|
+
if only_labeled_frames and exclude_user_labeled:
|
|
285
|
+
message = (
|
|
286
|
+
"--only_labeled_frames and --exclude_user_labeled are mutually exclusive "
|
|
287
|
+
"(would result in zero frames)"
|
|
288
|
+
)
|
|
289
|
+
logger.error(message)
|
|
290
|
+
raise ValueError(message)
|
|
291
|
+
|
|
292
|
+
if (
|
|
293
|
+
only_predicted_frames
|
|
294
|
+
and data_path is not None
|
|
295
|
+
and not data_path.endswith(".slp")
|
|
296
|
+
):
|
|
297
|
+
message = (
|
|
298
|
+
"--only_predicted_frames requires a .slp file input "
|
|
299
|
+
"(need Labels to know which frames have predictions)"
|
|
300
|
+
)
|
|
301
|
+
logger.error(message)
|
|
302
|
+
raise ValueError(message)
|
|
303
|
+
|
|
259
304
|
if model_paths is None or not len(
|
|
260
305
|
model_paths
|
|
261
306
|
): # if model paths is not provided, run tracking-only pipeline.
|
|
@@ -273,7 +318,8 @@ def run_inference(
|
|
|
273
318
|
raise ValueError(message)
|
|
274
319
|
|
|
275
320
|
start_inf_time = time()
|
|
276
|
-
|
|
321
|
+
start_datetime = datetime.now()
|
|
322
|
+
start_timestamp = str(start_datetime)
|
|
277
323
|
logger.info(f"Started tracking at: {start_timestamp}")
|
|
278
324
|
|
|
279
325
|
labels = sio.load_slp(data_path) if input_labels is None else input_labels
|
|
@@ -302,7 +348,22 @@ def run_inference(
|
|
|
302
348
|
|
|
303
349
|
if post_connect_single_breaks or tracking_pre_cull_to_target:
|
|
304
350
|
if tracking_target_instance_count is None and max_instances is None:
|
|
305
|
-
|
|
351
|
+
features_requested = []
|
|
352
|
+
if post_connect_single_breaks:
|
|
353
|
+
features_requested.append("--post_connect_single_breaks")
|
|
354
|
+
if tracking_pre_cull_to_target:
|
|
355
|
+
features_requested.append("--tracking_pre_cull_to_target")
|
|
356
|
+
features_str = " and ".join(features_requested)
|
|
357
|
+
|
|
358
|
+
if max_tracks is not None:
|
|
359
|
+
suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
|
|
360
|
+
else:
|
|
361
|
+
suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
|
|
362
|
+
|
|
363
|
+
message = (
|
|
364
|
+
f"{features_str} requires --tracking_target_instance_count to be set. "
|
|
365
|
+
f"{suggestion}"
|
|
366
|
+
)
|
|
306
367
|
logger.error(message)
|
|
307
368
|
raise ValueError(message)
|
|
308
369
|
elif tracking_target_instance_count is None:
|
|
@@ -332,21 +393,53 @@ def run_inference(
|
|
|
332
393
|
tracking_clean_iou_threshold=tracking_clean_iou_threshold,
|
|
333
394
|
)
|
|
334
395
|
|
|
335
|
-
|
|
396
|
+
end_datetime = datetime.now()
|
|
397
|
+
finish_timestamp = str(end_datetime)
|
|
336
398
|
total_elapsed = time() - start_inf_time
|
|
337
399
|
logger.info(f"Finished tracking at: {finish_timestamp}")
|
|
338
400
|
logger.info(f"Total runtime: {total_elapsed} secs")
|
|
339
401
|
|
|
402
|
+
# Build tracking-only provenance
|
|
403
|
+
tracking_params = {
|
|
404
|
+
"window_size": tracking_window_size,
|
|
405
|
+
"min_new_track_points": min_new_track_points,
|
|
406
|
+
"candidates_method": candidates_method,
|
|
407
|
+
"min_match_points": min_match_points,
|
|
408
|
+
"features": features,
|
|
409
|
+
"scoring_method": scoring_method,
|
|
410
|
+
"scoring_reduction": scoring_reduction,
|
|
411
|
+
"robust_best_instance": robust_best_instance,
|
|
412
|
+
"track_matching_method": track_matching_method,
|
|
413
|
+
"max_tracks": max_tracks,
|
|
414
|
+
"use_flow": use_flow,
|
|
415
|
+
"post_connect_single_breaks": post_connect_single_breaks,
|
|
416
|
+
}
|
|
417
|
+
provenance = build_tracking_only_provenance(
|
|
418
|
+
input_labels=labels,
|
|
419
|
+
input_path=data_path,
|
|
420
|
+
start_time=start_datetime,
|
|
421
|
+
end_time=end_datetime,
|
|
422
|
+
tracking_params=tracking_params,
|
|
423
|
+
frames_processed=len(tracked_frames),
|
|
424
|
+
)
|
|
425
|
+
|
|
340
426
|
output = sio.Labels(
|
|
341
427
|
labeled_frames=tracked_frames,
|
|
342
428
|
videos=labels.videos,
|
|
343
429
|
skeletons=labels.skeletons,
|
|
430
|
+
provenance=provenance,
|
|
344
431
|
)
|
|
345
432
|
|
|
346
433
|
else:
|
|
347
434
|
start_inf_time = time()
|
|
348
|
-
|
|
435
|
+
start_datetime = datetime.now()
|
|
436
|
+
start_timestamp = str(start_datetime)
|
|
349
437
|
logger.info(f"Started inference at: {start_timestamp}")
|
|
438
|
+
logger.info(get_startup_info_string())
|
|
439
|
+
|
|
440
|
+
# Convert device to string if it's a torch.device object
|
|
441
|
+
if hasattr(device, "type"):
|
|
442
|
+
device = str(device)
|
|
350
443
|
|
|
351
444
|
if device == "auto":
|
|
352
445
|
device = (
|
|
@@ -355,13 +448,6 @@ def run_inference(
|
|
|
355
448
|
else "mps" if torch.backends.mps.is_available() else "cpu"
|
|
356
449
|
)
|
|
357
450
|
|
|
358
|
-
if integral_refinement is not None and device == "mps": # TODO
|
|
359
|
-
# kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
|
|
360
|
-
logger.info(
|
|
361
|
-
"Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
|
|
362
|
-
)
|
|
363
|
-
integral_refinement = None
|
|
364
|
-
|
|
365
451
|
logger.info(f"Using device: {device}")
|
|
366
452
|
|
|
367
453
|
# initializes the inference model
|
|
@@ -380,6 +466,9 @@ def run_inference(
|
|
|
380
466
|
anchor_part=anchor_part,
|
|
381
467
|
)
|
|
382
468
|
|
|
469
|
+
# Set GUI mode for progress output
|
|
470
|
+
predictor.gui = gui
|
|
471
|
+
|
|
383
472
|
if (
|
|
384
473
|
tracking
|
|
385
474
|
and not isinstance(predictor, BottomUpMultiClassPredictor)
|
|
@@ -387,7 +476,22 @@ def run_inference(
|
|
|
387
476
|
):
|
|
388
477
|
if post_connect_single_breaks or tracking_pre_cull_to_target:
|
|
389
478
|
if tracking_target_instance_count is None and max_instances is None:
|
|
390
|
-
|
|
479
|
+
features_requested = []
|
|
480
|
+
if post_connect_single_breaks:
|
|
481
|
+
features_requested.append("--post_connect_single_breaks")
|
|
482
|
+
if tracking_pre_cull_to_target:
|
|
483
|
+
features_requested.append("--tracking_pre_cull_to_target")
|
|
484
|
+
features_str = " and ".join(features_requested)
|
|
485
|
+
|
|
486
|
+
if max_tracks is not None:
|
|
487
|
+
suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
|
|
488
|
+
else:
|
|
489
|
+
suggestion = "Add --tracking_target_instance_count N or --max_instances N where N is the expected number of instances per frame."
|
|
490
|
+
|
|
491
|
+
message = (
|
|
492
|
+
f"{features_str} requires --tracking_target_instance_count or --max_instances to be set. "
|
|
493
|
+
f"{suggestion}"
|
|
494
|
+
)
|
|
391
495
|
logger.error(message)
|
|
392
496
|
raise ValueError(message)
|
|
393
497
|
elif tracking_target_instance_count is None:
|
|
@@ -448,6 +552,8 @@ def run_inference(
|
|
|
448
552
|
frames=frames,
|
|
449
553
|
only_labeled_frames=only_labeled_frames,
|
|
450
554
|
only_suggested_frames=only_suggested_frames,
|
|
555
|
+
exclude_user_labeled=exclude_user_labeled,
|
|
556
|
+
only_predicted_frames=only_predicted_frames,
|
|
451
557
|
video_index=video_index,
|
|
452
558
|
video_dataset=video_dataset,
|
|
453
559
|
video_input_format=video_input_format,
|
|
@@ -458,6 +564,20 @@ def run_inference(
|
|
|
458
564
|
make_labels=make_labels,
|
|
459
565
|
)
|
|
460
566
|
|
|
567
|
+
# Filter overlapping instances (independent of tracking)
|
|
568
|
+
if filter_overlapping and make_labels:
|
|
569
|
+
from sleap_nn.inference.postprocessing import filter_overlapping_instances
|
|
570
|
+
|
|
571
|
+
output = filter_overlapping_instances(
|
|
572
|
+
output,
|
|
573
|
+
threshold=filter_overlapping_threshold,
|
|
574
|
+
method=filter_overlapping_method,
|
|
575
|
+
)
|
|
576
|
+
logger.info(
|
|
577
|
+
f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
|
|
578
|
+
f"threshold: {filter_overlapping_threshold}"
|
|
579
|
+
)
|
|
580
|
+
|
|
461
581
|
if tracking:
|
|
462
582
|
lfs = [x for x in output]
|
|
463
583
|
if tracking_clean_instance_count > 0:
|
|
@@ -492,12 +612,97 @@ def run_inference(
|
|
|
492
612
|
skeletons=output.skeletons,
|
|
493
613
|
)
|
|
494
614
|
|
|
495
|
-
|
|
615
|
+
end_datetime = datetime.now()
|
|
616
|
+
finish_timestamp = str(end_datetime)
|
|
496
617
|
total_elapsed = time() - start_inf_time
|
|
497
618
|
logger.info(f"Finished inference at: {finish_timestamp}")
|
|
498
|
-
logger.info(
|
|
499
|
-
|
|
500
|
-
|
|
619
|
+
logger.info(f"Total runtime: {total_elapsed} secs")
|
|
620
|
+
|
|
621
|
+
# Determine input labels for provenance preservation
|
|
622
|
+
input_labels_for_prov = None
|
|
623
|
+
if input_labels is not None:
|
|
624
|
+
input_labels_for_prov = input_labels
|
|
625
|
+
elif data_path is not None and data_path.endswith(".slp"):
|
|
626
|
+
# Load input labels to preserve provenance (if not already loaded)
|
|
627
|
+
try:
|
|
628
|
+
input_labels_for_prov = sio.load_slp(data_path)
|
|
629
|
+
except Exception:
|
|
630
|
+
pass
|
|
631
|
+
|
|
632
|
+
# Build inference parameters for provenance
|
|
633
|
+
inference_params = {
|
|
634
|
+
"peak_threshold": peak_threshold,
|
|
635
|
+
"filter_overlapping": filter_overlapping,
|
|
636
|
+
"filter_overlapping_method": filter_overlapping_method,
|
|
637
|
+
"filter_overlapping_threshold": filter_overlapping_threshold,
|
|
638
|
+
"integral_refinement": integral_refinement,
|
|
639
|
+
"integral_patch_size": integral_patch_size,
|
|
640
|
+
"batch_size": batch_size,
|
|
641
|
+
"max_instances": max_instances,
|
|
642
|
+
"crop_size": crop_size,
|
|
643
|
+
"input_scale": input_scale,
|
|
644
|
+
"anchor_part": anchor_part,
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
# Build tracking parameters if tracking was enabled
|
|
648
|
+
tracking_params_prov = None
|
|
649
|
+
if tracking:
|
|
650
|
+
tracking_params_prov = {
|
|
651
|
+
"window_size": tracking_window_size,
|
|
652
|
+
"min_new_track_points": min_new_track_points,
|
|
653
|
+
"candidates_method": candidates_method,
|
|
654
|
+
"min_match_points": min_match_points,
|
|
655
|
+
"features": features,
|
|
656
|
+
"scoring_method": scoring_method,
|
|
657
|
+
"scoring_reduction": scoring_reduction,
|
|
658
|
+
"robust_best_instance": robust_best_instance,
|
|
659
|
+
"track_matching_method": track_matching_method,
|
|
660
|
+
"max_tracks": max_tracks,
|
|
661
|
+
"use_flow": use_flow,
|
|
662
|
+
"post_connect_single_breaks": post_connect_single_breaks,
|
|
663
|
+
}
|
|
664
|
+
|
|
665
|
+
# Determine frame selection method
|
|
666
|
+
frame_selection_method = "all"
|
|
667
|
+
if only_labeled_frames:
|
|
668
|
+
frame_selection_method = "labeled"
|
|
669
|
+
elif only_suggested_frames:
|
|
670
|
+
frame_selection_method = "suggested"
|
|
671
|
+
elif only_predicted_frames:
|
|
672
|
+
frame_selection_method = "predicted"
|
|
673
|
+
elif frames is not None:
|
|
674
|
+
frame_selection_method = "specified"
|
|
675
|
+
|
|
676
|
+
# Determine model type from predictor class
|
|
677
|
+
predictor_type_map = {
|
|
678
|
+
"TopDownPredictor": "top_down",
|
|
679
|
+
"SingleInstancePredictor": "single_instance",
|
|
680
|
+
"BottomUpPredictor": "bottom_up",
|
|
681
|
+
"BottomUpMultiClassPredictor": "bottom_up_multi_class",
|
|
682
|
+
"TopDownMultiClassPredictor": "top_down_multi_class",
|
|
683
|
+
}
|
|
684
|
+
model_type = predictor_type_map.get(type(predictor).__name__)
|
|
685
|
+
|
|
686
|
+
# Build and set provenance (only for Labels objects)
|
|
687
|
+
if make_labels and isinstance(output, sio.Labels):
|
|
688
|
+
provenance = build_inference_provenance(
|
|
689
|
+
model_paths=model_paths,
|
|
690
|
+
model_type=model_type,
|
|
691
|
+
start_time=start_datetime,
|
|
692
|
+
end_time=end_datetime,
|
|
693
|
+
input_labels=input_labels_for_prov,
|
|
694
|
+
input_path=data_path,
|
|
695
|
+
frames_processed=(
|
|
696
|
+
len(output.labeled_frames)
|
|
697
|
+
if hasattr(output, "labeled_frames")
|
|
698
|
+
else None
|
|
699
|
+
),
|
|
700
|
+
frame_selection_method=frame_selection_method,
|
|
701
|
+
inference_params=inference_params,
|
|
702
|
+
tracking_params=tracking_params_prov,
|
|
703
|
+
device=device,
|
|
704
|
+
)
|
|
705
|
+
output.provenance = provenance
|
|
501
706
|
|
|
502
707
|
if no_empty_frames:
|
|
503
708
|
output.clean(frames=True, skeletons=False)
|