sleap-nn 0.1.0a2-py3-none-any.whl → 0.1.0a4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +168 -39
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/peak_finding.py +47 -17
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +213 -106
- sleap_nn/predict.py +35 -7
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +69 -22
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +67 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/inference/predictors.py
CHANGED

@@ -56,6 +56,8 @@ from rich.progress import (
     MofNCompleteColumn,
 )
 from time import time
+import json
+import sys
 
 
 def _filter_user_labeled_frames(
@@ -133,6 +135,8 @@ class Predictor(ABC):
             `backbone_config`. This determines the downsampling factor applied by the backbone,
             and is used to ensure that input images are padded or resized to be compatible
             with the model's architecture. Default: 16.
+        gui: If True, outputs JSON progress lines for GUI integration instead of
+            Rich progress bars. Default: False.
     """
 
     preprocess: bool = True
@@ -152,6 +156,7 @@ class Predictor(ABC):
     ] = None
     instances_key: bool = False
     max_stride: int = 16
+    gui: bool = False
 
     @classmethod
     def from_model_paths(
@@ -381,6 +386,102 @@ class Predictor(ABC):
                     v[n] = v[n].cpu().numpy()
         return output
 
+    def _process_batch(self) -> tuple:
+        """Process a single batch of frames from the pipeline.
+
+        Returns:
+            Tuple of (imgs, fidxs, vidxs, org_szs, instances, eff_scales, done)
+            where done is True if the pipeline has finished.
+        """
+        imgs = []
+        fidxs = []
+        vidxs = []
+        org_szs = []
+        instances = []
+        eff_scales = []
+        done = False
+
+        for _ in range(self.batch_size):
+            frame = self.pipeline.frame_buffer.get()
+            if frame["image"] is None:
+                done = True
+                break
+            frame["image"], eff_scale = apply_sizematcher(
+                frame["image"],
+                self.preprocess_config["max_height"],
+                self.preprocess_config["max_width"],
+            )
+            if self.instances_key:
+                frame["instances"] = frame["instances"] * eff_scale
+            if self.preprocess_config["ensure_rgb"] and frame["image"].shape[-3] != 3:
+                frame["image"] = frame["image"].repeat(1, 3, 1, 1)
+            elif (
+                self.preprocess_config["ensure_grayscale"]
+                and frame["image"].shape[-3] != 1
+            ):
+                frame["image"] = F.rgb_to_grayscale(
+                    frame["image"], num_output_channels=1
+                )
+
+            eff_scales.append(torch.tensor(eff_scale))
+            imgs.append(frame["image"].unsqueeze(dim=0))
+            fidxs.append(frame["frame_idx"])
+            vidxs.append(frame["video_idx"])
+            org_szs.append(frame["orig_size"].unsqueeze(dim=0))
+            if self.instances_key:
+                instances.append(frame["instances"].unsqueeze(dim=0))
+
+        return imgs, fidxs, vidxs, org_szs, instances, eff_scales, done
+
+    def _run_inference_on_batch(
+        self, imgs, fidxs, vidxs, org_szs, instances, eff_scales
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Run inference on a prepared batch of frames.
+
+        Args:
+            imgs: List of image tensors.
+            fidxs: List of frame indices.
+            vidxs: List of video indices.
+            org_szs: List of original sizes.
+            instances: List of instance tensors.
+            eff_scales: List of effective scales.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
+        imgs = torch.concatenate(imgs, dim=0)
+        fidxs = torch.tensor(fidxs, dtype=torch.int32)
+        vidxs = torch.tensor(vidxs, dtype=torch.int32)
+        org_szs = torch.concatenate(org_szs, dim=0)
+        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
+        if self.instances_key:
+            instances = torch.concatenate(instances, dim=0)
+        ex = {
+            "image": imgs,
+            "frame_idx": fidxs,
+            "video_idx": vidxs,
+            "orig_size": org_szs,
+            "eff_scale": eff_scales,
+        }
+        if self.instances_key:
+            ex["instances"] = instances
+        if self.preprocess:
+            scale = self.preprocess_config["scale"]
+            if scale != 1.0:
+                if self.instances_key:
+                    ex["image"], ex["instances"] = apply_resizer(
+                        ex["image"], ex["instances"]
+                    )
+                else:
+                    ex["image"] = resize_image(ex["image"], scale)
+            ex["image"] = apply_pad_to_stride(ex["image"], self.max_stride)
+        outputs_list = self.inference_model(ex)
+        if outputs_list is not None:
+            for output in outputs_list:
+                output = self._convert_tensors_to_numpy(output)
+                yield output
+
     def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
         """Create a generator that yields batches of inference results.
 
@@ -400,114 +501,14 @@ class Predictor(ABC):
         # Loop over data batches.
         self.pipeline.start()
         total_frames = self.pipeline.total_len()
-        done = False
 
         try:
-            with Progress(
-                "{task.description}",
-                BarColumn(),
-                "[progress.percentage]{task.percentage:>3.0f}%",
-                MofNCompleteColumn(),
-                "ETA:",
-                TimeRemainingColumn(),
-                "Elapsed:",
-                TimeElapsedColumn(),
-                RateColumn(),
-                auto_refresh=False,
-                refresh_per_second=4,  # Change to self.report_rate if needed
-                speed_estimate_period=5,
-            ) as progress:
-                task = progress.add_task("Predicting...", total=total_frames)
-                last_report = time()
-
-                done = False
-                while not done:
-                    imgs = []
-                    fidxs = []
-                    vidxs = []
-                    org_szs = []
-                    instances = []
-                    eff_scales = []
-                    for _ in range(self.batch_size):
-                        frame = self.pipeline.frame_buffer.get()
-                        if frame["image"] is None:
-                            done = True
-                            break
-                        frame["image"], eff_scale = apply_sizematcher(
-                            frame["image"],
-                            self.preprocess_config["max_height"],
-                            self.preprocess_config["max_width"],
-                        )
-                        if self.instances_key:
-                            frame["instances"] = frame["instances"] * eff_scale
-                        if (
-                            self.preprocess_config["ensure_rgb"]
-                            and frame["image"].shape[-3] != 3
-                        ):
-                            frame["image"] = frame["image"].repeat(1, 3, 1, 1)
-                        elif (
-                            self.preprocess_config["ensure_grayscale"]
-                            and frame["image"].shape[-3] != 1
-                        ):
-                            frame["image"] = F.rgb_to_grayscale(
-                                frame["image"], num_output_channels=1
-                            )
-
-                        eff_scales.append(torch.tensor(eff_scale))
-                        imgs.append(frame["image"].unsqueeze(dim=0))
-                        fidxs.append(frame["frame_idx"])
-                        vidxs.append(frame["video_idx"])
-                        org_szs.append(frame["orig_size"].unsqueeze(dim=0))
-                        if self.instances_key:
-                            instances.append(frame["instances"].unsqueeze(dim=0))
-                    if imgs:
-                        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
-                        imgs = torch.concatenate(imgs, dim=0)
-                        fidxs = torch.tensor(fidxs, dtype=torch.int32)
-                        vidxs = torch.tensor(vidxs, dtype=torch.int32)
-                        org_szs = torch.concatenate(org_szs, dim=0)
-                        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
-                        if self.instances_key:
-                            instances = torch.concatenate(instances, dim=0)
-                        ex = {
-                            "image": imgs,
-                            "frame_idx": fidxs,
-                            "video_idx": vidxs,
-                            "orig_size": org_szs,
-                            "eff_scale": eff_scales,
-                        }
-                        if self.instances_key:
-                            ex["instances"] = instances
-                        if self.preprocess:
-                            scale = self.preprocess_config["scale"]
-                            if scale != 1.0:
-                                if self.instances_key:
-                                    ex["image"], ex["instances"] = apply_resizer(
-                                        ex["image"], ex["instances"]
-                                    )
-                                else:
-                                    ex["image"] = resize_image(ex["image"], scale)
-                            ex["image"] = apply_pad_to_stride(
-                                ex["image"], self.max_stride
-                            )
-                        outputs_list = self.inference_model(ex)
-                        if outputs_list is not None:
-                            for output in outputs_list:
-                                output = self._convert_tensors_to_numpy(output)
-                                yield output
-
-                        # Advance progress
-                        num_frames = (
-                            len(ex["frame_idx"])
-                            if "frame_idx" in ex
-                            else self.batch_size
-                        )
-                        progress.update(task, advance=num_frames)
-
-                        # Manually refresh progress bar
-                        if time() - last_report > 0.25:
-                            progress.refresh()
-                            last_report = time()
+            if self.gui:
+                # GUI mode: emit JSON progress lines
+                yield from self._predict_generator_gui(total_frames)
+            else:
+                # Normal mode: use Rich progress bar
+                yield from self._predict_generator_rich(total_frames)
 
         except KeyboardInterrupt:
             logger.info("Inference interrupted by user")
@@ -520,6 +521,112 @@ class Predictor(ABC):
 
             self.pipeline.join()
 
+    def _predict_generator_gui(
+        self, total_frames: int
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Generator for GUI mode with JSON progress output.
+
+        Args:
+            total_frames: Total number of frames to process.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        start_time = time()
+        frames_processed = 0
+        last_report = time()
+        done = False
+
+        while not done:
+            imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
+                self._process_batch()
+            )
+
+            if imgs:
+                yield from self._run_inference_on_batch(
+                    imgs, fidxs, vidxs, org_szs, instances, eff_scales
+                )
+
+                # Update progress
+                num_frames = len(fidxs)
+                frames_processed += num_frames
+
+                # Emit JSON progress (throttled to ~4Hz)
+                if time() - last_report > 0.25:
+                    elapsed = time() - start_time
+                    rate = frames_processed / elapsed if elapsed > 0 else 0
+                    remaining = total_frames - frames_processed
+                    eta = remaining / rate if rate > 0 else 0
+
+                    progress_data = {
+                        "n_processed": frames_processed,
+                        "n_total": total_frames,
+                        "rate": round(rate, 1),
+                        "eta": round(eta, 1),
+                    }
+                    print(json.dumps(progress_data), flush=True)
+                    last_report = time()
+
+        # Final progress emit to ensure 100% is shown
+        elapsed = time() - start_time
+        progress_data = {
+            "n_processed": total_frames,
+            "n_total": total_frames,
+            "rate": round(frames_processed / elapsed, 1) if elapsed > 0 else 0,
+            "eta": 0,
+        }
+        print(json.dumps(progress_data), flush=True)
+
+    def _predict_generator_rich(
+        self, total_frames: int
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Generator for normal mode with Rich progress bar.
+
+        Args:
+            total_frames: Total number of frames to process.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        with Progress(
+            "{task.description}",
+            BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            MofNCompleteColumn(),
+            "ETA:",
+            TimeRemainingColumn(),
+            "Elapsed:",
+            TimeElapsedColumn(),
+            RateColumn(),
+            auto_refresh=False,
+            refresh_per_second=4,
+            speed_estimate_period=5,
+        ) as progress:
+            task = progress.add_task("Predicting...", total=total_frames)
+            last_report = time()
+            done = False
+
+            while not done:
+                imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
+                    self._process_batch()
+                )
+
+                if imgs:
+                    yield from self._run_inference_on_batch(
+                        imgs, fidxs, vidxs, org_szs, instances, eff_scales
+                    )
+
+                    # Advance progress
+                    num_frames = len(fidxs)
+                    progress.update(task, advance=num_frames)
+
+                    # Manually refresh progress bar
+                    if time() - last_report > 0.25:
+                        progress.refresh()
+                        last_report = time()
+
+        self.pipeline.join()
+
     def predict(
         self,
         make_labels: bool = True,
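When `gui=True`, the predictor prints one JSON object per progress update to stdout with the keys `n_processed`, `n_total`, `rate`, and `eta`. The sketch below shows one way a front end might parse that stream; the helper function is illustrative and not part of the package.

```python
import json


def parse_progress_line(line: str):
    """Return the progress dict if ``line`` is a JSON progress record, else None."""
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return None
    if isinstance(data, dict) and {"n_processed", "n_total"} <= data.keys():
        return data
    return None


# Example with a literal progress line of the documented shape.
record = parse_progress_line('{"n_processed": 64, "n_total": 1280, "rate": 42.3, "eta": 28.7}')
if record:
    pct = 100.0 * record["n_processed"] / max(record["n_total"], 1)
    print(f"{pct:.0f}% done at {record['rate']} frames/s, ~{record['eta']}s left")
```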
sleap_nn/predict.py
CHANGED

@@ -74,6 +74,9 @@ def run_inference(
     frames: Optional[list] = None,
     crop_size: Optional[int] = None,
     peak_threshold: Union[float, List[float]] = 0.2,
+    filter_overlapping: bool = False,
+    filter_overlapping_method: str = "iou",
+    filter_overlapping_threshold: float = 0.8,
     integral_refinement: Optional[str] = "integral",
     integral_patch_size: int = 5,
     return_confmaps: bool = False,
@@ -110,6 +113,7 @@ def run_inference(
     tracking_pre_cull_iou_threshold: float = 0,
     tracking_clean_instance_count: int = 0,
     tracking_clean_iou_threshold: float = 0,
+    gui: bool = False,
 ):
     """Entry point to run inference on trained SLEAP-NN models.
 
@@ -160,6 +164,15 @@ def run_inference(
             centroid and centered-instance model, where the first element corresponds
             to centroid model peak finding threshold and the second element is for
             centered-instance model peak finding.
+        filter_overlapping: (bool) If True, removes overlapping instances after
+            inference using greedy NMS. Applied independently of tracking.
+            Default: False.
+        filter_overlapping_method: (str) Similarity metric for filtering overlapping
+            instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
+            Default: "iou".
+        filter_overlapping_threshold: (float) Similarity threshold for filtering.
+            Instances with similarity > threshold are removed (keeping higher-scoring).
+            Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
         integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
             If `"integral"`, peaks will be refined with integral regression.
             Default: `"integral"`.
@@ -250,6 +263,8 @@ def run_inference(
         tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
         tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
         tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
+        gui: (bool) If True, outputs JSON progress lines for GUI integration instead
+            of Rich progress bars. Default: False.
 
     Returns:
         Returns `sio.Labels` object if `make_labels` is True. Else this function returns
@@ -433,13 +448,6 @@ def run_inference(
         else "mps" if torch.backends.mps.is_available() else "cpu"
     )
 
-    if integral_refinement is not None and device == "mps":  # TODO
-        # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
-        logger.info(
-            "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
-        )
-        integral_refinement = None
-
     logger.info(f"Using device: {device}")
 
     # initializes the inference model
@@ -458,6 +466,9 @@ def run_inference(
         anchor_part=anchor_part,
     )
 
+    # Set GUI mode for progress output
+    predictor.gui = gui
+
     if (
         tracking
         and not isinstance(predictor, BottomUpMultiClassPredictor)
@@ -553,6 +564,20 @@ def run_inference(
         make_labels=make_labels,
     )
 
+    # Filter overlapping instances (independent of tracking)
+    if filter_overlapping and make_labels:
+        from sleap_nn.inference.postprocessing import filter_overlapping_instances
+
+        output = filter_overlapping_instances(
+            output,
+            threshold=filter_overlapping_threshold,
+            method=filter_overlapping_method,
+        )
+        logger.info(
+            f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
+            f"threshold: {filter_overlapping_threshold}"
+        )
+
     if tracking:
         lfs = [x for x in output]
         if tracking_clean_instance_count > 0:
@@ -607,6 +632,9 @@ def run_inference(
     # Build inference parameters for provenance
     inference_params = {
         "peak_threshold": peak_threshold,
+        "filter_overlapping": filter_overlapping,
+        "filter_overlapping_method": filter_overlapping_method,
+        "filter_overlapping_threshold": filter_overlapping_threshold,
         "integral_refinement": integral_refinement,
         "integral_patch_size": integral_patch_size,
         "batch_size": batch_size,
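The new `filter_overlapping_*` options describe greedy suppression of overlapping instances by IoU or OKS; the actual implementation lives in the new `sleap_nn/inference/postprocessing.py`, which is not shown in this diff. As a rough illustration of the documented behavior only (not the package's `filter_overlapping_instances`), here is a minimal greedy IoU filter over bounding boxes and scores.

```python
import numpy as np


def bbox_iou(a: np.ndarray, b: np.ndarray) -> float:
    """IoU of two axis-aligned boxes given as (x0, y0, x1, y1)."""
    x0, y0 = max(a[0], b[0]), max(a[1], b[1])
    x1, y1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x1 - x0) * max(0.0, y1 - y0)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_overlap_filter(boxes: np.ndarray, scores: np.ndarray, threshold: float = 0.8):
    """Keep higher-scoring boxes; drop any box whose IoU with an already-kept
    box exceeds ``threshold`` (mirrors the documented greedy NMS behavior)."""
    keep = []
    for i in np.argsort(-scores):
        if all(bbox_iou(boxes[i], boxes[j]) <= threshold for j in keep):
            keep.append(i)
    return sorted(int(i) for i in keep)


boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(greedy_overlap_filter(boxes, scores, threshold=0.5))  # -> [0, 2]
```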
sleap_nn/train.py
CHANGED

@@ -118,6 +118,70 @@ def run_training(
         logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
         logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
 
+        # Log test metrics to wandb summary
+        if (
+            d_name.startswith("test")
+            and trainer.config.trainer_config.use_wandb
+        ):
+            import wandb
+
+            if wandb.run is not None:
+                summary_metrics = {
+                    f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
+                    f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
+                        "oks_voc.mAP"
+                    ],
+                    f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
+                        "oks_voc.mAR"
+                    ],
+                    f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
+                    f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
+                    f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
+                    f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
+                        "avg"
+                    ],
+                    f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
+                        "p50"
+                    ],
+                    f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
+                        "p95"
+                    ],
+                    f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
+                        "p99"
+                    ],
+                    f"eval/{d_name}/visibility_precision": metrics[
+                        "visibility_metrics"
+                    ]["precision"],
+                    f"eval/{d_name}/visibility_recall": metrics[
+                        "visibility_metrics"
+                    ]["recall"],
+                }
+                for key, value in summary_metrics.items():
+                    wandb.run.summary[key] = value
+
+    # Finish wandb run and cleanup after all evaluation is complete
+    if trainer.config.trainer_config.use_wandb:
+        import wandb
+        import shutil
+
+        if wandb.run is not None:
+            wandb.finish()
+
+        # Delete local wandb logs if configured
+        wandb_config = trainer.config.trainer_config.wandb
+        should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+            wandb_config.delete_local_logs is None
+            and wandb_config.wandb_mode != "offline"
+        )
+        if should_delete_wandb_logs:
+            wandb_dir = run_path / "wandb"
+            if wandb_dir.exists():
+                logger.info(
+                    f"Deleting local wandb logs at {wandb_dir}... "
+                    "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                )
+                shutil.rmtree(wandb_dir, ignore_errors=True)
+
 
 def train(
     train_labels_path: Optional[List[str]] = None,
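Test-split metrics are now mirrored into the wandb run summary under keys such as `eval/{d_name}/mOKS` and `eval/{d_name}/oks_voc_mAP`. Below is a hedged sketch of reading those summary values back with the public `wandb` API; the run path and the split name `test` are placeholders, not values from this diff.

```python
import wandb

# Placeholder run path: replace with your own entity/project/run id.
api = wandb.Api()
run = api.run("my-entity/my-project/abc123xy")

# Summary keys written by run_training() for a dataset split named "test".
for key in ("eval/test/mOKS", "eval/test/oks_voc_mAP", "eval/test/distance_p50"):
    value = run.summary.get(key)
    if value is not None:
        print(f"{key}: {value}")
```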
|