sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/inference/topdown.py
CHANGED
@@ -47,9 +47,6 @@ class CentroidCrop(L.LightningModule):
         crop_hw: Tuple (height, width) representing the crop size.
         input_scale: Float indicating if the images should be resized before being
             passed to the model.
-        precrop_resize: Float indicating the factor by which the original images
-            (not images resized for centroid model) should be resized before cropping.
-            Note: This resize happens only after getting the predictions for centroid model.
         max_stride: Maximum stride in a model that the images must be divisible by.
             If > 1, this will pad the bottom and right of the images to ensure they meet
             this divisibility criteria. Padding is applied after the scaling specified
@@ -74,7 +71,6 @@ class CentroidCrop(L.LightningModule):
         return_crops: bool = False,
         crop_hw: Optional[List[int]] = None,
         input_scale: float = 1.0,
-        precrop_resize: float = 1.0,
         max_stride: int = 1,
         use_gt_centroids: bool = False,
         anchor_ind: Optional[int] = None,
@@ -92,22 +88,25 @@ class CentroidCrop(L.LightningModule):
         self.return_crops = return_crops
         self.crop_hw = crop_hw
         self.input_scale = input_scale
-        self.precrop_resize = precrop_resize
         self.max_stride = max_stride
         self.use_gt_centroids = use_gt_centroids
         self.anchor_ind = anchor_ind

-    def _generate_crops(self, inputs):
+    def _generate_crops(self, inputs, cms: Optional[torch.Tensor] = None):
         """Generate Crops from the predicted centroids."""
         crops_dict = []
-        for centroid, centroid_val, image, fidx, vidx, sz, eff_sc in zip(
-            self.refined_peaks_batched,
-            self.peak_vals_batched,
-            inputs["image"],
-            inputs["frame_idx"],
-            inputs["video_idx"],
-            inputs["orig_size"],
-            inputs["eff_scale"],
+        if cms is not None:
+            cms = cms.detach()
+        for idx, (centroid, centroid_val, image, fidx, vidx, sz, eff_sc) in enumerate(
+            zip(
+                self.refined_peaks_batched,
+                self.peak_vals_batched,
+                inputs["image"],
+                inputs["frame_idx"],
+                inputs["video_idx"],
+                inputs["orig_size"],
+                inputs["eff_scale"],
+            )
         ):
             if torch.any(torch.isnan(centroid)):
                 if torch.all(torch.isnan(centroid)):
@@ -149,6 +148,11 @@ class CentroidCrop(L.LightningModule):
             ex["instance_image"] = instance_image.unsqueeze(dim=1)
             ex["orig_size"] = torch.cat([torch.Tensor(sz)] * n)
             ex["eff_scale"] = torch.Tensor([eff_sc] * n)
+            ex["pred_centroids"] = centroid
+            if self.return_confmaps:
+                ex["pred_centroid_confmaps"] = torch.cat(
+                    [cms[idx].unsqueeze(dim=0)] * n
+                )
             crops_dict.append(ex)

         return crops_dict
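
In the new `_generate_crops`, each crop dictionary now carries the predicted centroid and, when `return_confmaps` is on, a copy of that frame's centroid confidence map repeated once per crop. A minimal standalone sketch of that repeat (the map shape below is an assumption, not taken from sleap-nn):

    import torch

    # One centroid confidence map for a frame: (channels, height, width) -- assumed shape.
    cm = torch.rand(1, 96, 96)
    n = 3  # number of crops generated from this frame

    # Same pattern as the diff: repeat the map once per crop along a new batch dim.
    per_crop_cms = torch.cat([cm.unsqueeze(dim=0)] * n)
    assert per_crop_cms.shape == (3, 1, 96, 96)
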
@@ -204,12 +208,6 @@ class CentroidCrop(L.LightningModule):

         if self.return_crops:
             crops_dict = self._generate_crops(inputs)
-            inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
-            inputs["centroids"] *= self.precrop_resize
-            scaled_refined_peaks = []
-            for ref_peak in self.refined_peaks_batched:
-                scaled_refined_peaks.append(ref_peak * self.precrop_resize)
-            self.refined_peaks_batched = scaled_refined_peaks
             return crops_dict
         else:
             return inputs
@@ -274,19 +272,13 @@ class CentroidCrop(L.LightningModule):

         # Generate crops if return_crops=True to pass the crops to CenteredInstance model.
         if self.return_crops:
-            inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
-            scaled_refined_peaks = []
-            for ref_peak in self.refined_peaks_batched:
-                scaled_refined_peaks.append(ref_peak * self.precrop_resize)
-            self.refined_peaks_batched = scaled_refined_peaks
-
             inputs.update(
                 {
                     "centroids": self.refined_peaks_batched,
                     "centroid_vals": self.peak_vals_batched,
                 }
             )
-            crops_dict = self._generate_crops(inputs)
+            crops_dict = self._generate_crops(inputs, cms)
             return crops_dict
         else:
             # batch the peaks to pass it to FindInstancePeaksGroundTruth class.
@@ -359,7 +351,11 @@ class FindInstancePeaksGroundTruth(L.LightningModule):

     def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, np.array]:
         """Return the ground truth instance peaks given a set of crops."""
-        b, _, max_inst, nodes, _ = batch["instances"].shape
+        b, _, _, nodes, _ = batch["instances"].shape
+        # Use number of centroids as max_inst to ensure consistent output shape
+        # This handles the case where max_instances limits centroids but instances
+        # tensor has a different (global) max_instances from the labels file
+        num_centroids = batch["centroids"].shape[2]
         inst = (
             batch["instances"].unsqueeze(dim=-4).float()
         )  # (batch, 1, 1, n_inst, nodes, 2)
@@ -389,26 +385,26 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
         parsed = 0
         for i in range(b):
             if i not in matched_batch_inds:
-                batch_peaks = torch.full((max_inst, nodes, 2), torch.nan)
-                vals = torch.full((max_inst, nodes), torch.nan)
+                batch_peaks = torch.full((num_centroids, nodes, 2), torch.nan)
+                vals = torch.full((num_centroids, nodes), torch.nan)
             else:
                 c = counts[i]
                 batch_peaks = peaks_list[parsed : parsed + c]
                 num_inst = len(batch_peaks)
                 vals = torch.ones((num_inst, nodes))
-                if c < max_inst:
+                if c < num_centroids:
                     batch_peaks = torch.cat(
                         [
                             batch_peaks,
-                            torch.full((max_inst - num_inst, nodes, 2), torch.nan),
+                            torch.full((num_centroids - num_inst, nodes, 2), torch.nan),
                         ]
                     )
                     vals = torch.cat(
-                        [vals, torch.full((max_inst - num_inst, nodes), torch.nan)]
+                        [vals, torch.full((num_centroids - num_inst, nodes), torch.nan)]
                     )
                 else:
-                    batch_peaks = batch_peaks[:max_inst]
-                    vals = vals[:max_inst]
+                    batch_peaks = batch_peaks[:num_centroids]
+                    vals = vals[:num_centroids]
                 parsed += c

             batch_peaks = batch_peaks.unsqueeze(dim=0)
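
The ground-truth peak finder now pads or truncates every frame's peaks to `num_centroids` rows, using NaN for missing instances, so batched outputs always have the same shape. A self-contained sketch of the same padding rule with made-up sizes:

    import torch

    num_centroids, nodes = 4, 5
    # Two matched ground-truth instances for this frame (values are arbitrary).
    batch_peaks = torch.rand(2, nodes, 2)
    vals = torch.ones(2, nodes)

    c = batch_peaks.shape[0]
    if c < num_centroids:
        # Pad up to num_centroids with NaN rows so every frame has the same shape.
        batch_peaks = torch.cat(
            [batch_peaks, torch.full((num_centroids - c, nodes, 2), torch.nan)]
        )
        vals = torch.cat([vals, torch.full((num_centroids - c, nodes), torch.nan)])
    else:
        # Truncate if there are more matches than centroids.
        batch_peaks = batch_peaks[:num_centroids]
        vals = vals[:num_centroids]

    assert batch_peaks.shape == (num_centroids, nodes, 2)
    assert vals.shape == (num_centroids, nodes)
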
@@ -432,33 +428,45 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
         peaks_output["pred_instance_peaks"] = peaks
         peaks_output["pred_peak_values"] = peaks_vals

-        batch_size, num_centroids = (
-            batch["centroids"].shape[0],
-            batch["centroids"].shape[2],
-        )
+        batch_size = batch["centroids"].shape[0]
         output_dict = {}
         output_dict["centroid"] = batch["centroids"].squeeze(dim=1).reshape(-1, 1, 2)
         output_dict["centroid_val"] = batch["centroid_vals"].reshape(-1)
-        output_dict["pred_instance_peaks"] = batch["pred_instance_peaks"].reshape(
-            -1, nodes, 2
+        output_dict["pred_instance_peaks"] = peaks_output[
+            "pred_instance_peaks"
+        ].reshape(-1, nodes, 2)
+        output_dict["pred_peak_values"] = peaks_output["pred_peak_values"].reshape(
+            -1, nodes
         )
-        output_dict["pred_peak_values"] = batch["pred_peak_values"].reshape(-1, nodes)
         output_dict["instance_bbox"] = torch.zeros(
             (batch_size * num_centroids, 1, 4, 2)
         )
         frame_inds = []
         video_inds = []
         orig_szs = []
+        images = []
+        centroid_confmaps = []
         for b_idx in range(b):
             curr_batch_size = len(batch["centroids"][b_idx][0])
             frame_inds.extend([batch["frame_idx"][b_idx]] * curr_batch_size)
             video_inds.extend([batch["video_idx"][b_idx]] * curr_batch_size)
             orig_szs.append(torch.cat([batch["orig_size"][b_idx]] * curr_batch_size))
+            images.append(
+                batch["image"][b_idx].unsqueeze(0).repeat(curr_batch_size, 1, 1, 1, 1)
+            )
+            if "pred_centroid_confmaps" in batch:
+                centroid_confmaps.append(
+                    batch["pred_centroid_confmaps"][b_idx]
+                    .unsqueeze(0)
+                    .repeat(curr_batch_size, 1, 1, 1)
+                )

         output_dict["frame_idx"] = torch.tensor(frame_inds)
         output_dict["video_idx"] = torch.tensor(video_inds)
         output_dict["orig_size"] = torch.concatenate(orig_szs, dim=0)
-
+        output_dict["image"] = torch.cat(images, dim=0)
+        if centroid_confmaps:
+            output_dict["pred_centroid_confmaps"] = torch.cat(centroid_confmaps, dim=0)
         return output_dict

@@ -548,6 +556,8 @@ class FindInstancePeaks(L.LightningModule):
         # Network forward pass.
         # resize and pad the input image
         input_image = inputs["instance_image"]
+        # resize the crop image
+        input_image = resize_image(input_image, self.input_scale)
         if self.max_stride != 1:
             input_image = apply_pad_to_stride(input_image, self.max_stride)

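
With `precrop_resize` removed from `CentroidCrop`, the crop itself is resized by `input_scale` inside the peak finders and then padded so its sides are divisible by `max_stride`. A rough standalone sketch of that resize-then-pad order using plain torch ops (sleap_nn's `resize_image` and `apply_pad_to_stride` are assumed to behave roughly like this; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    def resize_then_pad(image: torch.Tensor, scale: float, max_stride: int) -> torch.Tensor:
        """Scale a (batch, channels, H, W) image, then zero-pad bottom/right to a stride multiple."""
        if scale != 1.0:
            image = F.interpolate(image, scale_factor=scale, mode="bilinear", align_corners=False)
        h, w = image.shape[-2:]
        pad_h = (max_stride - h % max_stride) % max_stride
        pad_w = (max_stride - w % max_stride) % max_stride
        # F.pad order is (left, right, top, bottom); pad only right and bottom.
        return F.pad(image, (0, pad_w, 0, pad_h))

    crops = torch.rand(2, 1, 100, 100)
    out = resize_then_pad(crops, scale=0.5, max_stride=16)
    assert out.shape[-2:] == (64, 64)  # 50 -> padded up to the next multiple of 16
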
@@ -569,8 +579,6 @@ class FindInstancePeaks(L.LightningModule):
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
         )

-        inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
         inputs["instance_bbox"] = inputs["instance_bbox"] / (
             inputs["eff_scale"]
             .unsqueeze(dim=1)
@@ -679,6 +687,8 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
         # Network forward pass.
         # resize and pad the input image
         input_image = inputs["instance_image"]
+        # resize the crop image
+        input_image = resize_image(input_image, self.input_scale)
         if self.max_stride != 1:
             input_image = apply_pad_to_stride(input_image, self.max_stride)

@@ -702,8 +712,6 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
         )

-        inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
         inputs["instance_bbox"] = inputs["instance_bbox"] / (
             inputs["eff_scale"]
             .unsqueeze(dim=1)
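
After this change, instance bounding boxes are mapped back to original-image coordinates by dividing only by the per-sample effective scale rather than also by `input_scale`. A toy sketch of that broadcast division (the box layout and scale values are assumptions):

    import torch

    # Four crops, each with one bounding box of four corner points (x, y) -- assumed layout.
    instance_bbox = torch.rand(4, 1, 4, 2) * 100.0
    # Per-sample effective scale applied earlier in the pipeline.
    eff_scale = torch.tensor([0.5, 0.5, 1.0, 2.0])

    # Broadcast (4,) -> (4, 1, 1, 1) so each box is divided by its own scale.
    restored = instance_bbox / eff_scale.view(-1, 1, 1, 1)
    assert restored.shape == instance_bbox.shape
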
sleap_nn/predict.py
CHANGED
@@ -15,6 +15,11 @@ from sleap_nn.tracking.tracker import (
     connect_single_breaks,
     cull_instances,
 )
+from sleap_nn.system_info import get_startup_info_string
+from sleap_nn.inference.provenance import (
+    build_inference_provenance,
+    build_tracking_only_provenance,
+)
 from omegaconf import OmegaConf
 import sleap_io as sio
 from pathlib import Path
@@ -58,6 +63,8 @@ def run_inference(
     anchor_part: Optional[str] = None,
     only_labeled_frames: bool = False,
     only_suggested_frames: bool = False,
+    exclude_user_labeled: bool = False,
+    only_predicted_frames: bool = False,
     no_empty_frames: bool = False,
     batch_size: int = 4,
     queue_maxsize: int = 8,
@@ -136,6 +143,8 @@ def run_inference(
             provided, the anchor part in the `training_config.yaml` is used. Default: `None`.
         only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
         only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
+        exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
+        only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
         no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
         batch_size: (int) Number of samples per batch. Default: 4.
         queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
@@ -145,7 +154,7 @@ def run_inference(
         video_input_format: (str) The input_format for HDF5 videos.
         frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
         crop_size: (int) Crop size. If not provided, the crop size from training_config.yaml is used.
-            Default: None.
+            If `input_scale` is provided, then the cropped image will be resized according to `input_scale`. Default: None.
         peak_threshold: (float) Minimum confidence threshold. Peaks with values below
             this will be ignored. Default: 0.2. This can also be `List[float]` for topdown
             centroid and centered-instance model, where the first element corresponds
@@ -256,6 +265,27 @@ def run_inference(
         "scale": input_scale,
     }

+    # Validate mutually exclusive frame filter flags
+    if only_labeled_frames and exclude_user_labeled:
+        message = (
+            "--only_labeled_frames and --exclude_user_labeled are mutually exclusive "
+            "(would result in zero frames)"
+        )
+        logger.error(message)
+        raise ValueError(message)
+
+    if (
+        only_predicted_frames
+        and data_path is not None
+        and not data_path.endswith(".slp")
+    ):
+        message = (
+            "--only_predicted_frames requires a .slp file input "
+            "(need Labels to know which frames have predictions)"
+        )
+        logger.error(message)
+        raise ValueError(message)
+
     if model_paths is None or not len(
         model_paths
     ):  # if model paths is not provided, run tracking-only pipeline.
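
The two new frame filters are checked before any model is loaded: `only_labeled_frames` cannot be combined with `exclude_user_labeled`, and `only_predicted_frames` needs a .slp input so the existing predictions are known. A self-contained sketch that mirrors those checks outside of `run_inference` (the helper name is made up for illustration):

    from typing import Optional

    def validate_frame_filters(
        only_labeled_frames: bool,
        exclude_user_labeled: bool,
        only_predicted_frames: bool,
        data_path: Optional[str],
    ) -> None:
        """Standalone mirror of the two checks added above."""
        if only_labeled_frames and exclude_user_labeled:
            raise ValueError(
                "--only_labeled_frames and --exclude_user_labeled are mutually exclusive "
                "(would result in zero frames)"
            )
        if (
            only_predicted_frames
            and data_path is not None
            and not data_path.endswith(".slp")
        ):
            raise ValueError(
                "--only_predicted_frames requires a .slp file input "
                "(need Labels to know which frames have predictions)"
            )

    try:
        validate_frame_filters(True, True, False, "labels.slp")
    except ValueError as err:
        print(err)  # mutually exclusive filters are rejected before any model is loaded
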
@@ -273,7 +303,8 @@ def run_inference(
             raise ValueError(message)

         start_inf_time = time()
-        start_timestamp = str(datetime.now())
+        start_datetime = datetime.now()
+        start_timestamp = str(start_datetime)
         logger.info(f"Started tracking at: {start_timestamp}")

         labels = sio.load_slp(data_path) if input_labels is None else input_labels
@@ -302,7 +333,22 @@ def run_inference(

         if post_connect_single_breaks or tracking_pre_cull_to_target:
             if tracking_target_instance_count is None and max_instances is None:
-
+                features_requested = []
+                if post_connect_single_breaks:
+                    features_requested.append("--post_connect_single_breaks")
+                if tracking_pre_cull_to_target:
+                    features_requested.append("--tracking_pre_cull_to_target")
+                features_str = " and ".join(features_requested)
+
+                if max_tracks is not None:
+                    suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
+                else:
+                    suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
+
+                message = (
+                    f"{features_str} requires --tracking_target_instance_count to be set. "
+                    f"{suggestion}"
+                )
                 logger.error(message)
                 raise ValueError(message)
             elif tracking_target_instance_count is None:
@@ -332,21 +378,53 @@ def run_inference(
             tracking_clean_iou_threshold=tracking_clean_iou_threshold,
         )

-        finish_timestamp = str(datetime.now())
+        end_datetime = datetime.now()
+        finish_timestamp = str(end_datetime)
         total_elapsed = time() - start_inf_time
         logger.info(f"Finished tracking at: {finish_timestamp}")
         logger.info(f"Total runtime: {total_elapsed} secs")

+        # Build tracking-only provenance
+        tracking_params = {
+            "window_size": tracking_window_size,
+            "min_new_track_points": min_new_track_points,
+            "candidates_method": candidates_method,
+            "min_match_points": min_match_points,
+            "features": features,
+            "scoring_method": scoring_method,
+            "scoring_reduction": scoring_reduction,
+            "robust_best_instance": robust_best_instance,
+            "track_matching_method": track_matching_method,
+            "max_tracks": max_tracks,
+            "use_flow": use_flow,
+            "post_connect_single_breaks": post_connect_single_breaks,
+        }
+        provenance = build_tracking_only_provenance(
+            input_labels=labels,
+            input_path=data_path,
+            start_time=start_datetime,
+            end_time=end_datetime,
+            tracking_params=tracking_params,
+            frames_processed=len(tracked_frames),
+        )
+
         output = sio.Labels(
             labeled_frames=tracked_frames,
             videos=labels.videos,
             skeletons=labels.skeletons,
+            provenance=provenance,
         )

     else:
         start_inf_time = time()
-        start_timestamp = str(datetime.now())
+        start_datetime = datetime.now()
+        start_timestamp = str(start_datetime)
         logger.info(f"Started inference at: {start_timestamp}")
+        logger.info(get_startup_info_string())
+
+        # Convert device to string if it's a torch.device object
+        if hasattr(device, "type"):
+            device = str(device)

         if device == "auto":
             device = (
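
sleap_io's `Labels` accepts a free-form `provenance` dictionary, which is what `build_tracking_only_provenance` produces here. A minimal sketch of attaching such a dictionary; the key names below are illustrative assumptions, not the actual schema of `sleap_nn.inference.provenance`:

    from datetime import datetime
    import sleap_io as sio

    start, end = datetime.now(), datetime.now()

    # Illustrative provenance payload; the real contents come from
    # sleap_nn.inference.provenance.build_tracking_only_provenance.
    provenance = {
        "pipeline": "tracking_only",
        "start_time": str(start),
        "end_time": str(end),
        "tracking_params": {"window_size": 5, "max_tracks": None},
        "frames_processed": 0,
    }

    output = sio.Labels(labeled_frames=[], videos=[], skeletons=[], provenance=provenance)
    print(output.provenance["pipeline"])
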
@@ -387,7 +465,22 @@ def run_inference(
         ):
             if post_connect_single_breaks or tracking_pre_cull_to_target:
                 if tracking_target_instance_count is None and max_instances is None:
-
+                    features_requested = []
+                    if post_connect_single_breaks:
+                        features_requested.append("--post_connect_single_breaks")
+                    if tracking_pre_cull_to_target:
+                        features_requested.append("--tracking_pre_cull_to_target")
+                    features_str = " and ".join(features_requested)
+
+                    if max_tracks is not None:
+                        suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
+                    else:
+                        suggestion = "Add --tracking_target_instance_count N or --max_instances N where N is the expected number of instances per frame."
+
+                    message = (
+                        f"{features_str} requires --tracking_target_instance_count or --max_instances to be set. "
+                        f"{suggestion}"
+                    )
                     logger.error(message)
                     raise ValueError(message)
                 elif tracking_target_instance_count is None:
@@ -448,6 +541,8 @@ def run_inference(
             frames=frames,
             only_labeled_frames=only_labeled_frames,
             only_suggested_frames=only_suggested_frames,
+            exclude_user_labeled=exclude_user_labeled,
+            only_predicted_frames=only_predicted_frames,
             video_index=video_index,
             video_dataset=video_dataset,
             video_input_format=video_input_format,
@@ -492,12 +587,94 @@ def run_inference(
                 skeletons=output.skeletons,
             )

-        finish_timestamp = str(datetime.now())
+        end_datetime = datetime.now()
+        finish_timestamp = str(end_datetime)
         total_elapsed = time() - start_inf_time
         logger.info(f"Finished inference at: {finish_timestamp}")
-        logger.info(
-            f"Total runtime: {total_elapsed} secs"
-        )
+        logger.info(f"Total runtime: {total_elapsed} secs")
+
+        # Determine input labels for provenance preservation
+        input_labels_for_prov = None
+        if input_labels is not None:
+            input_labels_for_prov = input_labels
+        elif data_path is not None and data_path.endswith(".slp"):
+            # Load input labels to preserve provenance (if not already loaded)
+            try:
+                input_labels_for_prov = sio.load_slp(data_path)
+            except Exception:
+                pass
+
+        # Build inference parameters for provenance
+        inference_params = {
+            "peak_threshold": peak_threshold,
+            "integral_refinement": integral_refinement,
+            "integral_patch_size": integral_patch_size,
+            "batch_size": batch_size,
+            "max_instances": max_instances,
+            "crop_size": crop_size,
+            "input_scale": input_scale,
+            "anchor_part": anchor_part,
+        }
+
+        # Build tracking parameters if tracking was enabled
+        tracking_params_prov = None
+        if tracking:
+            tracking_params_prov = {
+                "window_size": tracking_window_size,
+                "min_new_track_points": min_new_track_points,
+                "candidates_method": candidates_method,
+                "min_match_points": min_match_points,
+                "features": features,
+                "scoring_method": scoring_method,
+                "scoring_reduction": scoring_reduction,
+                "robust_best_instance": robust_best_instance,
+                "track_matching_method": track_matching_method,
+                "max_tracks": max_tracks,
+                "use_flow": use_flow,
+                "post_connect_single_breaks": post_connect_single_breaks,
+            }
+
+        # Determine frame selection method
+        frame_selection_method = "all"
+        if only_labeled_frames:
+            frame_selection_method = "labeled"
+        elif only_suggested_frames:
+            frame_selection_method = "suggested"
+        elif only_predicted_frames:
+            frame_selection_method = "predicted"
+        elif frames is not None:
+            frame_selection_method = "specified"
+
+        # Determine model type from predictor class
+        predictor_type_map = {
+            "TopDownPredictor": "top_down",
+            "SingleInstancePredictor": "single_instance",
+            "BottomUpPredictor": "bottom_up",
+            "BottomUpMultiClassPredictor": "bottom_up_multi_class",
+            "TopDownMultiClassPredictor": "top_down_multi_class",
+        }
+        model_type = predictor_type_map.get(type(predictor).__name__)
+
+        # Build and set provenance (only for Labels objects)
+        if make_labels and isinstance(output, sio.Labels):
+            provenance = build_inference_provenance(
+                model_paths=model_paths,
+                model_type=model_type,
+                start_time=start_datetime,
+                end_time=end_datetime,
+                input_labels=input_labels_for_prov,
+                input_path=data_path,
+                frames_processed=(
+                    len(output.labeled_frames)
+                    if hasattr(output, "labeled_frames")
+                    else None
+                ),
+                frame_selection_method=frame_selection_method,
+                inference_params=inference_params,
+                tracking_params=tracking_params_prov,
+                device=device,
+            )
+            output.provenance = provenance

     if no_empty_frames:
         output.clean(frames=True, skeletons=False)