sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/config/trainer_config.py +18 -0
  4. sleap_nn/evaluation.py +81 -22
  5. sleap_nn/export/__init__.py +21 -0
  6. sleap_nn/export/cli.py +1778 -0
  7. sleap_nn/export/exporters/__init__.py +51 -0
  8. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  9. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  10. sleap_nn/export/metadata.py +225 -0
  11. sleap_nn/export/predictors/__init__.py +63 -0
  12. sleap_nn/export/predictors/base.py +22 -0
  13. sleap_nn/export/predictors/onnx.py +154 -0
  14. sleap_nn/export/predictors/tensorrt.py +312 -0
  15. sleap_nn/export/utils.py +307 -0
  16. sleap_nn/export/wrappers/__init__.py +25 -0
  17. sleap_nn/export/wrappers/base.py +96 -0
  18. sleap_nn/export/wrappers/bottomup.py +243 -0
  19. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  20. sleap_nn/export/wrappers/centered_instance.py +56 -0
  21. sleap_nn/export/wrappers/centroid.py +58 -0
  22. sleap_nn/export/wrappers/single_instance.py +83 -0
  23. sleap_nn/export/wrappers/topdown.py +180 -0
  24. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  25. sleap_nn/inference/bottomup.py +86 -20
  26. sleap_nn/inference/postprocessing.py +284 -0
  27. sleap_nn/predict.py +29 -0
  28. sleap_nn/train.py +64 -0
  29. sleap_nn/training/callbacks.py +324 -8
  30. sleap_nn/training/lightning_modules.py +542 -32
  31. sleap_nn/training/model_trainer.py +48 -57
  32. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
  33. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
  34. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  35. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  36. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  37. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
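The per-file diff reproduced below appears to be item 30, sleap_nn/training/lightning_modules.py (+542 -32); diffs for the remaining files are not shown in this excerpt.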
@@ -1,6 +1,6 @@
  """This module has the LightningModule classes for all model types."""

- from typing import Optional, Union, Dict, Any
+ from typing import Optional, Union, Dict, Any, List
  import time
  from torch import nn
  import numpy as np
@@ -184,6 +184,15 @@ class LightningModel(L.LightningModule):
  self.val_loss = {}
  self.learning_rate = {}

+ # For epoch-averaged loss tracking
+ self._epoch_loss_sum = 0.0
+ self._epoch_loss_count = 0
+
+ # For epoch-end evaluation
+ self.val_predictions: List[Dict] = []
+ self.val_ground_truth: List[Dict] = []
+ self._collect_val_predictions: bool = False
+
  # Initialization for encoder and decoder stacks.
  if self.init_weights == "xavier":
  self.model.apply(xavier_init_weights)
@@ -305,12 +314,20 @@ class LightningModel(L.LightningModule):
  def on_train_epoch_start(self):
  """Configure the train timer at the beginning of each epoch."""
  self.train_start_time = time.time()
+ # Reset epoch loss tracking
+ self._epoch_loss_sum = 0.0
+ self._epoch_loss_count = 0
+
+ def _accumulate_loss(self, loss: torch.Tensor):
+ """Accumulate loss for epoch-averaged logging. Call this in training_step."""
+ self._epoch_loss_sum += loss.detach().item()
+ self._epoch_loss_count += 1

  def on_train_epoch_end(self):
  """Configure the train timer at the end of every epoch."""
  train_time = time.time() - self.train_start_time
  self.log(
- "train_time",
+ "train/time",
  train_time,
  prog_bar=False,
  on_step=False,
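The `_accumulate_loss` helper added above keeps a running sum and count of per-step losses that `on_train_epoch_end` later turns into an epoch-averaged `train/loss`. A minimal standalone sketch of that accumulate/reset/average cycle (illustrative names only, not sleap-nn API):

# Minimal sketch (not sleap-nn code) of the epoch-averaged loss pattern:
# sum detached per-batch loss values, count the batches, and report the mean
# once per epoch.
class EpochLossAverager:
    def __init__(self):
        self.loss_sum = 0.0
        self.loss_count = 0

    def reset(self):
        # Mirrors on_train_epoch_start: start a fresh epoch.
        self.loss_sum = 0.0
        self.loss_count = 0

    def update(self, loss_value: float):
        # Mirrors _accumulate_loss: called once per training step.
        self.loss_sum += loss_value
        self.loss_count += 1

    def average(self):
        # Mirrors on_train_epoch_end: the value that would be logged as "train/loss".
        return self.loss_sum / self.loss_count if self.loss_count else None


averager = EpochLossAverager()
averager.reset()
for batch_loss in [1.0, 0.5, 0.0]:  # stand-in for per-batch loss values
    averager.update(batch_loss)
print(averager.average())  # 0.5

Because each batch contributes equally, the reported value is a mean over optimizer steps rather than a sample-weighted mean, which only matters when the last batch of an epoch is smaller than the rest.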
@@ -327,16 +344,43 @@ class LightningModel(L.LightningModule):
  logger=True,
  sync_dist=True,
  )
+ # Log epoch-averaged training loss
+ if self._epoch_loss_count > 0:
+ avg_loss = self._epoch_loss_sum / self._epoch_loss_count
+ self.log(
+ "train/loss",
+ avg_loss,
+ prog_bar=False,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+ # Log current learning rate (useful for monitoring LR schedulers)
+ if self.trainer.optimizers:
+ lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+ self.log(
+ "train/lr",
+ lr,
+ prog_bar=False,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )

  def on_validation_epoch_start(self):
  """Configure the val timer at the beginning of each epoch."""
  self.val_start_time = time.time()
+ # Clear accumulated predictions for new epoch
+ self.val_predictions = []
+ self.val_ground_truth = []

  def on_validation_epoch_end(self):
  """Configure the val timer at the end of every epoch."""
  val_time = time.time() - self.val_start_time
  self.log(
- "val_time",
+ "val/time",
  val_time,
  prog_bar=False,
  on_step=False,
@@ -344,6 +388,16 @@ class LightningModel(L.LightningModule):
  logger=True,
  sync_dist=True,
  )
+ # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
+ # (mirrors what on_train_epoch_end does for train/* metrics)
+ self.log(
+ "epoch",
+ float(self.current_epoch),
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )

  def training_step(self, batch, batch_idx):
  """Training step."""
@@ -412,7 +466,7 @@ class LightningModel(L.LightningModule):
  "optimizer": optimizer,
  "lr_scheduler": {
  "scheduler": scheduler,
- "monitor": "val_loss",
+ "monitor": "val/loss",
  },
  }

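Since the scheduler's monitored key now reads "val/loss", anything else that watches the old "val_loss" name has to be updated in lockstep. A minimal sketch, assuming a ReduceLROnPlateau scheduler and the stock Lightning callbacks (this is not the actual sleap-nn trainer wiring):

import torch
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint


def configure_optimizers_sketch(model):
    # The "monitor" key must match a metric name logged via self.log() exactly,
    # so it is "val/loss" here rather than the old "val_loss".
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss"},
    }


# Checkpointing and early stopping must watch the same renamed metric.
callbacks = [
    ModelCheckpoint(monitor="val/loss", mode="min"),
    EarlyStopping(monitor="val/loss", mode="min"),
]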
@@ -591,7 +645,7 @@ class SingleInstanceLightningModule(LightningModel):
  channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
  for node_idx, name in enumerate(self.node_names):
  self.log(
- f"{name}",
+ f"train/confmaps/{name}",
  channel_wise_loss[node_idx],
  prog_bar=False,
  on_step=False,
@@ -599,8 +653,9 @@ class SingleInstanceLightningModule(LightningModel):
  logger=True,
  sync_dist=True,
  )
+ # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "train_loss",
+ "loss",
  loss,
  prog_bar=True,
  on_step=True,
@@ -608,6 +663,8 @@ class SingleInstanceLightningModule(LightningModel):
  logger=True,
  sync_dist=True,
  )
+ # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+ self._accumulate_loss(loss)
  return loss

  def validation_step(self, batch, batch_idx):
@@ -630,7 +687,7 @@ class SingleInstanceLightningModule(LightningModel):
  )
  val_loss = val_loss + ohkm_loss
  self.log(
- "val_loss",
+ "val/loss",
  val_loss,
  prog_bar=True,
  on_step=False,
@@ -639,6 +696,51 @@ class SingleInstanceLightningModule(LightningModel):
  sync_dist=True,
  )

+ # Collect predictions for epoch-end evaluation if enabled
+ if self._collect_val_predictions:
+ with torch.no_grad():
+ # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
+ inference_batch = {k: v for k, v in batch.items()}
+ if inference_batch["image"].ndim == 5:
+ inference_batch["image"] = inference_batch["image"].squeeze(1)
+ inference_output = self.single_instance_inf_layer(inference_batch)
+ if isinstance(inference_output, list):
+ inference_output = inference_output[0]
+
+ batch_size = len(batch["frame_idx"])
+ for i in range(batch_size):
+ eff = batch["eff_scale"][i].cpu().numpy()
+
+ # Predictions are already in original image space (inference divides by eff_scale)
+ pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
+ pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+ # Transform GT from preprocessed to original image space
+ # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+ gt_prep = batch["instances"][i].cpu().numpy()
+ if gt_prep.ndim == 4:
+ gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
+ gt_orig = gt_prep / eff
+ num_inst = batch["num_instances"][i].item()
+ gt_orig = gt_orig[:num_inst] # Only valid instances
+
+ self.val_predictions.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "pred_peaks": pred_peaks,
+ "pred_scores": pred_scores,
+ }
+ )
+ self.val_ground_truth.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "gt_instances": gt_orig,
+ "num_instances": num_inst,
+ }
+ )
+

  class TopDownCenteredInstanceLightningModule(LightningModel):
  """Lightning Module for TopDownCenteredInstance Model.
@@ -807,7 +909,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
  for node_idx, name in enumerate(self.node_names):
  self.log(
- f"{name}",
+ f"train/confmaps/{name}",
  channel_wise_loss[node_idx],
  prog_bar=False,
  on_step=False,
@@ -816,8 +918,9 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  sync_dist=True,
  )

+ # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "train_loss",
+ "loss",
  loss,
  prog_bar=True,
  on_step=True,
@@ -825,6 +928,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  logger=True,
  sync_dist=True,
  )
+ # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+ self._accumulate_loss(loss)
  return loss

  def validation_step(self, batch, batch_idx):
@@ -847,7 +952,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  )
  val_loss = val_loss + ohkm_loss
  self.log(
- "val_loss",
+ "val/loss",
  val_loss,
  prog_bar=True,
  on_step=False,
@@ -856,6 +961,62 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  sync_dist=True,
  )

+ # Collect predictions for epoch-end evaluation if enabled
+ if self._collect_val_predictions:
+ # SAVE bbox BEFORE inference (it modifies in-place!)
+ bbox_prep_saved = batch["instance_bbox"].clone()
+
+ with torch.no_grad():
+ inference_output = self.instance_peaks_inf_layer(batch)
+
+ batch_size = len(batch["frame_idx"])
+ for i in range(batch_size):
+ eff = batch["eff_scale"][i].cpu().numpy()
+
+ # Predictions from inference (crop-relative, original scale)
+ pred_peaks_crop = (
+ inference_output["pred_instance_peaks"][i].cpu().numpy()
+ )
+ pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+ # Compute bbox offset in original space from SAVED prep bbox
+ # bbox has shape (n_samples=1, 4, 2) where 4 corners
+ bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy() # (4, 2)
+ bbox_top_left_orig = (
+ bbox_prep[0] / eff
+ ) # Top-left corner in original space
+
+ # Full image coordinates (original space)
+ pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+ # GT transform: crop-relative preprocessed -> full image original
+ gt_crop_prep = (
+ batch["instance"][i].squeeze(0).cpu().numpy()
+ ) # (n_nodes, 2)
+ gt_crop_orig = gt_crop_prep / eff
+ gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+ self.val_predictions.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "pred_peaks": pred_peaks_full.reshape(
+ 1, -1, 2
+ ), # (1, n_nodes, 2)
+ "pred_scores": pred_scores.reshape(1, -1), # (1, n_nodes)
+ }
+ )
+ self.val_ground_truth.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "gt_instances": gt_full_orig.reshape(
+ 1, -1, 2
+ ), # (1, n_nodes, 2)
+ "num_instances": 1,
+ }
+ )
+

  class CentroidLightningModule(LightningModel):
  """Lightning Module for Centroid Model.
@@ -1004,8 +1165,9 @@

  y_preds = self.model(X)["CentroidConfmapsHead"]
  loss = nn.MSELoss()(y_preds, y)
+ # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "train_loss",
+ "loss",
  loss,
  prog_bar=True,
  on_step=True,
@@ -1013,6 +1175,8 @@
  logger=True,
  sync_dist=True,
  )
+ # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+ self._accumulate_loss(loss)
  return loss

  def validation_step(self, batch, batch_idx):
@@ -1025,7 +1189,7 @@
  y_preds = self.model(X)["CentroidConfmapsHead"]
  val_loss = nn.MSELoss()(y_preds, y)
  self.log(
- "val_loss",
+ "val/loss",
  val_loss,
  prog_bar=True,
  on_step=False,
@@ -1034,6 +1198,57 @@
  sync_dist=True,
  )

+ # Collect predictions for epoch-end evaluation if enabled
+ if self._collect_val_predictions:
+ with torch.no_grad():
+ inference_output = self.centroid_inf_layer(batch)
+
+ batch_size = len(batch["frame_idx"])
+ for i in range(batch_size):
+ eff = batch["eff_scale"][i].cpu().numpy()
+
+ # Predictions are in original image space (inference divides by eff_scale)
+ # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
+ pred_centroids = (
+ inference_output["centroids"][i].squeeze(0).cpu().numpy()
+ )
+ pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
+
+ # Transform GT centroids from preprocessed to original image space
+ gt_centroids_prep = (
+ batch["centroids"][i].cpu().numpy()
+ ) # (n_samples=1, max_inst, 2)
+ gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff # (max_inst, 2)
+ num_inst = batch["num_instances"][i].item()
+
+ # Filter to valid instances (non-NaN)
+ valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
+ pred_centroids = pred_centroids[valid_pred_mask]
+ pred_vals = pred_vals[valid_pred_mask]
+
+ gt_centroids_valid = gt_centroids_orig[:num_inst]
+
+ self.val_predictions.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "pred_peaks": pred_centroids.reshape(
+ -1, 1, 2
+ ), # (n_inst, 1, 2)
+ "pred_scores": pred_vals.reshape(-1, 1), # (n_inst, 1)
+ }
+ )
+ self.val_ground_truth.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "gt_instances": gt_centroids_valid.reshape(
+ -1, 1, 2
+ ), # (n_inst, 1, 2)
+ "num_instances": num_inst,
+ }
+ )
+

  class BottomUpLightningModule(LightningModel):
  """Lightning Module for BottomUp Model.
@@ -1126,12 +1341,13 @@
  self.bottomup_inf_layer = BottomUpInferenceModel(
  torch_model=self.forward,
  paf_scorer=paf_scorer,
- peak_threshold=0.2,
+ peak_threshold=0.1, # Lower threshold for epoch-end eval during training
  input_scale=1.0,
  return_confmaps=True,
  return_pafs=True,
  cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
  pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
+ max_peaks_per_node=100, # Prevents combinatorial explosion in early training
  )
  self.node_names = list(self.head_configs.bottomup.confmaps.part_names)

@@ -1248,8 +1464,9 @@
  "PartAffinityFieldsHead": pafs_loss,
  }
  loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+ # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "train_loss",
+ "loss",
  loss,
  prog_bar=True,
  on_step=True,
@@ -1257,8 +1474,10 @@
  logger=True,
  sync_dist=True,
  )
+ # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+ self._accumulate_loss(loss)
  self.log(
- "train_confmap_loss",
+ "train/confmaps_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
@@ -1266,7 +1485,7 @@
  sync_dist=True,
  )
  self.log(
- "train_paf_loss",
+ "train/paf_loss",
  pafs_loss,
  on_step=False,
  on_epoch=True,
@@ -1315,7 +1534,7 @@

  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
  self.log(
- "val_loss",
+ "val/loss",
  val_loss,
  prog_bar=True,
  on_step=False,
@@ -1324,7 +1543,7 @@
  sync_dist=True,
  )
  self.log(
- "val_confmap_loss",
+ "val/confmaps_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
@@ -1332,7 +1551,7 @@
  sync_dist=True,
  )
  self.log(
- "val_paf_loss",
+ "val/paf_loss",
  pafs_loss,
  on_step=False,
  on_epoch=True,
@@ -1340,6 +1559,53 @@
  sync_dist=True,
  )

+ # Collect predictions for epoch-end evaluation if enabled
+ if self._collect_val_predictions:
+ with torch.no_grad():
+ # Note: Do NOT squeeze the image here - the forward() method expects
+ # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+ inference_output = self.bottomup_inf_layer(batch)
+ if isinstance(inference_output, list):
+ inference_output = inference_output[0]
+
+ batch_size = len(batch["frame_idx"])
+ for i in range(batch_size):
+ eff = batch["eff_scale"][i].cpu().numpy()
+
+ # Predictions are already in original space (variable number of instances)
+ pred_peaks = inference_output["pred_instance_peaks"][i]
+ pred_scores = inference_output["pred_peak_values"][i]
+ if torch.is_tensor(pred_peaks):
+ pred_peaks = pred_peaks.cpu().numpy()
+ if torch.is_tensor(pred_scores):
+ pred_scores = pred_scores.cpu().numpy()
+
+ # Transform GT to original space
+ # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+ gt_prep = batch["instances"][i].cpu().numpy()
+ if gt_prep.ndim == 4:
+ gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
+ gt_orig = gt_prep / eff
+ num_inst = batch["num_instances"][i].item()
+ gt_orig = gt_orig[:num_inst] # Only valid instances
+
+ self.val_predictions.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "pred_peaks": pred_peaks, # Original space, variable instances
+ "pred_scores": pred_scores,
+ }
+ )
+ self.val_ground_truth.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "gt_instances": gt_orig, # Original space
+ "num_instances": num_inst,
+ }
+ )
+

  class BottomUpMultiClassLightningModule(LightningModel):
  """Lightning Module for BottomUp ID Model.
@@ -1541,8 +1807,9 @@
  "ClassMapsHead": classmaps_loss,
  }
  loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+ # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "train_loss",
+ "loss",
  loss,
  prog_bar=True,
  on_step=True,
@@ -1550,8 +1817,10 @@
  logger=True,
  sync_dist=True,
  )
+ # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+ self._accumulate_loss(loss)
  self.log(
- "train_confmap_loss",
+ "train/confmaps_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
@@ -1559,13 +1828,67 @@
  sync_dist=True,
  )
  self.log(
- "train_classmap_loss",
+ "train/classmap_loss",
  classmaps_loss,
  on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
  )
+
+ # Compute classification accuracy at GT keypoint locations
+ with torch.no_grad():
+ # Get output stride for class maps
+ cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+ # Get GT instances and sample class maps at those locations
+ instances = batch["instances"] # (batch, n_samples, max_inst, n_nodes, 2)
+ if instances.dim() == 5:
+ instances = instances.squeeze(1) # (batch, max_inst, n_nodes, 2)
+ num_instances = batch["num_instances"] # (batch,)
+
+ correct = 0
+ total = 0
+ for b in range(instances.shape[0]):
+ n_inst = num_instances[b].item()
+ for inst_idx in range(n_inst):
+ for node_idx in range(instances.shape[2]):
+ # Get keypoint location (in input image space)
+ kp = instances[b, inst_idx, node_idx] # (2,) = (x, y)
+ if torch.isnan(kp).any():
+ continue
+
+ # Convert to class map space
+ x_cm = (
+ (kp[0] / cms_stride)
+ .long()
+ .clamp(0, classmaps.shape[-1] - 1)
+ )
+ y_cm = (
+ (kp[1] / cms_stride)
+ .long()
+ .clamp(0, classmaps.shape[-2] - 1)
+ )
+
+ # Sample predicted and GT class at this location
+ pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+ gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+ if pred_class == gt_class:
+ correct += 1
+ total += 1
+
+ if total > 0:
+ class_accuracy = torch.tensor(correct / total, device=X.device)
+ self.log(
+ "train/class_accuracy",
+ class_accuracy,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+
  return loss

  def validation_step(self, batch, batch_idx):
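For reference, the class-accuracy block above samples the predicted class map at each ground-truth keypoint after mapping the keypoint from input-image pixels to the strided class-map grid. A small worked example of just that index arithmetic, with assumed shapes (not sleap-nn code):

import torch

stride = 4  # assumed class-map output stride
classmaps = torch.zeros(1, 3, 96, 128)  # (batch, n_classes, H/stride, W/stride)
kp = torch.tensor([101.7, 55.2])  # keypoint (x, y) in input-image pixels

# Divide by the stride, truncate to an integer cell, and clamp to the map bounds.
x_cm = (kp[0] / stride).long().clamp(0, classmaps.shape[-1] - 1)  # -> 25
y_cm = (kp[1] / stride).long().clamp(0, classmaps.shape[-2] - 1)  # -> 13
pred_class = classmaps[0, :, y_cm, x_cm].argmax()  # identity with highest score at that cell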
@@ -1599,7 +1922,7 @@

  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
  self.log(
- "val_loss",
+ "val/loss",
  val_loss,
  prog_bar=True,
  on_step=False,
@@ -1608,7 +1931,7 @@
  sync_dist=True,
  )
  self.log(
- "val_confmap_loss",
+ "val/confmaps_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
@@ -1616,7 +1939,7 @@
  sync_dist=True,
  )
  self.log(
- "val_classmap_loss",
+ "val/classmap_loss",
  classmaps_loss,
  on_step=False,
  on_epoch=True,
@@ -1624,6 +1947,106 @@
  sync_dist=True,
  )

+ # Compute classification accuracy at GT keypoint locations
+ with torch.no_grad():
+ # Get output stride for class maps
+ cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+ # Get GT instances and sample class maps at those locations
+ instances = batch["instances"] # (batch, n_samples, max_inst, n_nodes, 2)
+ if instances.dim() == 5:
+ instances = instances.squeeze(1) # (batch, max_inst, n_nodes, 2)
+ num_instances = batch["num_instances"] # (batch,)
+
+ correct = 0
+ total = 0
+ for b in range(instances.shape[0]):
+ n_inst = num_instances[b].item()
+ for inst_idx in range(n_inst):
+ for node_idx in range(instances.shape[2]):
+ # Get keypoint location (in input image space)
+ kp = instances[b, inst_idx, node_idx] # (2,) = (x, y)
+ if torch.isnan(kp).any():
+ continue
+
+ # Convert to class map space
+ x_cm = (
+ (kp[0] / cms_stride)
+ .long()
+ .clamp(0, classmaps.shape[-1] - 1)
+ )
+ y_cm = (
+ (kp[1] / cms_stride)
+ .long()
+ .clamp(0, classmaps.shape[-2] - 1)
+ )
+
+ # Sample predicted and GT class at this location
+ pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+ gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+ if pred_class == gt_class:
+ correct += 1
+ total += 1
+
+ if total > 0:
+ class_accuracy = torch.tensor(correct / total, device=X.device)
+ self.log(
+ "val/class_accuracy",
+ class_accuracy,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+
+ # Collect predictions for epoch-end evaluation if enabled
+ if self._collect_val_predictions:
+ with torch.no_grad():
+ # Note: Do NOT squeeze the image here - the forward() method expects
+ # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+ inference_output = self.bottomup_inf_layer(batch)
+ if isinstance(inference_output, list):
+ inference_output = inference_output[0]
+
+ batch_size = len(batch["frame_idx"])
+ for i in range(batch_size):
+ eff = batch["eff_scale"][i].cpu().numpy()
+
+ # Predictions are already in original space (variable number of instances)
+ pred_peaks = inference_output["pred_instance_peaks"][i]
+ pred_scores = inference_output["pred_peak_values"][i]
+ if torch.is_tensor(pred_peaks):
+ pred_peaks = pred_peaks.cpu().numpy()
+ if torch.is_tensor(pred_scores):
+ pred_scores = pred_scores.cpu().numpy()
+
+ # Transform GT to original space
+ # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+ gt_prep = batch["instances"][i].cpu().numpy()
+ if gt_prep.ndim == 4:
+ gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
+ gt_orig = gt_prep / eff
+ num_inst = batch["num_instances"][i].item()
+ gt_orig = gt_orig[:num_inst] # Only valid instances
+
+ self.val_predictions.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "pred_peaks": pred_peaks, # Original space, variable instances
+ "pred_scores": pred_scores,
+ }
+ )
+ self.val_ground_truth.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "gt_instances": gt_orig, # Original space
+ "num_instances": num_inst,
+ }
+ )
+

  class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  """Lightning Module for TopDownCenteredInstance ID Model.
@@ -1803,7 +2226,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
  for node_idx, name in enumerate(self.node_names):
  self.log(
- f"{name}",
+ f"train/confmaps/{name}",
  channel_wise_loss[node_idx],
  prog_bar=False,
  on_step=False,
@@ -1812,8 +2235,9 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  sync_dist=True,
  )

+ # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "train_loss",
+ "loss",
  loss,
  prog_bar=True,
  on_step=True,
@@ -1821,8 +2245,10 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  logger=True,
  sync_dist=True,
  )
+ # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+ self._accumulate_loss(loss)
  self.log(
- "train_confmap_loss",
+ "train/confmaps_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
@@ -1830,13 +2256,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  sync_dist=True,
  )
  self.log(
- "train_classvector_loss",
+ "train/classvector_loss",
  classvector_loss,
  on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
  )
+
+ # Compute classification accuracy
+ with torch.no_grad():
+ pred_classes = torch.argmax(classvector, dim=1)
+ gt_classes = torch.argmax(y_classvector, dim=1)
+ class_accuracy = (pred_classes == gt_classes).float().mean()
+ self.log(
+ "train/class_accuracy",
+ class_accuracy,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
  return loss

  def validation_step(self, batch, batch_idx):
@@ -1868,7 +2308,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  }
  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
  self.log(
- "val_loss",
+ "val/loss",
  val_loss,
  prog_bar=True,
  on_step=False,
@@ -1877,7 +2317,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  sync_dist=True,
  )
  self.log(
- "val_confmap_loss",
+ "val/confmaps_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
@@ -1885,10 +2325,80 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  sync_dist=True,
  )
  self.log(
- "val_classvector_loss",
+ "val/classvector_loss",
  classvector_loss,
  on_step=False,
  on_epoch=True,
  logger=True,
  sync_dist=True,
  )
+
+ # Compute classification accuracy
+ with torch.no_grad():
+ pred_classes = torch.argmax(classvector, dim=1)
+ gt_classes = torch.argmax(y_classvector, dim=1)
+ class_accuracy = (pred_classes == gt_classes).float().mean()
+ self.log(
+ "val/class_accuracy",
+ class_accuracy,
+ on_step=False,
+ on_epoch=True,
+ logger=True,
+ sync_dist=True,
+ )
+
+ # Collect predictions for epoch-end evaluation if enabled
+ if self._collect_val_predictions:
+ # SAVE bbox BEFORE inference (it modifies in-place!)
+ bbox_prep_saved = batch["instance_bbox"].clone()
+
+ with torch.no_grad():
+ inference_output = self.instance_peaks_inf_layer(batch)
+
+ batch_size = len(batch["frame_idx"])
+ for i in range(batch_size):
+ eff = batch["eff_scale"][i].cpu().numpy()
+
+ # Predictions from inference (crop-relative, original scale)
+ pred_peaks_crop = (
+ inference_output["pred_instance_peaks"][i].cpu().numpy()
+ )
+ pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+ # Compute bbox offset in original space from SAVED prep bbox
+ # bbox has shape (n_samples=1, 4, 2) where 4 corners
+ bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy() # (4, 2)
+ bbox_top_left_orig = (
+ bbox_prep[0] / eff
+ ) # Top-left corner in original space
+
+ # Full image coordinates (original space)
+ pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+ # GT transform: crop-relative preprocessed -> full image original
+ gt_crop_prep = (
+ batch["instance"][i].squeeze(0).cpu().numpy()
+ ) # (n_nodes, 2)
+ gt_crop_orig = gt_crop_prep / eff
+ gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+ self.val_predictions.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "pred_peaks": pred_peaks_full.reshape(
+ 1, -1, 2
+ ), # (1, n_nodes, 2)
+ "pred_scores": pred_scores.reshape(1, -1), # (1, n_nodes)
+ }
+ )
+ self.val_ground_truth.append(
+ {
+ "video_idx": batch["video_idx"][i].item(),
+ "frame_idx": batch["frame_idx"][i].item(),
+ "gt_instances": gt_full_orig.reshape(
+ 1, -1, 2
+ ), # (1, n_nodes, 2)
+ "num_instances": 1,
+ }
+ )
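The buffers filled in these validation steps (self.val_predictions and self.val_ground_truth) are consumed at epoch end by code that is not part of this excerpt. A hypothetical consumer that matches entries by (video_idx, frame_idx) and reports a mean Euclidean keypoint error might look like this (naive first-to-first instance pairing; a real evaluator would use proper instance matching):

import numpy as np

def mean_localization_error(val_predictions, val_ground_truth):
    # Index ground truth by frame so each prediction can find its counterpart.
    gt_by_frame = {
        (gt["video_idx"], gt["frame_idx"]): gt["gt_instances"]
        for gt in val_ground_truth
    }
    errors = []
    for pred in val_predictions:
        gt = gt_by_frame.get((pred["video_idx"], pred["frame_idx"]))
        if gt is None or len(gt) == 0:
            continue
        # Naive pairing: compare the first k predicted instances to the first k
        # GT instances; OKS/Hungarian matching would be used in practice.
        n = min(len(pred["pred_peaks"]), len(gt))
        diffs = pred["pred_peaks"][:n] - gt[:n]  # (n, n_nodes, 2)
        dists = np.linalg.norm(diffs, axis=-1)   # (n, n_nodes)
        errors.append(np.nanmean(dists))
    return float(np.nanmean(errors)) if errors else float("nan")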