sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +2 -4
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -10
- sleap_nn/config/trainer_config.py +0 -76
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +39 -411
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -74
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +184 -211
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/inference/predictors.py
CHANGED
@@ -56,8 +56,6 @@ from rich.progress import (
     MofNCompleteColumn,
 )
 from time import time
-import json
-import sys
 
 
 def _filter_user_labeled_frames(

@@ -135,8 +133,6 @@ class Predictor(ABC):
            `backbone_config`. This determines the downsampling factor applied by the backbone,
            and is used to ensure that input images are padded or resized to be compatible
            with the model's architecture. Default: 16.
-        gui: If True, outputs JSON progress lines for GUI integration instead of
-            Rich progress bars. Default: False.
     """
 
     preprocess: bool = True

@@ -156,7 +152,6 @@ class Predictor(ABC):
     ] = None
     instances_key: bool = False
     max_stride: int = 16
-    gui: bool = False
 
     @classmethod
     def from_model_paths(

@@ -360,7 +355,7 @@ class Predictor(ABC):
     def make_pipeline(
         self,
         data_path: str,
-        queue_maxsize: int =
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,

@@ -386,102 +381,6 @@ class Predictor(ABC):
                 v[n] = v[n].cpu().numpy()
         return output
 
-    def _process_batch(self) -> tuple:
-        """Process a single batch of frames from the pipeline.
-
-        Returns:
-            Tuple of (imgs, fidxs, vidxs, org_szs, instances, eff_scales, done)
-            where done is True if the pipeline has finished.
-        """
-        imgs = []
-        fidxs = []
-        vidxs = []
-        org_szs = []
-        instances = []
-        eff_scales = []
-        done = False
-
-        for _ in range(self.batch_size):
-            frame = self.pipeline.frame_buffer.get()
-            if frame["image"] is None:
-                done = True
-                break
-            frame["image"], eff_scale = apply_sizematcher(
-                frame["image"],
-                self.preprocess_config["max_height"],
-                self.preprocess_config["max_width"],
-            )
-            if self.instances_key:
-                frame["instances"] = frame["instances"] * eff_scale
-            if self.preprocess_config["ensure_rgb"] and frame["image"].shape[-3] != 3:
-                frame["image"] = frame["image"].repeat(1, 3, 1, 1)
-            elif (
-                self.preprocess_config["ensure_grayscale"]
-                and frame["image"].shape[-3] != 1
-            ):
-                frame["image"] = F.rgb_to_grayscale(
-                    frame["image"], num_output_channels=1
-                )
-
-            eff_scales.append(torch.tensor(eff_scale))
-            imgs.append(frame["image"].unsqueeze(dim=0))
-            fidxs.append(frame["frame_idx"])
-            vidxs.append(frame["video_idx"])
-            org_szs.append(frame["orig_size"].unsqueeze(dim=0))
-            if self.instances_key:
-                instances.append(frame["instances"].unsqueeze(dim=0))
-
-        return imgs, fidxs, vidxs, org_szs, instances, eff_scales, done
-
-    def _run_inference_on_batch(
-        self, imgs, fidxs, vidxs, org_szs, instances, eff_scales
-    ) -> Iterator[Dict[str, np.ndarray]]:
-        """Run inference on a prepared batch of frames.
-
-        Args:
-            imgs: List of image tensors.
-            fidxs: List of frame indices.
-            vidxs: List of video indices.
-            org_szs: List of original sizes.
-            instances: List of instance tensors.
-            eff_scales: List of effective scales.
-
-        Yields:
-            Dictionaries containing inference results for each frame.
-        """
-        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
-        imgs = torch.concatenate(imgs, dim=0)
-        fidxs = torch.tensor(fidxs, dtype=torch.int32)
-        vidxs = torch.tensor(vidxs, dtype=torch.int32)
-        org_szs = torch.concatenate(org_szs, dim=0)
-        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
-        if self.instances_key:
-            instances = torch.concatenate(instances, dim=0)
-        ex = {
-            "image": imgs,
-            "frame_idx": fidxs,
-            "video_idx": vidxs,
-            "orig_size": org_szs,
-            "eff_scale": eff_scales,
-        }
-        if self.instances_key:
-            ex["instances"] = instances
-        if self.preprocess:
-            scale = self.preprocess_config["scale"]
-            if scale != 1.0:
-                if self.instances_key:
-                    ex["image"], ex["instances"] = apply_resizer(
-                        ex["image"], ex["instances"]
-                    )
-                else:
-                    ex["image"] = resize_image(ex["image"], scale)
-            ex["image"] = apply_pad_to_stride(ex["image"], self.max_stride)
-        outputs_list = self.inference_model(ex)
-        if outputs_list is not None:
-            for output in outputs_list:
-                output = self._convert_tensors_to_numpy(output)
-                yield output
-
     def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
         """Create a generator that yields batches of inference results.
 

@@ -501,14 +400,114 @@ class Predictor(ABC):
         # Loop over data batches.
         self.pipeline.start()
         total_frames = self.pipeline.total_len()
+        done = False
 
         try:
-
-
-
-
-
-
+            with Progress(
+                "{task.description}",
+                BarColumn(),
+                "[progress.percentage]{task.percentage:>3.0f}%",
+                MofNCompleteColumn(),
+                "ETA:",
+                TimeRemainingColumn(),
+                "Elapsed:",
+                TimeElapsedColumn(),
+                RateColumn(),
+                auto_refresh=False,
+                refresh_per_second=4,  # Change to self.report_rate if needed
+                speed_estimate_period=5,
+            ) as progress:
+                task = progress.add_task("Predicting...", total=total_frames)
+                last_report = time()
+
+                done = False
+                while not done:
+                    imgs = []
+                    fidxs = []
+                    vidxs = []
+                    org_szs = []
+                    instances = []
+                    eff_scales = []
+                    for _ in range(self.batch_size):
+                        frame = self.pipeline.frame_buffer.get()
+                        if frame["image"] is None:
+                            done = True
+                            break
+                        frame["image"], eff_scale = apply_sizematcher(
+                            frame["image"],
+                            self.preprocess_config["max_height"],
+                            self.preprocess_config["max_width"],
+                        )
+                        if self.instances_key:
+                            frame["instances"] = frame["instances"] * eff_scale
+                        if (
+                            self.preprocess_config["ensure_rgb"]
+                            and frame["image"].shape[-3] != 3
+                        ):
+                            frame["image"] = frame["image"].repeat(1, 3, 1, 1)
+                        elif (
+                            self.preprocess_config["ensure_grayscale"]
+                            and frame["image"].shape[-3] != 1
+                        ):
+                            frame["image"] = F.rgb_to_grayscale(
+                                frame["image"], num_output_channels=1
+                            )
+
+                        eff_scales.append(torch.tensor(eff_scale))
+                        imgs.append(frame["image"].unsqueeze(dim=0))
+                        fidxs.append(frame["frame_idx"])
+                        vidxs.append(frame["video_idx"])
+                        org_szs.append(frame["orig_size"].unsqueeze(dim=0))
+                        if self.instances_key:
+                            instances.append(frame["instances"].unsqueeze(dim=0))
+                    if imgs:
+                        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
+                        imgs = torch.concatenate(imgs, dim=0)
+                        fidxs = torch.tensor(fidxs, dtype=torch.int32)
+                        vidxs = torch.tensor(vidxs, dtype=torch.int32)
+                        org_szs = torch.concatenate(org_szs, dim=0)
+                        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
+                        if self.instances_key:
+                            instances = torch.concatenate(instances, dim=0)
+                        ex = {
+                            "image": imgs,
+                            "frame_idx": fidxs,
+                            "video_idx": vidxs,
+                            "orig_size": org_szs,
+                            "eff_scale": eff_scales,
+                        }
+                        if self.instances_key:
+                            ex["instances"] = instances
+                        if self.preprocess:
+                            scale = self.preprocess_config["scale"]
+                            if scale != 1.0:
+                                if self.instances_key:
+                                    ex["image"], ex["instances"] = apply_resizer(
+                                        ex["image"], ex["instances"]
+                                    )
+                                else:
+                                    ex["image"] = resize_image(ex["image"], scale)
+                            ex["image"] = apply_pad_to_stride(
+                                ex["image"], self.max_stride
+                            )
+                        outputs_list = self.inference_model(ex)
+                        if outputs_list is not None:
+                            for output in outputs_list:
+                                output = self._convert_tensors_to_numpy(output)
+                                yield output
+
+                    # Advance progress
+                    num_frames = (
+                        len(ex["frame_idx"])
+                        if "frame_idx" in ex
+                        else self.batch_size
+                    )
+                    progress.update(task, advance=num_frames)
+
+                    # Manually refresh progress bar
+                    if time() - last_report > 0.25:
+                        progress.refresh()
+                        last_report = time()
 
         except KeyboardInterrupt:
             logger.info("Inference interrupted by user")

@@ -521,112 +520,6 @@ class Predictor(ABC):
 
         self.pipeline.join()
 
-    def _predict_generator_gui(
-        self, total_frames: int
-    ) -> Iterator[Dict[str, np.ndarray]]:
-        """Generator for GUI mode with JSON progress output.
-
-        Args:
-            total_frames: Total number of frames to process.
-
-        Yields:
-            Dictionaries containing inference results for each frame.
-        """
-        start_time = time()
-        frames_processed = 0
-        last_report = time()
-        done = False
-
-        while not done:
-            imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
-                self._process_batch()
-            )
-
-            if imgs:
-                yield from self._run_inference_on_batch(
-                    imgs, fidxs, vidxs, org_szs, instances, eff_scales
-                )
-
-            # Update progress
-            num_frames = len(fidxs)
-            frames_processed += num_frames
-
-            # Emit JSON progress (throttled to ~4Hz)
-            if time() - last_report > 0.25:
-                elapsed = time() - start_time
-                rate = frames_processed / elapsed if elapsed > 0 else 0
-                remaining = total_frames - frames_processed
-                eta = remaining / rate if rate > 0 else 0
-
-                progress_data = {
-                    "n_processed": frames_processed,
-                    "n_total": total_frames,
-                    "rate": round(rate, 1),
-                    "eta": round(eta, 1),
-                }
-                print(json.dumps(progress_data), flush=True)
-                last_report = time()
-
-        # Final progress emit to ensure 100% is shown
-        elapsed = time() - start_time
-        progress_data = {
-            "n_processed": total_frames,
-            "n_total": total_frames,
-            "rate": round(frames_processed / elapsed, 1) if elapsed > 0 else 0,
-            "eta": 0,
-        }
-        print(json.dumps(progress_data), flush=True)
-
-    def _predict_generator_rich(
-        self, total_frames: int
-    ) -> Iterator[Dict[str, np.ndarray]]:
-        """Generator for normal mode with Rich progress bar.
-
-        Args:
-            total_frames: Total number of frames to process.
-
-        Yields:
-            Dictionaries containing inference results for each frame.
-        """
-        with Progress(
-            "{task.description}",
-            BarColumn(),
-            "[progress.percentage]{task.percentage:>3.0f}%",
-            MofNCompleteColumn(),
-            "ETA:",
-            TimeRemainingColumn(),
-            "Elapsed:",
-            TimeElapsedColumn(),
-            RateColumn(),
-            auto_refresh=False,
-            refresh_per_second=4,
-            speed_estimate_period=5,
-        ) as progress:
-            task = progress.add_task("Predicting...", total=total_frames)
-            last_report = time()
-            done = False
-
-            while not done:
-                imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
-                    self._process_batch()
-                )
-
-                if imgs:
-                    yield from self._run_inference_on_batch(
-                        imgs, fidxs, vidxs, org_szs, instances, eff_scales
-                    )
-
-                # Advance progress
-                num_frames = len(fidxs)
-                progress.update(task, advance=num_frames)
-
-                # Manually refresh progress bar
-                if time() - last_report > 0.25:
-                    progress.refresh()
-                    last_report = time()
-
-            self.pipeline.join()
-
     def predict(
         self,
         make_labels: bool = True,

@@ -1214,7 +1107,7 @@ class TopDownPredictor(Predictor):
     def make_pipeline(
         self,
        inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int =
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,

@@ -1228,7 +1121,7 @@ class TopDownPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.

@@ -1644,7 +1537,7 @@ class SingleInstancePredictor(Predictor):
     def make_pipeline(
         self,
        inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int =
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,

@@ -1658,7 +1551,7 @@ class SingleInstancePredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.

@@ -2094,7 +1987,7 @@ class BottomUpPredictor(Predictor):
     def make_pipeline(
         self,
        inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int =
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,

@@ -2108,7 +2001,7 @@ class BottomUpPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.

@@ -2541,7 +2434,7 @@ class BottomUpMultiClassPredictor(Predictor):
     def make_pipeline(
         self,
        inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int =
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,

@@ -2555,7 +2448,7 @@ class BottomUpMultiClassPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.

@@ -3296,7 +3189,7 @@ class TopDownMultiClassPredictor(Predictor):
     def make_pipeline(
         self,
        inference_object: Union[str, Path, sio.Labels, sio.Video],
-        queue_maxsize: int =
+        queue_maxsize: int = 8,
         frames: Optional[list] = None,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,

@@ -3310,7 +3203,7 @@ class TopDownMultiClassPredictor(Predictor):
 
         Args:
             inference_object: (str) Path to `.slp` file or `.mp4` or sio.Labels or sio.Video to run inference on.
-            queue_maxsize: (int) Maximum size of the frame buffer queue. Default:
+            queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
             frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
             only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
             only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
sleap_nn/legacy_models.py
CHANGED
@@ -7,8 +7,9 @@ TensorFlow/Keras backend to PyTorch format compatible with sleap-nn.
 import h5py
 import numpy as np
 import torch
-from typing import Dict, Any, Optional
+from typing import Dict, Tuple, Any, Optional, List
 from pathlib import Path
+from omegaconf import OmegaConf
 import re
 from loguru import logger
 

@@ -180,61 +181,18 @@ def parse_keras_layer_name(layer_path: str) -> Dict[str, Any]:
     return info
 
 
-def filter_legacy_weights_by_component(
-    legacy_weights: Dict[str, np.ndarray], component: Optional[str]
-) -> Dict[str, np.ndarray]:
-    """Filter legacy weights based on component type.
-
-    Args:
-        legacy_weights: Dictionary of legacy weights from load_keras_weights()
-        component: Component type to filter for. One of:
-            - "backbone": Keep only encoder/decoder weights (exclude heads)
-            - "head": Keep only head layer weights
-            - None: No filtering (keep all weights)
-
-    Returns:
-        Filtered dictionary of legacy weights
-    """
-    if component is None:
-        return legacy_weights
-
-    filtered = {}
-    for path, weight in legacy_weights.items():
-        # Check if this is a head layer (contains "Head" in the path)
-        is_head_layer = "Head" in path
-
-        if component == "backbone" and not is_head_layer:
-            filtered[path] = weight
-        elif component == "head" and is_head_layer:
-            filtered[path] = weight
-
-    return filtered
-
-
 def map_legacy_to_pytorch_layers(
-    legacy_weights: Dict[str, np.ndarray],
-    pytorch_model: torch.nn.Module,
-    component: Optional[str] = None,
+    legacy_weights: Dict[str, np.ndarray], pytorch_model: torch.nn.Module
 ) -> Dict[str, str]:
     """Create mapping between legacy Keras layers and PyTorch model layers.
 
     Args:
         legacy_weights: Dictionary of legacy weights from load_keras_weights()
         pytorch_model: PyTorch model instance to map to
-        component: Optional component type for filtering weights before mapping.
-            One of "backbone", "head", or None (no filtering).
 
     Returns:
         Dictionary mapping legacy layer paths to PyTorch parameter names
     """
-    # Filter weights based on component type
-    filtered_weights = filter_legacy_weights_by_component(legacy_weights, component)
-
-    if component is not None:
-        logger.info(
-            f"Filtered legacy weights for {component}: "
-            f"{len(filtered_weights)}/{len(legacy_weights)} weights"
-        )
     mapping = {}
 
     # Get all PyTorch parameters with their shapes

@@ -243,7 +201,7 @@ def map_legacy_to_pytorch_layers(
         pytorch_params[name] = param.shape
 
     # For each legacy weight, find the corresponding PyTorch parameter
-    for legacy_path, weight in
+    for legacy_path, weight in legacy_weights.items():
         # Extract the layer name from the legacy path
         # Legacy path format: "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0"
         clean_path = legacy_path.replace("model_weights/", "")

@@ -262,6 +220,8 @@ def map_legacy_to_pytorch_layers(
         # This handles cases where Keras uses suffixes like _0, _1, etc.
         if "Head" in layer_name:
             # Remove trailing _N where N is a number
+            import re
+
             layer_name_clean = re.sub(r"_\d+$", "", layer_name)
         else:
             layer_name_clean = layer_name

@@ -306,17 +266,12 @@ def map_legacy_to_pytorch_layers(
     if not mapping:
         logger.info(
             f"No mappings could be created between legacy weights and PyTorch model. "
-            f"Legacy weights: {len(
+            f"Legacy weights: {len(legacy_weights)}, PyTorch parameters: {len(pytorch_params)}"
         )
     else:
         logger.info(
-            f"Successfully mapped {len(mapping)}/{len(
+            f"Successfully mapped {len(mapping)}/{len(legacy_weights)} legacy weights to PyTorch parameters"
         )
-        unmatched_count = len(filtered_weights) - len(mapping)
-        if unmatched_count > 0:
-            logger.warning(
-                f"({unmatched_count} legacy weights did not match any parameters in this model component)"
-            )
 
     return mapping
 

@@ -325,7 +280,6 @@ def load_legacy_model_weights(
     pytorch_model: torch.nn.Module,
     h5_path: str,
     mapping: Optional[Dict[str, str]] = None,
-    component: Optional[str] = None,
 ) -> None:
     """Load legacy Keras weights into a PyTorch model.
 

@@ -334,10 +288,6 @@ def load_legacy_model_weights(
         h5_path: Path to the legacy .h5 model file
         mapping: Optional manual mapping of layer names. If None,
            will attempt automatic mapping.
-        component: Optional component type for filtering weights. One of:
-            - "backbone": Only load encoder/decoder weights (exclude heads)
-            - "head": Only load head layer weights
-            - None: Load all weights (default, for full model loading)
     """
     # Load legacy weights
     legacy_weights = load_keras_weights(h5_path)

@@ -345,9 +295,7 @@ def load_legacy_model_weights(
     if mapping is None:
         # Attempt automatic mapping
         try:
-            mapping = map_legacy_to_pytorch_layers(
-                legacy_weights, pytorch_model, component=component
-            )
+            mapping = map_legacy_to_pytorch_layers(legacy_weights, pytorch_model)
         except Exception as e:
             logger.error(f"Failed to create weight mappings: {e}")
             return

@@ -469,9 +417,7 @@ def load_legacy_model_weights(
                     ).item()
                     diff = abs(keras_mean - torch_mean)
                     if diff > 1e-6:
-                        message = f"Weight verification failed for {pytorch_name}
-                        logger.error(message)
-                        verification_errors.append(message)
+                        message = f"Weight verification failed for {pytorch_name} linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
                 else:
                     # Bias : just compare all values
                     keras_mean = np.mean(original_weight)

@@ -480,7 +426,7 @@ def load_legacy_model_weights(
                     ).item()
                     diff = abs(keras_mean - torch_mean)
                     if diff > 1e-6:
-                        message = f"Weight verification failed for {pytorch_name}
+                        message = f"Weight verification failed for {pytorch_name} bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
                         logger.error(message)
                         verification_errors.append(message)
 
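
With the `component` filtering removed in the legacy_models.py change above, callers pass only the model, the .h5 path, and optionally a manual mapping. Below is a hedged sketch of the simplified call pattern, assuming these functions are importable from `sleap_nn.legacy_models` as the file layout suggests; the model and the .h5 path are placeholders, not real sleap-nn objects.

import torch
from sleap_nn.legacy_models import (
    load_keras_weights,
    map_legacy_to_pytorch_layers,
    load_legacy_model_weights,
)

# Stand-in model and path; a real call would use a sleap-nn backbone and an
# exported legacy SLEAP .h5 checkpoint.
model = torch.nn.Sequential(torch.nn.Conv2d(1, 16, kernel_size=3, padding=1))
h5_path = "path/to/legacy_model/best_model.h5"  # hypothetical path

# Inspect the automatic mapping: every legacy weight is considered now,
# with no backbone/head filtering step.
legacy_weights = load_keras_weights(h5_path)
mapping = map_legacy_to_pytorch_layers(legacy_weights, model)

# Or let load_legacy_model_weights build the mapping itself (mapping=None).
load_legacy_model_weights(model, h5_path, mapping=None)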