sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

@@ -33,6 +33,7 @@ from sleap_nn.inference.bottomup import (
  )
  from sleap_nn.inference.paf_grouping import PAFScorer
  from sleap_nn.architectures.model import Model
+ from sleap_nn.data.normalization import normalize_on_gpu
  from sleap_nn.training.losses import compute_ohkm_loss
  from loguru import logger
  from sleap_nn.training.utils import (
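The new import suggests that image normalization now happens inside each module's `forward` pass rather than in the data pipeline (see the `normalize_on_gpu(img)` lines added to every `forward` below). The helper's definition is not shown in this diff; a minimal sketch of what such a function plausibly does, assuming uint8 inputs scaled to [0, 1]:

```python
# Hypothetical sketch of a GPU-side normalization helper; the real
# normalize_on_gpu in sleap_nn.data.normalization may handle more cases
# (dtype variants, per-channel statistics, etc.).
import torch


def normalize_on_gpu(img: torch.Tensor) -> torch.Tensor:
    """Convert a uint8 image batch to float32 in [0, 1] on its current device."""
    if not torch.is_floating_point(img):
        img = img.to(torch.float32) / 255.0
    return img
```

Doing this on-device avoids shipping float32 tensors across the host-to-GPU copy; the batch can stay uint8 until it reaches the model.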
@@ -40,7 +41,13 @@ from sleap_nn.training.utils import (
  plot_confmaps,
  plot_img,
  plot_peaks,
+ VisualizationData,
  )
+ import matplotlib
+
+ matplotlib.use(
+ "Agg"
+ ) # Use non-interactive backend to avoid tkinter issues on Windows CI
  import matplotlib.pyplot as plt
  from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
  from sleap_nn.config.trainer_config import (
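Note the ordering: `matplotlib.use("Agg")` runs before `import matplotlib.pyplot as plt`, because pyplot binds a backend at import time. A standalone sketch of the pattern (the output filename is illustrative):

```python
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported;
# once pyplot initializes a GUI backend (e.g. TkAgg), switching can fail.
matplotlib.use("Agg")

import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot([0, 1], [0, 1])
fig.savefig("demo.png")  # Agg renders to raster buffers; no display needed
plt.close(fig)
```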
@@ -311,6 +318,15 @@ class LightningModel(L.LightningModule):
  logger=True,
  sync_dist=True,
  )
+ # Log epoch explicitly for custom x-axis support in wandb
+ self.log(
+ "epoch",
+ float(self.current_epoch),
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )

  def on_validation_epoch_start(self):
  """Configure the val timer at the beginning of each epoch."""
@@ -493,8 +509,15 @@ class SingleInstanceLightningModule(LightningModel):
  )
  self.node_names = self.head_configs.single_instance.confmaps.part_names

- def visualize_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
+ def get_visualization_data(self, sample) -> VisualizationData:
+ """Extract visualization data from a sample.
+
+ Args:
+ sample: A sample dictionary from the data pipeline.
+
+ Returns:
+ VisualizationData containing image, confmaps, peaks, etc.
+ """
  ex = sample.copy()
  ex["eff_scale"] = torch.tensor([1.0])
  for k, v in ex.items():
@@ -502,27 +525,41 @@ class SingleInstanceLightningModule(LightningModel):
  ex[k] = v.to(device=self.device)
  ex["image"] = ex["image"].unsqueeze(dim=0)
  output = self.single_instance_inf_layer(ex)[0]
+
  peaks = output["pred_instance_peaks"].cpu().numpy()
- img = (
- output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ peak_values = output["pred_peak_values"].cpu().numpy()
+ img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
  gt_instances = ex["instances"][0].cpu().numpy()
- confmaps = (
- output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+ return VisualizationData(
+ image=img,
+ pred_confmaps=confmaps,
+ pred_peaks=peaks,
+ pred_peak_values=peak_values,
+ gt_instances=gt_instances,
+ node_names=list(self.node_names) if self.node_names else [],
+ output_scale=confmaps.shape[0] / img.shape[0],
+ is_paired=True,
+ )
+
+ def visualize_example(self, sample):
+ """Visualize predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
- plot_peaks(gt_instances, peaks, paired=True)
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
  return fig

  def forward(self, img):
  """Forward pass of the model."""
  img = torch.squeeze(img, dim=1).to(self.device)
+ img = normalize_on_gpu(img)
  return self.model(img)["SingleInstanceConfmapsHead"]

  def training_step(self, batch, batch_idx):
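The refactor splits tensor extraction (`get_visualization_data`) from plotting (`visualize_example`), so callbacks can consume the raw arrays without building a matplotlib figure. The container itself is not shown in this diff; a hypothetical reconstruction from the call sites above and below (field types and defaults are guesses):

```python
# Hypothetical sketch of the VisualizationData container implied by the
# call sites; the actual definition lives in sleap_nn.training.utils and
# may differ.
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np


@dataclass
class VisualizationData:
    image: np.ndarray                 # (H, W, C) input image
    pred_confmaps: np.ndarray         # (h, w, n_nodes) predicted confidence maps
    pred_peaks: np.ndarray            # predicted peak coordinates
    pred_peak_values: np.ndarray      # confidence at each predicted peak
    gt_instances: np.ndarray          # ground-truth keypoints
    node_names: List[str] = field(default_factory=list)
    output_scale: float = 1.0         # confmap height / image height
    is_paired: bool = True            # whether predictions pair 1:1 with GT
    pred_pafs: Optional[np.ndarray] = None        # bottom-up models only
    pred_class_maps: Optional[np.ndarray] = None  # multi-class models only
```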
@@ -556,8 +593,8 @@ class SingleInstanceLightningModule(LightningModel):
  self.log(
  f"{name}",
  channel_wise_loss[node_idx],
- prog_bar=True,
- on_step=True,
+ prog_bar=False,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -567,7 +604,7 @@ class SingleInstanceLightningModule(LightningModel):
  loss,
  prog_bar=True,
  on_step=True,
- on_epoch=True,
+ on_epoch=False,
  logger=True,
  sync_dist=True,
  )
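Across these hunks the logging strategy changes: the headline `train_loss`/`val_loss` stay per-step on the progress bar (`on_step=True, on_epoch=False` for training), while diagnostic metrics move to epoch-level aggregates (`on_step=False, on_epoch=True`). A schematic sketch of the two flag combinations (`compute_loss` is a hypothetical helper; this is not a complete training script):

```python
import lightning as L


class Demo(L.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper

        # Streamed every optimizer step; nothing aggregated at epoch end.
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False)

        # Accumulated across the epoch and logged once as the epoch mean.
        self.log("train_loss_epoch", loss, on_step=False, on_epoch=True)
        return loss
```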
@@ -592,21 +629,11 @@ class SingleInstanceLightningModule(LightningModel):
  loss_scale=self.loss_scale,
  )
  val_loss = val_loss + ohkm_loss
- lr = self.optimizers().optimizer.param_groups[0]["lr"]
- self.log(
- "learning_rate",
- lr,
- prog_bar=True,
- on_step=True,
- on_epoch=True,
- logger=True,
- sync_dist=True,
- )
  self.log(
  "val_loss",
  val_loss,
  prog_bar=True,
- on_step=True,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
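The hand-rolled per-step `learning_rate` logging removed here (and in the other modules below) duplicated what Lightning's built-in callback provides; whether sleap-nn wires that callback up elsewhere is not visible in this diff. For reference, the standard replacement:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor

# Logs the current learning rate of every optimizer once per epoch.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
trainer = Trainer(callbacks=[lr_monitor])
```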
@@ -705,8 +732,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):

  self.node_names = self.head_configs.centered_instance.confmaps.part_names

- def visualize_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
+ def get_visualization_data(self, sample) -> VisualizationData:
+ """Extract visualization data from a sample."""
  ex = sample.copy()
  ex["eff_scale"] = torch.tensor([1.0])
  for k, v in ex.items():
@@ -714,27 +741,41 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  ex[k] = v.to(device=self.device)
  ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
  output = self.instance_peaks_inf_layer(ex)
+
  peaks = output["pred_instance_peaks"].cpu().numpy()
- img = (
- output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ peak_values = output["pred_peak_values"].cpu().numpy()
+ img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
  gt_instances = ex["instance"].cpu().numpy()
- confmaps = (
- output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+ return VisualizationData(
+ image=img,
+ pred_confmaps=confmaps,
+ pred_peaks=peaks,
+ pred_peak_values=peak_values,
+ gt_instances=gt_instances,
+ node_names=list(self.node_names) if self.node_names else [],
+ output_scale=confmaps.shape[0] / img.shape[0],
+ is_paired=True,
+ )
+
+ def visualize_example(self, sample):
+ """Visualize predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
- plot_peaks(gt_instances, peaks, paired=True)
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
  return fig

  def forward(self, img):
  """Forward pass of the model."""
  img = torch.squeeze(img, dim=1).to(self.device)
+ img = normalize_on_gpu(img)
  return self.model(img)["CenteredInstanceConfmapsHead"]

  def training_step(self, batch, batch_idx):
@@ -768,8 +809,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  self.log(
  f"{name}",
  channel_wise_loss[node_idx],
- prog_bar=True,
- on_step=True,
+ prog_bar=False,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -780,7 +821,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  loss,
  prog_bar=True,
  on_step=True,
- on_epoch=True,
+ on_epoch=False,
  logger=True,
  sync_dist=True,
  )
@@ -805,21 +846,11 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  loss_scale=self.loss_scale,
  )
  val_loss = val_loss + ohkm_loss
- lr = self.optimizers().optimizer.param_groups[0]["lr"]
- self.log(
- "learning_rate",
- lr,
- prog_bar=True,
- on_step=True,
- on_epoch=True,
- logger=True,
- sync_dist=True,
- )
  self.log(
  "val_loss",
  val_loss,
  prog_bar=True,
- on_step=True,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -916,9 +947,10 @@ class CentroidLightningModule(LightningModel):
  output_stride=self.head_configs.centroid.confmaps.output_stride,
  input_scale=1.0,
  )
+ self.node_names = ["centroid"]

- def visualize_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
+ def get_visualization_data(self, sample) -> VisualizationData:
+ """Extract visualization data from a sample."""
  ex = sample.copy()
  ex["eff_scale"] = torch.tensor([1.0])
  for k, v in ex.items():
@@ -927,26 +959,40 @@ class CentroidLightningModule(LightningModel):
  ex["image"] = ex["image"].unsqueeze(dim=0)
  gt_centroids = ex["centroids"].cpu().numpy()
  output = self.centroid_inf_layer(ex)
+
  peaks = output["centroids"][0].cpu().numpy()
- img = (
- output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
- confmaps = (
- output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ centroid_vals = output["centroid_vals"][0].cpu().numpy()
+ img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
+ confmaps = output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+ return VisualizationData(
+ image=img,
+ pred_confmaps=confmaps,
+ pred_peaks=peaks,
+ pred_peak_values=centroid_vals,
+ gt_instances=gt_centroids,
+ node_names=self.node_names,
+ output_scale=confmaps.shape[0] / img.shape[0],
+ is_paired=False,
+ )
+
+ def visualize_example(self, sample):
+ """Visualize predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
- plot_peaks(gt_centroids, peaks, paired=False)
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
  return fig

  def forward(self, img):
  """Forward pass of the model."""
  img = torch.squeeze(img, dim=1).to(self.device)
+ img = normalize_on_gpu(img)
  return self.model(img)["CentroidConfmapsHead"]

  def training_step(self, batch, batch_idx):
@@ -963,7 +1009,7 @@ class CentroidLightningModule(LightningModel):
  loss,
  prog_bar=True,
  on_step=True,
- on_epoch=True,
+ on_epoch=False,
  logger=True,
  sync_dist=True,
  )
@@ -978,21 +1024,11 @@ class CentroidLightningModule(LightningModel):

  y_preds = self.model(X)["CentroidConfmapsHead"]
  val_loss = nn.MSELoss()(y_preds, y)
- lr = self.optimizers().optimizer.param_groups[0]["lr"]
- self.log(
- "learning_rate",
- lr,
- prog_bar=True,
- on_step=True,
- on_epoch=True,
- logger=True,
- sync_dist=True,
- )
  self.log(
  "val_loss",
  val_loss,
  prog_bar=True,
- on_step=True,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -1097,9 +1133,12 @@ class BottomUpLightningModule(LightningModel):
  cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
  pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
  )
+ self.node_names = list(self.head_configs.bottomup.confmaps.part_names)

- def visualize_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
+ def get_visualization_data(
+ self, sample, include_pafs: bool = False
+ ) -> VisualizationData:
+ """Extract visualization data from a sample."""
  ex = sample.copy()
  ex["eff_scale"] = torch.tensor([1.0])
  for k, v in ex.items():
@@ -1107,54 +1146,65 @@ class BottomUpLightningModule(LightningModel):
  ex[k] = v.to(device=self.device)
  ex["image"] = ex["image"].unsqueeze(dim=0)
  output = self.bottomup_inf_layer(ex)[0]
+
  peaks = output["pred_instance_peaks"][0].cpu().numpy()
- img = (
- output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ peak_values = output["pred_peak_values"][0].cpu().numpy()
+ img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
  gt_instances = ex["instances"][0].cpu().numpy()
- confmaps = (
- output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+ pred_pafs = None
+ if include_pafs:
+ pafs = output["pred_part_affinity_fields"].cpu().numpy()[0]
+ pred_pafs = pafs # (h, w, 2*edges)
+
+ return VisualizationData(
+ image=img,
+ pred_confmaps=confmaps,
+ pred_peaks=peaks,
+ pred_peak_values=peak_values,
+ gt_instances=gt_instances,
+ node_names=self.node_names,
+ output_scale=confmaps.shape[0] / img.shape[0],
+ is_paired=False,
+ pred_pafs=pred_pafs,
+ )
+
+ def visualize_example(self, sample):
+ """Visualize predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
  plt.xlim(plt.xlim())
  plt.ylim(plt.ylim())
- plot_peaks(gt_instances, peaks, paired=False)
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
  return fig

  def visualize_pafs_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
- ex = sample.copy()
- ex["eff_scale"] = torch.tensor([1.0])
- for k, v in ex.items():
- if isinstance(v, torch.Tensor):
- ex[k] = v.to(device=self.device)
- ex["image"] = ex["image"].unsqueeze(dim=0)
- output = self.bottomup_inf_layer(ex)[0]
- img = (
- output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
- pafs = output["pred_part_affinity_fields"].cpu().numpy()[0] # (h, w, 2*edges)
+ """Visualize PAF predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample, include_pafs=True)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)

+ pafs = data.pred_pafs
  pafs = pafs.reshape((pafs.shape[0], pafs.shape[1], -1, 2))
  pafs_mag = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
- plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / img.shape[0])
+ plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / data.image.shape[0])
  return fig

  def forward(self, img):
  """Forward pass of the model."""
  img = torch.squeeze(img, dim=1).to(self.device)
+ img = normalize_on_gpu(img)
  output = self.model(img)
  return {
  "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1203,6 +1253,22 @@ class BottomUpLightningModule(LightningModel):
  loss,
  prog_bar=True,
  on_step=True,
+ on_epoch=False,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "train_confmap_loss",
+ confmap_loss,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "train_paf_loss",
+ pafs_loss,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -1248,21 +1314,27 @@ class BottomUpLightningModule(LightningModel):
  }

  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
- lr = self.optimizers().optimizer.param_groups[0]["lr"]
  self.log(
- "learning_rate",
- lr,
+ "val_loss",
+ val_loss,
  prog_bar=True,
- on_step=True,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
  )
  self.log(
- "val_loss",
- val_loss,
- prog_bar=True,
- on_step=True,
+ "val_confmap_loss",
+ confmap_loss,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "val_paf_loss",
+ pafs_loss,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
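One subtlety in the weighted total: iterating a dict yields its keys in insertion order, so `zip(self.loss_weights, losses)` pairs each weight with the head whose loss was inserted at that position. An equivalent, illustrative expansion:

```python
# Illustrative weights and losses; insertion order defines the pairing.
loss_weights = [1.0, 0.5]
losses = {
    "MultiInstanceConfmapsHead": 0.21,
    "PartAffinityFieldsHead": 0.08,
}
val_loss = sum(w * losses[head] for w, head in zip(loss_weights, losses))
# == 1.0 * 0.21 + 0.5 * 0.08
```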
@@ -1361,9 +1433,14 @@ class BottomUpMultiClassLightningModule(LightningModel):
  cms_output_stride=self.head_configs.multi_class_bottomup.confmaps.output_stride,
  class_maps_output_stride=self.head_configs.multi_class_bottomup.class_maps.output_stride,
  )
+ self.node_names = list(
+ self.head_configs.multi_class_bottomup.confmaps.part_names
+ )

- def visualize_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
+ def get_visualization_data(
+ self, sample, include_class_maps: bool = False
+ ) -> VisualizationData:
+ """Extract visualization data from a sample."""
  ex = sample.copy()
  ex["eff_scale"] = torch.tensor([1.0])
  for k, v in ex.items():
@@ -1371,54 +1448,65 @@ class BottomUpMultiClassLightningModule(LightningModel):
  ex[k] = v.to(device=self.device)
  ex["image"] = ex["image"].unsqueeze(dim=0)
  output = self.bottomup_inf_layer(ex)[0]
+
  peaks = output["pred_instance_peaks"][0].cpu().numpy()
- img = (
- output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ peak_values = output["pred_peak_values"][0].cpu().numpy()
+ img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
  gt_instances = ex["instances"][0].cpu().numpy()
- confmaps = (
- output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+ pred_class_maps = None
+ if include_class_maps:
+ pred_class_maps = (
+ output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
+ )
+
+ return VisualizationData(
+ image=img,
+ pred_confmaps=confmaps,
+ pred_peaks=peaks,
+ pred_peak_values=peak_values,
+ gt_instances=gt_instances,
+ node_names=self.node_names,
+ output_scale=confmaps.shape[0] / img.shape[0],
+ is_paired=False,
+ pred_class_maps=pred_class_maps,
+ )
+
+ def visualize_example(self, sample):
+ """Visualize predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
  plt.xlim(plt.xlim())
  plt.ylim(plt.ylim())
- plot_peaks(gt_instances, peaks, paired=False)
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
  return fig

  def visualize_class_maps_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
- ex = sample.copy()
- ex["eff_scale"] = torch.tensor([1.0])
- for k, v in ex.items():
- if isinstance(v, torch.Tensor):
- ex[k] = v.to(device=self.device)
- ex["image"] = ex["image"].unsqueeze(dim=0)
- output = self.bottomup_inf_layer(ex)[0]
- img = (
- output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
- classmaps = (
- output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
- ) # (n_classes, h, w)
+ """Visualize class map predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample, include_class_maps=True)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
-
- plot_confmaps(classmaps, output_scale=classmaps.shape[0] / img.shape[0])
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(
+ data.pred_class_maps,
+ output_scale=data.pred_class_maps.shape[0] / data.image.shape[0],
+ )
  return fig

  def forward(self, img):
  """Forward pass of the model."""
  img = torch.squeeze(img, dim=1).to(self.device)
+ img = normalize_on_gpu(img)
  output = self.model(img)
  return {
  "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1458,6 +1546,22 @@ class BottomUpMultiClassLightningModule(LightningModel):
  loss,
  prog_bar=True,
  on_step=True,
+ on_epoch=False,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "train_confmap_loss",
+ confmap_loss,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "train_classmap_loss",
+ classmaps_loss,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -1494,21 +1598,27 @@ class BottomUpMultiClassLightningModule(LightningModel):
  }

  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
- lr = self.optimizers().optimizer.param_groups[0]["lr"]
  self.log(
- "learning_rate",
- lr,
+ "val_loss",
+ val_loss,
  prog_bar=True,
- on_step=True,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
  )
  self.log(
- "val_loss",
- val_loss,
- prog_bar=True,
- on_step=True,
+ "val_confmap_loss",
+ confmap_loss,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "val_classmap_loss",
+ classmaps_loss,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -1607,8 +1717,8 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):

  self.node_names = self.head_configs.multi_class_topdown.confmaps.part_names

- def visualize_example(self, sample):
- """Visualize predictions during training (used with callbacks)."""
+ def get_visualization_data(self, sample) -> VisualizationData:
+ """Extract visualization data from a sample."""
  ex = sample.copy()
  ex["eff_scale"] = torch.tensor([1.0])
  for k, v in ex.items():
@@ -1616,27 +1726,41 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  ex[k] = v.to(device=self.device)
  ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
  output = self.instance_peaks_inf_layer(ex)
+
  peaks = output["pred_instance_peaks"].cpu().numpy()
- img = (
- output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ peak_values = output["pred_peak_values"].cpu().numpy()
+ img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
  gt_instances = ex["instance"].cpu().numpy()
- confmaps = (
- output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
- ) # convert from (C, H, W) to (H, W, C)
+ confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+ return VisualizationData(
+ image=img,
+ pred_confmaps=confmaps,
+ pred_peaks=peaks,
+ pred_peak_values=peak_values,
+ gt_instances=gt_instances,
+ node_names=list(self.node_names) if self.node_names else [],
+ output_scale=confmaps.shape[0] / img.shape[0],
+ is_paired=True,
+ )
+
+ def visualize_example(self, sample):
+ """Visualize predictions during training (used with callbacks)."""
+ data = self.get_visualization_data(sample)
  scale = 1.0
- if img.shape[0] < 512:
+ if data.image.shape[0] < 512:
  scale = 2.0
- if img.shape[0] < 256:
+ if data.image.shape[0] < 256:
  scale = 4.0
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
- plot_peaks(gt_instances, peaks, paired=True)
+ fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
  return fig

  def forward(self, img):
  """Forward pass of the model."""
  img = torch.squeeze(img, dim=1).to(self.device)
+ img = normalize_on_gpu(img)
  output = self.model(img)
  return {
  "CenteredInstanceConfmapsHead": output["CenteredInstanceConfmapsHead"],
@@ -1681,8 +1805,8 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  self.log(
  f"{name}",
  channel_wise_loss[node_idx],
- prog_bar=True,
- on_step=True,
+ prog_bar=False,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -1693,6 +1817,22 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  loss,
  prog_bar=True,
  on_step=True,
+ on_epoch=False,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "train_confmap_loss",
+ confmap_loss,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "train_classvector_loss",
+ classvector_loss,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
@@ -1727,22 +1867,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  "ClassVectorsHead": classvector_loss,
  }
  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-
- lr = self.optimizers().optimizer.param_groups[0]["lr"]
  self.log(
- "learning_rate",
- lr,
+ "val_loss",
+ val_loss,
  prog_bar=True,
- on_step=True,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
  )
  self.log(
- "val_loss",
- val_loss,
- prog_bar=True,
- on_step=True,
+ "val_confmap_loss",
+ confmap_loss,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ self.log(
+ "val_classvector_loss",
+ classvector_loss,
+ on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,