sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +62 -20
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +35 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/training/lightning_modules.py

@@ -184,6 +184,10 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}
 
+        # For epoch-averaged loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
         # For epoch-end evaluation
         self.val_predictions: List[Dict] = []
         self.val_ground_truth: List[Dict] = []
@@ -310,12 +314,20 @@ class LightningModel(L.LightningModule):
     def on_train_epoch_start(self):
         """Configure the train timer at the beginning of each epoch."""
         self.train_start_time = time.time()
+        # Reset epoch loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+    def _accumulate_loss(self, loss: torch.Tensor):
+        """Accumulate loss for epoch-averaged logging. Call this in training_step."""
+        self._epoch_loss_sum += loss.detach().item()
+        self._epoch_loss_count += 1
 
     def on_train_epoch_end(self):
         """Configure the train timer at the end of every epoch."""
         train_time = time.time() - self.train_start_time
         self.log(
-            "
+            "train/time",
             train_time,
             prog_bar=False,
             on_step=False,
@@ -332,6 +344,30 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch-averaged training loss
+        if self._epoch_loss_count > 0:
+            avg_loss = self._epoch_loss_sum / self._epoch_loss_count
+            self.log(
+                "train/loss",
+                avg_loss,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
+        # Log current learning rate (useful for monitoring LR schedulers)
+        if self.trainer.optimizers:
+            lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+            self.log(
+                "train/lr",
+                lr,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
 
     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
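The three hunks above form a single pattern: `on_train_epoch_start` zeroes the counters, each subclass's `training_step` feeds its batch loss into `_accumulate_loss`, and `on_train_epoch_end` logs the running mean as `train/loss` plus the current optimizer LR as `train/lr`. A minimal standalone sketch of the same bookkeeping, outside any Lightning module (the `EpochLossTracker` name is hypothetical, not part of sleap-nn):

```python
import torch

class EpochLossTracker:
    """Running mean of per-batch losses, reset at each epoch start."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._sum = 0.0
        self._count = 0

    def accumulate(self, loss: torch.Tensor):
        # detach() + item() so the stored float carries no autograd graph
        self._sum += loss.detach().item()
        self._count += 1

    def mean(self) -> float:
        return self._sum / self._count if self._count else float("nan")

tracker = EpochLossTracker()
for batch_loss in [torch.tensor(0.5), torch.tensor(0.3)]:
    tracker.accumulate(batch_loss)
print(tracker.mean())  # ~0.4
```

Averaging manually, rather than leaning on `self.log(..., on_epoch=True)` reduction, keeps the denominator explicit: every batch counts exactly once.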
@@ -344,7 +380,7 @@ class LightningModel(L.LightningModule):
         """Configure the val timer at the end of every epoch."""
         val_time = time.time() - self.val_start_time
         self.log(
-            "
+            "val/time",
             val_time,
             prog_bar=False,
             on_step=False,
@@ -352,6 +388,16 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
+        # (mirrors what on_train_epoch_end does for train/* metrics)
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
 
     def training_step(self, batch, batch_idx):
         """Training step."""
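Logging `epoch` as a metric of its own matters because wandb plots everything against its internal step counter by default, and `val/*` values logged with `on_epoch=True` land on different steps than the batch-level `loss`. A sketch of the downstream setup this enables (the project name is a placeholder; `wandb.define_metric` is wandb's public API):

```python
import wandb

run = wandb.init(project="sleap-nn-demo")  # hypothetical project

# Use the explicitly logged "epoch" value as the x-axis for all
# namespaced training and validation metrics.
wandb.define_metric("epoch")
wandb.define_metric("train/*", step_metric="epoch")
wandb.define_metric("val/*", step_metric="epoch")
```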
@@ -420,7 +466,7 @@ class LightningModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": scheduler,
-                "monitor": "
+                "monitor": "val/loss",
             },
         }
 
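The `monitor` key is how Lightning knows which logged metric drives metric-dependent schedulers such as `ReduceLROnPlateau`, so it has to track the rename to `val/loss` exactly. A hedged sketch of the wiring with a placeholder model:

```python
import torch
import lightning as L

class TinyModule(L.LightningModule):
    """Placeholder module; only configure_optimizers is of interest here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                # "monitor" must name a metric logged via self.log(),
                # i.e. "val/loss" after this diff.
                "scheduler": scheduler,
                "monitor": "val/loss",
            },
        }
```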
@@ -599,7 +645,7 @@ class SingleInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
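The per-node losses now live under a `train/confmaps/` namespace so they group together in the logger UI. A small worked example of the channel-wise decomposition that produces them (shapes are illustrative): summing squared error over dims `(0, 2, 3)` and dividing by `batch_size * h * w` yields one value per node, and the mean over nodes recovers the plain MSE.

```python
import torch

batch_size, n_nodes, h, w = 2, 3, 8, 8
y_preds = torch.rand(batch_size, n_nodes, h, w)
y = torch.rand(batch_size, n_nodes, h, w)

mse = (y_preds - y) ** 2
# One scalar per confidence-map channel (i.e., per node)
channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)

# Averaging the per-node values gives back the overall MSE.
assert torch.allclose(channel_wise_loss.mean(), torch.nn.MSELoss()(y_preds, y))
```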
@@ -607,8 +653,9 @@ class SingleInstanceLightningModule(LightningModel):
                 logger=True,
                 sync_dist=True,
             )
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -616,6 +663,8 @@ class SingleInstanceLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -638,7 +687,7 @@ class SingleInstanceLightningModule(LightningModel):
             )
             val_loss = val_loss + ohkm_loss
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -860,7 +909,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -869,8 +918,9 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -878,6 +928,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -900,7 +952,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             )
             val_loss = val_loss + ohkm_loss
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1113,8 +1165,9 @@ class CentroidLightningModule(LightningModel):
 
         y_preds = self.model(X)["CentroidConfmapsHead"]
         loss = nn.MSELoss()(y_preds, y)
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1122,6 +1175,8 @@ class CentroidLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -1134,7 +1189,7 @@ class CentroidLightningModule(LightningModel):
         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1409,8 +1464,9 @@ class BottomUpLightningModule(LightningModel):
             "PartAffinityFieldsHead": pafs_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
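One idiom worth noting in the hunk above: iterating over a dict yields its keys, so `zip(self.loss_weights, losses)` pairs each weight with a head name in insertion order. A toy numeric check (the confmaps head name is assumed for illustration; only `PartAffinityFieldsHead` appears in this hunk):

```python
import torch

loss_weights = [1.0, 0.5]
losses = {
    "MultiInstanceConfmapsHead": torch.tensor(0.2),  # assumed head name
    "PartAffinityFieldsHead": torch.tensor(0.8),
}
# zip pairs weights with dict keys in insertion order
loss = sum([s * losses[t] for s, t in zip(loss_weights, losses)])
print(loss)  # tensor(0.6000) = 1.0 * 0.2 + 0.5 * 0.8
```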
@@ -1418,8 +1474,10 @@ class BottomUpLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1427,7 +1485,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "train/paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
@@ -1476,7 +1534,7 @@ class BottomUpLightningModule(LightningModel):
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1485,7 +1543,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1493,7 +1551,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
@@ -1749,8 +1807,9 @@ class BottomUpMultiClassLightningModule(LightningModel):
             "ClassMapsHead": classmaps_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1758,8 +1817,10 @@ class BottomUpMultiClassLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1767,13 +1828,67 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "train/classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "train/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+
         return loss
 
     def validation_step(self, batch, batch_idx):
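The accuracy pass above visits every instance and node in a Python loop, which is fine for occasional logging but costs interpreter overhead proportional to `batch * instances * nodes`. A hedged vectorized sketch of the same sampling idea, with dummy shapes rather than the sleap-nn batch layout: gather both class maps at the strided GT coordinates in one advanced-indexing call and compare argmaxes under a NaN-validity mask.

```python
import torch

batch, n_classes, h, w = 2, 3, 16, 16
cms_stride = 2
classmaps = torch.rand(batch, n_classes, h, w)   # predicted class maps
y_classmap = torch.rand(batch, n_classes, h, w)  # GT class maps
kps = torch.rand(batch, 4, 5, 2) * 30            # (batch, max_inst, n_nodes, 2)
kps[0, 0, 0] = float("nan")                      # a missing keypoint

valid = ~torch.isnan(kps).any(dim=-1)            # (batch, max_inst, n_nodes)
xy = kps.nan_to_num(0.0)                         # placeholder coords for NaNs
x_cm = (xy[..., 0] / cms_stride).long().clamp(0, w - 1)
y_cm = (xy[..., 1] / cms_stride).long().clamp(0, h - 1)

# Advanced indexing with a slice in the middle yields
# (batch, max_inst, n_nodes, n_classes).
b_idx = torch.arange(batch).view(-1, 1, 1).expand_as(x_cm)
pred = classmaps[b_idx, :, y_cm, x_cm].argmax(dim=-1)
gt = y_classmap[b_idx, :, y_cm, x_cm].argmax(dim=-1)
class_accuracy = (pred == gt)[valid].float().mean()
print(class_accuracy)
```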
@@ -1807,7 +1922,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1816,7 +1931,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1824,7 +1939,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
@@ -1832,6 +1947,106 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "val/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance ID Model.
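Two coordinate conventions meet in the collection block above: the bottom-up inference layer returns peaks already in original image space, while the GT keypoints are still in preprocessed (resized) space and are mapped back by dividing by the effective scale. A toy illustration of that rescale (numbers are made up):

```python
import numpy as np

eff_scale = 0.5                    # frame was downscaled 2x during preprocessing
gt_prep = np.array([[10.0, 20.0],  # keypoints in preprocessed pixels
                    [30.0, 40.0]])
gt_orig = gt_prep / eff_scale      # back to original pixels
print(gt_orig)                     # [[20. 40.] [60. 80.]]
```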
@@ -2011,7 +2226,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -2020,8 +2235,9 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -2029,8 +2245,10 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -2038,13 +2256,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "train/classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "train/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -2076,7 +2308,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -2085,7 +2317,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -2093,10 +2325,80 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "val/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
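The top-down variant adds one more step to the coordinate chain: peaks are predicted inside an instance crop, so after undoing the resize, the crop's top-left corner (taken from the bounding box saved before inference mutates it in place) is added to reach full-image coordinates. A toy walk-through with made-up numbers:

```python
import numpy as np

eff_scale = 0.5
bbox_prep = np.array([[50.0, 60.0],    # 4 crop corners in preprocessed space,
                      [150.0, 60.0],   # top-left corner first
                      [150.0, 160.0],
                      [50.0, 160.0]])
pred_peaks_crop = np.array([[12.0, 8.0]])  # crop-relative, original scale

bbox_top_left_orig = bbox_prep[0] / eff_scale    # -> [100., 120.]
pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
print(pred_peaks_full)                           # [[112. 128.]]
```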