sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/training/lightning_modules.py

@@ -33,6 +33,7 @@ from sleap_nn.inference.bottomup import (
 )
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.architectures.model import Model
+from sleap_nn.data.normalization import normalize_on_gpu
 from sleap_nn.training.losses import compute_ohkm_loss
 from loguru import logger
 from sleap_nn.training.utils import (
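The `normalize_on_gpu` import added above backs the `forward` changes repeated throughout this file: each module now normalizes the batch on whatever device it already lives on instead of in the CPU data pipeline. The real implementation is in `sleap_nn/data/normalization.py` (the +45 -2 entry in the file list); the sketch below is only a guess at its shape, assuming uint8 or float input:

    import torch


    def normalize_on_gpu(img: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for sleap_nn.data.normalization.normalize_on_gpu:
        # cast integer images to float32 and rescale to [0, 1] on whatever
        # device (CPU or GPU) the tensor already occupies.
        if not torch.is_floating_point(img):
            img = img.to(torch.float32) / 255.0
        return img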
@@ -40,7 +41,13 @@ from sleap_nn.training.utils import (
     plot_confmaps,
     plot_img,
     plot_peaks,
+    VisualizationData,
 )
+import matplotlib
+
+matplotlib.use(
+    "Agg"
+)  # Use non-interactive backend to avoid tkinter issues on Windows CI
 import matplotlib.pyplot as plt
 from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
 from sleap_nn.config.trainer_config import (
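Placing `matplotlib.use("Agg")` before the `import matplotlib.pyplot as plt` line matters: the backend is simplest to pin down before pyplot initializes one, and Agg renders purely off-screen, so the visualization callbacks cannot trip over a missing or broken tkinter install on headless Windows CI runners. A minimal demonstration of headless rendering:

    import matplotlib

    matplotlib.use("Agg")  # select the backend before pyplot spins one up
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    fig.savefig("viz.png")  # rendering works with no display attached
    plt.close(fig)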
@@ -311,6 +318,15 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch explicitly for custom x-axis support in wandb
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )

     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
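Logging the epoch as a regular metric, rather than leaning on the logger's internal step counter, is what makes an epoch-based x-axis possible in dashboards. In Weights & Biases, for example, the wiring would look roughly like this (a sketch using the public `define_metric` API; the project name is a placeholder):

    import wandb

    run = wandb.init(project="my-sleap-nn-runs")
    # Plot every val_* metric against the explicitly logged "epoch" metric
    # instead of wandb's internal step counter.
    wandb.define_metric("epoch")
    wandb.define_metric("val_*", step_metric="epoch")
    run.log({"epoch": 0, "val_loss": 0.42})
    run.finish()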
@@ -493,8 +509,15 @@ class SingleInstanceLightningModule(LightningModel):
         )
         self.node_names = self.head_configs.single_instance.confmaps.part_names

-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample.
+
+        Args:
+            sample: A sample dictionary from the data pipeline.
+
+        Returns:
+            VisualizationData containing image, confmaps, peaks, etc.
+        """
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -502,27 +525,41 @@ class SingleInstanceLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.single_instance_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_instances, peaks, paired=True)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["SingleInstanceConfmapsHead"]

     def training_step(self, batch, batch_idx):
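This hunk sets the pattern repeated for every head type below: the old monolithic `visualize_example` is split into `get_visualization_data`, which runs inference and extracts arrays, and a thin plotting method. `VisualizationData` itself is defined in `sleap_nn/training/utils.py` (part of the +367 -2 change) and is not shown in this diff; judging purely from its call sites here, it is roughly a dataclass like the sketch below (field types inferred, not authoritative):

    from dataclasses import dataclass
    from typing import List, Optional

    import numpy as np


    @dataclass
    class VisualizationData:
        # Approximate shape of the container, inferred from its call sites.
        image: np.ndarray              # (H, W, C) input image
        pred_confmaps: np.ndarray      # (h, w, n_nodes) predicted confidence maps
        pred_peaks: np.ndarray         # predicted peak coordinates
        pred_peak_values: np.ndarray   # confidence at each predicted peak
        gt_instances: np.ndarray       # ground-truth keypoints
        node_names: List[str]
        output_scale: float            # confmap height / image height
        is_paired: bool                # whether GT and predicted peaks align 1:1
        pred_pafs: Optional[np.ndarray] = None        # bottom-up models only
        pred_class_maps: Optional[np.ndarray] = None  # multi-class bottom-up only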
@@ -556,8 +593,8 @@ class SingleInstanceLightningModule(LightningModel):
             self.log(
                 f"{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=True,
-                on_step=True,
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
                 logger=True,
                 sync_dist=True,
@@ -567,7 +604,7 @@ class SingleInstanceLightningModule(LightningModel):
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=True,
+            on_epoch=False,
             logger=True,
             sync_dist=True,
         )
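The `on_step`/`on_epoch` flips recurring throughout this diff follow a single policy: the scalar training loss stays a per-step progress-bar metric (`on_step=True, on_epoch=False`), while per-node and validation metrics are aggregated once per epoch (`on_step=False, on_epoch=True`). With `on_epoch=True`, Lightning accumulates the values logged across an epoch and emits the reduced aggregate (mean by default) at epoch end. An illustrative module showing the two patterns side by side:

    import lightning as L
    import torch


    class LoggingDemo(L.LightningModule):
        def training_step(self, batch, batch_idx):
            loss = torch.tensor(0.0, requires_grad=True)
            # Per-step signal: shown on the progress bar every batch, never averaged.
            self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
            # Per-epoch signal: accumulated and mean-reduced once at epoch end.
            self.log("train_confmap_loss", loss, on_step=False, on_epoch=True)
            return loss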
@@ -592,21 +629,11 @@ class SingleInstanceLightningModule(LightningModel):
                 loss_scale=self.loss_scale,
             )
             val_loss = val_loss + ohkm_loss
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
-        self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
             "val_loss",
             val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
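The hand-rolled learning-rate logging (`self.optimizers().optimizer.param_groups[0]["lr"]` plus a `self.log` call) is deleted from every module's `validation_step` in this release. The diff here does not show where the replacement lives; given that `sleap_nn/training/callbacks.py` grows by +317 lines, a callback is a plausible home. Lightning's built-in equivalent would be:

    import lightning as L
    from lightning.pytorch.callbacks import LearningRateMonitor

    # One callback replaces the per-module LR bookkeeping: it logs the learning
    # rate of every optimizer and param group through the attached logger.
    trainer = L.Trainer(callbacks=[LearningRateMonitor(logging_interval="epoch")])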
@@ -705,8 +732,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):

         self.node_names = self.head_configs.centered_instance.confmaps.part_names

-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -714,27 +741,41 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
         output = self.instance_peaks_inf_layer(ex)
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-        img = (
-            output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instance"].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_instances, peaks, paired=True)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["CenteredInstanceConfmapsHead"]

     def training_step(self, batch, batch_idx):
@@ -768,8 +809,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             self.log(
                 f"{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=True,
-                on_step=True,
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
                 logger=True,
                 sync_dist=True,
@@ -780,7 +821,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=True,
+            on_epoch=False,
             logger=True,
             sync_dist=True,
         )
@@ -805,21 +846,11 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 loss_scale=self.loss_scale,
             )
             val_loss = val_loss + ohkm_loss
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
-        self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
             "val_loss",
             val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
@@ -916,9 +947,10 @@ class CentroidLightningModule(LightningModel):
             output_stride=self.head_configs.centroid.confmaps.output_stride,
             input_scale=1.0,
         )
+        self.node_names = ["centroid"]

-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -927,26 +959,40 @@ class CentroidLightningModule(LightningModel):
         ex["image"] = ex["image"].unsqueeze(dim=0)
         gt_centroids = ex["centroids"].cpu().numpy()
         output = self.centroid_inf_layer(ex)
+
         peaks = output["centroids"][0].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
-        confmaps = (
-            output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )
+        centroid_vals = output["centroid_vals"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
+        confmaps = output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=centroid_vals,
+            gt_instances=gt_centroids,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_centroids, peaks, paired=False)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["CentroidConfmapsHead"]

     def training_step(self, batch, batch_idx):
@@ -963,7 +1009,7 @@ class CentroidLightningModule(LightningModel):
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=True,
+            on_epoch=False,
             logger=True,
             sync_dist=True,
         )
@@ -978,21 +1024,11 @@ class CentroidLightningModule(LightningModel):

         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
-        self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
             "val_loss",
             val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
@@ -1097,9 +1133,12 @@ class BottomUpLightningModule(LightningModel):
             cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
         )
+        self.node_names = list(self.head_configs.bottomup.confmaps.part_names)

-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(
+        self, sample, include_pafs: bool = False
+    ) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1107,54 +1146,65 @@ class BottomUpLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.bottomup_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"][0].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        pred_pafs = None
+        if include_pafs:
+            pafs = output["pred_part_affinity_fields"].cpu().numpy()[0]
+            pred_pafs = pafs  # (h, w, 2*edges)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+            pred_pafs=pred_pafs,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
         plt.xlim(plt.xlim())
         plt.ylim(plt.ylim())
-        plot_peaks(gt_instances, peaks, paired=False)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def visualize_pafs_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
-        ex = sample.copy()
-        ex["eff_scale"] = torch.tensor([1.0])
-        for k, v in ex.items():
-            if isinstance(v, torch.Tensor):
-                ex[k] = v.to(device=self.device)
-        ex["image"] = ex["image"].unsqueeze(dim=0)
-        output = self.bottomup_inf_layer(ex)[0]
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
-        pafs = output["pred_part_affinity_fields"].cpu().numpy()[0]  # (h, w, 2*edges)
+        """Visualize PAF predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample, include_pafs=True)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)

+        pafs = data.pred_pafs
         pafs = pafs.reshape((pafs.shape[0], pafs.shape[1], -1, 2))
         pafs_mag = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
-        plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / img.shape[0])
+        plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / data.image.shape[0])
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1203,6 +1253,22 @@ class BottomUpLightningModule(LightningModel):
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train_confmap_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train_paf_loss",
+            pafs_loss,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
@@ -1248,21 +1314,27 @@ class BottomUpLightningModule(LightningModel):
         }

         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
+            "val_loss",
+            val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
         self.log(
-            "val_loss",
-            val_loss,
-            prog_bar=True,
-            on_step=True,
+            "val_confmap_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "val_paf_loss",
+            pafs_loss,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
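Both bottom-up modules now log each head's loss separately next to the weighted total, but the total itself is computed exactly as before: `losses` is a dict keyed by head name, and zipping a list of weights against a dict pairs each weight with a key. A standalone check of that reduction:

    # Names and numbers are illustrative, not taken from the package.
    loss_weights = [1.0, 0.5]
    losses = {
        "MultiInstanceConfmapsHead": 0.8,  # confmap_loss
        "PartAffinityFieldsHead": 1.2,     # pafs_loss
    }
    # zip over a dict iterates its keys, pairing weight i with head name i.
    val_loss = sum(s * losses[t] for s, t in zip(loss_weights, losses))
    assert val_loss == 1.0 * 0.8 + 0.5 * 1.2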
@@ -1361,9 +1433,14 @@ class BottomUpMultiClassLightningModule(LightningModel):
             cms_output_stride=self.head_configs.multi_class_bottomup.confmaps.output_stride,
             class_maps_output_stride=self.head_configs.multi_class_bottomup.class_maps.output_stride,
         )
+        self.node_names = list(
+            self.head_configs.multi_class_bottomup.confmaps.part_names
+        )

-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(
+        self, sample, include_class_maps: bool = False
+    ) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1371,54 +1448,65 @@ class BottomUpMultiClassLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.bottomup_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"][0].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        pred_class_maps = None
+        if include_class_maps:
+            pred_class_maps = (
+                output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
+            )
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+            pred_class_maps=pred_class_maps,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
         plt.xlim(plt.xlim())
         plt.ylim(plt.ylim())
-        plot_peaks(gt_instances, peaks, paired=False)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def visualize_class_maps_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
-        ex = sample.copy()
-        ex["eff_scale"] = torch.tensor([1.0])
-        for k, v in ex.items():
-            if isinstance(v, torch.Tensor):
-                ex[k] = v.to(device=self.device)
-        ex["image"] = ex["image"].unsqueeze(dim=0)
-        output = self.bottomup_inf_layer(ex)[0]
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
-        classmaps = (
-            output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
-        )  # (n_classes, h, w)
+        """Visualize class map predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample, include_class_maps=True)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(classmaps, output_scale=classmaps.shape[0] / img.shape[0])
-
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(
+            data.pred_class_maps,
+            output_scale=data.pred_class_maps.shape[0] / data.image.shape[0],
+        )
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1458,6 +1546,22 @@ class BottomUpMultiClassLightningModule(LightningModel):
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train_confmap_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train_classmap_loss",
+            classmaps_loss,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
@@ -1494,21 +1598,27 @@ class BottomUpMultiClassLightningModule(LightningModel):
         }

         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
+            "val_loss",
+            val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
         self.log(
-            "val_loss",
-            val_loss,
-            prog_bar=True,
-            on_step=True,
+            "val_confmap_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "val_classmap_loss",
+            classmaps_loss,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
@@ -1607,8 +1717,8 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):

         self.node_names = self.head_configs.multi_class_topdown.confmaps.part_names

-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1616,27 +1726,41 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
                 ex[k] = v.to(device=self.device)
         ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
         output = self.instance_peaks_inf_layer(ex)
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-        img = (
-            output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instance"].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_instances, peaks, paired=True)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig

     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "CenteredInstanceConfmapsHead": output["CenteredInstanceConfmapsHead"],
@@ -1681,8 +1805,8 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             self.log(
                 f"{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=True,
-                on_step=True,
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
                 logger=True,
                 sync_dist=True,
@@ -1693,6 +1817,22 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train_confmap_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train_classvector_loss",
+            classvector_loss,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
@@ -1727,22 +1867,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             "ClassVectorsHead": classvector_loss,
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
+            "val_loss",
+            val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
         self.log(
-            "val_loss",
-            val_loss,
-            prog_bar=True,
-            on_step=True,
+            "val_confmap_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "val_classvector_loss",
+            classvector_loss,
+            on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,