sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (35)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/evaluation.py +8 -0
  4. sleap_nn/export/__init__.py +21 -0
  5. sleap_nn/export/cli.py +1778 -0
  6. sleap_nn/export/exporters/__init__.py +51 -0
  7. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  8. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  9. sleap_nn/export/metadata.py +225 -0
  10. sleap_nn/export/predictors/__init__.py +63 -0
  11. sleap_nn/export/predictors/base.py +22 -0
  12. sleap_nn/export/predictors/onnx.py +154 -0
  13. sleap_nn/export/predictors/tensorrt.py +312 -0
  14. sleap_nn/export/utils.py +307 -0
  15. sleap_nn/export/wrappers/__init__.py +25 -0
  16. sleap_nn/export/wrappers/base.py +96 -0
  17. sleap_nn/export/wrappers/bottomup.py +243 -0
  18. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  19. sleap_nn/export/wrappers/centered_instance.py +56 -0
  20. sleap_nn/export/wrappers/centroid.py +58 -0
  21. sleap_nn/export/wrappers/single_instance.py +83 -0
  22. sleap_nn/export/wrappers/topdown.py +180 -0
  23. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  24. sleap_nn/inference/postprocessing.py +284 -0
  25. sleap_nn/predict.py +29 -0
  26. sleap_nn/train.py +64 -0
  27. sleap_nn/training/callbacks.py +62 -20
  28. sleap_nn/training/lightning_modules.py +332 -30
  29. sleap_nn/training/model_trainer.py +35 -67
  30. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
  31. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
  32. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  33. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  34. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
@@ -184,6 +184,10 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}
 
+        # For epoch-averaged loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
         # For epoch-end evaluation
         self.val_predictions: List[Dict] = []
         self.val_ground_truth: List[Dict] = []
@@ -310,12 +314,20 @@ class LightningModel(L.LightningModule):
     def on_train_epoch_start(self):
         """Configure the train timer at the beginning of each epoch."""
         self.train_start_time = time.time()
+        # Reset epoch loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+    def _accumulate_loss(self, loss: torch.Tensor):
+        """Accumulate loss for epoch-averaged logging. Call this in training_step."""
+        self._epoch_loss_sum += loss.detach().item()
+        self._epoch_loss_count += 1
 
     def on_train_epoch_end(self):
         """Configure the train timer at the end of every epoch."""
         train_time = time.time() - self.train_start_time
         self.log(
-            "train_time",
+            "train/time",
             train_time,
             prog_bar=False,
             on_step=False,
@@ -332,6 +344,30 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch-averaged training loss
+        if self._epoch_loss_count > 0:
+            avg_loss = self._epoch_loss_sum / self._epoch_loss_count
+            self.log(
+                "train/loss",
+                avg_loss,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
+        # Log current learning rate (useful for monitoring LR schedulers)
+        if self.trainer.optimizers:
+            lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+            self.log(
+                "train/lr",
+                lr,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
 
     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
@@ -344,7 +380,7 @@ class LightningModel(L.LightningModule):
         """Configure the val timer at the end of every epoch."""
         val_time = time.time() - self.val_start_time
         self.log(
-            "val_time",
+            "val/time",
             val_time,
             prog_bar=False,
             on_step=False,
@@ -352,6 +388,16 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
+        # (mirrors what on_train_epoch_end does for train/* metrics)
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
 
     def training_step(self, batch, batch_idx):
         """Training step."""
@@ -420,7 +466,7 @@ class LightningModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": scheduler,
-                "monitor": "val_loss",
+                "monitor": "val/loss",
             },
         }
 
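Note: renaming the scheduler monitor to "val/loss" means every other consumer of the old "val_loss" key has to be updated in lockstep, or it may error or silently never trigger. A minimal sketch with the standard Lightning callbacks (the settings here are illustrative):

    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

    callbacks = [
        # Monitors must match the key passed to self.log() exactly.
        ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=1),
        EarlyStopping(monitor="val/loss", mode="min", patience=10),
    ]
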
@@ -599,7 +645,7 @@ class SingleInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
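
Note: the per-node values logged under train/confmaps/{name} are per-channel MSE terms. A sketch of the reduction, consistent with the surrounding code and assuming confidence maps shaped (batch, n_nodes, h, w):

    import torch

    def per_node_mse(y_preds: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """MSE reduced over batch and spatial dims, kept separate per node."""
        batch_size, _, h, w = y.shape
        mse = (y_preds - y) ** 2  # (batch, n_nodes, h, w)
        return torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)  # (n_nodes,)
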
@@ -607,8 +653,9 @@ class SingleInstanceLightningModule(LightningModel):
                 logger=True,
                 sync_dist=True,
             )
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -616,6 +663,8 @@ class SingleInstanceLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -638,7 +687,7 @@ class SingleInstanceLightningModule(LightningModel):
         )
         val_loss = val_loss + ohkm_loss
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
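
Note: the ohkm_loss term added to val_loss comes from online hard keypoint mining (OHKM), which restricts the loss to the worst-performing keypoint channels. A hedged sketch of the usual top-k formulation (illustrative; the package's exact implementation may differ):

    import torch

    def ohkm(channel_losses: torch.Tensor, top_k: int) -> torch.Tensor:
        """Mean loss over the k hardest (highest-loss) keypoint channels."""
        hard_losses, _ = torch.topk(channel_losses, k=top_k)
        return hard_losses.mean()
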
@@ -860,7 +909,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -869,8 +918,9 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -878,6 +928,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -900,7 +952,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         )
         val_loss = val_loss + ohkm_loss
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1113,8 +1165,9 @@ class CentroidLightningModule(LightningModel):
 
         y_preds = self.model(X)["CentroidConfmapsHead"]
         loss = nn.MSELoss()(y_preds, y)
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1122,6 +1175,8 @@ class CentroidLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -1134,7 +1189,7 @@ class CentroidLightningModule(LightningModel):
         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1409,8 +1464,9 @@ class BottomUpLightningModule(LightningModel):
             "PartAffinityFieldsHead": pafs_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1418,8 +1474,10 @@ class BottomUpLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "train_confmap_loss",
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1427,7 +1485,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "train_paf_loss",
+            "train/paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
@@ -1476,7 +1534,7 @@ class BottomUpLightningModule(LightningModel):
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1485,7 +1543,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "val_confmap_loss",
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1493,7 +1551,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "val_paf_loss",
+            "val/paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
@@ -1749,8 +1807,9 @@ class BottomUpMultiClassLightningModule(LightningModel):
             "ClassMapsHead": classmaps_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1758,8 +1817,10 @@ class BottomUpMultiClassLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "train_confmap_loss",
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1767,13 +1828,67 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "train_classmap_loss",
+            "train/classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "train/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+
         return loss
 
     def validation_step(self, batch, batch_idx):
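
Note: the nested Python loops above do one .item() sync per instance and one argmax per keypoint; correctness is fine, but on GPU this can dominate step time for crowded frames. A hedged sketch of an equivalent vectorized gather (assumes padded instance slots are NaN, which the valid mask then excludes):

    import torch

    def class_accuracy_at_keypoints(
        classmaps: torch.Tensor,   # (batch, n_classes, h, w) predicted
        y_classmap: torch.Tensor,  # (batch, n_classes, h, w) ground truth
        instances: torch.Tensor,   # (batch, max_inst, n_nodes, 2) as (x, y)
        stride: int,
    ) -> torch.Tensor:
        """Fraction of visible GT keypoints whose argmax class matches."""
        b, _, h, w = classmaps.shape
        valid = ~torch.isnan(instances).any(dim=-1)  # (b, max_inst, n_nodes)
        xy = torch.nan_to_num(instances / stride).long()
        x = xy[..., 0].clamp(0, w - 1)
        y = xy[..., 1].clamp(0, h - 1)
        bidx = torch.arange(b, device=classmaps.device).view(b, 1, 1)
        # Advanced indexing gathers one class vector per keypoint:
        # result shape (b, max_inst, n_nodes, n_classes).
        pred = classmaps[bidx, :, y, x].argmax(dim=-1)
        gt = y_classmap[bidx, :, y, x].argmax(dim=-1)
        matches = (pred == gt) & valid
        return matches.sum().float() / valid.sum().clamp(min=1)
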
@@ -1807,7 +1922,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1816,7 +1931,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "val_confmap_loss",
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1824,7 +1939,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
        )
         self.log(
-            "val_classmap_loss",
+            "val/classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
@@ -1832,6 +1947,106 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "val/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance ID Model.
@@ -2011,7 +2226,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -2020,8 +2235,9 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -2029,8 +2245,10 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "train_confmap_loss",
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -2038,13 +2256,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "train_classvector_loss",
+            "train/classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "train/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -2076,7 +2308,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -2085,7 +2317,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "val_confmap_loss",
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -2093,10 +2325,80 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "val_classvector_loss",
+            "val/classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "val/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
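
Note: the coordinate bookkeeping in this last block reduces to two affine steps: divide by eff_scale to undo preprocessing, then add the crop's top-left corner (itself mapped back to original space). A worked sketch with illustrative numbers:

    import numpy as np

    eff = 0.5  # preprocessing scale factor (illustrative)
    bbox_top_left_prep = np.array([100.0, 60.0])  # crop corner, preprocessed space
    peak_crop = np.array([12.0, 8.0])  # peak, crop-relative, original scale

    # Crop corner back in original image space.
    bbox_top_left_orig = bbox_top_left_prep / eff  # -> [200., 120.]
    # Full-image coordinates: crop-relative peak + crop corner, both in original space.
    peak_full = peak_crop + bbox_top_left_orig  # -> [212., 128.]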